[LOJ 6509] 「雅礼集训 2018 Day7」C

LOJ #6509.

假设我们知道每个点非最后一次被走到的期望次数,那么将他们乘上所有点到这个点的平均距离求和就是答案。

设当前还有$i$个黑点,白/黑色点非最后一次被走到的期望次数为$f_{i, 0/1}$,则可以列出以下转移方程
$$\begin{split}
f_{i, 0} &= \frac{i}{n}f_{i-1, 0} + \frac{n-i-1}{n}f_{i+1, 0} + \frac{f_{i+1, 1} + 1}{n}\\
f_{i, 1} &= \frac{i-1}{n}f_{i-1, 1} + \frac{n-i}{n}f_{i+1, 1} + \frac{f_{i-1, 0} + 1}{n}\\
\end{split}$$
化简一下得到
$$\begin{split}
f_{i+1, 0} &= \frac{n}{n-i-1}f_{i, 0} – \frac{i}{n – i – 1}f_{i-1, 0} – \frac{f_{i + 1, 1} + 1}{n-i-1}\\
f_{i+1, 1} &= \frac{n}{n-i}f_{i, 1} – \frac{i-1}{n-i}f_{i-1, 1} – \frac{f_{i-1, 0} + 1}{n-i}\\
\end{split}$$
然后设$f_{1, 0} = x, f_{1, 1} = y$,所有就都可以表示为$ax+by+c$的形式,然后dp出来$f_{n, 0/1}$解一下方程即可

#include <iostream>
#include <cstdio>
#include <vector>

#define pb push_back
#define MAX_N 100010
#define MOD 1000000007

int ksm(int x, int y){
	int z = 1;
	for(; y; y >>= 1, x = 1ll * x * x % MOD)
		if(y & 1) z = 1ll * x * z % MOD;
	return z;
}

struct Data{
	int a, b, c;	// ax + by + c
	Data(){}
	Data(int a, int b, int c) : a(a), b(b), c(c){}
}f[MAX_N][2];

Data operator *(int x, Data y){
	return Data(1ll * x * y.a % MOD, 1ll * x * y.b % MOD, 1ll * x * y.c % MOD);
}

Data operator -(Data x, Data y){
	return Data((x.a - y.a) % MOD, (x.b - y.b) % MOD, (x.c - y.c) % MOD);
}

int n, inv[MAX_N], res[MAX_N], sz[MAX_N];
std::vector<int> to[MAX_N];

void dfs1(int x, int d){
	res[1] = (res[1] + d) % MOD;
	sz[x] = 1;
	for(auto i: to[x]){
		dfs1(i, d + 1);
		sz[x] += sz[i];
	}
}

void dfs2(int x, int fa){
	if(fa)
		res[x] = ((res[fa] + n) % MOD - 2ll * sz[x] % MOD) % MOD;
	for(auto i: to[x])
		dfs2(i, x);
}

int main(){
	inv[1] = 1;
	for(int i = 2; i < MAX_N; i++)
		inv[i] = 1ll * (MOD - MOD / i) * inv[MOD % i] % MOD;
	static char s[MAX_N];
	scanf("%d%s", &n, s + 1);
	for(int i = 2, x; i <= n; i++){
		scanf("%d", &x);
		to[x].pb(i);
	}
	f[1][0] = Data(1, 0, 0);
	f[1][1] = Data(0, 1, 0);
	for(int i = 1; i < n - 1; i++){
		f[i+1][1] = inv[n-i] * (n * f[i][1] - (i-1) * f[i-1][1] - f[i-1][0] - Data(0, 0, i != 1));
		f[i+1][0] = inv[n-i-1] * (n * f[i][0] - i * f[i-1][0] - f[i+1][1] - Data(0, 0, 1));
	}
	Data f0, f1;
	f0 = n * f[n-1][0] - (n-1) * f[n-2][0];
	f1 = n * f[n-1][1] - (n-2) * f[n-2][1] - f[n-2][0] - Data(0, 0, 1);
	Data t = f0 - 1ll * f0.a * ksm(f1.a, MOD - 2) % MOD * f1;
	int x, y, cnt = 0, val[2];
	for(int i = 1; i <= n; i++)
		cnt += s[i] - '0';
	y = -1ll * t.c * ksm(t.b, MOD - 2) % MOD;
	x = 1ll * ksm(f1.a, MOD - 2) * (-f1.c - 1ll * f1.b * y % MOD) % MOD;
	val[0] = (1ll * f[cnt][0].a * x % MOD + 1ll * f[cnt][0].b * y % MOD + f[cnt][0].c + inv[n]) % MOD;
	val[1] = (1ll * f[cnt][1].a * x % MOD + 1ll * f[cnt][1].b * y % MOD + f[cnt][1].c + inv[n]) % MOD;
	dfs1(1, 0); dfs2(1, 0);
	int ans = 0;
	for(int i = 1; i <= n; i++)
		ans = (ans + 1ll * res[i] * val[s[i] - '0'] % MOD) % MOD;
	ans = 1ll * ans * inv[n] % MOD;
	printf("%d\n", (ans + MOD) % MOD);
	return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注