[CodeChef]Prime Distance On Tree

CodeChef PRIMEDST

思路

考虑暴力,很显然这个就是一个淀粉质板题,但是如果这样写的话还不如暴力跑得快。

其实淀粉质的calc函数中做的$O(n^2)$运算就是一个卷积,那么FFT加速就可以了

总的时间复杂度为$O(n\cdot \log^2 n)$

不知道为啥我写NTT就WA,改成FFT就过了

代码

#include <iostream>
#include <cstdio>
#include <cmath>

typedef long long LL;

struct Node{
	struct Edge *head;
	bool vis;
	int siz, mx;
	Node() : head(NULL), vis(false){}
}node[50010], *root;

struct Edge{
	Node *to;
	Edge *ne;
	Edge(int x, int y) : to(&node[y]), ne(node[x].head){}
};

int N, SumSiz, pri[100010], mxdis;
LL cnt[140010], ans[140010], tmp[140010];
const double Pi=3.1415926535897;
bool isn_pri[100010];

namespace FFT{
	struct Complex{
		double x, y;
		Complex(double x=0, double y=0) : x(x), y(y){}
		Complex operator +(Complex _){return Complex(x+_.x, y+_.y);}
		Complex operator -(Complex _){return Complex(x-_.x, y-_.y);}
		Complex operator *(Complex _){return Complex(x*_.x-y*_.y, x*_.y+y*_.x);}
	}C[140010], D[140010];

	int n, rev[140010];
	
	void fft(Complex *P, int opt){
		for(int i=0; i<n; i++) if(rev[i]>i)
			std::swap(P[i], P[rev[i]]);
		for(int i=1; i<n; i<<=1){
			Complex W=Complex(cos(Pi/i), sin(Pi/i)*opt);
			for(int p=i*2, j=0; j<n; j+=p){
				Complex w=Complex(1, 0);
				for(int k=0; k<i; k++, w=w*W){
					Complex X=P[j+k], Y=P[i+j+k]*w;
					P[j+k]=X+Y; P[i+j+k]=X-Y;
				}
			}
		}
	}
	
	void conv(LL *a, LL *b, LL *c){
		for(int i=0; i<=2*mxdis; i++){
			C[i]=Complex(a[i], 0), D[i]=Complex(b[i], 0);
		}
		for(n=1; n<=2*mxdis; n<<=1);
		for(int i=2*mxdis+1; i<n; i++) C[i]=D[i]=Complex(0, 0);
		for(int i=0; i<n; i++)
			rev[i]=(rev[i>>1]>>1) | ((i&1) ? n>>1 : 0);
		fft(C, 1); fft(D, 1);
		for(int i=0; i<n; i++) C[i]=C[i]*D[i];
		fft(C, -1);
		for(int i=0; i<n; i++) c[i]=LL(C[i].x/n+0.5);
	}
}

void get_root(Node *x, Node *fa){
	x->siz=1, x->mx=0;
	for(Edge *i=x->head; i; i=i->ne) if(i->to!=fa && !i->to->vis){
		get_root(i->to, x); x->siz+=i->to->siz; x->mx=std::max(x->mx, i->to->siz);
	}
	x->mx=std::max(x->mx, SumSiz-x->siz);
	root=root ? (x->mx<root->mx ? x : root) : x;
}

void dfs1(Node *x, Node *fa, int sum){
	cnt[sum]+=1; mxdis=std::max(mxdis, sum);
	for(Edge *i=x->head; i; i=i->ne) if(i->to!=fa && !i->to->vis){
		dfs1(i->to, x, sum+1);
	}
}

void calc1(Node *x, int val, int opt){
	mxdis=0;
	dfs1(x, NULL, val);
	FFT::conv(cnt, cnt, tmp);
	for(int i=1; i<=2*mxdis; i++) ans[i]+=opt*tmp[i];
	for(int i=0; i<=2*mxdis; i++) cnt[i]=0;
}

void divide(Node *x){
	calc1(x, 0, 1); x->vis=true;
	for(Edge *i=x->head; i; i=i->ne) if(!i->to->vis)
		calc1(i->to, 1, -1);
	for(Edge *i=x->head; i; i=i->ne) if(!i->to->vis){
		SumSiz=i->to->siz; root=NULL;
		get_root(i->to, NULL);
		divide(root);
	}
}

void get_pri(){
	for(int i=2; i<=100000; i++){
		if(!isn_pri[i]) pri[++pri[0]]=i;
		for(int j=1; j<=pri[0] && pri[j]*i<=100000; j++){
			isn_pri[i*pri[j]]=true;
			if(i%pri[j]==0) break;
		}
	}
}

int main(){
	get_pri();
	int x, y;
	scanf("%d", &N);
	for(int i=1; i<N; i++){
		scanf("%d%d", &x, &y);
		node[x].head=new Edge(x, y);
		node[y].head=new Edge(y, x);
	}
	root=NULL; SumSiz=N;
	get_root(&node[1], NULL);
	divide(root);
	double tot1=0, tot2=0;
	for(int i=1; i<=pri[0] && pri[i]<2*N-1; i++)
		tot1+=ans[pri[i]]/2;
	tot2=1LL*N*(N-1)/2;
	printf("%.8lf\n", tot1/tot2);
	return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注