一只$O(m\log n)$的最小方差生成树

给定$n$个点$m$条边的无向图,求一棵所有边权的方差最小的生成树

$n, m\leq 10^5$

首先有一个很noip的思路,我们可以枚举所有可能的$\overline{x}$的取值,对于每种取值跑一次kruskal。

另外有一个显而易见的思路,如果$\overline{x}\leq \frac{a+b}{2}$,那么选较小的一个更优,否则选较大的一个更优。这样的话需要枚举的$\overline{x}$的取值就只有$m^2$种了。

而每条边存在的时间必然是一个区间,考虑LCT求最小生成树的步骤,假如当前加入的边是$a$,$u_a$到$v_a$已经有一条路径了,那么在这条路径上,我们需要考虑删除一条边$b$,使得删除$b$加入$a$最优。因为方差是数分布越集中越小,所以$b$应该是$w_b$最小的一条边。而何时删除$b$加入$a$最好呢?根据刚才的思路很显然是$\overline{x}=\frac{w_a+w_b}{2}$的时候,这样就求出所有边的存在的$\overline{x}$区间了,之后遍历所有区间,用下边的式子算一下那个时候的方差就好了

$$\begin{split}
\sigma &= \frac{\sum \limits_{i = 1} ^n(x_i – \overline{x})^2}{n}\\
&= \frac{\sum \limits_{i=1}^n {x_i}^2 + n\overline{x}^2 – 2\overline{x}\sum \limits_{i=1}^n x_i}{n}\\
&= \frac{\sum \limits_{i=1}^n {x_i}^2}{n}+\frac{\left(\sum \limits_{i=1}^n x_i\right)^2}{n^2}
\end{split}$$

#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>

#define X first
#define Y second
#define pb push_back
#define mp std::make_pair
#define MAX_N 100010
#define MAX_M MAX_N
typedef std::pair<int, int> pii;

namespace LCT{
    struct Node *ww;
    struct Node{
        Node *ch[2], *fa, *mn;
        int val, laz_rev;
        Node() : fa(NULL), laz_rev(0){
            ch[0] = ch[1] = NULL;
        }
        int rlt(){
            if(!fa) return -1;
            return fa->ch[0] == this ? 0 : fa->ch[1] == this ? 1 : -1;
        }
        void rev(){
            std::swap(ch[0], ch[1]);
            laz_rev ^= 1;
        }
        void push_down(){
            if(!laz_rev) return;
            if(ch[0]) ch[0]->rev();
            if(ch[1]) ch[1]->rev();
            laz_rev = 0;
        }
        void push_all(){
            if(rlt() != -1) fa->push_all();
            push_down();
        }
        void push_up(){
            mn = this;
            if(ch[0] && ch[0]->mn->val < mn->val)
                mn = ch[0]->mn;
            if(ch[1] && ch[1]->mn->val < mn->val)
                mn = ch[1]->mn;
        }
        Node *getrt(){
            access(), splay();
            Node *v = this;
            while(v->ch[0]) v = v->ch[0];
            return v;
        }
        void rotate(){
            Node *o = fa; int d = rlt();
            fa = o->fa;
            if(o->rlt() != -1) o->fa->ch[o->rlt()] = this;
            o->ch[d] = ch[d ^ 1];
            if(ch[d ^ 1]) ch[d ^ 1]->fa = o;
            ch[d ^ 1] = o, o->fa = this;
            o->push_up(), push_up();
        }
        void splay(){
            push_all();
            while(rlt() != -1){
                if(fa->rlt() == -1) rotate();
                else if(fa->rlt() == rlt())
                    fa->rotate(), rotate();
                else rotate(), rotate();
            }
        }
        void access(){
            for(Node *x = this, *y = NULL; x; y = x, x = x->fa){
                x->splay(), x->ch[1] = y, x->push_up();
            }
        }
        void mkrt(){
            access(), splay(), rev();
        }
    }nd[MAX_N + MAX_M];

    int getrt(int x){
        return nd[x].getrt() - nd;
    }

    void link(Node *x, Node *y){
        x->mkrt(), x->fa = y;
    }

    void link(int x, int y){
        link(nd + x, nd + y);
    }

    void cut(Node *x, Node *y){
        x->mkrt(), y->access(), y->splay();
        x->fa = y->ch[0] = NULL;
        y->push_up();
    }

    void cut(int x, int y){
        cut(nd + x, nd + y);
    }

    Node *query(Node *x, Node *y){
        x->mkrt(), y->access(), y->splay();
        return y->mn;
    }

    int query(int x, int y){
        return query(nd + x, nd + y) - nd;
    }
}

struct Data{
    int x, y, w;
}a[MAX_M];

int n, m;

std::vector <pii> vec;

int main(){LCT::ww = LCT::nd;
    in(n), in(m), in(a[1].x);
    for(int i = 1; i <= m; i++)
        in(a[i].x), in(a[i].y), in(a[i].w);
    for(int i = 1; i <= n; i++)
        LCT::nd[i].val = 1e9;
    std::sort(a + 1, a + m + 1, 
        [](Data x, Data y){return x.w < y.w;});
    for(int i = 1; i <= m; i++)
        LCT::nd[i + n].val = a[i].w;
    for(int i = 1; i <= m; i++){;
        if(LCT::getrt(a[i].x) != LCT::getrt(a[i].y)){
            vec.pb(mp(-1e9, a[i].w));
        }else{
            int p = LCT::query(a[i].x, a[i].y) - n;
            vec.pb(mp(a[p].w + a[i].w, a[i].w));
            vec.pb(mp(a[p].w + a[i].w, -a[p].w));
            LCT::cut(a[p].x, p + n);
            LCT::cut(a[p].y, p + n);
        }
        LCT::link(a[i].x, i + n);
        LCT::link(a[i].y, i + n);
    }
    std::sort(vec.begin(), vec.end());
    __int128 ans = 1e36, sum = 0, sum2 = 0;
    int cnt = 0;
    for(auto i: vec){
        int opt = i.Y < 0 ? -1 : 1;
        cnt += opt;
        sum += i.Y;
        sum2 += (__int128) i.Y * i.Y * opt; 
        if(cnt == n-1)
            ans = std::min(ans, sum2 * (n-1) - sum * sum);
    }
    if(ans == 1e36)
        out(-1), putc('\n');
    else
        out(ans), putc('\n');
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注