codeforces 600E(dsu on tree or 线段树合并)
PS:昨天晚上,已经学了两三天的后缀自动机的我,大致有点明白但是还是不是特别明白...心情有点烦躁。于是索性就去整理了一下自己有哪些还是不会的算法,于是便发现了dsu on tree和线段树合并这两个算法是我很久以前就想学的,但是因为时间有限还是一直拖到了现在...今天下午来实验室,打算先学dsu on tree,就先在网上看了几篇博客后,大致就懂了,还去特地做掉了cf上大佬们推荐的所谓的模板题发现的确不难,是个套路算法-_-。ac那题后,我点开了standing,发现那道题几乎我的每个好友都做了。(心想看来的确是模板题实锤了)这时发现我的一个队友以前也做了这道题,点开了她的代码后,发现为啥和我的做法丝毫不同?? 大致看了下好像还正是我等会要学的线段树合并算法???,后面继续百度了一下发现这题的确可以用线段树合并做。于是便又借助了蓝书上的线段树合并模板成功又ac了一次。
难道这真的是巧合吗???
###题意:给出一棵有根树(1为根),树上每个点都有一个权值代表不同的颜色。求出以每个点为根的子树里出现次数最多的颜色之和。 ###解法1:dsu on tree 就是先预处理出每个结点的重儿子。然后就是每次先递归进轻儿子,完了先算重儿子的贡献,再算轻儿子的贡献。轻儿子的贡献算完后要消除影响,但是重儿子不消除。这一点就是dsu on tree的核心,通过这一点可以被证明复杂度是O(nlog n),(?) 大致框架:
void dsu(int x,int f){
for(int i=head[x];i!=-1;i=edge[i].next){
int y = edge[i].e;
if(y==fat[x]||y==son[x])continue;
dsu(y,0);
}//暴力计算轻儿子
if(son[x])dsu(son[x],1);//计算重儿子
for(int i=head[x];i!=-1;i=edge[i].next){
int y = edge[i].e;
if(y==fat[x]||y==son[x])continue;
ans[x]+=cal(y);//计算轻儿子对点x的贡献
add(y);//添加轻儿子状态
}
XXXXXXX;//添加重儿子状态
if(!f)del(x);//删除轻儿子状态
}
AC代码:
#include<bits stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define X first
#define Y second
#define pb push_back
#define pll pair<ll, ll>
#define pli pair<ll, int>
#define pii pair<int,int>
#define New_Time srand((unsigned)time(NULL))
inline ll gcd(ll a, ll b) { while (b != 0) { ll c = a % b; a = b; b = c; }return a < 0 ? -a : a; }
inline ll lowbit(ll x) { return x & (-x); }
int head[2000010], Edge_Num;
struct Edge { int to, next; ll w; }e[4000010];
inline void ade(int x, int y, ll w) { e[++Edge_Num] = { y,head[x],w }; head[x] = Edge_Num; }
inline void G_init(int n) { memset(head, 0, sizeof(int) * (n + 100)); Edge_Num = 0; }
int dir[8][2] = { {-1,0},{0,-1},{-1,-1},{1,-1},{1,0},{0,1},{1,1},{-1,1} };
const long double PI = 3.14159265358979323846;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
inline ll rd() {
ll x = 0; bool f = 1; char ch = getchar();
while (ch<'0' || ch>'9') { if (ch == '-')f = 0; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
return f ? x : -x;
}
const double eps = 1e-8;
const ll mod = 1e9 + 7;
const int M = 1e6 + 10;
const int N = 1e6 + 10;
int hson[N], siz[N];
void dfs(int x, int pre) {
siz[x] = 1;
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if (y == pre)continue;
dfs(y, x);
siz[x] += siz[y];
if (!hson[x] || siz[y] > siz[hson[x]])hson[x] = y;
}
}
ll now;
int maxx;
int cnt[N], a[N];
int Son;
ll ans[N];
void add(int x,int pre,int val) {
cnt[a[x]] += val;
if (cnt[a[x]] > maxx) {
now = a[x];
maxx = cnt[a[x]];
}
else if (cnt[a[x]] == maxx) {
now += a[x];
}
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if (y == pre || y == Son)continue;
add(y, x, val);
}
}
void gao(int x, int pre,bool op) {
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if (y == pre||y==hson[x])continue;
gao(y, x, 0);
}
if (hson[x])gao(hson[x], x, 1), Son = hson[x];
add(x, pre, 1); Son = 0;
ans[x] = now;
if (!op)add(x, pre, -1), now = 0, maxx = 0;
}
void solve() {
int n = rd();
for (int i = 1; i <= n; i++)a[i] = rd();
for (int i = 1; i < n; i++) {
int u = rd(), v = rd();
ade(u, v, 1), ade(v, u, 1);
}
dfs(1, 0);
gao(1, 0, 0);
for (int i = 1; i <= n; i++) {
cout << ans[i] << ' ';
}
cout << endl;
}
int main() {
int _T = 1;
// _T = rd();
while (_T--)solve();
}
##另一道dsu on tree的题目但是我觉得是写法上有一点差异 题目链接:https://ac.nowcoder.com/acm/contest/4853/E
##AC代码
#include<bits stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define X first
#define Y second
#define pb push_back
#define pll pair<ll, ll>
#define pli pair<ll, int>
#define pii pair<int,int>
#define New_Time srand((unsigned)time(NULL))
inline ll gcd(ll a, ll b) { while (b != 0) { ll c = a % b; a = b; b = c; }return a < 0 ? -a : a; }
inline ll lowbit(ll x) { return x & (-x); }
int head[2000010], Edge_Num;
struct Edge { int to, next; ll w; }e[4000010];
inline void ade(int x, int y, ll w) { e[++Edge_Num] = { y,head[x],w }; head[x] = Edge_Num; }
inline void G_init(int n) { memset(head, 0, sizeof(int) * (n + 100)); Edge_Num = 0; }
int dir[8][2] = { {-1,0},{0,-1},{-1,-1},{1,-1},{1,0},{0,1},{1,1},{-1,1} };
const long double PI = 3.14159265358979323846;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
inline ll rd() {
ll x = 0; bool f = 1; char ch = getchar();
while (ch<'0' || ch>'9') { if (ch == '-')f = 0; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
return f ? x : -x;
}
const double eps = 1e-8;
const ll mod = 1e9 + 7;
const int M = 1e6 + 10;
const int N = 1e6 + 10;
int hson[N], siz[N], d[N], fa[N];
void dfs(int x) {
siz[x] = 1;
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if (y == fa[y])continue;
d[y] = d[x] + 1;
fa[y] = x;
dfs(y);
siz[x] += siz[y];
if (!hson[x] || siz[y] > siz[hson[x]])hson[x] = y;
}
}
ll cnt[N], a[N], num[N];
ll ans[N];
int n, k;
void add(int x,int f) {
cnt[d[x]] += f * a[x];
num[d[x]] += f;
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if (y == fa[x])continue;
add(y, f);
}
}
void cal(int x, int rt) {
int v = 2 * d[rt] + k - d[x];
if (v >= 0 && v <= n)ans[rt] += cnt[v] + a[x] * num[v];
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if (y == fa[x])continue;
cal(y, rt);
}
}
void dsu(int x,bool op) {
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if (y == fa[x] || y == hson[x])continue;
dsu(y, 0);
}
if (hson[x])dsu(hson[x], 1);
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if (y == hson[x] || y == fa[x])continue;
cal(y, x);
add(y, 1);
}
cnt[d[x]] += a[x];
num[d[x]]++;
if (!op)add(x, -1);
}
void solve() {
n = rd(), k = rd();
for (int i = 1; i <= n; i++)a[i] = rd();
for (int i = 1; i < n; i++) {
int u = rd(), v = rd();
ade(u, v, 1), ade(v, u, 1);
}
dfs(1);
dsu(1, 0);
for (int i = 1; i <= n; i++) {
cout << ans[i] << ' ';
}
cout << endl;
}
int main() {
int _T = 1;
// _T = rd();
while (_T--)solve();
}
##解法2 线段树合并 AC代码:
#include<bits stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define X first
#define Y second
#define pb push_back
#define pll pair<ll, ll>
#define pli pair<ll, int>
#define pii pair<int,int>
#define New_Time srand((unsigned)time(NULL))
inline ll gcd(ll a, ll b) { while (b != 0) { ll c = a % b; a = b; b = c; }return a < 0 ? -a : a; }
inline ll lowbit(ll x) { return x & (-x); }
int head[2000010], Edge_Num;
struct Edge { int to, next; ll w; }e[4000010];
inline void ade(int x, int y, ll w) { e[++Edge_Num] = { y,head[x],w }; head[x] = Edge_Num; }
inline void G_init(int n) { memset(head, 0, sizeof(int) * (n + 100)); Edge_Num = 0; }
int dir[8][2] = { {-1,0},{0,-1},{-1,-1},{1,-1},{1,0},{0,1},{1,1},{-1,1} };
const long double PI = 3.14159265358979323846;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
inline ll rd() {
ll x = 0; bool f = 1; char ch = getchar();
while (ch<'0' || ch>'9') { if (ch == '-')f = 0; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
return f ? x : -x;
}
const double eps = 1e-8;
const ll mod = 1e9 + 7;
const int M = 1e6 + 10;
const int N = 1e6 + 10;
struct SegmentTree {
int lc, rc;
pll dat;
}t[N << 2];
int root[N], tot, a[N];
pll get(pll A, pll B) {
if (A.X > B.X)return A;
else if (B.X > A.X)return B;
else {
A.Y += B.Y;
return A;
}
}
void build(int& now) {
now = ++tot;
}
void pushup(int p) {
t[p].dat = get(t[t[p].lc].dat, t[t[p].rc].dat);
}
int merge(int p, int q, int l, int r) {
if (!p)return q;
if (!q)return p;
if (l == r) {
t[p].dat.X += t[q].dat.X;
return p;
}
int mid = (l + r) >> 1;
t[p].lc = merge(t[p].lc, t[q].lc, l, mid);
t[p].rc = merge(t[p].rc, t[q].rc, mid + 1, r);
pushup(p);
return p;
}
void insert(int p,int l, int r, int val) {
if (l == r) {
t[p].dat.X++;
t[p].dat.Y = l;
return;
}
int mid = (l + r) >> 1;
if (val <= mid) {
if (!t[p].lc)build(t[p].lc);
insert(t[p].lc, l, mid, val);
}
else {
if (!t[p].rc)build(t[p].rc);
insert(t[p].rc, mid + 1, r, val);
}
pushup(p);
}
int n;
ll ans[N];
void dfs(int x, int pre) {
build(root[x]);
insert(root[x], 1, n, a[x]);
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if (y == pre)continue;
dfs(y, x);
root[x] = merge(root[x], root[y], 1, n);
}
ans[x] = t[root[x]].dat.Y;
}
void solve() {
n = rd();
for (int i = 1; i <= n; i++)a[i] = rd();
for (int i = 1; i < n; i++) {
int u = rd(), v = rd();
ade(u, v, 1), ade(v, u, 1);
}
dfs(1, 0);
for (int i = 1; i <= n; i++)cout << ans[i] << ' ';
cout << endl;
}
int main() {
int _T = 1;
// _T = rd();
while (_T--)solve();
}
```</int,int></ll,></ll,></bits></int,int></ll,></ll,></bits></int,int></ll,></ll,></bits>
阿里云成长空间 733人发布