BZOJ-1112: [POI2008]砖块Klo

[文章目录]

Description

给一段长度为n的数列,将一个元素加一或者减一花费为1,求最小花费使得数列中有一段长度为k的连续子列高度相同。

每一段长度为k的子列使之相同的最小代价是将其每个元素变成其中位数的代价。splay维护数,查找中位数和区间和。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
int n,k,h[101000];
ll fna=1ll<<60;
struct node
{
    node *fa,*c[2];
    int siz,w; ll sum;
    node(int x);
    void pushup();
}*null=new node(0),*root;
node::node(int x)
{
    fa=c[0]=c[1]=null;
    siz=null?1:0;
    w=sum=x;
}
void node::pushup()
{
    siz=c[0]->siz+c[1]->siz+1;
    sum=w+c[0]->sum+c[1]->sum;
}
void rotate(node *x)
{
    node *y=x->fa; int l=(x==y->c[1]),r=l^1;
    y->c[l]=x->c[r]; x->c[r]->fa=y;
    x->c[r]=y; x->fa=y->fa;
    if(y==y->fa->c[0]) y->fa->c[0]=x;
    else y->fa->c[1]=x;
    y->fa=x; y->pushup(); x->pushup();
    if(root==y) root=x;
}
void splay(node *x,node *tar)
{
    while(x->fa!=tar)
    {
        node *y=x->fa,*z=y->fa;
        if(z==tar) {rotate(x); break;}
        if((x==y->c[0])^(y==z->c[0])) rotate(x);
        else rotate(y); rotate(x);
    }
}
void initialize()
{
    root=new node(-1<<30);
    root->c[1]=new node(1<<30);
    root->c[1]->fa=root;
    root->pushup();
}
void insert(int x)
{
    node *l,*r,*p=root;
    while(p!=null)
    {
        if(p->w>=x) r=p,p=p->c[0];
        else l=p,p=p->c[1];
    } 
    splay(l,null); splay(r,root);
    root->c[1]->c[0]=new node(x);
    root->c[1]->c[0]->fa=root->c[1];
    root->c[1]->pushup(); root->pushup();
}
node *aft(node *x)
{
    if(x->c[1]!=null)
    {
        x=x->c[1];
        while(x->c[0]!=null) x=x->c[0];
    }
    else
    {
        while(x==x->fa->c[1]) x=x->fa;
        x=x->fa;
    }
    return x;
}
void del(int x)
{
    node *l,*r,*p=root;
    while(p!=null)
    {
        if(p->w>=x) r=p,p=p->c[0];
        else l=p,p=p->c[1];
    }
    splay(l,null); splay(aft(r),root);
    node *tmp=root->c[1]->c[0];
    root->c[1]->c[0]=null;
    root->c[1]->pushup();
    root->pushup();
}
void find(node *x,int y,node *z)
{
    while(1)
    {
        int tmp=x->c[0]->siz;
        if(y<=tmp) x=x->c[0];
        else
        {
            y-=tmp+1;
            if(!y) break;
            x=x->c[1];
        }
    }
    splay(x,z);
}
int main()
{
    scanf("%d%d",&n,&k);
    initialize(); int i; ll tmp,ans;
    for(i=1;i<=n;++i)
    {
        scanf("%d",h+i);
        insert(h[i]);
        if(i>k) del(h[i-k]);
        if(i>=k)
        {
            find(root,k/2+2,null);
            tmp=root->w;
            find(root,1,null);
            find(root,k/2+2,root);
            ans=tmp*root->c[1]->c[0]->siz-root->c[1]->c[0]->sum;
            find(root,k/2+2,null);
            find(root,k+2,root);
            ans=ans+root->c[1]->c[0]->sum-tmp*root->c[1]->c[0]->siz;
            fna=min(fna,ans);
        }
    }
    printf("%lld",fna);
    return 0;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注