BZOJ-3772: 精神污染

[文章目录]

Description

统计树上一些路径互相包含的对数。点数<=100000 路径条数<=100000 所以说对数在10^10级别

话说这题我还讲了。。。讲的不怎么地就是了。
对于一条路径A->B,如果覆盖路径C->D当且仅当C和D都在A->B的路径上(废话),发现一条路径的两个段点都有限制,我们可以先满足一个,查询另一个。对于每个路径,将其中一个端点用链表连在另一端点上。对于每个点,我们在该维护到根节点路径上所有另一端点,入栈时间+1,出栈时间-1。那么对于一个路径A->B,我们路径一端在路径上的集合的另一端点可以拿出来,然后容斥原理查询A->B链上的另一端点个数。统计答案再减去相同链非完全覆盖覆盖的次数就是答案。

#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 100010 
typedef long long ll;
vector<int>v[N];
ll ans;
int n,m,tot;
int head[N],to[N<<1],nxt[N<<1],cnt,fa[N][17],deep[N],in[N],out[N];
inline void add(int x,int y) {to[++cnt]=y; nxt[cnt]=head[x]; head[x]=cnt;}
void dfs1(int x,int pre)
{
    in[x]=++tot; fa[x][0]=pre; deep[x]=deep[pre]+1;
    for(int i=1;(1<<i)<=deep[x];++i) fa[x][i]=fa[fa[x][i-1]][i-1];
    for(int i=head[x];i;i=nxt[i]) if(to[i]!=pre)
        dfs1(to[i],x);
    out[x]=++tot;
}
struct ques {
    int x,y;
}q[N];
bool cmp(ques x,ques y){return x.x==y.x?x.y<y.y:x.x<y.x;}
struct seg {
    int ls,rs,siz;
}s[N*39];
int se,rt[N];
void update(int l,int r,int &x,int y,int w,int z)
{
    x=++se; s[x].siz=s[y].siz+z;
    if(l==r) return ; int mid=(l+r)>>1;
    if(w<=mid) s[x].rs=s[y].rs,update(l,mid,s[x].ls,s[y].ls,w,z);
    else s[x].ls=s[y].ls,update(mid+1,r,s[x].rs,s[y].rs,w,z);
}
int query(int l,int r,int x,int y,int r1,int r2,int r3,int r4)
{
    if(x<=l&&y>=r) return s[r1].siz+s[r2].siz-s[r3].siz-s[r4].siz;
    int mid=(l+r)>>1;
    if(y<=mid) return query(l,mid,x,y,s[r1].ls,s[r2].ls,s[r3].ls,s[r4].ls);
    else if(x>mid) return query(mid+1,r,x,y,s[r1].rs,s[r2].rs,s[r3].rs,s[r4].rs);
    else return query(l,mid,x,y,s[r1].ls,s[r2].ls,s[r3].ls,s[r4].ls)
        +query(mid+1,r,x,y,s[r1].rs,s[r2].rs,s[r3].rs,s[r4].rs);
}
inline int getlca(int x,int y)
{
    if(deep[x]<deep[y]) swap(x,y);
    for(int i=16;~i;--i) if(deep[fa[x][i]]>=deep[y]) x=fa[x][i];
    if(x==y) return x;
    for(int i=16;~i;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
void dfs2(int x,int pre)
{
    rt[x]=rt[pre];
    for(int i=0;i<v[x].size();++i)
        update(1,tot,rt[x],rt[x],in[v[x][i]],1),
        update(1,tot,rt[x],rt[x],out[v[x][i]],-1);
    for(int i=head[x];i;i=nxt[i]) if(to[i]!=pre) dfs2(to[i],x);
}
void getans(int x,int y,int xx,int yy)
{
    if(x!=xx) ans+=query(1,tot,in[xx]+1,in[x],rt[x],rt[y],rt[xx],rt[yy]);
    ans+=query(1,tot,in[xx],in[y],rt[x],rt[y],rt[xx],rt[yy]);
    ans--;
}
ll gcd(ll x,ll y){return y?gcd(y,x%y):x;}
int main()
{
    scanf("%d%d",&n,&m); int x,y;
    for(int i=1;i!=n;++i)
    {
        scanf("%d%d",&x,&y);
        add(x,y); add(y,x);
    }
    for(int i=1;i<=m;++i)
    {
        scanf("%d%d",&x,&y);
        v[x].push_back(y); q[i].x=x; q[i].y=y;
    }
    dfs1(1,0); dfs2(1,0); int anc;
    for(int i=1;i<=m;++i)
    {
        anc=getlca(q[i].x,q[i].y);
        getans(q[i].x,q[i].y,anc,fa[anc][0]);
    }
    sort(q+1,q+m+1,cmp); int cnt=1;
    for(int i=2;i<=m;++i)
    {
        if(q[i].x==q[i-1].x&&q[i].y==q[i-1].y) cnt++;
        else ans-=(ll)cnt*(cnt-1)/2,cnt=1;
    }
    ans-=(ll)cnt*(cnt-1)/2;
    ll fm=(ll)m*(m-1)/2; ll gc=gcd(ans,fm);
    printf("%lld/%lld\n",ans/gc,fm/gc);
    return 0;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注