最小乘积生成树和最小匹配

Problem

  • 给出$n$个点和$m$条边,每条边上有权值$a_i,b_i$,求一个生成树使得选出的边的$\sum a_i \cdot \sum b_i$最小/最大

  • 给出$n$个点和$m$条边的二分图,$i$点和$j$点匹配的权值为$a_{ij}$和$b_{ij}$, 求一种匹配使得$\sum a_{ij} \cdot \sum b_{ij}$最小/最大

Solution

把解空间看做二维平面上的点,$(\sum a_i,\sum b_i)$ 分别看做点的横纵坐标。显然最优解一定是在解集构成的下凸壳上。

这里需要用到另外一种求凸壳的方法。先确定最左边的点$A$,(不管第二维对于第一维的最优解),再确定最右边的点$B$,(不管第一维,第二维的最优解),然后找到离直线$AB$最远的点$C$。

对于$A$点和$B$点,我们希望最大化$ABC$的面积,也就是$(\vec A -\vec B)\cdot (\vec C- \vec B)$
$$
(A.x-B.x,A.y-B.y)\cdot (C.x-B.x,C.y-B.y)\\
=(A.x-B.x)\cdot (C.y-B.y)-(A.y-B.y)\cdot (C.x-B.x)\\
=C.x\cdot (B.y-A.y)+C.y\cdot (A.x-B.x)+B.x\cdot (A.y-B.y)-B.y\cdot (A.x-B.x)
$$
然后式子后面部分和C无关,那么只要最大化$C.x \cdot (B.y-A.y)+C.y \cdot (A.x-B.x)$就好了。

找到的$C$一定在凸壳上,然后递归$AC$,$CB$部分的凸壳。当$C$在直线$A,B$上时结束递归。

假设最小/最大生成树和最小/最大匹配最复杂度为$O(T)$,可以证明凸壳上的点数为$O(m^2)$,实际复杂度$O(m^2T)$

Example

BZOJ2395 Timeismoney

最小乘积生成树板题,直接写就好了。注意细节是最好把最小生成树写成返回一个最优点的函数,不能用全局变量。不然在递归过程中会有一些莫名其妙的问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <bits/stdc++.h>
using namespace std;
#define IL inline
#define rep(i, j, k) for (int i = j; i <= k; ++i)
const int maxn = 5e5 + 10;
IL int read() {
char ch = getchar(); int u = 0, f = 1;
while (!isdigit(ch)) { if (ch == '-') f = -1; ch = getchar(); }
while (isdigit(ch)) { u = (u << 1) + (u << 3) + ch - 48; ch = getchar(); }
return u * f;
}
class ed { public: int u, v, a, b; long long c; ed() {u = v = a = b = c = 0;} }e[maxn];
IL bool cmpa(ed x, ed y) { return x.a < y.a; }
IL bool cmpb(ed x, ed y) { return x.b < y.b; }
IL bool cmpc(ed x, ed y) { return x.c > y.c; }
int fa[maxn], n, m;
int findf(int x) { return fa[x] == x ? x : fa[x] = findf(fa[x]); }
class point {
public : int x, y;
point(int q = 0, int p = 0) : x(p), y(q) {}
IL bool operator < (point r) { return 1ll * x * y < 1ll * r.x * r.y; }
}ans;
point kus() {
point now;
now.x = now.y = 0;
rep(i, 1, n) fa[i] = i; int cnt = 0;
rep(i, 1, m) {
if (findf(e[i].u) != findf(e[i].v)) {
fa[findf(e[i].u)] = findf(e[i].v), now.x += e[i].a, now.y += e[i].b;
cnt++;
if (cnt == n - 1) break;
}
}
if (now < ans) ans = now;
return now;
}
long long cross(point a, point b, point c) {
return 1ll * (b.x - a.x) * (c.y - a.y) - 1ll * (c.x - a.x) * (b.y - a.y);
}
void solve(point a, point b) {
rep(i, 1, m) e[i].c = 1ll * e[i].b * (a.x - b.x) + 1ll * e[i].a * (b.y - a.y);
sort(e + 1, e + 1 + m, cmpc);
point now = kus();
if (cross(a, b, now) >= 0) return;
solve(a, now);
solve(now, b);
}
int main() {
n = read(), m = read();
ans = point((int)1e9, (int)1e9);
for (int i = 1; i <= m; ++i) e[i].u = read()+1, e[i].v = read()+1, e[i].a = read(), e[i].b = read();
sort(e + 1, e + 1 + m, cmpa);
point x = kus();
sort(e + 1, e + 1 + m, cmpb);
point y = kus();
solve(x, y);
printf("%d %d\n", ans.x, ans.y);
return 0;
}

BZOJ3571 画框

最小乘积匹配板题,但是由于费用流EK复杂度为$O(NEK)$的,对于这个题卡满就是$O(N^6)$的(复杂度上天)。

实际复杂度为$O(n^5)​$(虽然跑不满)。对于大一点的数据3秒多才能跑出来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include <bits/stdc++.h>
using namespace std;
#define IL inline
#define rep(i, j, k) for (int i = j; i <= k; ++i)
const int maxn = 200;
const int maxm = 1e5 + 10;
IL int read() {
char ch = getchar(); int u = 0, f = 1;
while (!isdigit(ch)) { if (ch == '-') f = -1; ch = getchar(); }
while (isdigit(ch)) { u = (u << 1) + (u << 3) + ch - 48; ch = getchar(); }
return u * f;
}
int n, a[maxn][maxn], b[maxn][maxn], rem_cnt;
class point {
public : int x, y;
point(int q = 0, int p = 0) : x(q), y(p) {}
IL bool operator < (point r) { return 1ll * x * y < 1ll * r.x * r.y; }
}real_ans;

class costFlow {
public:
class edge { public: int fr, to, f, w, next, val_a, val_b, bkf; }e[maxm];
int cnt, s, t, ansa, ansb, d[maxn], flag[maxn], incf[maxm], head[maxn], ans, maxflow, pre[maxn], da[maxn], db[maxn];
costFlow() { cnt = 1; ans = s = t = 0; }
IL void addedge(int x, int y, int f, int w, int va, int vb) { e[++cnt].to = y; e[cnt].next = head[x]; head[x] = cnt; e[cnt].f = f; e[cnt].w = w; e[cnt].val_a = va; e[cnt].val_b = vb; e[cnt].fr = x; e[cnt].bkf = f; }
IL void add(int x, int y, int f, int w, int val_a, int val_b) { addedge(x, y, f, w, val_a, val_b); addedge(y, x, 0, -w, -val_a, -val_b); }
IL void init(int _s, int _t) { s = _s; t = _t; }
bool spfa() {
queue<int>q;
memset(d, 0x3f, sizeof d);
memset(flag, 0, sizeof flag);
memset(da, 0x3f, sizeof da);
memset(db, 0x3f, sizeof db);
q.push(s); da[s] = db[s] = d[s] = 0; flag[s] = 1;
incf[s] = 1 << 30;
while (q.size()) {
int x = q.front(); q.pop(); flag[x] = 0;
for (int i = head[x]; i; i = e[i].next) {
if (!e[i].f) continue;
int y = e[i].to;
if (d[y] > d[x] + e[i].w) {
d[y] = d[x] + e[i].w;
da[y] = da[x] + e[i].val_a;
db[y] = db[x] + e[i].val_b;
incf[y] = min(incf[x], e[i].f);
pre[y] = i;
if (!flag[y]) flag[y] = 1, q.push(y);
}
}
}
if (d[t] == 0x3f3f3f3f) return false;
return true;
}

void update() {
int x = t;
while (x != s) {
int i = pre[x];
e[i].f -= incf[t];
e[i ^ 1].f += incf[t];
x = e[i ^ 1].to;
}
maxflow += incf[t];
ans += d[t] * incf[t];
ansa += da[t] * incf[t];
ansb += db[t] * incf[t];
}
IL void flow_reset() { rep(i, 1, cnt) e[i].f = e[i].bkf; }
point mcf() {
flow_reset();
ans = ansa = ansb = 0;
point Rtn = point(0, 0);
while (spfa()) {
update();
}
Rtn = point(ansa, ansb);
if (Rtn < real_ans) real_ans = Rtn;
return Rtn;
}
IL void clear() {
cnt = 1;
memset(head, 0, sizeof head);
rem_cnt = 0;
}
}G;

long long cross(point a, point b, point c) {
return 1ll * (b.x - a.x) * (c.y - a.y) - 1ll * (c.x - a.x) * (b.y - a.y);
}
void solve(point a, point b) {
rep(i, 2, rem_cnt) {
int new_val = G.e[i].val_b * (a.x - b.x) + 1ll * G.e[i].val_a * (b.y - a.y);
G.e[i].w = -new_val; G.e[i ^ 1].w = new_val;
++i;
}
point now = G.mcf();
if (cross(a, b, now) >= 0) return;
solve(a, now);
solve(now, b);
}
int main() {
// freopen("frame.in", "r", stdin);
// freopen("frame.out", "w", stdout);
int T = read();
while (T--) {
G.clear();
n = read(); real_ans = point((int)1e9, (int)1e9);
rep(i, 1, n) rep(j, 1, n) a[i][j] = read();
rep(i, 1, n) rep(j, 1, n) b[i][j] = read();
int S = 2 * n + 1; int T = S + 1;
G.init(S, T);
rep(i, 1, n) rep(j, 1, n) G.add(i, j + n, 1, 0, a[i][j], b[i][j]);
rem_cnt = G.cnt;
rep(i, 1, n) G.add(S, i, 1, 0, 0, 0);
rep(i, 1, n) G.add(i + n, T, 1, 0, 0, 0);

rep(i, 2, rem_cnt) {
int u = G.e[i].fr, v = G.e[i].to;
G.e[i].w = a[u][v - n]; G.e[i ^ 1].w = -a[u][v - n];
i++;
}
point Maxx = G.mcf();

rep(i, 2, rem_cnt) {
int u = G.e[i].fr, v = G.e[i].to;
G.e[i].w = b[u][v - n]; G.e[i ^ 1].w = -b[u][v - n];
i++;
}
point Maxy = G.mcf();

solve(Maxx, Maxy);
printf("%lld\n", 1ll * real_ans.x * real_ans.y);
}
return 0;
}