JDOJ-2013: [ZJOI2008]树的统计Count

[文章目录]

Description

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成一些操作:

I.                   CHANGE u t : 把结点u的权值改为t

II.                QMAX u v: 询问从点u到点v的路径上的节点的最大权值

III.             QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

first time 写树链剖分,好爽。(%%%zkw线段树)

并没有用lca,deep运用要思考一下,deep[tid[x]]>deep[tid[y]]时,将x向上爬。

#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 38324
int n,val[N],head[N],to[N<<1],nxt[N<<1],cnt,m;
int deep[N],size[N],son[N],fa[N],tot;
int mx[N<<2],sum[N<<2],tid[N],top[N],M;
char st[100];
void add(int x,int y)
{
	to[++cnt]=y;
	nxt[cnt]=head[x];
	head[x]=cnt;
}
void dfs1(int x,int pre)//找重边,son[x]为x的重儿子,deep[x]为x的深度
{
	int tmp=-1;
	deep[x]=deep[pre]+1;fa[x]=pre;//记录fa[x],后面从重边向轻边上跳
	for(int y,i=head[x];i;i=nxt[i])
	{
		y=to[i];if(y!=pre)
		{
			dfs1(y,x);size[x]+=size[y];
			if(size[y]>tmp) tmp=size[y],son[x]=y;
		}
	}
	size[x]++;
}
void dfs2(int x,int pre,int anc)
{
	top[x]=anc;tid[x]=++M;
	sum[M]=mx[M]=val[x];//用来建树(zkw线段树)
	if(son[x])  dfs2(son[x],x,anc);//一定先搜索重儿子,保证重链在线段树上的映射是连续的
	for(int y,i=head[x];i;i=nxt[i])
	{
		y=to[i];
		if(y!=pre&&y!=son[x]) dfs2(y,x,y);
	}
}
void build()
{
	dfs1(1,1);
	for(M=1;M<=n+1;M<<=1);//
	int c=M;
	dfs2(1,1,1); 
	M=c;
	mx[M+n+1]=mx[M]=-1<<30;//别忘处理开区间的值,防止向上更新时出错
	for(int i=M-1;i;i--) mx[i]=max(mx[i<<1],mx[i<<1|1]);
	for(int i=M-1;i;i--) sum[i]=sum[i<<1]+sum[i<<1|1];
}
int query_sum(int l,int r)//zkw线段树
{
	int re=0;
	for(l--,r++;l^r^1;l>>=1,r>>=1)
	{
		if(~l&1) re+=sum[l^1];
		if(r&1) re+=sum[r^1];
	}
	return re;
}
int query_mx(int l,int r)//%%%zkw
{
	int re=-1<<30;
	for(l--,r++;l^r^1;l>>=1,r>>=1)
	{
		if(~l&1) re=max(re,mx[l^1]);
		if(r&1) re=max(re,mx[r^1]);
	}
	return re;
}
void fix(int id,int now)
{
	val[id]=now;id=tid[id];//一定不能用在逗号表达式里
	for(mx[id]=sum[id]=now,id>>=1;id;id>>=1)
	{
		mx[id]=max(mx[id<<1],mx[id<<1|1]);
		sum[id]=sum[id<<1]+sum[id<<1|1];
	}
}
int q_mx(int x,int y)
{
	int re=-1<<30;
	while(top[x]!=top[y])
	{
		if(deep[top[x]]<deep[top[y]])  swap(x,y);//一定是以重链起点的深度参考
		re=max(re,query_mx(tid[top[x]],tid[x]));
		x=fa[top[x]];//跳到另一个重链上
	}
	if(deep[x]>deep[y]) swap(x,y);
	re=max(re,query_mx(tid[x],tid[y]));//
	return re;
}
int q_sum(int x,int y)
{
	int re=0;
	while(top[x]!=top[y])
	{
		if(deep[top[x]]<deep[top[y]])  swap(x,y);
		re+=query_sum(tid[top[x]],tid[x]);
		x=fa[top[x]];
	}
	if(deep[x]>deep[y]) swap(x,y);
	re+=query_sum(tid[x],tid[y]);
	return re;
}
int main()
{
	scanf("%d",&n);
	for(int x,y,i=1;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y);add(y,x);
	}
	for(int i=1;i<=n;i++)
		scanf("%d",&val[i]);
	build(); 
	scanf("%d",&m);
	int k1,k2;
	while(m--)
	{
		scanf("%s%d%d",st,&k1,&k2);
		if(st[3]=='X') printf("%d\n",q_mx(k1,k2));
		else if(st[3]=='M') printf("%d\n",q_sum(k1,k2));
		else fix(k1,k2);
	}
	return 0;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注