recollection 题解

题意

给定一棵Trie,求Trie上所有$n$个串$s_{1\dots n}$中的$$f(i, j)=\text{lcp}(s_i, s_j)+\text{lcs}(s_i, s_j)\ \ (i!=j)$$的最大值

题解

有一种叫广义后缀数组的科技,说白了就是在Trie上建的后缀数组。其实构造与普通后缀数组很相似,唯一不同的是,每次倍增时求出的$rk$数组都必须保留下来,因为这个后缀数组是在Trie上跑的,不是在序列上跑的,所以那个$height_i\geq height_{i-1}-1$的规律就不存在惹,所以每次求$height$的时候就需要单独倍增一下,就要用到这个$rk$数组。

之后的话我们瞪着$$f(i, j)=\text{lcp}(s_i, s_j)+\text{lcs}(s_i, s_j)\ \ (i!=j)$$看上两秒,可以发现$\text{lcp}(s_i, s_j)$其实就相当于$i$和$j$的LCA的深度,那么只要枚举LCA,这个就没有了,剩下需要求的就是$\text{lcs}(s_i, s_j)$,也就是这个LCA的子树中的两点$i, j$的lcs的最大值。众所周知,假如有$rk_a<rk_b<rk_c<rk_d$,那么$\text{lcp}(b, c)\geq \text{lcp}(a, d)$。所以我们只要用一个线段树,维护一下这个子树中所有串按$rk$排序后每相邻两个的lcp,每次向父亲线段树合并一下就行了。

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>

template <typename T> void in(T &_){
    char c=getchar(); _=0; int fl=1;
    while(c<'0'||c>'9') fl=c=='-'?-1:fl, c=getchar();
    while(c>='0'&&c<='9') _=_*10+c-'0', c=getchar(); _*=fl;
}

#define MX 200010
int N, f[MX][19], s[MX], dep[MX];
int buc[MX], rk[MX][19], sa[MX], fir[MX], sec[MX], tmp[MX];

namespace ST{
    int w[MX][19];
    void init(){
        for(int j=1; j<19; j++)
            for(int i=1; i+(1<<(j-1))<=N; i++)
                w[i][j]=std::min(w[i][j-1], w[i+(1<<(j-1))][j-1]);
    }
    
    int query(int l, int r){
        int lim=log2(r-l+1);
        return std::min(w[l][lim], w[r-(1<<lim)+1][lim]);
    }
}

struct Data{
    int l, r, mx, siz;
    Data(){}
    Data(int l, int r, int mx, int siz) : l(l), r(r), mx(mx), siz(siz){}
};

Data operator +(Data a, Data b){
    if(!a.siz) return b;
    if(!b.siz) return a;
    Data c=Data(a.l, b.r, std::max(a.mx, b.mx), a.siz+b.siz);
    c.mx=std::max(c.mx, ST::query(a.r+1, b.l));
    return c;
}

struct Node{
    Node *lch, *rch;
    Data val;
    Node() : lch(NULL), rch(NULL){}
    void push_up(){
        if(!lch && !rch) return;
        if(!lch) val=rch->val;
        else if(!rch) val=lch->val;
        else val=lch->val+rch->val;
    }
}*rt[MX];

#define mid ((l+r)>>1)

void insert(Node *&v, int l, int r, int p){
    if(!v) v=new Node();
    if(l==r){
        v->val=Data(p, p, 0, 1);
        return;
    }
    if(p<=mid) insert(v->lch, l, mid, p);
    else       insert(v->rch, mid+1, r, p);
    v->push_up();
}

Node *merge(Node *v1, Node *v2, int l, int r){ 
    if(!v1) return v2;
    if(!v2) return v1;
    v1->lch=merge(v1->lch, v2->lch, l, mid);
    v1->rch=merge(v1->rch, v2->rch, mid+1, r);
    v1->push_up();
    return v1;
}

int main(){
    in(N);
    for(int i=1; i<=N; i++) rt[i]=NULL;
    for(int i=2; i<=N; i++){
        in(f[i][0]); in(s[i]);
        dep[i]=dep[f[i][0]]+1;
    }
    memset(buc, 0, sizeof(buc));
    for(int i=2; i<=N; i++) buc[s[i]+1]++;
    for(int i=1; i<=301; i++) buc[i]+=buc[i-1];
    for(int i=2; i<=N; i++) rk[i][0]=buc[s[i]]+1;
    for(int j=1; j<19; j++){
        for(int i=2; i<=N; i++){
            fir[i]=rk[i][j-1];
            sec[i]=f[i][j-1] ? rk[f[i][j-1]][j-1] : 0;
            f[i][j]=f[f[i][j-1]][j-1];
        }
        memset(buc, 0, sizeof(buc));
        for(int i=2; i<=N; i++) buc[sec[i]]++;
        for(int i=1; i<=N; i++) buc[i]+=buc[i-1];
        for(int i=2; i<=N; i++) tmp[N-1- --buc[sec[i]]]=i;
        memset(buc, 0, sizeof(buc));
        for(int i=2; i<=N; i++) buc[fir[i]]++;
        for(int i=1; i<=N; i++) buc[i]+=buc[i-1];
        for(int i=1; i<N; i++) sa[buc[fir[tmp[i]]]--]=tmp[i];
        for(int i=1, k=0; i<N; i++){
            if(!k) rk[sa[i]][j]=1;
            else if(fir[sa[i]]==fir[k] && sec[sa[i]]==sec[k])
                rk[sa[i]][j]=rk[k][j];
            else rk[sa[i]][j]=rk[k][j]+1;
            k=sa[i];
        }
    }
    for(int i=2; i<N; i++){
        int x=sa[i-1], y=sa[i], now=0;
        for(int j=18; ~j; j--)
            if(rk[x][j]==rk[y][j])
                x=f[x][j], y=f[y][j], now+=1<<j;
        ST::w[i][0]=now;
    }
    ST::init();
    for(int i=2; i<=N; i++) insert(rt[i], 1, N, rk[i][18]);
    int ans=0;
    for(int i=N; i; i--){
        if(rt[i]->val.siz>1) ans=std::max(ans, dep[i]+rt[i]->val.mx);
        if(f[i][0]) rt[f[i][0]]=merge(rt[f[i][0]], rt[i], 1, N);
    }
    printf("%d\n", ans);
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注