[国家集训队] Crash的文明世界
Description
给定一棵 \(n\) 个点的树,对于每个点 \(i\) 求 \(S(i)=\sum\limits_{j=1}^n \operatorname{dist(i,j)}^k\) 。\(n\leq 50000,k\leq 150\)。
Sol
根据斯特林展开,原式化为
\[
\begin{align}S(i)&=\sum\limits_{j=1}^n \sum\limits_{p=0}^k S(k,p)\cdot \dbinom{\operatorname{dist(i,j)}}{p} \cdot p!\\&=\sum_{p=0}^kS(k,p)\cdot p!\cdot\sum_{j=1}^n\dbinom{\operatorname{dist(i,j)}}{p}\end{align}
\]
这个式子启发我们对于每个点 \(i\) 和每个 \(p\) ,维护好 \(\sum\limits_{j=1}^n \dbinom{\operatorname{dist(i,j)}}p\) 就好了
又因为 \(\dbinom{n}{m}=\dbinom{n-1}{m-1}+\dbinom{n-1}{m}\) ,所以设 \(dp[i][p]=\sum\limits_{j=1}^n \dbinom{\operatorname{dist(i,j)}}p\) ,这样就可以递推了。
先做一遍树形\(\text{DP}\)求出每个点子树的\(\mathrm{dp}\)值,再换根一下求出子树外的\(\text{dp}\)值就行了。
复杂度 \(O(nk)\)。
Code
#pragma GCC optimize(2)
#include<bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long ll;
const int K=155;
const int N=50005;
const int mod=10007;
int dp[N][K],f[K];
int fac[N],S[K][K];
int n,k,cnt,head[N];
struct Edge{
int to,nxt;
}edge[N<<1];
void add(int x,int y){
edge[++cnt].to=y;
edge[cnt].nxt=head[x];
head[x]=cnt;
}
void init(int n,int m){
fac[0]=1;
for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
S[0][0]=1;
for(int i=1;i<=K;i++)
for(int j=1;j<=i;j++)
S[i][j]=(S[i-1][j-1]+1ll*S[i-1][j]*j%mod)%mod;
}
void dfs(int now,int fa=0){
dp[now][0]=1;
for(int i=head[now];i;i=edge[i].nxt){
int to=edge[i].to;
if(to==fa) continue;
dfs(to,now);
(dp[now][0]+=dp[to][0])%=mod;
for(int j=1;j<=k;j++)
(dp[now][j]+=dp[to][j-1]+dp[to][j])%=mod;
}
}
void dfs2(int now,int fa=0){
if(fa){
f[0]=dp[now][0];
(dp[fa][0]-=dp[now][0]-mod)%=mod;
for(int j=1;j<=k;j++)
(dp[fa][j]-=dp[now][j-1]+dp[now][j]-mod-mod)%=mod,f[j]=dp[now][j];
(dp[now][0]+=dp[fa][0])%=mod;
for(int j=1;j<=k;j++)
(dp[now][j]+=dp[fa][j-1]+dp[fa][j])%=mod;
(dp[fa][0]+=f[0])%=mod;
for(int j=1;j<=k;j++)
(dp[fa][j]+=f[j]+f[j-1])%=mod;
}
for(int i=head[now];i;i=edge[i].nxt){
int to=edge[i].to;
if(to==fa) continue;
dfs2(to,now);
}
}
signed main(){
init(N-5,K-5);
scanf("%d%d",&n,&k);
for(int x,y,i=1;i<n;i++)
scanf("%d%d",&x,&y),add(x,y),add(y,x);
dfs(1); dfs2(1);
for(int i=1;i<=n;i++){
int ans=0;
for(int j=0;j<=k;j++)
(ans+=1ll*S[k][j]%mod*fac[j]%mod*dp[i][j]%mod)%=mod;
printf("%d\n",ans);
} return 0;
}