JDOJ-2183: 普通平衡树

[文章目录]

Description

此为平衡树系列第一道:普通平衡树您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)

1.treap 179ms

treap的模板题,突然想起来以前学过的一个叫switch的东西顺便用了一下。

整体思想就是随机化,利用堆的性质,随机左右旋转,期望树高为logn。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
int n,sw,x,befor,aftr,tot,root;
struct node
{
    int size,lson,rson,rnd,val,num;
}a[501000];
void push(int k)//因为只是改变数的结构,val不需要更新,只更新size
{
    a[k].size=a[k].num+a[a[k].lson].size+a[a[k].rson].size;
}
void lturn(int &k)//右儿子提上来
{
    int t=a[k].rson;
    a[k].rson=a[t].lson;
    a[t].lson=k;
    push(k);push(t);k=t;
}
void rturn(int &k)//左儿子提上来
{
    int t=a[k].lson;
    a[k].lson=a[t].rson;
    a[t].rson=k;
    push(k);push(t);k=t;
}
void update(int &k,int temp)
{
    if(!k)//加分支
    {
        k=++tot;
        a[k].num=a[k].size=1;
        a[k].val=temp;
        a[k].rnd=rand();
        return ;
    }
    a[k].size++;
    if(temp==a[k].val) a[k].num++;
    else if(temp<a[k].val)//维护平衡
    {
        update(a[k].lson,temp);
        if(a[a[k].lson].rnd<a[k].rnd) rturn(k);
    }
    else
    {
        update(a[k].rson,temp);
        if(a[a[k].rson].rnd<a[k].rnd) lturn(k);
    }
}
void del(int &k,int temp)
{
    if(!k) return ;//没找到
    if(temp==a[k].val)
    {
        if(a[k].num>1) a[k].num--,a[k].size--;
        else if(a[k].lson*a[k].rson==0)//因为传进来的是地址,所以相当于直接把非空的儿子提了上去,不破坏中序遍历
            k=a[k].lson+a[k].rson;
        else if(a[a[k].lson].rnd<a[a[k].rson].rnd)//由于num==1,所以通过不破坏中序遍历的左右旋转将val==temp的下沉到底部del
            rturn(k),del(k,temp);
        else lturn(k),del(k,temp);
    }
    else if(temp<a[k].val) a[k].size--,del(a[k].lson,temp);//更新size+继续搜寻
    else a[k].size--,del(a[k].rson,temp);
}
int ask_num(int k,int temp)//排名为temp的数值
{
    if(!k) return 0;
    if(temp<=a[a[k].lson].size) return ask_num(a[k].lson,temp);
    else if(temp>a[a[k].lson].size+a[k].num)
        return ask_num(a[k].rson,temp-a[k].num-a[a[k].lson].size);//进入右儿子时更新排名
    else return a[k].val;
}
int ask_rank(int k,int temp)//数值为temp的排名,没有返回0
{
    if(!k) return 0;
    if(a[k].val==temp) return a[a[k].lson].size+1;
    else if(temp>a[k].val)
        return a[a[k].lson].size+a[k].num+ask_rank(a[k].rson,temp);//进入右儿子排名加上前面的个数
    else return ask_rank(a[k].lson,temp);
}
void ask_before(int k,int temp)//<x的数
{
    if(!k) return ;
    if(temp>a[k].val)
    {
        befor=a[k].val;//记录,防止一直下搜时无解
        ask_before(a[k].rson,temp);
    }
    else ask_before(a[k].lson,temp);
}
void ask_after(int k,int temp)//>x的数
{
    if(!k) return ;
    if(temp<a[k].val)
    {
        aftr=a[k].val;//记录
        ask_after(a[k].lson,temp);
    }
    else ask_after(a[k].rson,temp);
}
int main()
{
    scanf("%d",&n);
    while(n--)
    {
        scanf("%d%d",&sw,&x);sw--;
        switch(sw)//神奇的switch
        {
            case 0:update(root,x);break;//根节点由于要左右旋转可能标号改变,直接定义变量,有旋转的操作都传地址
            case 1:del(root,x);break;
            case 2:printf("%d\n",ask_rank(root,x));break;
            case 3:printf("%d\n",ask_num(root,x));break;
            case 4:befor=0;ask_before(root,x);printf("%d\n",befor);break;//清0 befor和aftr
            case 5:aftr=0;ask_after(root,x);printf("%d\n",aftr);break;
        }
    }   
    return 0;
}

好像代码好看了

2.splay 271ms

大常数啊,既难写啊,还慢

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
int n;
struct node
{
    node *fa,*c[2];
    int siz,w;
    node(int x);
    void pushup(){siz=c[0]->siz+c[1]->siz+1;}
}*null=new node(0),*root;
node::node(int x)
{
    fa=c[0]=c[1]=null;
    siz=null?1:0;
    w=x;
}
void initialize()
{
    root=new node(-1<<30);
    root->c[1]=new node(1<<30);
    root->c[1]->fa=root;
    root->pushup();
}

void rotate(node *x)
{
    node *y=x->fa; bool l=(x==y->c[1]),r=l^1;
    y->c[l]=x->c[r]; x->c[r]->fa=y;
    x->c[r]=y; x->fa=y->fa;
    if(y==y->fa->c[0]) y->fa->c[0]=x;
    else y->fa->c[1]=x;
    y->fa=x; y->pushup(); x->pushup();
    if(root==y) root=x;
}
void splay(node *x,node *tar)
{
    while(x->fa!=tar)
    {
        node *y=x->fa,*z=y->fa;
        if(z==tar){rotate(x); break;}
        if((x==y->c[0])^(y==z->c[0])==0) rotate(y);
        else rotate(x); rotate(x);
    }
}
node *pos(int x)
{
    node *p=root;
    while(1)
    {
        if(x<=p->c[0]->siz) p=p->c[0];
        else
        {
            x-=p->c[0]->siz+1;
            if(!x) return p;
            p=p->c[1];
        }
    }
}
void Insert()
{
    int x; scanf("%d",&x);
    node *p=root,*l,*r;
    while(p!=null)
    {
        if(p->w>=x) r=p,p=p->c[0];
        else l=p,p=p->c[1];
    }
    splay(l,null); splay(r,root);
    root->c[1]->c[0]=new node(x);
    root->c[1]->c[0]->fa=root->c[1];
    root->c[1]->pushup();
    root->pushup();
}
node *after(node *x)
{
    if(x->c[1]!=null)
    {
        x=x->c[1];
        while(x->c[0]!=null) x=x->c[0];
        return x;
    }
    while(x==x->fa->c[1]) x=x->fa;
    return x->fa;   
}
void Delete()
{
    int x; scanf("%d",&x);
    node *p=root,*l,*r;
    while(p!=null)
    {
        if(p->w>=x) r=p,p=p->c[0];
        else l=p,p=p->c[1];
    }
    r=after(r);
    splay(l,null); splay(r,root);
    root->c[1]->c[0]=null;
    root->c[1]->pushup();
    root->pushup();
}
void Qrank()
{
    int x; scanf("%d",&x);
    node *p=root;int ans=0;
    while(p!=null)
    {
        if(p->w>=x) p=p->c[0];
        else ans+=(p->c[0]->siz+1),p=p->c[1];
    }
    printf("%d\n",ans);
}
void Qdata()
{
    int x; scanf("%d",&x);
    printf("%d\n",pos(x+1)->w);
}
void Qlower()
{
    int x; scanf("%d",&x);
    node *p=root,*ans;
    while(p!=null)
    {
        if(p->w>=x) p=p->c[0];
        else ans=p,p=p->c[1];
    }
    printf("%d\n",ans->w);
}
void Qupper()
{
    int x; scanf("%d",&x);
    node *p=root,*ans;
    while(p!=null)
    {
        if(p->w<=x) p=p->c[1];
        else ans=p,p=p->c[0];
    }
    printf("%d\n",ans->w);
}
int main()
{
    initialize();
    scanf("%d",&n);
    int sw;
    while(n--)
    {
        scanf("%d",&sw);
        if(sw==1) Insert();
        else if(sw==2) Delete();
        else if(sw==3) Qrank();
        else if(sw==4) Qdata();
        else if(sw==5) Qlower();
        else Qupper();
    }   
    return 0;
}
3.非旋转treap 251ms

好写,能过,好调。啊,完美

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 101000
#define mp(x,y) make_pair(x,y)
using namespace std;
typedef pair<int,int> pr;
int root,n,tot,op;
int ls[N],rs[N],w[N],siz[N],key[N];
void updata(int pos)
{
    siz[pos]=siz[ls[pos]]+siz[rs[pos]]+1;
}
pr split(int x,int y)
{
    if(y==0) return mp(0,x);
    if(y==siz[ls[x]])
    {
        int t=ls[x]; ls[x]=0;
        updata(x);
        return mp(t,x);
    }
    if(y==siz[ls[x]]+1)
    {
        int t=rs[x];
        rs[x]=0; updata(x);
        return mp(x,t);
    }
    if(y<siz[ls[x]])
    {
        pr t=split(ls[x],y);
        ls[x]=t.second; updata(x);
        return mp(t.first,x);
    }
    pr t=split(rs[x],y-siz[ls[x]]-1);
    rs[x]=t.first; updata(x);
    return mp(x,t.second);
}
int merge(int x,int y)
{
    if(!x||!y) return x|y;
    if(key[x]<key[y])
    {
        rs[x]=merge(rs[x],y);
        updata(x);
        return x;
    }
    ls[y]=merge(x,ls[y]);
    updata(y);
    return y;
}
void insert(int x,int y)
{
    if(!root)
    {
        w[++tot]=y; siz[tot]=1; key[tot]=rand()*rand();
        root=tot;
        return ;
    }
    int l=0;
    while(x)
    {
        if(w[x]>=y) x=ls[x];
        else l+=siz[ls[x]]+1,x=rs[x];
    }
    pr t=split(root,l);
    w[++tot]=y; siz[tot]=1; key[tot]=rand()*rand();
    t.first=merge(t.first,tot);
    root=merge(t.first,t.second);
}
void del(int x,int y)
{
    int l=0;
    while(x)
    {
        if(w[x]>=y) x=ls[x];
        else l+=siz[ls[x]]+1,x=rs[x];
    }
    pr t=split(root,l);
    t.second=split(t.second,1).second;
    root=merge(t.first,t.second);
}
void q_rank(int x,int y)
{
    int ans=0;
    while(x)
    {
        if(w[x]>=y) x=ls[x];
        else ans+=siz[ls[x]]+1,x=rs[x];
    }
    printf("%d\n",ans+1);
}
void q_val(int x,int y)
{
    while(x)
    {
        if(y<=siz[ls[x]]) x=ls[x];
        else
        {
            y-=siz[ls[x]]+1;
            if(!y){printf("%d\n",w[x]); return ;}
            x=rs[x];
        }
    }
}
void q_before(int x,int y)
{
    int l=0;
    while(x)
    {
        if(w[x]>=y) x=ls[x];
        else l=x,x=rs[x];
    }
    printf("%d\n",w[l]);
}
void q_after(int x,int y)
{
    int r;
    while(x)
    {
        if(w[x]<=y) x=rs[x];
        else r=x,x=ls[x];
    }
    printf("%d\n",w[r]);
}
int main()
{
    scanf("%d",&n);
    int x;
    while(n--)
    {
        scanf("%d%d",&op,&x);
        switch(op)
        {
            case 1:insert(root,x); break;
            case 2:del(root,x); break;
            case 3:q_rank(root,x); break;
            case 4:q_val(root,x); break;
            case 5:q_before(root,x); break;
            case 6:q_after(root,x); break;
        }
    }
    return 0;
}
4.替罪羊树。。。153ms。

debug了一上午。。。快,还好写,好想。哇呜!!!常数这么小啊。开心。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const double arf=0.75;
int cgg;
struct node
{
    node *c[2];
    int w,siz,ms;
    node(int x);
    void pushup();
}*null=new node(0),*root=null,*cg[101000];
node::node(int x)
{
    c[0]=c[1]=null;
    w=x; siz=ms=null?1:0;
}
void node::pushup() {siz=c[0]->siz+c[1]->siz+1; ms=siz;}
void flatten(node *x)
{
    if(x==null) return ;
    flatten(x->c[0]);
    cg[++cgg]=x;
    flatten(x->c[1]);
}
node *build(int l,int r)
{
    if(l>r) return null;
    int mid=l+r>>1;
    cg[mid]->c[0]=build(l,mid-1);
    cg[mid]->c[1]=build(mid+1,r);
    cg[mid]->pushup(); 
    return cg[mid];
}
void rebuild(node *&x)
{
    cgg=0; flatten(x);
    x=build(1,cgg);
}
void insert(node *&x,int y,int fl)
{
    if(x==null) {x=new node(y); return ;}
    x->siz++; x->ms=max(x->siz,x->ms);
    if(y<=x->w)
    {
        if(!fl&&x->c[0]->siz+1>x->ms*arf) insert(x->c[0],y,1),rebuild(x);
        else insert(x->c[0],y,0);
    }
    else
    {
        if(!fl&&x->c[1]->siz+1>x->ms*arf) insert(x->c[1],y,1),rebuild(x);
        else insert(x->c[1],y,0);
    }
}
int del(node *&x,int y)
{
    if(x==null) return 0;
    if(x->w==y)
    {
        if(x->c[1]==null||x->c[0]==null)
        {
            node *tmp=x;
            x=x->c[x->c[0]==null];
            free(tmp);      
        }
        else
        {
            node *p=x->c[1],*ff=x;
            while(p->c[0]!=null) ff=p,p->siz--,p=p->c[0];
            ff->c[ff->c[1]==p]=p->c[1];
            x->w=p->w; x->siz--; free(p);
        }
        return 1;
    }
    if(del(x->c[x->w<y],y)) x->siz--;
}
int q_rank(node *x,int y)
{
    if(x==null) return 1;
    if(x->w>=y) return q_rank(x->c[0],y);
    return x->c[0]->siz+1+q_rank(x->c[1],y);
}
int q_date(node *x,int y)
{
    while(1)
    {
        if(y<=x->c[0]->siz) x=x->c[0];
        else
        {
            y-=x->c[0]->siz+1;
            if(!y) return x->w;
            x=x->c[1];
        }
    }
}
int q_fro(node *x,int y)
{
    int re;
    while(x!=null)
    {
        if(x->w>=y) x=x->c[0];
        else re=x->w,x=x->c[1];
    }
    return re;
}
int q_aft(node *x,int y)
{
    int re;
    while(x!=null)
    {
        if(x->w<=y) x=x->c[1];
        else re=x->w,x=x->c[0];
    }
    return re;
}
int main()
{
    int n,sw,x; scanf("%d",&n);
    while(n--)
    {
        scanf("%d%d",&sw,&x);
        switch(sw)
        {
            case 1:insert(root,x,0); break;
            case 2:del(root,x); break;
            case 3:printf("%d\n",q_rank(root,x)); break;
            case 4:printf("%d\n",q_date(root,x)); break;
            case 5:printf("%d\n",q_fro(root,x));  break;
            case 6:printf("%d\n",q_aft(root,x));  break;
        }
    }
    return 0;
}
5.SBT(size balanced tree) 160ms

受大师(CQzhangyu)怂恿学了一发SBT,据说常数小,比treap还快,还好写。学了一发,好像的确这样诶。随随便便就D掉了treap。。。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 101000 
int rot,tot;
struct SBT
{
    int ls,rs,siz,w;
}t[N];
void zig(int &x)
{
    int l=t[x].ls;
    t[x].ls=t[l].rs; t[l].rs=x; 
    t[l].siz=t[x].siz;
    t[x].siz=t[t[x].ls].siz+t[t[x].rs].siz+1;
    x=l;
}
void zag(int &x)
{
    int r=t[x].rs;
    t[x].rs=t[r].ls; t[r].ls=x;
    t[r].siz=t[x].siz;
    t[x].siz=t[t[x].ls].siz+t[t[x].rs].siz+1;
    x=r; 
}
void maintain(int &k,bool fl)
{
    if(!fl)
    {
        if(t[t[t[k].ls].ls].siz>t[t[k].rs].siz) zig(k);
        else if(t[t[t[k].ls].rs].siz>t[t[k].rs].siz) zag(t[k].ls),zig(k);
        else return ;
    }
    else
    {
        if(t[t[t[k].rs].rs].siz>t[t[k].ls].siz) zag(k);
        else if(t[t[t[k].rs].ls].siz>t[t[k].ls].siz) zig(t[k].rs),zag(k);
        else return ;
    }
    maintain(t[k].ls,0); maintain(t[k].rs,1);
    maintain(k,0); maintain(k,1);
}
void del(int &k,int x)
{
    if(!k) return ; t[k].siz--;
    if(x<t[k].w) del(t[k].ls,x);
    else if(x>t[k].w) del(t[k].rs,x);
    else
    {
        if(!t[k].ls||!t[k].rs) k=t[k].ls|t[k].rs;
        else
        {
            int p=t[k].ls,las=k;
            while(t[p].rs) t[p].siz-- , las=p , p=t[p].rs;
            if(p==t[las].ls) t[las].ls=t[p].ls;
            else t[las].rs=t[p].ls;
            t[p].ls=t[k].ls; t[p].rs=t[k].rs; t[p].siz=t[k].siz; k=p;
        }
    }
}
void insert(int &k,int x)
{
    if(!k) k=++tot,t[k].w=x,t[k].siz=1;
    else
    {
        t[k].siz++;
        if(x<t[k].w) insert(t[k].ls,x);
        else insert(t[k].rs,x);
        maintain(k,x>=t[k].w);
    }
}
int query_data(int x)
{
    int p=rot,tmp;
    while(1)
    {
        tmp=t[t[p].ls].siz;
        if(x<=tmp) p=t[p].ls;
        else
        {
            x-=(tmp+1);
            if(!x) return t[p].w;
            p=t[p].rs;
        }
    }
}
int query_rank(int x)
{
    int p=rot,re=1;
    while(p)
    {
        if(x<=t[p].w) p=t[p].ls;
        else
        {
            re+=t[t[p].ls].siz+1;
            p=t[p].rs;
        }
    }
    return re;
}
int query_fro(int x)
{
    int p=rot,re;
    while(p)
    {
        if(x<=t[p].w) p=t[p].ls;
        else re=t[p].w, p=t[p].rs;
    }
    return re;
}
int query_aft(int x)
{
    int p=rot,re;
    while(p)
    {
        if(x>=t[p].w) p=t[p].rs;
        else re=t[p].w, p=t[p].ls;
    }
    return re;
}
int main()
{
    int n,op,x;
    scanf("%d",&n);
    while(n--)
    {
        scanf("%d%d",&op,&x);
        switch(op)
        {
            case 1: insert(rot,x); break;
            case 2: del(rot,x); break;
            case 3: printf("%d\n",query_rank(x)); break;
            case 4: printf("%d\n",query_data(x)); break;
            case 5: printf("%d\n",query_fro(x)); break;
            case 6: printf("%d\n",query_aft(x)); break;
        }
    }
    return 0;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注