联考day7 C. 树和森林 树形DP
题目描述
样例
样例输入
8 5
BBWWWBBW
1 2
2 3
4 5
6 7
7 8
样例输出
84
2
1 4
样例解释
分析
首先,我们要预处理出一个点到该联通块内所有点的距离之和 \(f\)
这个东西用换根 \(DP\) 搞一下就可以了
那么这个联通块内所有点对的距离之和就是这个联通块所有点的 \(f\) 值之和除以 \(2\)
除以 \(2\) 是因为点对是无序的
对于子任务一:
当联通块的个数为 \(2\) 时,两个联通块内的贡献我们已经考虑了
我们需要考虑的就是跨过联通块的贡献
我们设从联通块 \(1\) 中选择的点为 \(a\),从联通块 \(2\) 中选择的点为 \(b\) ,联通块的大小是 \(cnt\)
那么贡献就是 \((f[a]+cnt[1])*(n-cnt[1])+cnt[1]*f[b]\)
前半部分统计的是 \(1\) 联通块经过 \((a,b)\) 这条边的贡献
后半部分统计的是 \(2\) 联通块对 \(1\) 联通块的贡献
显然,我们需要把两个联通块内 \(f\) 值最大的点连接起来
当联通块的个数为 \(3\) 时,我们可以枚举哪个联通块在中间
设第一个联通块与第二个联通块通过 \((x, y)\) 相连
第二个联通块与第三个联通块通过 \((u, v)\) 相连。
则联通块与联通块之间的贡献为:
\((f[x]+cnt[1])(n−cnt[1])+(f[v]+cnt[3])(n−cnt[3])+cnt[1]f[y]+cnt[3]f[u]+dis(y,u)cnt[1]cnt[3]\)
其中 \(dis\) 代表两点间的距离
那么 \(x\), \(v\) 应该是联通块 \(1\) 和联通块 \(3\) 中 \(f\) 最大的点。
对于联通块 \(2\),我们只要求出 \(cnt[1]f[y]+cnt[3]f[u]+dis(y,u)cnt[1]cnt[3]\) 的最大值即可,这个
可以通过 \(dp\) 实现
我们分别开两个数组存储当前 \(cnt[1]f[y]\) 和 \(cnt[3]f[u]\)的最大值
自底向上 \(dp\)
对于后面的 \(dis\) 值,我们只需要在向上递归是加一个 \(cnt[1]cnt[3]\) 即可
对于子任务二:
考虑一棵树内的所有不满足条件的点。如果有奇数个这样的点,那么无解,否则一定有解,
并且唯一。
我们要使这些点变成合法的,就需要对它们进行两两匹配,然后改变每一对点路径上所有
边的存在情况。
那么,如果一条边两侧的连通块内有奇数个这样的点,这个边的状态就一定被改变了奇数
次,因此它被删掉了;否则它没有被删掉。
总复杂度 \(O(n)\)
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#define rg register
inline int read(){
rg int x=0,fh=1;
rg char ch=getchar();
while(ch<'0' || ch>'9'){
if(ch=='-') fh=-1;
ch=getchar();
}
while(ch>='0' && ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*fh;
}
const int maxn=1e6+5;
int n,m,sl,h[maxn],tot=1,rt1,rt2,rt3,siz[maxn],cnt[maxn],vis[maxn];
struct asd{
int to,nxt;
}b[maxn];
void ad(int aa,int bb){
b[tot].to=bb;
b[tot].nxt=h[aa];
h[aa]=tot++;
}
char s[maxn];
long long f[maxn],g[maxn];
void dfs(int rt,int now,int fa){
siz[now]=1;
vis[now]=rt;
cnt[rt]++;
for(rg int i=h[now];i!=-1;i=b[i].nxt){
rg int u=b[i].to;
if(u==fa) continue;
dfs(rt,u,now);
siz[now]+=siz[u];
g[now]+=g[u]+siz[u];
}
}
void dfs2(int rt,int now,int fa){
if(now==rt)f[now]=g[now];
for(rg int i=h[now];i!=-1;i=b[i].nxt){
rg int u=b[i].to;
if(u==fa) continue;
f[u]=f[now]+cnt[rt]-siz[u]-siz[u];
dfs2(rt,u,now);
}
}
int jl,mmax,A,B,C,D;
void solve1(){
rg long long ans=0;
for(rg int i=1;i<=n;i++){
if(vis[i]==0){
dfs(i,i,0);
if(!rt1) rt1=i;
else rt2=i;
}
}
dfs2(rt1,rt1,0);
dfs2(rt2,rt2,0);
for(rg int i=1;i<=n;i++){
ans+=f[i];
}
ans/=2;
mmax=-1,jl=0;
for(rg int i=1;i<=n;i++){
if(vis[i]==rt1){
if(f[i]>mmax){
mmax=f[i];
jl=i;
}
}
}
A=jl;
mmax=-1,jl=0;
for(rg int i=1;i<=n;i++){
if(vis[i]==rt2){
if(f[i]>mmax){
mmax=f[i];
jl=i;
}
}
}
B=jl;
ans+=(f[A]+cnt[rt1])*(n-cnt[rt1])+cnt[rt1]*f[B];
printf("%lld\n",ans);
}
long long maxb[maxn],maxc[maxn],haha=0;
void dfs5(int now,int fa,int cntl,int cntr){
maxb[now]=cntl*f[now];
maxc[now]=cntr*f[now];
for(rg int i=h[now];i!=-1;i=b[i].nxt){
rg int u=b[i].to;
if(u==fa) continue;
dfs5(u,now,cntl,cntr);
haha=std::max(haha,maxb[now]+maxc[u]+1LL*cntl*cntr);
haha=std::max(haha,maxc[now]+maxb[u]+1LL*cntl*cntr);
maxb[now]=std::max(maxb[u]+1LL*cntl*cntr,maxb[now]);
maxc[now]=std::max(maxc[u]+1LL*cntl*cntr,maxc[now]);
}
}
long long js(int l,int mids,int r){
memset(maxb,0,sizeof(maxb));
memset(maxc,0,sizeof(maxc));
rg long long jla=0,jld=0;
haha=0;
for(rg int i=1;i<=n;i++){
if(vis[i]==l){
if(f[i]>jla) jla=f[i];
}
if(vis[i]==r){
if(f[i]>jld) jld=f[i];
}
}
dfs5(mids,0,cnt[l],cnt[r]);
haha+=(jla+cnt[l])*(n-cnt[l])+(jld+cnt[r])*(n-cnt[r]);
return haha;
}
void solve2(){
rg long long ans=0;
for(rg int i=1;i<=n;i++){
if(vis[i]==0){
dfs(i,i,0);
if(!rt1) rt1=i;
else if(!rt2)rt2=i;
else rt3=i;
}
}
dfs2(rt1,rt1,0);
dfs2(rt2,rt2,0);
dfs2(rt3,rt3,0);
for(rg int i=1;i<=n;i++){
ans+=f[i];
}
ans/=2;
rg long long nans=0;
nans=std::max(nans,js(rt1,rt2,rt3));
nans=std::max(nans,js(rt1,rt3,rt2));
nans=std::max(nans,js(rt2,rt1,rt3));
ans+=nans;
printf("%lld\n",ans);
}
int sta[maxn],tp,du[maxn],num[maxn];
bool kil[maxn];
void dfs3(int now,int fa){
for(rg int i=h[now];i!=-1;i=b[i].nxt){
rg int u=b[i].to;
if(u==fa) continue;
dfs3(u,now);
num[now]+=num[u];
}
}
void dfs4(int rt,int now,int fa){
for(rg int i=h[now];i!=-1;i=b[i].nxt){
rg int u=b[i].to;
if(u==fa) continue;
dfs4(rt,u,now);
if((num[rt]-num[u])&1 && num[u]&1){
kil[i]=1;
}
}
}
void solve3(){
for(rg int i=1;i<=n;i++){
if(s[i]=='B'){
if(du[i]%2==0) num[i]=1;
} else {
if(du[i]&1) num[i]=1;
}
}
if(sl==2){
dfs3(rt1,0);
dfs3(rt2,0);
if(num[rt1]&1 || num[rt2]&1){
printf("-1\n");
return;
}
dfs4(rt1,rt1,0);
dfs4(rt2,rt2,0);
} else {
dfs3(rt1,0);
dfs3(rt2,0);
dfs3(rt3,0);
if(num[rt1]&1 || num[rt2]&1 || num[rt3]&1){
printf("-1\n");
return;
}
dfs4(rt1,rt1,0);
dfs4(rt2,rt2,0);
dfs4(rt3,rt3,0);
}
for(rg int i=1;i<tot;i+=2){
if(kil[i] || kil[i+1]) continue;
sta[++tp]=(i+1)/2;
}
printf("%d\n",tp);
for(rg int i=1;i<=tp;i++){
printf("%d ",sta[i]);
}
printf("\n");
}
int main(){
freopen("lct.in","r",stdin);
freopen("lct.out","w",stdout);
memset(h,-1,sizeof(h));
n=read(),m=read();
sl=(n-m);
scanf("%s",s+1);
rg int aa,bb;
for(rg int i=1;i<=m;i++){
aa=read(),bb=read();
ad(aa,bb);
ad(bb,aa);
du[aa]++;
du[bb]++;
}
if(sl==2){
solve1();
} else {
solve2();
}
solve3();
return 0;
}