[Luogu P2664] 树上游戏

Luogu P2664

考虑点分治。

一个点$u$的答案会由以下两种点对贡献:

  1. 以$u$为分治中心时的$u$到联通块一点的链上的颜色数
  2. 以$x$为分治中心时一条以$u$开始跨越$x$的链上的颜色数

第一个是很好算的,从$u$开始dfs,如果遇到一个点$v$,$c_v$在$v$到$u$的路径上没有出现过,设以$v$为根的子树大小为$siz_v$,那么这个点$v$就为$u$的答案贡献了$siz_v$。

第二个的话按照点分治的套路我们是要把这个路径在点$x$处断成两半进行计算的。假设$u$所在的联通块的根是$y$,$u$到$x$的路径上的颜色数为$num$,那么$u$到$x$的链对答案的贡献就是$(siz_x-siz_y)\cdot num$。而$x$到链的终点的贡献就是第一个算出的答案减去$y$子树对答案的贡献

#include <iostream>
#include <cstdio>

#define MXN 100010
typedef long long ll;

struct Edge *head[MXN];
struct Edge{
    int to; Edge *ne;
    Edge(int x, int y) : to(y), ne(head[x]){}
};
int N, c[MXN], rt, tot_siz, siz[MXN], mx[MXN], num;
ll ans[MXN], sum, sumc[MXN];
bool vis[MXN], fl[MXN];

void get_rt(int x, int fa){
    siz[x]=1; mx[x]=0;
    for(Edge *i=head[x]; i; i=i->ne) 
        if(i->to!=fa && !vis[i->to]){
            get_rt(i->to, x);
            siz[x]+=siz[i->to];
            mx[x]=std::max(mx[x], siz[i->to]);
        }
    mx[x]=std::max(mx[x], tot_siz-siz[x]);
    if(!rt || mx[x]<mx[rt]) rt=x; 
}

void dfs1(int x, int fa){
    siz[x]=1;
    bool tmp=fl[c[x]];
    fl[c[x]]=true;
    for(Edge *i=head[x]; i; i=i->ne) 
        if(i->to!=fa && !vis[i->to]){
            dfs1(i->to, x);
            siz[x]+=siz[i->to];
        }
    if(!tmp){
        fl[c[x]]=false;
        sum+=siz[x];
        sumc[c[x]]+=siz[x];
    }
}

void dfs2(int x, int fa, int S){
    bool tmp=fl[c[x]];
    fl[c[x]]=true;
    if(!tmp) num++, sum-=sumc[c[x]];
    ans[x]+=sum+S*num;
    for(Edge *i=head[x]; i; i=i->ne) 
        if(i->to!=fa && !vis[i->to])
            dfs2(i->to, x, S);
    if(!tmp){
        fl[c[x]]=false;
        num--, sum+=sumc[c[x]];
    }
}

void modify(int x, int fa, int opt){
    bool tmp=fl[c[x]];
    fl[c[x]]=true;
    for(Edge *i=head[x]; i; i=i->ne)
        if(i->to!=fa && !vis[i->to])
            modify(i->to, x, opt);
    if(!tmp){
        fl[c[x]]=false;
        sum+=opt*siz[x];
        sumc[c[x]]+=opt*siz[x];
    }
}

void clear(int x, int fa){
    sumc[c[x]]=0;
    for(Edge *i=head[x]; i; i=i->ne) 
        if(i->to!=fa && !vis[i->to])
            clear(i->to, x);
}

void divide(int x){
    vis[x]=true;
    dfs1(x, 0);
    ans[x]+=sum;
    for(Edge *i=head[x]; i; i=i->ne) if(!vis[i->to]){
        fl[c[x]]=true;
        sum-=siz[i->to];
        sumc[c[x]]-=siz[i->to];
        modify(i->to, 0, -1);
        dfs2(i->to, 0, siz[x]-siz[i->to]);
        modify(i->to, 0, 1);
        sum+=siz[i->to];
        sumc[c[x]]+=siz[i->to];
        fl[c[x]]=false;
    }
    sum=0, num=0;
    clear(x, 0);
    for(Edge *i=head[x]; i; i=i->ne) if(!vis[i->to]){
        rt=0, tot_siz=siz[i->to];
        get_rt(i->to, 0);
        divide(rt);
    }
}

int main(){
    int x, y;
    scanf("%d", &N);
    for(int i=1; i<=N; i++) scanf("%d", &c[i]);
    for(int i=1; i<N; i++){
        scanf("%d%d", &x, &y);
        head[x]=new Edge(x, y);
        head[y]=new Edge(y, x);
    }
    rt=0, tot_siz=N;
    get_rt(1, 0);
    divide(1);
    for(int i=1; i<=N; i++) printf("%lld\n", ans[i]);
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注