[CTSC 2010] 珠宝商

Luogu P4218

将长度为 $M$ 的串称作 $s$，将点 $i$ 上的字母称作 $c_i$。

很显然，可以用一种 $O(N^2)$（$N$ 为树的点数）的方法来统计：

对串 $s$ 建出 SAM，然后枚举路径 $(u, v)$ 的起点 $u$：以 $u$ 为根 DFS 整棵树，同时在 SAM 的转移 DAG 上同步往下走，把经过的每个状态所代表的子串在 $s$ 中的出现次数（即该状态 endpos 集合的大小）都累加起来；一旦 SAM 上没有对应的转移，这棵子树就不必再往下搜。

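点数较少时可以这样做，点数较多时就要考虑别的方法。（下面的代码以 $B=\lfloor\sqrt N\rfloor$ 为阈值：规模不超过 $B$ 的连通块直接用上面的暴力；更大的连通块做点分治，对经过分治中心的路径，用 $s$ 和它的反串各建一个 SAM，按分治中心的字母在 $s$ 中对应的位置把路径拆成两半分别计数再相乘。）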

先咕一会……等会继续……

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>

namespace IO{
#define MX (1 << 21)
    // Output buffer: putc appends to it, flush writes it to stdout (also done at exit by `ff`).
    char obuf[MX], *oS = obuf, *oT = obuf + MX;
    // Most recently read character. Stored as an int (not a char) so that EOF can be
    // told apart from ordinary input bytes.
    int c;
    // Reads one character from stdin; returns EOF at end of input.
    // (The previous body returned getchar() immediately and kept an unreachable
    // fread-based path after it; that dead code has been removed.)
    inline int getc(){ return getchar(); }
    // Writes everything buffered so far to stdout and empties the buffer.
    inline void flush(){
        fwrite(obuf, 1, oS - obuf, stdout);
        oS = obuf;
    }
    // Appends one character to the output buffer, flushing first when it is full.
    inline void putc(char _){
        if(oS == oT) flush();
        *oS++ = _;
    }
    // Flushes any pending output when static objects are destroyed at program exit.
    struct autof{ ~autof(){flush();} }ff;
    // Reads the next integer (optionally preceded by '-') into _.
    // Non-digit characters before the number are skipped. If input ends before any
    // digit is seen, _ becomes 0 instead of spinning forever on EOF.
    // On return, c holds the character right after the number.
    template <typename T> inline void in(T &_){
        c = getc(), _ = 0; int fl = 1;
        while(c != EOF && (c < '0' || c > '9'))
            fl = c == '-' ? -1 : fl, c = getc();
        while(c >= '0' && c <= '9')
            _ = _ * 10 + c - '0', c = getc();
        _ *= fl;
    }
    // Writes _ in decimal to the output buffer.
    template <typename T> void out(T _){
        if(_ < 0) putc('-'), out(-_);
        else{
            if(_ / 10) out(_ / 10);
            putc('0' + _ % 10);
        }
    }
    // Skips (starting from the current c) to the next lowercase letter, then stores the
    // next l characters into s[1..l]. Stops early if the input runs out.
    inline void getstr(char *s, int l){
        while(c != EOF && (c < 'a' || c > 'z')) c = getc();
        for(int i = 1; i <= l && c != EOF; i++)
            s[i] = static_cast<char>(c), c = getc();
    }
#undef MX
}
using IO::in;
using IO::out;
using IO::putc;
using IO::getstr;

// Component-size threshold (about sqrt(N)): components no larger than this are
// brute-forced. A macro rather than a constant because it reads the runtime value of N.
#define B static_cast<int>(std::sqrt(N))
// Capacity used to size the global per-vertex / per-character arrays.
#define MAX_N 50010
using ll = long long;


struct Edge;
// head[x]: first node of x's adjacency list (null when x has no edges yet).
Edge *head[MAX_N];
// Adjacency-list node. The constructor links the new node in front of head[x]'s
// current list; the caller stores the returned node back into head[x].
struct Edge{
    int to;     // vertex this edge leads to
    Edge *ne;   // next edge leaving the same vertex
    Edge(int x, int y) : to(y), ne(head[x]){}
};

ll ans;             // running total, printed at the end of main
int N;              // number of tree vertices
int M;              // length of the pattern string
int siz[MAX_N];     // component sizes filled by dfs1
int mx[MAX_N];      // largest piece left if the vertex were removed (dfs1)
char s1[MAX_N];     // s1[i]: letter on tree vertex i (1-based)
char s2[MAX_N];     // pattern string, 1-based
bool vis[MAX_N];    // vertices already used as a divide center

struct SAM{
#define MX (MAX_N << 1)
    int ch[MX][26], fa[MX], len[MX], sz[MX], buc[MX], q[MX];
    int pos[MX], son[MX][26], tag[MX], w[MX], end, cnt, l;
    char s[MX];
    SAM() : cnt(1), end(1){
        memset(buc, 0, sizeof(buc));    
    }
    void insert(int c){
        int p = end; end = ++cnt, len[end] = len[p] + 1, sz[end] = 1;
        for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = end;
        if(!p) fa[end] = 1;
        else{
            int q = ch[p][c];
            if(len[q] == len[p] + 1) fa[end] = q;
            else{
                ++cnt;
                for(int i = 0; i < 26; i++) ch[cnt][i] = ch[q][i];
                fa[cnt] = fa[q], fa[q] = fa[end] = cnt;
                len[cnt] = len[p] + 1;
                for(; ch[p][c] == q; p = fa[p]) ch[p][c] = cnt;
            }
        }
    }
    void build(char *_s, int _l){
        l = _l;
        for(int i = 1; i <= l; i++) s[i]=_s[i];
        for(int i = 1; i <= l; i++){
            insert(s[i] - 'a');
            pos[end] = i, w[i] = end;
        }
        for(int i = 1; i <= cnt; i++) buc[len[i]]++;
        for(int i = 1; i <= cnt; i++) buc[i] += buc[i-1]; 
        for(int i = 1; i <= cnt; i++) q[buc[len[i]]--] = i;
        for(int i = cnt; i; i--){
            sz[fa[q[i]]] += sz[q[i]];
            if(!pos[fa[q[i]]]) pos[fa[q[i]]] = pos[q[i]];
            son[fa[q[i]]][s[pos[q[i]] - len[fa[q[i]]]] - 'a'] = q[i];
        }
    }
    int work(char *s, int l){
        int v = 1;
        for(int i = 1; i <= l; i++)
            v = ch[v][s[i] - 'a'];
        return sz[v];
    }
    void dfs(int x, int fa, int v, int _l){ 
        if(_l == len[v]) v = son[v][s1[x] - 'a'];
        else if(s[pos[v] - _l] != s1[x]) v = 0;
        if(!v) return;
        _l++, tag[v]++;
        for(Edge *i = head[x]; i; i = i->ne) 
            if(i->to != fa && !vis[i->to])
                dfs(i->to, x, v, _l);
    }
    void calc(){
        for(int i = 1; i <= cnt; i++)
            tag[q[i]] += tag[fa[q[i]]];
    }
#undef MX
}S1, S2;

void dfs1(int x, int fa, int &rt, int tot){
    siz[x] = 1, mx[x] = 0;
    for(Edge *i = head[x]; i; i = i->ne) 
        if(i->to != fa && !vis[i->to]){
            dfs1(i->to, x, rt, tot);
            siz[x] += siz[i->to];
            mx[x] = std::max(mx[x], siz[i->to]);
        }
    mx[x] = std::max(mx[x], tot - siz[x]);
    if(!rt || mx[x] < mx[rt]) rt = x;
}

// Appends the letter of x to the forward-automaton match in state v. Returns the
// summed occurrence counts (S1.sz) of every path string obtained by continuing from x
// into its unvisited neighbours other than fa. A missing transition contributes 0.
ll dfs2(int x, int fa, int v){
    const int nxt = S1.ch[v][s1[x] - 'a'];
    if(nxt == 0) return 0;
    ll total = S1.sz[nxt];
    for(Edge *e = head[x]; e; e = e->ne){
        if(e->to == fa || vis[e->to]) continue;
        total += dfs2(e->to, x, nxt);
    }
    return total;
}

// Brute force for a small component. Every vertex y reachable from x (without crossing
// fa or visited vertices) is used as a path start, and dfs2 from y adds the counts of
// all paths beginning there.
ll calc1(int x, int fa){
    ll total = dfs2(x, 0, 1);
    for(Edge *e = head[x]; e; e = e->ne){
        const int y = e->to;
        if(y != fa && !vis[y]) total += calc1(y, x);
    }
    return total;
}

// Counts occurrences of path strings u -> ... -> c -> ... -> v passing through a center c.
// Each occurrence is split at the pattern position i holding c's letter:
// S1 (automaton of the pattern) counts the halves ending at i, and
// S2 (automaton of the reversed pattern) counts the halves starting at i.
// fa == 0: the center is x itself, and both halves range over x's whole component.
// fa != 0: the center is fa (already visited), and both halves stay on x's side. This
//          counts pairs whose endpoints lie in the same subtree, which the caller subtracts.
ll calc2(int x, int fa){
    // dfs/calc only touch states 0..cnt, so clearing that prefix is enough; clearing
    // the whole arrays cost 2 * MX ints per call regardless of the pattern length.
    memset(S1.tag, 0, sizeof(S1.tag[0]) * (S1.cnt + 1));
    memset(S2.tag, 0, sizeof(S2.tag[0]) * (S2.cnt + 1));
    if(fa){
        // Seed both walks with the center's letter as an already-matched length-1 string.
        S1.dfs(x, fa, S1.ch[1][s1[fa] - 'a'], 1);
        S2.dfs(x, fa, S2.ch[1][s1[fa] - 'a'], 1);
    }else{
        S1.dfs(x, fa, 1, 0);
        S2.dfs(x, fa, 1, 0);
    }
    S1.calc(), S2.calc();
    // Named `pairs` so it does not shadow the global running total `ans`.
    ll pairs = 0;
    for(int i = 1; i <= M; i++)
        pairs += 1ll * S1.tag[S1.w[i]] * S2.tag[S2.w[M-i+1]];
    return pairs;
}

void divide(int x){
    int rt = 0;
    dfs1(x, 0, rt, 0);
    ans += calc2(x, 0);
    vis[x] = true;
    for(Edge *i = head[x]; i; i = i->ne) if(!vis[i->to])
        ans -= calc2(i->to, x);
    for(Edge *i = head[x]; i; i = i->ne) if(!vis[i->to]){
        if(siz[i->to] <= B) ans += calc1(i->to, x);
        else{
            rt = 0;
            dfs1(i->to, 0, rt, siz[i->to]);
            divide(i->to);
        }
    }
}

int main(){
    in(N), in(M);
    // N - 1 undirected edges; each direction is pushed onto the front of its vertex's list.
    for(int k = 1; k < N; k++){
        int a, b;
        in(a), in(b);
        head[a] = new Edge(a, b);
        head[b] = new Edge(b, a);
    }
    getstr(s1, N);   // letters on the tree vertices
    getstr(s2, M);   // the pattern
    // S1 is built on the pattern, S2 on its reverse.
    S1.build(s2, M);
    std::reverse(s2 + 1, s2 + M + 1);
    S2.build(s2, M);
    int center = 0;
    dfs1(1, 0, center, N);
    if(N > B) divide(center);
    else ans += calc1(1, 0);
    printf("%lld\n", ans);
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注