BZOJ5398 admirable

Problem

给出一棵树,选出k条路径,使得每条边只会被0,1,k条路径覆盖。求方案数模1e9+9。

Solution

首先可以枚举那个被选了k次的路径,那么选了一次的边一定是从这个路径延伸而来的。

那么问题就是在某个点u的子树中选个k个点,使得选出的点两两LCA是u的方案数;对于u的每个子树,最多只能选一个点,选一个点的方案数为size[v]。

设f[u][i]表示在u的子树中选了i个点的方案,对于(u,v)这条边,显然有转移为$f[u][i]=f[u][i]+f[u][i-1]\cdot size[v]$

题目要求的就是在k个点中选若干个点走到子树中,我们可以定义点x向下的方案数为g[x]。

$$g[x]=\sum_{i=1}^{min(deg[x],k)}A(k,i)\cdot f[x][i]$$

但是这样dp一遍以后只求出了以1号点为根的方案树。

对于这两类情况只能我们发现右边的做不了。我们需要的是以u子树内离u最近的那个点为根的答案。

如果我们以每个点为根都做一遍复杂度为O(n^3)。

考虑怎么优化这个过程。

首先上面的那个dp有点想一个背包,仔细观察一下发现每个点可以看成${1,size[x]}$这样的多项式,这个dp的过程就是一个卷积。根据套路可以分治fft快速计算。

然后会获得一个$O(n^2log^n)$的做法,但是5000都过不了。

其实,对于每个点去dp一遍是没有必要的,因为有很多dp出来的状态是没有用的,类似退背包的操作,可以考虑退掉自己这个子树的多项式。

我们最开始把所有答案加入多项式,也要加入自己上面的点的贡献,也就是(n-size[x])。然后由于{1,size[x]}这个多项式是很独特的,并不需要多项式除法,可以O(n)快速乘或者除掉。

对于x点的每个子树,把它的贡献退掉以后就可以得到以他为根的时候的答案。

然而这样做并不行,对于一个菊花图,退一个点然后算一个点,复杂度退化为了O(n^2)级别。

再想想?其实我们每次退掉一个点只和它的子树大小有关,和这个点具体是什么没有关系,由于不同的子树大小为根号级别,每个大小单独算一下就好了。

复杂度$O(nlog^2n+n\sqrt n)$

Code

说起来很简单,代码好想并不是很好打。(本蒟蒻打了一下午+一晚上才勉强打完,但是常数巨大)。

由于出题人过于毒瘤,NTT不行FFT精度不够,又又又要打MTT,甚至不开longdouble就不行。于是常数就上天了,迟早💊。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#include <bits/stdc++.h>
using namespace std;
#define IL inline
#define rep(i, j, k) for (int i = j; i <= k; ++i)
#define repd(i, j, k) for (int i = j; i >= k; --i)
#define pb push_back
#define mp make_pair
#define fr first
#define se second
#define ef else if
#define re real
#define im imag
#define fre(x) freopen(x".in","r",stdin),freopen(x".out", "w", stdout)
#define popcnt(x) __builtin_popcount(x)
#define ll long long
#define m(x) ((x%mo+mo)%mo)
const int mo=1e9+9;
IL int read() { int x; int _w = scanf("%d", &x); return x; }
IL int read(int &x) { int _w=scanf("%d", &x); }
int pls(int x,int y){ x+=y;return x>=mo?x-mo:x;}
int dec(int x,int y){ x-=y;return x<0?x+mo:x;}
int mul(int x,int y){ return 1ll*x*y%mo;}
IL void write(int x) { printf("%d\n", x); }
int fpw(int x,int y,int r=1){for(;y;y>>=1,x=1ll*x*x%mo)if(y&1)r=1ll*r*x%mo;return r;}
const int maxn=4e5+10;
int n,k,ans;
vector<int>e[maxn],B[maxn],g[maxn];
int fac[maxn],ifac[maxn],size[maxn];
int A(int x,int y){return 1ll*fac[x]*ifac[x-y]%mo;}

typedef long double db;
class cpx{
public:
db a,b;
cpx(db aa=0,db bb=0){a=aa,b=bb;}
cpx operator + (cpx r)const {return cpx(a+r.a,b+r.b);}
cpx operator - (cpx r)const {return cpx(a-r.a,b-r.b);}
cpx operator * (cpx r){return cpx(a*r.a-b*r.b,a*r.b+b*r.a);}
};

vector<cpx>F1,F2,G1,G2,H1,H2,H3,H4;
const db pi=acos(-1);
int rev[maxn],Ff[maxn];

void dft(vector<cpx>&a,db o){
int n=a.size();
rep(i,0,n-1)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1){
cpx e=cpx(cos(pi/i),o*sin(pi/i));
for(int j=0;j<n;j+=(i<<1)){
cpx x=cpx(1,0);
for(int k=j;k<j+i;k++,x=x*e){
cpx a1=a[k]; cpx a2=a[k+i]*x;
a[k]=a1+a2;a[k+i]=a1-a2;
}
}
}
}

void conv(vector<int>&a,vector<int>&b,vector<int>&c){
int n=a.size(),m=b.size();
int s=n+m-1;
while(s!=(s&-s))s+=(s&-s);int bb=1;
while((1<<bb)<s)bb++;
rep(i,0,s-1)rev[i]=rev[i>>1]>>1|(i&1)<<(bb-1);
F1.clear();F1.resize(s);F2.clear();F2.resize(s);G1.clear();G1.resize(s);G2.clear();G2.resize(s);
rep(i,0,n-1)F1[i]=cpx(a[i]&32767,0);rep(i,n,s-1)F1[i]=cpx(0,0);
rep(i,0,n-1)F2[i]=cpx(a[i]>>15,0);rep(i,n,s-1)F2[i]=cpx(0,0);
rep(i,0,m-1)G1[i]=cpx(b[i]&32767,0);rep(i,m,s-1)G1[i]=cpx(0,0);
rep(i,0,m-1)G2[i]=cpx(b[i]>>15,0);rep(i,m,s-1)G2[i]=cpx(0,0);
dft(F1,1);dft(G1,1);dft(F2,1);dft(G2,1);
H1.clear();H1.resize(s);H2.clear();H2.resize(s);H3.clear();H3.resize(s);H4.clear();H4.resize(s);
rep(i,0,s-1)H1[i]=F1[i]*G1[i];rep(i,0,s-1)H2[i]=F2[i]*G1[i];rep(i,0,s-1)H3[i]=F1[i]*G2[i];rep(i,0,s-1)H4[i]=F2[i]*G2[i];
dft(H1,-1);dft(H2,-1);dft(H3,-1);dft(H4,-1);
c.clear();c.resize(n+m);
rep(i,0,n+m-1){
ll a1=(ll)(H1[i].a/s+0.5)%mo;ll a2=(ll)(H2[i].a/s+0.5)%mo;
ll a3=(ll)(H3[i].a/s+0.5)%mo;ll a4=(ll)(H4[i].a/s+0.5)%mo;
c[i]=(a1+((a2+a3)<<15)+(a4<<30))%mo;
}
}

void test() {
int n,m;read(n),read(m);
vector<int>aa,bb,cc;
rep(i,0,n)aa.push_back(read());
rep(i,0,m)bb.push_back(read());
conv(aa,bb,cc);
rep(i,0,n+m)printf("%d ",cc[i]);
}

vector<int>tr[maxn<<2];
vector<int>Lst;
void bt(int p,int x,int y){
if(x==y){tr[p].clear();tr[p].pb(1);tr[p].pb(size[Lst[x]]);return;}
int mid=(x+y)>>1;
bt(p<<1,x,mid);bt(p<<1|1,mid+1,y);
conv(tr[p<<1],tr[p<<1|1],tr[p]);
}

void getdiv(vector<int>&a,int x){
int n=a.size();
rep(i,1,n-1)a[i]=dec(a[i],mul(a[i-1],x));
a.pop_back();
}

void getmul(vector<int>&a,int x){
int n=a.size();
a.push_back(mul(a[n-1],x));
repd(i,n-1,1)a[i]=pls(a[i],mul(a[i-1],x));
}

void dfs(int x,int fa){
int sz=1,tot=0;size[x]=1;
for(int i=0;i<e[x].size();++i){int v=e[x][i];if(v!=fa)dfs(v,x),size[x]+=size[v];}
Lst.clear();
for(int i=0;i<e[x].size();++i)if(e[x][i]!=fa)Lst.pb(e[x][i]),tot++;
vector<int>tmp,f,P;
if(Lst.size()){
tmp.pb(1);
bt(1,0,Lst.size()-1);
conv(tmp,tr[1],P);
} else P.pb(1);
for(int i=0;i<P.size();++i)B[x].pb(P[i]);
rep(i,0,min(k,tot))Ff[x]=pls(Ff[x],mul(A(k,i),P[i]));
getmul(P,n-size[x]);

vector<int>Gs,Ans;
for(int i=0;i<e[x].size();++i)if(e[x][i]!=fa)Gs.pb(size[e[x][i]]);
sort(Gs.begin(),Gs.end());
Gs.resize(unique(Gs.begin(),Gs.end())-Gs.begin());
Ans.resize(Gs.size());

for(int i=0;i<Gs.size();++i){
getdiv(P,Gs[i]);
rep(j,0,min(k,tot))Ans[i]=pls(Ans[i],mul(A(k,j),P[j]));
getmul(P,Gs[i]);
}

g[x].resize(e[x].size());
for(int i=0;i<e[x].size();++i)if(e[x][i]!=fa)g[x][i]=Ans[lower_bound(Gs.begin(),Gs.end(),size[e[x][i]])-Gs.begin()];
}

int sf[maxn];
void calc(int x,int fa) {
sf[x]=Ff[x];
for(int i=0;i<e[x].size();++i)if(e[x][i]!=fa)calc(e[x][i],x),sf[x]=pls(sf[x],sf[e[x][i]]);
}
void calcF(int x,int fa,int lst){
if(x!=1){
ans=pls(ans,mul(Ff[x],lst));
}
for(int i=0;i<e[x].size();++i){
if(e[x][i]==fa)continue;
calcF(e[x][i],x,pls(lst,dec(dec(sf[x],Ff[x]),sf[e[x][i]])));
}
}

int fa[maxn][1],dep[maxn],pos_f[maxn];
void prep(int x){
dep[x]=dep[fa[x][0]]+1;
for(int i=0;i<e[x].size();++i){
int v=e[x][i];
if(v==fa[x][0])continue;
fa[v][0]=x; pos_f[v]=i;
prep(v);
}
}
int jump(int x,int y){
repd(i,18,0)if(dep[fa[x][i]]>dep[y])x=fa[x][i];
return x;
}
int LCA(int x,int y){
if(dep[x]<dep[y])swap(x,y);
repd(i,18,0)if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
if(x==y)return x;
repd(i,18,0)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int solve(int x,int y){
int rhs=0;
if(dep[x]<dep[y])swap(x,y);
int t=LCA(x,y);
if(t==y){
int up=jump(x,y);
return mul(g[y][up],Ff[x]);
}else return mul(Ff[x],Ff[y]);
}
int main() {
fre("a");
read(n),read(k);
fac[0]=1;
rep(i,1,maxn-1)fac[i]=1ll*fac[i-1]*i%mo;
ifac[0]=1;ifac[maxn-1]=fpw(fac[maxn-1],mo-2);
repd(i,maxn-2,1)ifac[i]=1ll*ifac[i+1]*(i+1)%mo;
rep(i,1,n-1){int u=read(),v=read();e[u].pb(v);e[v].pb(u);}
prep(1);
dfs(1,0);
calc(1,1);
calcF(1,0,0);
ans=mul(ans,(mo+1)/2);
rep(x,2,n){ans=pls(ans,mul(sf[x],g[fa[x][0]][pos_f[x]]));}
write(ans);
return 0;
}