Codeforces 796C

题目

给一颗$N$个节点的树,节点$i$有权值$a_i$。第一次可以随意删除一个节点i,费用为$a_i$,并且与$a_i$距离为$1$的节点$j$的权值$$a_j增加$1$,与$i$距离为$2$的节点$K$的权值$$a_k增加$2$。之后删除的点必须与某个已被删除的点距离为$1$,其余删除规则与第一次一样。问:删除所有的点,费用的最大值的最小值是多少?

数据范围

$1 \leq N \leq 3 \times 10^5 \quad |a_i| \leq 10^9$

做法

$O(N\times log_2N)$

第一个被删除的点$i$的费用是$a_i$,与$i$相连的点$j$的费用是$a_j+1$,其余的点$k$的费用都是$a_k+2$(由树上无环的性质和删除的规则可以推出)。于是,我们只要枚举第一个删除的点,模拟一下删除的过程即可,需要一个数据结构来支持高效地插入、删除和查询最大值的操作,multiset可以满足要求。

$O(N)$

上文已分析出:每个点的权值最多增加$2$。所以答案只能是$m,m+1,m+2$中的数。其中,$m$为初始权值的最大值。所以,我们在枚举第一个删除的点后模拟删除的过程中,只要记录剩余的集合中$m-1$的个数和$m$的个数即可。

注意:删除与第一个删除的数相连的数时,如果它是m-1,那么权值的最大值至少是m+1。我们可以等价地处理成在删掉第一个数和所有与第一个数相连的数的集合中加入数m-1。

代码

$O(N\times log_2N)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 3e5 + 5;
int N, a[MAX_N];
multiset<int> S;
vector<int> G[MAX_N];
int main()
{
cin >> N;
for (int i = 0; i < N; ++i) scanf("%d", a + i), S.insert(a[i]);
for (int i = 1; i < N; ++i) {
int a, b; scanf("%d%d", &a, &b); --a; --b;
G[a].push_back(b);
G[b].push_back(a);
}
int ans = INT_MAX;
for (int v = 0; v < N; ++v) {
int tmp_ans = a[v];
S.erase(S.find(a[v]));
for (int u : G[v]) {
if (u != v) {
S.erase(S.find(a[u]));
tmp_ans = max(tmp_ans, a[u] + 1);
}
}
if (!S.empty()) tmp_ans = max(tmp_ans, *S.rbegin() + 2);
for (int u : G[v]) {
if (u != v) {
S.insert(a[u]);
}
}
S.insert(a[v]);
ans = min(ans, tmp_ans);
}
printf("%d\n", ans);
return 0;
}

$O(N)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 3e5 + 5;
int N;
int a[MAX_N];
vector<int> G[MAX_N];
int main()
{
cin >> N;
int max_num = INT_MIN;
for (int i = 0; i < N; ++i) scanf("%d", a + i), max_num = max(max_num, a[i]);
for (int i = 1; i < N; ++i) {
int a, b; scanf("%d%d", &a, &b); --a; --b;
G[a].push_back(b);
G[b].push_back(a);
}
int x = 0, y = 0;
for (int i = 0; i < N; ++i) {
if (a[i] == max_num) x++;
else if (a[i] == max_num - 1) y++;
}
int ans = max_num + 2;
for (int i = 0; i < N; ++i) {
int nx = x, ny = y;
if (a[i] == max_num) --nx;
else if (a[i] == max_num - 1) --ny;
for (int u : G[i]) if (u != i) {
if (a[u] == max_num) --nx, ++ny;
else if (a[u] == max_num - 1) --ny;
}
if (nx == 0) {
ans = min(ans, max_num + 1);
if (ny == 0) ans = min(ans, max_num);
}
}
printf("%d\n", ans);
return 0;
}