[ZJOI 2019] 语言

LOJ #3046.

考虑一下对于一个$u$,所有的$v$的集合是什么。很显然,就是所有经过$u$的链的链并。

因为这些链都经过同一个点$u$,所以链并就是一个联通块。设所有链的端点组成的点集为$S$,在点集内已经按dfs序排好序,那么链并大小就是$$\sum \limits_{i=1}^{|S|}dep_{S_i} – \sum \limits_{i=1}^{|S| – 1}dep_{lca\{S_{i}, S_{i+1}\}}-dep_{lca(S)} + 1$$

很显然,这样的链并的贡献是支持插入链和删除链的,直接以dfs序开一个线段树上线段树就可以了

而对于一条链,这条链上所有点的联通块都要包括这条链,那么我们就可以用树上差分的思路,在$s, t$插入这条链,在$lca(s, t)$和$lca(s, t)$的父亲处分别删除一次这条链,剩下的部分线段树合并维护就可以了

#include <iostream>
#include <cstdio>
#include <vector>
#include <cmath>

#define pb push_back
#define inf 0x3f3f3f3f
#define MAX_N 100010
typedef long long ll;

int n, m, dfn[MAX_N], dep[MAX_N], pos[MAX_N], fir[MAX_N], fat[MAX_N], dfn_cnt, tim_cnt;
std::vector<int> to[MAX_N], del[MAX_N];

int lca(int x, int y);

namespace SEG{
    #define mid ((l + r) >> 1)
    struct Node{
        int cnt, mn, mx, sum;
        Node *ls, *rs;
        Node() : ls(NULL), rs(NULL), cnt(0), mn(inf), mx(0), sum(0){}
        void push_up(){
            sum = 0, mn = inf, mx = 0;
            if(ls){
                sum += ls->sum;
                mn = std::min(mn, ls->mn);
                mx = std::max(mx, ls->mx);
            }
            if(rs){
                sum += rs->sum;
                mn = std::min(mn, rs->mn);
                mx = std::max(mx, rs->mx);
            }
            if(ls && rs && ls->mx != 0 && rs->mn != inf)
                sum -= dep[lca(pos[ls->mx], pos[rs->mn])];
        }
    }*rt[MAX_N];

    void mdf(Node *&v, int l, int r, int p, int val){
        if(!v) v = new Node();
        if(l == r){
            v->cnt += val;
            v->sum = v->cnt ? dep[pos[l]] : 0;
            v->mn = v->cnt ? l : inf;
            v->mx = v->cnt ? l : 0;
            return;
        }
        if(p <= mid) mdf(v->ls, l, mid, p, val);
        else mdf(v->rs, mid + 1, r, p, val);
        v->push_up();
    }

    int qry(Node *v){
        if(!v) return 0;
        return v->sum - (v->mx ? dep[lca(pos[v->mn], pos[v->mx])] : 0);
    }

    Node *merge(Node *x, Node *y, int l, int r){
        if(!x) return y; if(!y) return x;
        if(l == r){
            x->cnt += y->cnt; x->sum |= y->sum;
            x->mn = std::min(x->mn, y->mn), x->mx = std::max(x->mx, y->mx);
            return x;
        }
        x->ls = merge(x->ls, y->ls, l, mid);
        x->rs = merge(x->rs, y->rs, mid + 1, r);
        x->push_up();
        return x;
    }
    #undef mid
}
using SEG::rt;

namespace ST{
    int f[MAX_N << 2][19];
    void build(){
        for(int j = 1; j < 19; j++)
            for(int i = 1; i <= tim_cnt; i++)
                f[i][j] = std::min(f[i][j - 1], f[i + (1<<(j-1))][j - 1]);
    }
    int query(int x, int y){
        if(x > y) std::swap(x, y);
        int lim = log2(y - x + 1);
        return std::min(f[x][lim], f[y - (1<<lim) + 1][lim]);
    }
}

void dfs1(int x, int fa){
    fat[x] = fa;
    dfn[x] = ++dfn_cnt;
    pos[dfn[x]] = x;
    ST::f[++tim_cnt][0] = dfn[x];
    fir[x] = tim_cnt;
    for(auto i: to[x]) if(i ^ fa){
        dep[i] = dep[x] + 1;
        dfs1(i, x);
        ST::f[++tim_cnt][0] = dfn[x];
    }
}

int lca(int x, int y){
    return pos[ST::query(fir[x], fir[y])];
}

ll ans = 0;

void dfs2(int x, int fa){
    for(auto i: to[x]) if(i ^ fa){
        dfs2(i, x);
        rt[x] = SEG::merge(rt[x], rt[i], 1, n);
    }
    for(auto i: del[x])
        SEG::mdf(rt[x], 1, n, i, -1);
    ans += SEG::qry(rt[x]);
}

int main(){
    scanf("%d%d", &n, &m);
    for(int i = 1, x, y; i < n; i++){
        scanf("%d%d", &x, &y);
        to[x].pb(y), to[y].pb(x);
    }
    dfs1(1, 0); ST::build();
    for(int i = 1, x, y, l; i <= m; i++){
        scanf("%d%d", &x, &y);
        l = lca(x, y);
        SEG::mdf(rt[x], 1, n, dfn[x], 1);
        SEG::mdf(rt[x], 1, n, dfn[y], 1);
        SEG::mdf(rt[y], 1, n, dfn[x], 1);
        SEG::mdf(rt[y], 1, n, dfn[y], 1);
        del[l].pb(dfn[x]); del[fat[l]].pb(dfn[x]);
        del[l].pb(dfn[y]); del[fat[l]].pb(dfn[y]);
    }
    dfs2(1, 0);
    printf("%lld\n", ans >> 1);
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注