BZOJ-4940: [Ynoi2016]这是我自己的发明

[文章目录]

Description

给一个树,n 个点,有点权,初始根是 1。m 个操作,每次操作:
1. 将树根换为 x。
2. 给出两个点 x,y,从 x 的子树中选每一个点,y 的子树中选每一个点,如果两个点点权相等,ans++,求 ans。
n <= 100000 , m <= 500000 , 1 <= a[i] <= 1000000000

将子树看成dfs序上的区间,询问等价于对于每种颜色,累加在两个区间出现次数的乘积。
将区间拆成两个前缀相减,问题转变成求两个前缀每种颜色出现次数的乘积,之后莫队处理每两个前缀的贡献就好了。

#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 101000 
#define M 501000 
typedef long long ll;
inline char nc()
{
    static char buf[100000],*p1,*p2;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
    int re=0; char ch=nc();
    while(!isdigit(ch)) ch=nc();
    while(isdigit(ch)) re=re*10+(ch^'0'),ch=nc();
    return re;
}

char pbuf[100000],*pp=pbuf;
inline void push(const char ch)
{
    if(pp==pbuf+100000) fwrite(pbuf,1,100000,stdout),pp=pbuf;
    *pp++=ch;
}
void write(ll x)
{
    static char sta[20];
    if(!x) push('0'); int top=0;
    while(x) sta[++top]=(x%10)^'0',x/=10;
    while(top) push(sta[top--]); push('\n');
}

int n,m,id,tot,a[N],b[N],c[N];
int head[N],to[N<<1],nxt[N<<1],cnt;
inline void add(int x,int y)
{
    to[++cnt]=y; nxt[cnt]=head[x]; head[x]=cnt;
    to[++cnt]=x; nxt[cnt]=head[y]; head[y]=cnt;
}
int root,tim,in[N],out[N],Log[N],fa[N][20],dep[N];
void dfs(int x,int pre)
{
    in[x]=++tim; c[tim]=a[x]; fa[x][0]=pre;
    for(int i=1;i<=Log[dep[x]];++i)
        fa[x][i]=fa[fa[x][i-1]][i-1];
    for(int i=head[x];i;i=nxt[i])
        if(to[i]!=pre)
            dep[to[i]]=dep[x]+1,dfs(to[i],x);
    out[x]=tim;
}
int getfa(int x,int y)
{
    while(y) x=fa[x][Log[-y&y]],y-=-y&y;
    return x;
}
struct node
{
    int b,l,r,id,f;
    node(){}
    node(int _l,int _r,int _id,int _f){l=_l,r=_r,id=_id,f=_f; if(l>r) swap(l,r);}
    bool operator < (const node &x)const {return b==x.b ? r<x.r : b<x.b;}
}q[M*9];
int bl,t1[N],t2[N];
ll ans[M],now;

int main()
{
    n=read(); m=read(); int i;
    for(i=1;i<=n;++i) b[i]=a[i]=read(); sort(b+1,b+n+1);
    for(i=1;i<=n;++i) a[i]=lower_bound(b+1,b+n+1,a[i])-b;
    for(i=2;i<=n;++i) add(read(),read()),Log[i]=Log[i>>1]+1;
    dfs(1,0); root=1;
    int sw,x,y; tot=0;
    int z1[5],z2[5],r1,r2;
    for(i=1;i<=m;++i)
    {
        sw=read();
        if(sw==1) root=read();
        else
        {
            ++id; x=read(); y=read(); r1=r2=0;
            if(x==root) z1[++r1]=n;
            else if(in[root]<in[x]||in[root]>out[x]) z1[++r1]=out[x],z1[++r1]=in[x]-1;
            else x=getfa(root,dep[root]-dep[x]-1),z1[++r1]=n,z1[++r1]=out[x],z1[++r1]=in[x]-1;
            if(y==root) z2[++r2]=n;
            else if(in[root]<in[y]||in[root]>out[y]) z2[++r2]=out[y],z2[++r2]=in[y]-1;
            else y=getfa(root,dep[root]-dep[y]-1),z2[++r2]=n,z2[++r2]=out[y],z2[++r2]=in[y]-1;
            for(x=1;x<=r1;++x)
                for(y=1;y<=r2;++y) if(z1[x]&&z2[y])
                    q[++tot]=node(z1[x],z2[y],id,(x&1)^(y&1));
        }
    }
    bl=(int)(n/sqrt(tot))*2+1;
    for(i=1;i<=tot;++i) q[i].b=(q[i].l-1)/bl;
    sort(q+1,q+tot+1); x=y=0;
    for(i=1;i<=tot;++i)
    {
        while(x<q[i].l) ++x,++t1[c[x]],now+=t2[c[x]];
        while(y<q[i].r) ++y,++t2[c[y]],now+=t1[c[y]];
        while(x>q[i].l) --t1[c[x]],now-=t2[c[x]],--x;
        while(y>q[i].r) --t2[c[y]],now-=t1[c[y]],--y;
        ans[q[i].id]+=(q[i].f ? -now : now);
    }
    for(i=1;i<=id;++i) write(ans[i]);
    fwrite(pbuf,1,pp-pbuf,stdout);
    return 0;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注