熟练剖分学习笔记
1.前言
Q:为什么要用树链剖分???
A:用于树上路径搞事情!!!
2.概念与说法
重儿子:x的儿子的子树中节点数最多的儿子
重边:连接重儿子
重链:重边组成
轻边:对应重边
轻链:轻边构成
siz[x]以x为根的子树的节点数,包括它自己,son[x]为x的重儿子,top[x]:x节点所在重链的链头编号
fa[x]代表x的父亲节点,id[x]表示x节点重新编号后的编号(待会将怎么用),dep[x]x节点的深度。
3.树链剖分实质
树链剖分实质就是指将一棵树剖分成一条条重链和轻链。
4.树链剖分实现
主要由两个dfs实现
第一个dfs,处理出每个节点的重儿子,子树节点数,父亲节点,与深度。
代码:
void dfs1(int x,int f,int deep)//当前节点x,父亲f,深度deep { fa[x]=f;siz[x]=1; dep[x]=deep;int maxson=-1; for(int i=head[x]; i; i=G[i].nxt) { int v=G[i].v; if(v==f)continue;//不能访问父亲 dfs1(v,x,deep+1); siz[x]+=siz[v];//子树节点数计算 if(siz[v]>maxson)maxson=siz[v],son[x]=v;//找出重儿子 } }
第二个dfs,处理出每个节点,重新编号后的编号id[],所在重链的链头编号。
注意:dfs时,先遍历玩重儿子,再遍历轻儿子,这样做的好处是重链上的编号是连续的,便于用数据结构维护(线段树,树状数组等)
还有一点,如果这个节点不在重链上,那么它的top[x]就默认为它自己。
代码:
void dfs2(int x,int topf)//当前节点x,链头topf { id[x]=++tot;wt[tot]=w[x]; top[x]=topf;if(!son[x])return; dfs2(son[x],topf);//重儿子 for(int i=head[x]; i; i=G[i].nxt) { int v=G[i].v; if(v==fa[x]||v==son[x])continue; dfs2(v,v);//链头为自己(轻儿子) } }
5.模板题
洛古
P3384 【模板】树链剖分
https://www.luogu.org/problemnew/show/P3384
给出线段树维护版本
代码:
#include<bits/stdc++.h> using namespace std; typedef long long ll; inline int read() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } #define mid ((l+r)>>1) #define lson (rt<<1) #define rson (rt<<1|1) #define len (r-l+1) const int maxn=200000+10; struct Node{int nxt,v;}G[maxn]; int head[maxn],cnt=1,tot=0,a[maxn<<2],laz[maxn<<2],wt[maxn],dep[maxn],res; int w[maxn],fa[maxn],son[maxn],top[maxn],siz[maxn],root,n,m,mod,id[maxn]; inline void insert(int u,int v){G[cnt]=(Node){head[u],v};head[u]=cnt++;} inline void dfs1(int x,int f,int deep) { dep[x]=deep; fa[x]=f;siz[x]=1; int maxson=-1; for(int i=head[x];i;i=G[i].nxt) { int v=G[i].v; if(v==f)continue; dfs1(v,x,deep+1); siz[x]+=siz[v]; if(siz[v]>maxson)son[x]=v,maxson=siz[v]; } } inline void dfs2(int x,int topf) { id[x]=++tot;wt[tot]=w[x]; top[x]=topf;if(!son[x])return; dfs2(son[x],topf); for(int i=head[x];i;i=G[i].nxt) { int v=G[i].v; if(v==fa[x]||v==son[x])continue; dfs2(v,v); } } inline void build(int rt,int l,int r) { if(l==r){a[rt]=wt[l];if(a[rt]>mod)a[rt]%=mod;return;} build(lson,l,mid);build(rson,mid+1,r); a[rt]=(a[lson]+a[rson])%mod; } inline void pushdown(int rt,int leen) { laz[lson]+=laz[rt];laz[rson]+=laz[rt]; a[lson]+=laz[rt]*(leen-(leen>>1)); a[rson]+=laz[rt]*(leen>>1); a[lson]%=mod;a[rson]%=mod;laz[rt]=0; } inline void update(int rt,int l,int r,int L,int R,int k) { if(L<=l&&r<=R){laz[rt]+=k;a[rt]+=len*k;a[rt]%=mod;return;} else { if(laz[rt])pushdown(rt,len); if(L<=mid)update(lson,l,mid,L,R,k); if(R>mid)update(rson,mid+1,r,L,R,k); a[rt]=(a[lson]+a[rson])%mod; } } inline void query(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R){res+=a[rt];res%=mod;return;} else { if(laz[rt])pushdown(rt,len); if(L<=mid)query(lson,l,mid,L,R); if(R>mid)query(rson,mid+1,r,L,R); } } inline void addlu(int x,int y,int k) { k%=mod; while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); update(1,1,n,id[top[x]],id[x],k); x=fa[top[x]]; } if(dep[x]>dep[y])swap(x,y); update(1,1,n,id[x],id[y],k); } inline int asklu(int x,int y) { int ans=0;res=0; while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); res=0; query(1,1,n,id[top[x]],id[x]);ans+=res; ans%=mod;x=fa[top[x]]; } if(dep[x]>dep[y])swap(x,y);res=0; query(1,1,n,id[x],id[y]);ans+=res; return ans%mod; } inline void addshu(int x,int y) { update(1,1,n,id[x],id[x]+siz[x]-1,y); } inline int askshu(int x) { res=0; query(1,1,n,id[x],id[x]+siz[x]-1); return res; } inline void init() { n=read();m=read();root=read();mod=read(); for(int i=1;i<=n;++i)w[i]=read(); for(int i=1,x,y;i<n;++i) { x=read();y=read(); insert(x,y);insert(y,x); } dfs1(root,0,1); dfs2(root,root); build(1,1,n); } int main() { init(); for(int i=1;i<=m;++i) { int k,x,y,z;k=read(); if(k==1){x=read();y=read();z=read();addlu(x,y,z);} else if(k==2){x=read();y=read();printf("%d\n",asklu(x,y));} else if(k==3){x=read();z=read();addshu(x,z);} else {x=read();printf("%d\n",askshu(x));} } return 0; }
View Code
6.求lca
比较简单直接给出代码(懒):
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn=500000+10; inline int read() { int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } int n,w[maxn],wt[maxn],cnt=1,tot,res,m,s; int q,id[maxn],fa[maxn],top[maxn],siz[maxn],son[maxn],head[maxn],dep[maxn]; #define mid ((l+r)>>1) #define lson (rt<<1),l,mid #define rson (rt<<1|1),mid+1,r #define len r-l+1 struct Node{int nxt,v;} G[maxn<<1]; void insert(int u,int v){G[cnt]=(Node){head[u],v};head[u]=cnt++;} void dfs1(int x,int f,int deep) { fa[x]=f;siz[x]=1; dep[x]=deep;int maxson=-1; for(int i=head[x]; i; i=G[i].nxt) { int v=G[i].v; if(v==f)continue; dfs1(v,x,deep+1); siz[x]+=siz[v]; if(siz[v]>maxson)maxson=siz[v],son[x]=v; } } void dfs2(int x,int topf) { id[x]=++tot;wt[tot]=w[x]; top[x]=topf;if(!son[x])return; dfs2(son[x],topf); for(int i=head[x]; i; i=G[i].nxt) { int v=G[i].v; if(v==fa[x]||v==son[x])continue; dfs2(v,v); } } int LCA(int x,int y) { while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); x=fa[top[x]]; } if(dep[x]>dep[y])swap(x,y); return x; } void init() { n=read(),m=read(),s=read(); for(int i=1,a,b; i<n; ++i) { a=read(),b=read(); insert(a,b);insert(b,a); } dfs1(s,0,1);dfs2(s,s); } int main() { init(); for(int i=1,x,y; i<=m; ++i) { x=read();y=read(); printf("%d\n",LCA(x,y)); } return 0; }
View Code