[文章目录]
Description
给你一颗n个节点带有点权的树,给出m次询问,每次询问u-v路径上点权第k小的点权。强制在线。n,m<=
树上主席树。
发现第k大,显然主席树(其他什么也不会啊233)。考虑将节点到根的链看作为所维护的权值线段树的前缀,每个节点到根节点的链上的所有点权就可以直接从自己父亲那里继承过来大部分。
那么每次一个区间的第k大,就相当于两个端点的链-lca的父亲的链的区间相减。不过lca节点会计算两遍,需要减掉。
另外,神一般的输出格式。。。无语了。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 101000
using namespace std;
int n,m,w[N],b[N],c[N],col,ans;
int head[N],nxt[N<<1],to[N<<1],cnt;
int fa[N][20],deep[N];
inline void add(int x,int y)
{
to[++cnt]=y;
nxt[cnt]=head[x];
head[x]=cnt;
}
struct seg
{
int l,r,sum;
}tre[N*20];
int root[N],tot;
int hash(int x)
{
int l=1,r=col+1,mid;
while(l<r)
{
mid=l+r>>1;
if(c[mid]>=x) r=mid;
else l=mid+1;
}
return r;
}
void update(int l,int r,int val,int &x,int y)
{
tre[++tot]=tre[y]; tre[tot].sum++; x=tot;
if(l==r) return ;
int mid=l+r>>1;
if(val<=mid) update(l,mid,val,tre[x].l,tre[y].l);
else update(mid+1,r,val,tre[x].r,tre[y].r);
}
void dfs(int x)
{
update(1,col,hash(w[x]),root[x],root[fa[x][0]]);
for(int i=head[x];i;i=nxt[i])
if(to[i]!=fa[x][0])
{
fa[to[i]][0]=x;
deep[to[i]]=deep[x]+1;
dfs(to[i]);
}
}
int lca(int x,int y)
{
if(deep[x]<deep[y]) swap(x,y);
for(int i=17;i>=0;i--)
if(deep[fa[x][i]]>=deep[y]) x=fa[x][i];
if(x==y) return x;
for(int i=17;i>=0;i--)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int query(int l,int r,int k,int r1,int r2,int anc,int val)
{
if(l==r) return r;
int tmp=tre[tre[r1].l].sum+tre[tre[r2].l].sum-2*tre[tre[anc].l].sum,mid=l+r>>1;//一定是左儿子的相减
if(l<=val&&mid>=val) tmp--;//减去lca节点重复计算次数
if(k<=tmp) return query(l,mid,k,tre[r1].l,tre[r2].l,tre[anc].l,val);
else return k-=tmp,query(mid+1,r,k,tre[r1].r,tre[r2].r,tre[anc].r,val);//不能忘记减k,逗号表达式返回值为最后面的语句
}
void cal(bool flag)
{
int x,y,k;scanf("%d%d%d",&x,&y,&k);
static int ans; x^=ans; int anc=lca(x,y);//static 函数中直接缓存上一次调用的答案
ans=c[query(1,col,k,root[x],root[y],root[fa[anc][0]],hash(w[anc]))];
if(flag) printf("%d\n",ans);//神。。。PE
else printf("%d",ans);
}
int main()
{
scanf("%d%d",&n,&m);
int i,j,x,y;
for(i=1;i<=n;i++) scanf("%d",w+i),b[i]=w[i];
sort(b+1,b+n+1);
c[++col]=b[1];
for(i=2;i<=n;i++) if(b[i]!=b[i-1]) c[++col]=b[i];
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
deep[1]=1; dfs(1);
for(i=1;i<=17;i++)
for(j=1;j<=n;j++)
fa[j][i]=fa[fa[j][i-1]][i-1];
for(i=1;i<=m;i++) cal(i!=m);
return 0;
}