[Luogu P5115] Check,Check,Check one two!

Luogu P5115

其实思路本身挺简单的……建两个SAM,分别记为$S_1$和$S_2$,其中$S_1$上放的是正串,$S_2$上放的是反串,那么$\text{lcp}(i, j)\cdot \text{lcs}(i, j)$就是在两个SAM上的LCA的$len$之积。所以我们可以枚举第一个SAM的后缀树上的LCA,用边分树合并求出所有点对在第二个SAM的后缀树上的LCA的$len$之和,乘一下计入答案。

至于代码实现……我只想说……nmdwsmzmnx

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
#include <queue>

#define pb push_back
typedef unsigned long long ull;

char s[100010];
int N, K1, K2; 
ull ans;

struct SAM{
    #define MX 200010
    int ch[MX][26], fa[MX], len[MX], pos[MX], cnt, end;
    std::vector<int> to[MX];
    #undef MX
    
    SAM() : cnt(1), end(1){}
    
    void insert(int c){
        int p=end; end=++cnt; len[end]=len[p]+1;
        for(; p && !ch[p][c]; p=fa[p]) ch[p][c]=end;
        if(!p){fa[end]=1; return;}
        int q=ch[p][c];
        if(len[q]==len[p]+1){fa[end]=q; return;}
        ++cnt; for(int i=0; i<26; i++) ch[cnt][i]=ch[q][i];
        fa[cnt]=fa[q]; fa[q]=fa[end]=cnt; len[cnt]=len[p]+1;
        for(; p && ch[p][c]==q; p=fa[p]) ch[p][c]=cnt;
    }
    
    void build(){
        for(int i=1; i<=N; i++){
            insert(s[i]-'a');
            pos[i]=end;
        }
        for(int i=1; i<=cnt; i++) if(fa[i]){
            to[fa[i]].pb(i);
            to[i].pb(fa[i]);
        }
    }
}S1, S2; 

namespace ED{
    #define MX 800010
    struct Node{
        ull val[2];
        Node *ch[2];
        Node(){
            ch[0]=ch[1]=NULL;
            val[0]=val[1]=0;
        }
    }*rt[MX>>2];
    int cnt, dfn, lsb[MX], rsb[MX], dep[MX];
    int siz[MX], ch[MX][2], RT, f[MX], val[MX][25];
    bool vis[MX];
    std::vector<int> to[MX], vec;
    std::queue<int> q;
    
    void dfs1(int x, int fa){
        val[x][0]=S2.len[x];
        for(auto i: S2.to[x]) if(i^fa) q.push(i);
        while(q.size()>2){
            int u=q.front(); q.pop();
            int v=q.front(); q.pop();
            val[++cnt][0]=val[x][0];
            to[cnt].pb(u); to[u].pb(cnt);
            to[cnt].pb(v); to[v].pb(cnt); 
            q.push(cnt);
        }
        while(!q.empty()){
            int u=q.front(); q.pop();
            to[x].pb(u); to[u].pb(x);
        }
        for(auto i: S2.to[x]) if(i^fa) dfs1(i, x);
    }
    
    void dfs2(int x, int fa){
        lsb[x]=++dfn;
        for(auto i: to[x]) if(i^fa){
            dep[i]=dep[x]+1;
            dfs2(i, x);
        }
        rsb[x]=dfn;
    }
    
    void dfs3(int x, int fa){
        f[x]=fa; siz[x]=1; vec.pb(x);
        for(auto i: to[x]) if(i!=fa && !vis[i]){
            dfs3(i, x);
            siz[x]+=siz[i];
        }
    }
    
    void dfs4(int x, int fa, int w, int d){
        val[x][d]=abs(w);
        for(auto i: to[x]) if(i!=fa && !vis[i])
            dfs4(i, x, std::min(val[i][0], w), d);
    }
    
    void divide(int &v, int x, int d){
        vec.clear();
        dfs3(x, 0);
        if(vec.size()<=1) return;
        int mn=0x3f3f3f3f;
        for(auto i: vec)
            if(std::max(siz[i], siz[x]-siz[i])<mn)
                mn=std::max(siz[i], siz[x]-siz[i]), v=i;
        int u=f[v]; 
        if(dep[u]>dep[v]) std::swap(u, v);
        bool tmp=vis[u]; vis[u]=true;
        divide(ch[v][0], v, d+1); vis[u]=tmp;
        tmp=vis[v]; vis[v]=true;
        divide(ch[v][1], u, d+1); vis[v]=tmp;
        dfs4(v, u, -1, d); dfs4(u, v, val[u][0], d); 
    }
    
    Node *build(int x, int p, int d){
        if(!x) return NULL;
        Node *v=new Node();
        ull tmp=val[p][d]<=K2 ? val[p][d] : 0;
        if(lsb[x]<=lsb[p] && lsb[p]<=rsb[x]){
            v->val[0]=tmp;
            v->ch[0]=build(ch[x][0], p, d+1);
        }else{
            v->val[1]=tmp;
            v->ch[1]=build(ch[x][1], p, d+1);
        }
        return v;
    }
    
    void work(){
        cnt=S2.cnt; 
        dfs1(1, 0);
        dfs2(1, 0);
        divide(RT, 1, 1);
        for(int i=1; i<=N; i++)
            rt[S1.pos[i]]=build(RT, S2.pos[N-i+1], 1);
    }
    
    Node *merge(Node *x, Node *y, ull w){
        if(!x) return y;
        if(!y) return x;
        ans+=w*(x->val[0]*y->val[1]+x->val[1]*y->val[0]);
        x->val[0]+=y->val[0], x->val[1]+=y->val[1];
        x->ch[0]=merge(x->ch[0], y->ch[0], w);
        x->ch[1]=merge(x->ch[1], y->ch[1], w);
        return x;
    }
}

void dfs(int x, int fa){
    ull tmp=S1.len[x]<=K1 ? S1.len[x] : 0;
    for(auto i: S1.to[x]) if(i!=fa){
        dfs(i, x);
        ED::rt[x]=ED::merge(ED::rt[x], ED::rt[i], tmp);
    }
}

int main(){
    scanf("%s%d%d", s+1, &K1, &K2);
    N=strlen(s+1);
    S1.build();
    std::reverse(s+1, s+N+1);
    S2.build();
    ED::work();
    dfs(1, 0);
    printf("%llu\n", ans);
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注