论如何在很快的时间中求解最小树形图

前言

学到了这种很神奇的 Trick ,于是来写一下

好吧其实感觉是正常智商水平算法竞赛选手能想到的,可能我太菜了

另外本来标题想叫「论如何在$O(xxx)$的时间中求解最小树形图」,但是发现好像不太会分析时间复杂度…… 欢迎会分析的dalao教我分析

平凡的朱-刘算法

现在网上满天飞的求最小树形图的方法就是朱-刘算法,而那种很快的算法也是朱刘算法改出来的,所以先说朱刘算法。

朱刘算法其实就是重复以下操作:

  1. 对于每个点,把他的边权最小的入边拎出来。
  2. 观察所有边权最小的入边构成的图,如果有环的话就把这个环缩起来。
  3. 把环上所有的点的其他入边的权值都减去那个最小的权值,并将那个最小权值计入答案。

重复这些步骤,直到没有环存在。这时再将所有边权最小的边选出来的话,得到的就是一个最小树形图了。

优化

其实朱刘算法的操作过程,可以用另一种想法来实现。

每次随意找出一个点,为这个点找出权值最小的入边,重复这一操作,直到找出一条到根的路径,或者一条路径+一个环。

如果这条路径上没有环,也无法到根,那么显然是无解的。

如果这是一条到根的路径,那么很显然,最终的树形图是包含这条路径的。那么就把这条路径上的权值都加到答案里。

如果这条路径有一个环,就把这个环按上边的方法缩起来,并且将剩下的路径上的权值加入答案。

当一个点的路径到达了根,这个点就不用管了。

所以可以用一个并查集维护每个点和根是否联通,用一个并查集维护几个点是否被缩成了一个点,用可并堆维护一下每个点的入边。

这样的话,复杂度大概应该可能是:

  1. $O(n\log m)$合并可并堆$n$次
  2. $O(m\log m)$push、pop所有边

所以可能就是$O((n+m)\log m)$了叭……欢迎有会分析的聚聚告诉我我分析对不对_(:з」∠)_

代码

#include <iostream>
#include <cstdio>

#define X first
#define Y second
#define mp std::make_pair
typedef std::pair<int, int> pii;

#define MAX_N 110

class Union{
private:
    int fa[MAX_N];
public:
    Union(){
        for(int i = 1; i < MAX_N; i++)
            fa[i] = i;
    }
    int find(int x){
        return fa[x] == x ? x : fa[x] = find(fa[x]);
    }
    void merge(int x, int y){
        x = find(x), y = find(y);
        fa[y] = x;
    }
}U1, U; //U1维护是否与根联通,U维护是否在一个环里

template <typename T>
class Heap{
private:
    struct Node{
        T val;
        int dis, laz;
        Node *lch, *rch;
        Node(T val) : val(val), dis(1), laz(0), lch(NULL), rch(NULL){}
        void push_down(){
            if(lch) lch->val.X += laz, lch->laz += laz;
            if(rch) rch->val.X += laz, rch->laz += laz;
            laz = 0;
        }
    }*rt[MAX_N];
    Node *merge(Node *x, Node *y){
        if(!x) return y;
        if(!y) return x;
        x->push_down(), y->push_down();
        if(y->val < x->val) std::swap(x, y);
        x->rch = merge(x->rch, y);
        if(!x->lch || x->lch->dis < x->rch->dis)
            std::swap(x->lch, x->rch);
        x->dis = (x->rch ? x->rch->dis : 0) + 1;
        return x;
    }
public:
    Heap(){
        for(int i = 0; i < MAX_N; i++)
            rt[i] = NULL;
    }
    void push(int x, T y){
        x = U.find(x);
        rt[x] = merge(rt[x], new Node(y));
    }
    void pop(int x){
        x = U.find(x);
        rt[x]->push_down();
        rt[x] = merge(rt[x]->lch, rt[x]->rch);
    }
    T top(int x){
        x = U.find(x);
        while(rt[x]){
            rt[x]->push_down();
            pii ans = rt[x]->val;
            ans.Y = U.find(ans.Y);
            if(ans.Y != x) return ans;
            rt[x] = merge(rt[x]->lch, rt[x]->rch);
        }
        return mp(-1, -1);
    }
    void merge(int x, int y){
        x = U.find(x), y = U.find(y);
        if(x == y) return;
        U.merge(x, y);
        if(y == U.find(x))
            std::swap(x, y);
        rt[x] = merge(rt[x], rt[y]);
    }
    void modify(int x, int y){
        x = U.find(x);
        rt[x]->laz = y;
        rt[x]->val.X += y;
    }
    bool empty(int x){
        x = U.find(x);
        return top(x).Y == -1;
    }
};

int n, m, rt, vis[MAX_N], pre[MAX_N];

int main(){
    Heap <pii> H;
    in(n), in(m), in(rt);
    for(int i = 1; i <= m; i++){
        int x, y, z;
        in(x), in(y), in(z);
        H.push(y, mp(z, x));
    }
    int ans = 0, tim = 0;
    while(1){
        for(int i = 1; i <= n; i++) if(U1.find(U.find(i)) != U1.find(U.find(rt))){
            int x = U.find(i), y = 0;
            tim++; 
            while(vis[x] != tim && U1.find(x) != U1.find(rt)){
                if(H.empty(x)){ //如果没有入边,则无解
                    out(-1), putc('\n');
                    return 0;
                }
                vis[x] = tim;
                pre[x] = H.top(x).Y;
                U1.merge(x, pre[x]);
                ans += H.top(x).X;
                H.modify(x, -H.top(x).X);
                y = x, x = pre[x];
            }
            if(vis[x] == tim) //把环上的点缩起来
                for(int v = x; pre[v] != x; v = pre[v])
                    H.merge(v, pre[v]);
        }
        bool fl = true;
        for(int i = 1; i <= n; i++) //如果所有点都可以到根,那么就做完了
            if(U1.find(U.find(i)) != U1.find(U.find(rt)))
                fl = false;
        if(fl) break;
    }
    out(ans), putc('\n');
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注