codeforces 600E(dsu on tree or 线段树合并)

PS:昨天晚上,已经学了两三天的后缀自动机的我,大致有点明白但是还是不是特别明白...心情有点烦躁。于是索性就去整理了一下自己有哪些还是不会的算法,于是便发现了dsu on tree和线段树合并这两个算法是我很久以前就想学的,但是因为时间有限还是一直拖到了现在...今天下午来实验室,打算先学dsu on tree,就先在网上看了几篇博客后,大致就懂了,还去特地做掉了cf上大佬们推荐的所谓的模板题发现的确不难,是个套路算法-_-。ac那题后,我点开了standing,发现那道题几乎我的每个好友都做了。(心想看来的确是模板题实锤了)这时发现我的一个队友以前也做了这道题,点开了她的代码后,发现为啥和我的做法丝毫不同?? 大致看了下好像还正是我等会要学的线段树合并算法???,后面继续百度了一下发现这题的确可以用线段树合并做。于是便又借助了蓝书上的线段树合并模板成功又ac了一次。 难道这真的是巧合吗??? 图片说明

###题意:给出一棵有根树(1为根),树上每个点都有一个权值代表不同的颜色。求出以每个点为根的子树里出现次数最多的颜色之和。 ###解法1:dsu on tree 就是先预处理出每个结点的重儿子。然后就是每次先递归进轻儿子,完了先算重儿子的贡献,再算轻儿子的贡献。轻儿子的贡献算完后要消除影响,但是重儿子不消除。这一点就是dsu on tree的核心,通过这一点可以被证明复杂度是O(nlog n),(?) 大致框架:

void dsu(int x,int f){
    for(int i=head[x];i!=-1;i=edge[i].next){
        int y = edge[i].e;
        if(y==fat[x]||y==son[x])continue;
        dsu(y,0);
    }//暴力计算轻儿子
    if(son[x])dsu(son[x],1);//计算重儿子
    for(int i=head[x];i!=-1;i=edge[i].next){
        int y = edge[i].e;
        if(y==fat[x]||y==son[x])continue;
        ans[x]+=cal(y);//计算轻儿子对点x的贡献
        add(y);//添加轻儿子状态
    }
    XXXXXXX;//添加重儿子状态
    if(!f)del(x);//删除轻儿子状态
}

AC代码:

#include<bits stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define X first
#define Y second
#define pb push_back
#define pll pair<ll, ll>
#define pli pair<ll, int>
#define pii pair<int,int>
#define New_Time srand((unsigned)time(NULL))
inline ll gcd(ll a, ll b) { while (b != 0) { ll c = a % b; a = b; b = c; }return a &lt; 0 ? -a : a; }
inline ll lowbit(ll x) { return x &amp; (-x); }
int head[2000010], Edge_Num;
struct Edge { int to, next; ll w; }e[4000010];
inline void ade(int x, int y, ll w) { e[++Edge_Num] = { y,head[x],w }; head[x] = Edge_Num; }
inline void G_init(int n) { memset(head, 0, sizeof(int) * (n + 100)); Edge_Num = 0; }
int dir[8][2] = { {-1,0},{0,-1},{-1,-1},{1,-1},{1,0},{0,1},{1,1},{-1,1} };
const long double PI = 3.14159265358979323846;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
inline ll rd() {
	ll x = 0; bool f = 1; char ch = getchar();
	while (ch&lt;'0' || ch&gt;'9') { if (ch == '-')f = 0; ch = getchar(); }
	while (ch &gt;= '0' &amp;&amp; ch &lt;= '9') { x = (x &lt;&lt; 3) + (x &lt;&lt; 1) + (ch ^ 48); ch = getchar(); }
	return f ? x : -x;
}
const double eps = 1e-8;
const ll mod = 1e9 + 7;
const int M = 1e6 + 10;
const int N = 1e6 + 10;
int hson[N], siz[N];
void dfs(int x, int pre) {
	siz[x] = 1;
	for (int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if (y == pre)continue;
		dfs(y, x);
		siz[x] += siz[y];
		if (!hson[x] || siz[y] &gt; siz[hson[x]])hson[x] = y;
	}
}
ll now;
int maxx;
int cnt[N], a[N];
int Son;
ll ans[N];
void add(int x,int pre,int val) {
	cnt[a[x]] += val;
	if (cnt[a[x]] &gt; maxx) {
		now = a[x];
		maxx = cnt[a[x]];
	}
	else if (cnt[a[x]] == maxx) {
		now += a[x];
	}
	for (int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if (y == pre || y == Son)continue;
		add(y, x, val);
	}
}
void gao(int x, int pre,bool op) {
	for (int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if (y == pre||y==hson[x])continue;
		gao(y, x, 0);
	}
	if (hson[x])gao(hson[x], x, 1), Son = hson[x];
	add(x, pre, 1); Son = 0;
	ans[x] = now;
	if (!op)add(x, pre, -1), now = 0, maxx = 0;
}
void solve() {
	int n = rd();
	for (int i = 1; i &lt;= n; i++)a[i] = rd();
	for (int i = 1; i &lt; n; i++) {
		int u = rd(), v = rd();
		ade(u, v, 1), ade(v, u, 1);
	}
	dfs(1, 0);
	gao(1, 0, 0);
	for (int i = 1; i &lt;= n; i++) {
		cout &lt;&lt; ans[i] &lt;&lt; ' ';
	}
	cout &lt;&lt; endl;
}
int main() {
	int _T = 1;
	//	_T = rd();
	while (_T--)solve();
}

##另一道dsu on tree的题目但是我觉得是写法上有一点差异 题目链接:https://ac.nowcoder.com/acm/contest/4853/E

##AC代码

#include<bits stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define X first
#define Y second
#define pb push_back
#define pll pair<ll, ll>
#define pli pair<ll, int>
#define pii pair<int,int>
#define New_Time srand((unsigned)time(NULL))
inline ll gcd(ll a, ll b) { while (b != 0) { ll c = a % b; a = b; b = c; }return a &lt; 0 ? -a : a; }
inline ll lowbit(ll x) { return x &amp; (-x); }
int head[2000010], Edge_Num;
struct Edge { int to, next; ll w; }e[4000010];
inline void ade(int x, int y, ll w) { e[++Edge_Num] = { y,head[x],w }; head[x] = Edge_Num; }
inline void G_init(int n) { memset(head, 0, sizeof(int) * (n + 100)); Edge_Num = 0; }
int dir[8][2] = { {-1,0},{0,-1},{-1,-1},{1,-1},{1,0},{0,1},{1,1},{-1,1} };
const long double PI = 3.14159265358979323846;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
inline ll rd() {
	ll x = 0; bool f = 1; char ch = getchar();
	while (ch&lt;'0' || ch&gt;'9') { if (ch == '-')f = 0; ch = getchar(); }
	while (ch &gt;= '0' &amp;&amp; ch &lt;= '9') { x = (x &lt;&lt; 3) + (x &lt;&lt; 1) + (ch ^ 48); ch = getchar(); }
	return f ? x : -x;
}
const double eps = 1e-8;
const ll mod = 1e9 + 7;
const int M = 1e6 + 10;
const int N = 1e6 + 10;
int hson[N], siz[N], d[N], fa[N];
void dfs(int x) {
	siz[x] = 1;
	for (int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if (y == fa[y])continue;
		d[y] = d[x] + 1;
		fa[y] = x;
		dfs(y);
		siz[x] += siz[y];
		if (!hson[x] || siz[y] &gt; siz[hson[x]])hson[x] = y;
	}
}
ll cnt[N], a[N], num[N];
ll ans[N];
int n, k;
void add(int x,int f) {
	cnt[d[x]] += f * a[x];
	num[d[x]] += f;
	for (int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if (y == fa[x])continue;
		add(y, f);
	}
}
void cal(int x, int rt) {
	int v = 2 * d[rt] + k - d[x];
	if (v &gt;= 0 &amp;&amp; v &lt;= n)ans[rt] += cnt[v] + a[x] * num[v];
	for (int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if (y == fa[x])continue;
		cal(y, rt);
	}
}
void dsu(int x,bool op) {
	for (int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if (y == fa[x] || y == hson[x])continue;
		dsu(y, 0);
	}
	if (hson[x])dsu(hson[x], 1);
	for (int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if (y == hson[x] || y == fa[x])continue;
		cal(y, x);
		add(y, 1);
	}
	cnt[d[x]] += a[x];
	num[d[x]]++;
	if (!op)add(x, -1);
}
void solve() {
	n = rd(), k = rd();
	for (int i = 1; i &lt;= n; i++)a[i] = rd();
	for (int i = 1; i &lt; n; i++) {
		int u = rd(), v = rd();
		ade(u, v, 1), ade(v, u, 1);
	}
	dfs(1);
	dsu(1, 0);
	for (int i = 1; i &lt;= n; i++) {
		cout &lt;&lt; ans[i] &lt;&lt; ' ';
	}
	cout &lt;&lt; endl;
}
int main() {
	int _T = 1;
	//	_T = rd();
	while (_T--)solve();
}

##解法2 线段树合并 AC代码:

#include<bits stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define X first
#define Y second
#define pb push_back
#define pll pair<ll, ll>
#define pli pair<ll, int>
#define pii pair<int,int>
#define New_Time srand((unsigned)time(NULL))
inline ll gcd(ll a, ll b) { while (b != 0) { ll c = a % b; a = b; b = c; }return a &lt; 0 ? -a : a; }
inline ll lowbit(ll x) { return x &amp; (-x); }
int head[2000010], Edge_Num;
struct Edge { int to, next; ll w; }e[4000010];
inline void ade(int x, int y, ll w) { e[++Edge_Num] = { y,head[x],w }; head[x] = Edge_Num; }
inline void G_init(int n) { memset(head, 0, sizeof(int) * (n + 100)); Edge_Num = 0; }
int dir[8][2] = { {-1,0},{0,-1},{-1,-1},{1,-1},{1,0},{0,1},{1,1},{-1,1} };
const long double PI = 3.14159265358979323846;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
inline ll rd() {
	ll x = 0; bool f = 1; char ch = getchar();
	while (ch&lt;'0' || ch&gt;'9') { if (ch == '-')f = 0; ch = getchar(); }
	while (ch &gt;= '0' &amp;&amp; ch &lt;= '9') { x = (x &lt;&lt; 3) + (x &lt;&lt; 1) + (ch ^ 48); ch = getchar(); }
	return f ? x : -x;
}
const double eps = 1e-8;
const ll mod = 1e9 + 7;
const int M = 1e6 + 10;
const int N = 1e6 + 10;
struct SegmentTree {
	int lc, rc;
	pll dat;
}t[N &lt;&lt; 2];
int root[N], tot, a[N];
pll get(pll A, pll B) {
	if (A.X &gt; B.X)return A;
	else if (B.X &gt; A.X)return B;
	else {

		A.Y += B.Y;
		return A;
	}
}
void build(int&amp; now) {
	now = ++tot;
}
void pushup(int p) {
	t[p].dat = get(t[t[p].lc].dat, t[t[p].rc].dat);
}
int merge(int p, int q, int l, int r) {
	if (!p)return q;
	if (!q)return p;
	if (l == r) {
		t[p].dat.X += t[q].dat.X;
		return p;
	}
	int mid = (l + r) &gt;&gt; 1;
	t[p].lc = merge(t[p].lc, t[q].lc, l, mid);
	t[p].rc = merge(t[p].rc, t[q].rc, mid + 1, r);
	pushup(p);
	return p;
}
void insert(int p,int l, int r, int val) {
	if (l == r) {
		t[p].dat.X++;
		t[p].dat.Y = l;
		return;
	}
	int mid = (l + r) &gt;&gt; 1;
	if (val &lt;= mid) {
		if (!t[p].lc)build(t[p].lc);
		insert(t[p].lc, l, mid, val);
	}
	else {
		if (!t[p].rc)build(t[p].rc);
		insert(t[p].rc, mid + 1, r, val);
	}
	pushup(p);
}
int n;
ll ans[N];
void dfs(int x, int pre) {
	build(root[x]);
	insert(root[x], 1, n, a[x]);
	for (int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if (y == pre)continue;
		dfs(y, x);
		root[x] = merge(root[x], root[y], 1, n);
	}
	ans[x] = t[root[x]].dat.Y;
}
void solve() {
	n = rd();
	for (int i = 1; i &lt;= n; i++)a[i] = rd();
	for (int i = 1; i &lt; n; i++) {
		int u = rd(), v = rd();
		ade(u, v, 1), ade(v, u, 1);
	}
	dfs(1, 0);
	for (int i = 1; i &lt;= n; i++)cout &lt;&lt; ans[i] &lt;&lt; ' ';
	cout &lt;&lt; endl;
}
int main() {
	int _T = 1;
	//	_T = rd();
	while (_T--)solve();
}
```</int,int></ll,></ll,></bits></int,int></ll,></ll,></bits></int,int></ll,></ll,></bits>
全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务