「Avito Code Challenge 2018」H. K Paths

Problem

Description

给定一棵 n 个点的树 T 和一个数 k ,求在树上选出 k 条路径使得其交非空且不在其交上的边至多被一条路径覆盖的方案数。
k 条路径间是有序的,亦即 \left\{(u_{1}, v_{1}),(u_{2}, v_{2})\right\} \left\{(u_{2}, v_{2}),(u_{1}, v_{1})\right\} 是不同的两种方案。

答案对 998244353 取模。

Constraints

1\le n,k\le 10^{5}

Solution

Analysis

这蠢题我想了两天没想到定根可以简化问题。。。我也是够蠢的。。

容易想到枚举 k 条路径的交来计算方案数。
(u, v) 为当前枚举的路径交,考虑将 v 作为根,在 u 的子树中选出 k 条路径端点的方案数,容易得到生成函数 P_{u}(x) 及方案数 f_{u}

\begin{aligned} P_{u}(x)&=\prod_{w\in\operatorname{son}(u)}\left(size_{w}x+1\right) \\ f_{u}&=\sum_{i=0}^{k}\binom{k}{i}i![x^{i}]P_{u}(x) \end{aligned}

P_{u}(x) 显然可以使用分治 FFT 在 O\left(\operatorname{deg}_{u}\log^{2}{\operatorname{deg}_{u}}\right) 的时间复杂度内求出,这部分的总时间复杂度即为 O\left(n\log^{2}{n}\right)

但是枚举 (u, v) 的时间复杂度是 O(n^{2}) 的,考虑简化问题。
稍加思考即可发现定根可以简化问题:
定根后记 \operatorname{sub}(u) 为以 u 为根的子树, \operatorname{son}(u, v) u v 方向上的儿子,记 s_{u}

s_{u}=\sum_{v\in\operatorname{sub}(u)}f_{v}

  • 对于不具有祖先-后代关系的点对 (u, v) ,容易通过时间复杂度 O(n) 的树形 DP 计算出其贡献:

\sum_{u=1}^{n}\sum_{v,w\in\operatorname{son}(u)}s_{v}s_{w}

  • 对于具有祖先-后代关系的点对 (u, v) ,考虑枚举深度较小的点 u w=\operatorname{son}(u, v) ,将 v\in\operatorname{sub}(w) 作为整体考虑:

\begin{aligned} Q_{u,w}(x)&=\frac{(n-size_{u}) x + 1}{size_{w} x + 1}P_{u}(x) \\ g_{u,w}&=\sum_{i=0}^{k}\binom{k}{i}i![x^{i}]Q_{u,w}(x) \end{aligned}

这部分对答案的贡献即为:

\sum_{u=1}^{n}\sum_{v\in\operatorname{son}(u)}g_{u,v}s_{v}

观察到不同的 size_{w} 个数是 O\left(\sqrt{n}\right) 的,故在 u 相同时将每个 size_{w} 对应的 g 值记录下来即可。
多项式乘|除以一次多项式是 O(n) 的,故这部分的时间复杂度为 O\left(n\sqrt{n}\right)

时间复杂度 O\left(n\log^{2}{n}+n\sqrt{n}\right)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#include<cstdio>
#include<algorithm>

using i64=long long;

constexpr int maxn(100000);
constexpr int mxdg(262144);
constexpr int p(998244353);
constexpr int proot(3);

template<typename _Tp>
inline void swap(_Tp&x,_Tp&y)
{_Tp z=x;x=y;y=z;}
template<typename _Tp>
inline void inc(_Tp&x,const _Tp&y)
{x+=y;(p<=x)&&(x-=p);}
template<typename x_tp,typename y_tp>
inline int sub(const x_tp&x,const y_tp&y)
{return(x<y)?(x-y+p):(x-y);}
template<typename _Tp>
inline int fpow(_Tp v,int n){
int pw=1;
for(;n;n>>=1,v=(i64)v*v%p)
if(n&1)pw=(i64)pw*v%p;
return pw;
}

namespace IOManager{
constexpr int FILESZ(131072);
char buf[FILESZ];
const char*ibuf=buf,*tbuf=buf;

struct IOManager{
inline char gc()
{return(ibuf==tbuf)&&(tbuf=(ibuf=buf)+fread(buf,1,FILESZ,stdin),ibuf==tbuf)?EOF:*ibuf++;}

template<typename _Tp>
inline operator _Tp(){
_Tp s=0u;char c=gc();
for(;c<48;c=gc());
for(;c>47;c=gc())
s=(_Tp)(s*10u+c-48u);
return s;
}
};
}IOManager::IOManager io;

struct Edge{
int v;Edge*las;

inline Edge* init(const int&to,Edge*const&ls)
{return v=to,las=ls,this;}
}*las[maxn+1];

inline void lnk(){
static Edge pool[maxn<<1],*alc=pool-1;
const int u=io,v=io;
las[u]=(++alc)->init(v,las[u]);
las[v]=(++alc)->init(u,las[v]);
}

namespace poly{
using Z=int;
using MZ=i64;
using poly_t=Z[mxdg];
using poly=Z*const;

inline int calcpw2(const int&n){
int t=1;
for(;t<n;t<<=1);
return t;
}

poly_t wn,iwn,cw;

inline void polyinit(){
const int wnv=fpow(proot,(p-1)/mxdg);
for(int i=(wn[0]=iwn[0]=1);i!=mxdg;++i)
wn[i]=(MZ)wn[i-1]*wnv%p;
std::reverse_copy(wn+1,wn+mxdg,iwn+1);
}

void DFT(poly&a,const int&n){
for(int i=0,j=0;i!=n;++i){
if(i<j)swap(a[i],a[j]);
for(int l=n>>1;(j^=l)<l;l>>=1);
}
cw[0]=1;
poly endpos=a+n,wendpos=wn+mxdg;
for(int l=1,tp=mxdg>>1;l!=n;l<<=1,tp>>=1){
for(Z*i=wn+tp,*w=cw;i!=wendpos;i+=tp)
*++w=*i;
for(Z*i=a,z;i!=endpos;i+=l+l)
for(int j=0;j!=l;++j)
z=(MZ)i[j+l]*cw[j]%p,
i[j+l]=sub(i[j],z),
inc(i[j],z);
}
}
void IDFT(poly&a,const int&n){
for(int i=0,j=0;i!=n;++i){
if(i<j)swap(a[i],a[j]);
for(int l=n>>1;(j^=l)<l;l>>=1);
}
cw[0]=1;
poly endpos=a+n,wendpos=iwn+mxdg;
for(int l=1,tp=mxdg>>1;l!=n;l<<=1,tp>>=1){
for(Z*i=iwn+tp,*w=cw;i!=wendpos;i+=tp)
*++w=*i;
for(Z*i=a,z;i!=endpos;i+=l+l)
for(int j=0;j!=l;++j)
z=(MZ)i[j+l]*cw[j]%p,
i[j+l]=sub(i[j],z),
inc(i[j],z);
}const Z invn=fpow(n,p-2);
for(Z*i=a;i!=endpos;++i)
*i=(MZ)*i*invn%p;
}

inline void mul(const poly&f,const int&df,const poly&g,const int&dg,poly&ans){
static poly_t mul_f,mul_g;

if(df+dg<=150){
static MZ mul_r[200];
std::fill(mul_r,mul_r+df+dg+1,0);

for(int i=0;i<=df;++i)
for(int j=0;j<=dg;++j)
mul_r[i+j]+=(MZ)f[i]*g[j]%p;

for(int i=0;i<=df+dg;++i)
ans[i]=mul_r[i]%p;
}else{
const int n=calcpw2(df+dg+1);

std::copy(f,f+df+1,mul_f);std::fill(mul_f+df+1,mul_f+n,0);
std::copy(g,g+dg+1,mul_g);std::fill(mul_g+dg+1,mul_g+n,0);

DFT(mul_f,n);DFT(mul_g,n);
for(int i=0;i!=n;++i)
mul_f[i]=(MZ)mul_f[i]*mul_g[i]%p;
IDFT(mul_f,n);

std::copy(mul_f,mul_f+df+dg+1,ans);
}
}
}

poly::poly_t cp_t[18];
poly::poly cp=cp_t[0];
int idx,
mem[maxn+1],
lasi[maxn+1],
expf[maxn+1],
sz[maxn+1],
sum[maxn+1],
lis[maxn+1];

void calc(const int&d,const int&l,const int&r){
if(l==r)
return(void)(cp_t[d][l+l]=1,cp_t[d][l+l+1]=lis[l]);
const int m=(l+r)>>1;
calc(d+1,l,m);calc(d+1,m+1,r);
poly::mul(cp_t[d+1]+l+l,m-l+1,cp_t[d+1]+m+m+2,r-m,cp_t[d]+l+l);
}

const int n=io,k=io;

int ans;
void calc(const int&u,const int&fa){
sz[u]=1;
for(Edge*o=las[u];o;o=o->las)
if(o->v!=fa)
calc(o->v,u),
ans=((i64)sum[u]*sum[o->v]+ans)%p,
inc(sum[u],sum[o->v]),
sz[u]+=sz[o->v];

int cur=0;
for(Edge*o=las[u];o;o=o->las)
if(o->v!=fa)
lis[++cur]=sz[o->v];
if(cur)
lis[0]=lis[cur],
calc(0,0,cur-1);
else cp[0]=1;

const int usz=cur+1;

int tsum=0;
for(int i=0;i<=k&&i!=usz;++i)
tsum=((i64)cp[i]*expf[i]+tsum)%p;
inc(sum[u],tsum);

if(fa){
const int osz=n-sz[u];
cp[cur+1]=0;
for(int i=cur;~i;--i)
cp[i+1]=((i64)cp[i]*osz+cp[i+1])%p;
}

const int cn=cur+(fa!=0);

++idx;
for(Edge*o=las[u];o;o=o->las)
if(o->v!=fa){
const int vsz=sz[o->v];
if(lasi[vsz]!=idx){
int val=0;
for(int i=0,res=0;i<=k&&i!=cn;++i)
res=sub(cp[i],(i64)res*vsz%p),
val=((i64)expf[i]*res+val)%p;

lasi[vsz]=idx;
mem[vsz]=val;
}ans=((i64)mem[vsz]*sum[o->v]+ans)%p;
}
}

int main(){
if(k==1)
return printf("%lld",(1ll*(n-1ll)*n/2ll)%p),0;

expf[0]=1;
for(int i=0;i!=k;++i)
expf[i+1]=(i64)expf[i]*(k-i)%p;
for(int i=n-1;i;--i)
lnk();

poly::polyinit();
calc(1,0);
printf("%d",ans);

return 0;
}