[Luogu P3413] SAC#1 – 萌数

[Luogu P3413] SAC#1 – 萌数

思路

很显然,要看一个数是不是萌数,看$2$或$3$位就OK了。很显然是一个数位DP。

很显然直接DP记录萌数去重会很麻烦,所以用$f[i][j][k]$表示长度为$i$,最左边两位为$j, k$的非萌数个数。作为比较无脑的人不太会记搜,所以就刷个表统计答案好咯,细节在代码注释里qwq。

代码

#include <iostream>
#include <cstdio>
using namespace std;

const int MOD=1000000007;
int l[1010], r[1010], f[1010][10][10];

void getnum(int *a){
    char ch=getchar();
    while(ch<'0' || ch>'9')	ch=getchar();
    while(ch>='0' && ch<='9'){
        a[++a[0]]=ch-'0';
        ch=getchar();
    }
}

long long solve(int *a){
    long long tot=0, ans=0;
	int X=-1, Y=-1;					
    //X为当前位往左一位,Y为当前位往左两位 
    bool fl=true;
    for(int i=1; i<=a[0]; i++)
        tot=(tot*10+1LL*a[i])%MOD;
    //统计1~a的数的个数
	tot++;	//加个0 
    for(int i=1; i<a[0]; i++)
        for(int j=1; j<=9; j++)
            for(int k=0; k<=9; k++)
                ans=(ans+f[i][j][k])%MOD;		
	//求少于a[0]位的非萌数有多少 
    if(a[0]>1)	ans=(ans+10)%MOD;				
	//把一位数(0~9)也算上 
    for(int i=a[0]; i>1; i--){
	//从最左边一位到从右往左第二位确定每一位 
        int v=a[a[0]-i+1];
        for(int j=0; j<v; j++)	
            if(i!=a[0] || j!=0)
			//不能有前导0 
                for(int k=0; k<=9; k++)
                    if(X!=j&&Y!=j&&k!=X)
					//避免统计到萌数 
                        ans=(ans+f[i][j][k])%MOD;
        if(v==X || v==Y){fl=false; break;}
		//如果确定当位后有回文就break,并进行标记 
        Y=X, X=v;									
    }
    if(fl)										
	//如果没有统计到过回文就进行最右一位的统计 
        for(int j=0; j<=a[a[0]]; j++)
            if(j!=Y && j!=X)
                ans++, ans%=MOD;
    return (tot-ans+MOD)%MOD; 
}

int main(){
    for(int i=2; i<=1000; i++)
        for(int j=0; j<=9; j++)
            for(int k=0; k<=9; k++)if(j!=k){
                for(int h=0; h<=9; h++)
                    if(h!=j && h!=k)
                        f[i][j][k]=(f[i][j][k]+f[i-1][k][h])%MOD;
                if(i==2)	f[i][j][k]++;
            }
    getnum(l);
    getnum(r);
    if(l[l[0]]!=0)	l[l[0]]--;
    else{
        int now=l[0];
        while(l[now]==0)
            l[now]=9, now--;
        l[now]--;
        if(now==1 && l[now]==0){
            for(int i=1; i<=l[0]; i++)	l[i]=l[i+1];
            l[0]--;
        }
    }
    //对l进行-1 
    cout<<(solve(r)-solve(l)+MOD)%MOD;
    return 0;
}