[文章目录]
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
first time 写树链剖分,好爽。(%%%zkw线段树)
并没有用lca,deep运用要思考一下,deep[tid[x]]>deep[tid[y]]时,将x向上爬。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 38324
int n,val[N],head[N],to[N<<1],nxt[N<<1],cnt,m;
int deep[N],size[N],son[N],fa[N],tot;
int mx[N<<2],sum[N<<2],tid[N],top[N],M;
char st[100];
void add(int x,int y)
{
to[++cnt]=y;
nxt[cnt]=head[x];
head[x]=cnt;
}
void dfs1(int x,int pre)//找重边,son[x]为x的重儿子,deep[x]为x的深度
{
int tmp=-1;
deep[x]=deep[pre]+1;fa[x]=pre;//记录fa[x],后面从重边向轻边上跳
for(int y,i=head[x];i;i=nxt[i])
{
y=to[i];if(y!=pre)
{
dfs1(y,x);size[x]+=size[y];
if(size[y]>tmp) tmp=size[y],son[x]=y;
}
}
size[x]++;
}
void dfs2(int x,int pre,int anc)
{
top[x]=anc;tid[x]=++M;
sum[M]=mx[M]=val[x];//用来建树(zkw线段树)
if(son[x]) dfs2(son[x],x,anc);//一定先搜索重儿子,保证重链在线段树上的映射是连续的
for(int y,i=head[x];i;i=nxt[i])
{
y=to[i];
if(y!=pre&&y!=son[x]) dfs2(y,x,y);
}
}
void build()
{
dfs1(1,1);
for(M=1;M<=n+1;M<<=1);//
int c=M;
dfs2(1,1,1);
M=c;
mx[M+n+1]=mx[M]=-1<<30;//别忘处理开区间的值,防止向上更新时出错
for(int i=M-1;i;i--) mx[i]=max(mx[i<<1],mx[i<<1|1]);
for(int i=M-1;i;i--) sum[i]=sum[i<<1]+sum[i<<1|1];
}
int query_sum(int l,int r)//zkw线段树
{
int re=0;
for(l--,r++;l^r^1;l>>=1,r>>=1)
{
if(~l&1) re+=sum[l^1];
if(r&1) re+=sum[r^1];
}
return re;
}
int query_mx(int l,int r)//%%%zkw
{
int re=-1<<30;
for(l--,r++;l^r^1;l>>=1,r>>=1)
{
if(~l&1) re=max(re,mx[l^1]);
if(r&1) re=max(re,mx[r^1]);
}
return re;
}
void fix(int id,int now)
{
val[id]=now;id=tid[id];//一定不能用在逗号表达式里
for(mx[id]=sum[id]=now,id>>=1;id;id>>=1)
{
mx[id]=max(mx[id<<1],mx[id<<1|1]);
sum[id]=sum[id<<1]+sum[id<<1|1];
}
}
int q_mx(int x,int y)
{
int re=-1<<30;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);//一定是以重链起点的深度参考
re=max(re,query_mx(tid[top[x]],tid[x]));
x=fa[top[x]];//跳到另一个重链上
}
if(deep[x]>deep[y]) swap(x,y);
re=max(re,query_mx(tid[x],tid[y]));//
return re;
}
int q_sum(int x,int y)
{
int re=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
re+=query_sum(tid[top[x]],tid[x]);
x=fa[top[x]];
}
if(deep[x]>deep[y]) swap(x,y);
re+=query_sum(tid[x],tid[y]);
return re;
}
int main()
{
scanf("%d",&n);
for(int x,y,i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
for(int i=1;i<=n;i++)
scanf("%d",&val[i]);
build();
scanf("%d",&m);
int k1,k2;
while(m--)
{
scanf("%s%d%d",st,&k1,&k2);
if(st[3]=='X') printf("%d\n",q_mx(k1,k2));
else if(st[3]=='M') printf("%d\n",q_sum(k1,k2));
else fix(k1,k2);
}
return 0;
}