21牛客多校J - Tree(思维,rmq)
题目链接
J-Tree_2021牛客暑期多校训练营8 (nowcoder.com)
题解
将Toilet-Ares简称为\(\rm T\),将Unidentified-person简称为\(\rm U\)。
在\(s\)到\(t\)的路径中,如果离开路径,那么玩家能走的最远距离就是他走进的子树的最远距离。以\(t\)为根构建有根树,很容易就可以预处理出在\(s\)到\(t\)的路径上一个点离开后能走的最远距离。
假设数组\(a[i]\)代表\(\rm T\)从\(s\)出发在\(i\)点离开的能走最远距离,数组\(b[i]\)代表\(\rm U\)从\(t\)出发在\(i\)点离开能走的最远距离,路径范围为\([1, {\rm len}]\)。如果当前\(\rm T\)先手在\(l\)点,\(\rm U\)在\(r\)点,\(\rm T\)离开路径,那么\(\rm U\)就会选择选择\(\max\limits_{l<i\le r}(b[i])\),此时差值为\(a[l]-\max\limits_{l<i\le r}(b[i])\),\(\rm U\)先手离开路径也同理。
所以如果一方离开路径,马上就可以知道结果;否则由于两方都没有离开路径,所以可以枚举\(\rm T\)和\(\rm U\)所在位置,因为他们是轮流走的,复杂度是线性的,按照以下策略统计答案:
- 如果当前\(\rm T\)离开路径差值严格大于下一步\(\rm U\)离开路径的差值,那么可以结束循环,因为\(\rm U\)会尽可能让差值小,如果\(\rm T\)不离开,\(\rm U\)至少会在下一步离开使得差值比现在\(\rm T\)离开路径差值变小;
- 否则,假设当前\(\rm T\)离开路径的差值为\(d_{\rm T}\),接下来\(\rm U\)离开路径的差值为\(d_{\rm U}\)。由于\(\rm T\)要差值最大,会有\(ans_{\rm T}=\max(ans_{\rm T}, d_T)\);\(\rm U\)要差值最小,有\(ans_{\rm U}=\max(ans_{\rm U}, d_U)\)。如果想在这一步之前结束,有\(ans=\max(ans_{\rm T}, ans_{\rm U})\)。
- 继续枚举下一步\(\rm T\)和\(\rm U\)的位置,最终答案为\(ans\)。
#include <bits/stdc++.h>
#define endl \'\n\'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N)
typedef long long ll;
using namespace std;
/*-----------------------------------------------------------------*/
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f
const int N = 1e6 + 10;
const int M = 1e6 + 10;
const double eps = 1e-5;
struct edge {
int ntp, ne;
} ed[M];
int head[N];
int si;
void add(int u, int v) {
si++;
ed[si] = {v, head[u]};
head[u] = si;
}
int up[N];
int mxlen[N];
int s, t;
int dfs(int p, int fa, int d) {
up[p] = fa;
int mxd = d;
for(int e = head[p]; e; e = ed[e].ne) {
int ntp = ed[e].ntp;
if(ntp == fa) continue;
mxd = max(mxd, dfs(ntp, p, d + 1));
}
mxlen[p] = mxd - d + 1;
return mxd;
}
int len[N];
int premx[26][N], lasmx[26][N];
int que(int l, int r, int arr[][N]) {
int k = log2(r - l + 1);
return max(arr[k][l], arr[k][r-(1<<k)+1]);
}
int main() {
IOS;
int n;
cin >> n >> s >> t;
for(int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
add(v, u);
add(u, v);
}
dfs(t, 0, 1);
int cur = s;
int pre = 0;
int cnt = 0;
while(cur) {
int num = 0;
for(int e = head[cur]; e; e = ed[e].ne) {
int ntp = ed[e].ntp;
if(ntp == pre || ntp == up[cur]) continue;
num = max(mxlen[ntp], num);
}
len[++cnt] = num;
pre = cur;
cur = up[cur];
}
for(int i = 1; i <= cnt; i++) {
premx[0][i] = len[i] + i - 1;
}
for(int i = cnt; i >= 1; i--) {
lasmx[0][i] = len[i] + cnt - i;
}
for(int i = 1; i < 25; i++) {
for(int j = 1; j + (1 << i) - 1 <= cnt; j++) {
premx[i][j] = max(premx[i - 1][j], premx[i - 1][j + (1 << (i - 1))]);
lasmx[i][j] = max(lasmx[i - 1][j], lasmx[i - 1][j + (1 << (i - 1))]);
}
}
int ans1 = -INF;
int ans2 = INF;
int ans = -INF;
int l = 1, r = cnt;
while(l < r) {
ans1 = max(ans1, premx[0][l] - que(l + 1, r, lasmx));
ans = max(ans, min(ans1, ans2));
if(l + 1 < r - 1) {
if(premx[0][l] - que(l + 1, r, lasmx) > que(l + 1, r - 1, premx) - lasmx[0][r]) break;
ans2 = min(ans2, que(l + 1, r - 1, premx) - lasmx[0][r]);
}
l++, r--;
}
if(l == r) {
ans1 = max(ans1, premx[0][l] - lasmx[0][l + 1]);
}
ans = max(ans, min(ans1, ans2));
cout << ans << endl;
}