第一行输入三个整数 $n, a, b$,分别代表节点数以及给定的上下限。
第二行输入 $n$ 个整数 $w_1, w_2, \dots, w_n$,代表每个节点的权值。
此后 $n-1$ 行,每行输入两个整数 $u, v$,代表一条无向边连接树上 $u$ 和 $v$ 两个节点。
在一行上输出一个整数,代表好路径的条数。
样例输入:
5 2 3
5 4 3 3 1
1 2
1 3
3 4
3 5
样例输出:
4
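对于这个样例,如下图所示(图略)。路径 $3 \to 5$ 是好的,因为路径点权最小值 $\min = 1 \le a = 2$,且点权最大值 $\max = 3 \ge b = 3$。
除此之外,以下路径也是好的:
$1 \to 3 \to 5$;
$2 \to 1 \to 3 \to 5$;
$4 \to 3 \to 5$。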
from collections import deque
import sys

# Entire stdin, split on '\n' into raw lines; read once at import time and
# consumed by main().
inps = sys.stdin.read().split('\n')
def count(n, inSet, tree):
    """Return how many unordered node pairs share a connected component of the
    subgraph induced by the nodes whose ``inSet`` flag is true.

    Args:
        n: number of nodes (indices 0..n-1).
        inSet: per-node boolean filter.
        tree: adjacency lists; presumably the tree built by main(), in which case
            every counted pair is joined by exactly one path lying inside the filter.

    Returns:
        Sum of size * (size - 1) // 2 over all filtered components.
    """
    seen = [False] * n
    pairs = 0
    for root in range(n):
        if seen[root] or not inSet[root]:
            continue
        # Iterative DFS over the filtered component containing `root`.
        seen[root] = True
        pending = [root]
        size = 0
        while pending:
            node = pending.pop()
            size += 1
            for nxt in tree[node]:
                if inSet[nxt] and not seen[nxt]:
                    seen[nxt] = True
                    pending.append(nxt)
        pairs += size * (size - 1) // 2
    return pairs
def main():
    """Parse the tree from ``inps`` and print the number of good paths.

    A path (unordered pair of distinct nodes) is good when the minimum node weight
    on it is <= a and the maximum is >= b. It is counted by complement:
        good = total - (paths with all weights > a) - (paths with all weights < b)
                     + (paths with all weights strictly between a and b)
    """
    # Tokenise every line instead of trusting the exact line layout: the old
    # ``inps[2:-1]`` slice silently dropped the last edge when the input had no
    # trailing newline, and ``split(' ')`` broke on repeated spaces.
    nums = map(int, (tok for line in inps for tok in line.split()))
    n, a, b = next(nums), next(nums), next(nums)
    ws = [next(nums) for _ in range(n)]
    tree = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = next(nums) - 1, next(nums) - 1
        tree[u].append(v)
        tree[v].append(u)
    total = n * (n - 1) // 2
    inA = [w > a for w in ws]
    inB = [w < b for w in ws]
    # Paths counted in both cntA and cntB; added back once (inclusion-exclusion).
    inAB = [a < w < b for w in ws]
    cntA = count(n, inA, tree)
    cntB = count(n, inB, tree)
    cntAB = count(n, inAB, tree)
    oup = total - (cntA + cntB - cntAB)
    print(oup)
# Counting strategy:
#
# Total number of paths (unordered pairs of distinct nodes), e.g. for n = 5:
#   1...2, 1...3, 1...4, 1...5
#   2...3, 2...4, 2...5
#   3...4, 3...5
#   4...5
#   4 + 3 + 2 + 1 + 0 = n * (n - 1) / 2
#
# A path is good when min <= a && max >= b,
#   i.e. at least one weight <= a && at least one weight >= b.
# That is hard to count directly, so count the complement instead:
#   good = all paths - bad paths
#        = all paths - ((all > a) + (all < b) - (a < all < b))
# The nodes kept by each filter need not form one complete tree, so first group
# them into connected components (a forest) and sum the paths of each component.
if __name__ == '__main__':
    main()
#include <cstdint>
#include <cstdio>
#include <vector>
using namespace std;
// Bit flags describing where a node weight lies relative to the bounds a and b;
// a node may carry both flags at once.
enum SECTION {
    NONE = 0x0,
    MORE_THAN_A = 0x1,  // weight > a
    LESS_THAN_B = 0x2   // weight < b
};
// Disjoint-set forest over nodes 0..n-1 using path halving and union by size.
// main() uses one instance per weight filter to group the nodes that pass it
// into connected components.
class UnionFind {
public:
    explicit UnionFind(int n) : parent(n), size(n, 1) {
        for (int i = 0; i < n; i++) {
            parent[i] = i;  // every node starts as the root of its own set
        }
    }

    // Returns the root of node's set, pointing every other node on the walk
    // at its grandparent (path halving) to keep future lookups short.
    int find(int node) {
        while (parent[node] != node) {
            parent[node] = parent[parent[node]];
            node = parent[node];
        }
        return node;
    }

    // Merges the sets containing x and y. The smaller set is attached below the
    // larger one so tree height stays logarithmic (the old code linked by index,
    // ignoring the sizes it already tracked).
    void combine(int x, int y) {
        int rootx = find(x);
        int rooty = find(y);
        if (rootx == rooty) {
            return;
        }
        if (size[rootx] < size[rooty]) {
            parent[rootx] = rooty;
            size[rooty] += size[rootx];
        } else {
            parent[rooty] = rootx;
            size[rootx] += size[rooty];
        }
    }

    // Returns the number of unordered pairs of distinct nodes that share a set:
    // the sum of C(size, 2) over all roots (singletons contribute 0).
    int64_t calculate() const {
        int64_t ret = 0;
        const int count = static_cast<int>(parent.size());
        for (int i = 0; i < count; i++) {
            if (parent[i] == i) {
                ret += static_cast<int64_t>(size[i]) * (size[i] - 1) / 2;
            }
        }
        return ret;
    }

private:
    std::vector<int> parent;  // parent[i] == i marks a root
    std::vector<int> size;    // set size; only meaningful at roots
};
// Reads n, a, b, the n node weights and the n-1 tree edges (1-based) from stdin and
// prints the number of good paths (min weight <= a and max weight >= b) via the
// complement:
//   good = total - bad1 - bad2 + bad3, where
//   bad1 = paths whose nodes all have weight > a,
//   bad2 = paths whose nodes all have weight < b,
//   bad3 = paths whose nodes all have a < weight < b (subtracted twice above).
// Returns 1 on malformed input instead of reading uninitialized values.
int main() {
    int n = 0, a = 0, b = 0;
    if (scanf("%d %d %d", &n, &a, &b) != 3 || n < 0) {
        return 1;
    }
    vector<int> flag(n);
    for (int i = 0; i < n; i++) {
        int value = 0;
        if (scanf("%d", &value) != 1) {
            return 1;
        }
        if (value > a) {
            flag[i] |= MORE_THAN_A;
        }
        if (value < b) {
            flag[i] |= LESS_THAN_B;
        }
    }
    // Each edge is merged as soon as it is read; union-find needs no adjacency list.
    UnionFind bad1(n), bad2(n), bad3(n);
    const int bothFlags = MORE_THAN_A | LESS_THAN_B;
    for (int i = 0; i < n - 1; i++) {
        int u = 0, v = 0;
        if (scanf("%d %d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n) {
            return 1;
        }
        u--;
        v--;
        const int shared = flag[u] & flag[v];  // flags carried by both endpoints
        if (shared & MORE_THAN_A) {
            bad1.combine(u, v);
        }
        if (shared & LESS_THAN_B) {
            bad2.combine(u, v);
        }
        if ((shared & bothFlags) == bothFlags) {
            bad3.combine(u, v);
        }
    }
    const int64_t totalPath = static_cast<int64_t>(n) * (n - 1) / 2;
    const int64_t goodPath =
        totalPath - bad1.calculate() - bad2.calculate() + bad3.calculate();
    // "%ld" is wrong for int64_t wherever long is 32-bit (e.g. Windows); print as long long.
    printf("%lld\n", static_cast<long long>(goodPath));
    return 0;
}

#include <iostream>
#include <stack>
#include <utility>
#include <vector>
using namespace std;
// Disjoint-set forest for nodes 1..n (index 0 is allocated but unused) with full
// path compression; size_map tracks set sizes at the roots.
class union_find
{
public:
    // father_map[i] is the parent of node i; a node is a root when it is its own parent.
    vector<int> father_map;
    // size_map[r] is the number of nodes in the set rooted at r (meaningful only for roots).
    vector<int> size_map;

    union_find (int n): father_map(n+1), size_map(n+1, 1)
    {
        int node = n;
        while (node >= 1)
        {
            father_map[node] = node;
            --node;
        }
    }

    // Returns the root of i's set and re-points every node on the walked path directly at it.
    int find_father(int i)
    {
        int root = i;
        while (father_map[root] != root)
        {
            root = father_map[root];
        }
        while (i != root)
        {
            const int next = father_map[i];
            father_map[i] = root;
            i = next;
        }
        return root;
    }

    // Merges the sets of x and y; x's root is attached under y's root.
    void uinon_(int x, int y)
    {
        const int fx = find_father(x);
        const int fy = find_father(y);
        if (fx == fy)
        {
            return;
        }
        father_map[fx] = fy;
        size_map[fy] += size_map[fx];
    }
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, a, b;
cin >> n >> a >> b;
vector<int> w(n+1);
for (int i = 1; i<=n; ++i)
{
cin >> w[i];
}
vector<pair<int, int>> edges;
for (int i = 0, u, v; i<n-1; ++i)
{
cin >> u >> v;
edges.emplace_back(u,v);
}
//把所有点>a ,<b, >a&&<b的各建立一个并查集,然后按这三个条件合并,分别统计路径的数目,路径的数目就是组合数公式,从一个集合里挑选两个点组成路径。
union_find uf1(n) , uf2(n), uf3(n);
long long sum_a=0, sum_b=0, sum_ab=0;
for (auto& [u, v] : edges)
{
if (w[u] > a && w[v] > a)
{
uf1.uinon_(u, v);
}
if (w[u] < b && w[v] < b) {
uf2.uinon_(u, v);
}
if (w[u] > a && w[u] < b && w[v] > a && w[v] < b) {
uf3.uinon_(u, v);
}
}
long long total = (long long)(n-1)*n/2;
for (int i = 1; i<=n; ++i)
{
if (uf1.father_map[i] == i && uf1.size_map[i] >= 2)
{
sum_a += (long long)uf1.size_map[i]*(uf1.size_map[i]-1)/2;
}
if (uf2.father_map[i] == i && uf2.size_map[i] >= 2)
{
sum_b += (long long)uf2.size_map[i]*(uf2.size_map[i]-1)/2;
}
if (uf3.father_map[i] == i && uf3.size_map[i] >= 2)
{
sum_ab += (long long)uf3.size_map[i]*(uf3.size_map[i]-1)/2;
}
}
cout << total - sum_a -sum_b + sum_ab;
} line = [int(item) for item in input().split(" ")]
# Brute-force cross-check. `line` holds the already-parsed first input line (n, a, b).
#
# The previous version only examined a path when its DFS reached a dead end, so paths
# whose endpoints are both internal nodes were never checked. It also counted a
# single-node "path" when n == 1, and deduplicated with O(R) list lookups over an
# n*n adjacency matrix. This version walks the tree from every node and checks every
# unordered pair of distinct nodes exactly once.
node_num = line[0]
min_num = line[1]
max_num = line[2]
node_val = [int(item) for item in input().split()]
# 0-based adjacency lists: O(n) memory and O(degree) neighbour scans.
node_adj = [[] for _ in range(node_num)]
for _ in range(node_num - 1):
    u, v = map(int, input().split())
    node_adj[u - 1].append(v - 1)
    node_adj[v - 1].append(u - 1)


def count_good_paths_from(start):
    """Count good paths between `start` and every node with a larger index.

    Walks the tree from `start`, carrying the min and max weight seen on the unique
    path to each node. A path is good when min <= min_num and max >= max_num.
    Restricting the far endpoint to indices > start counts each unordered pair once.
    """
    total = 0
    weight = node_val[start]
    stack = [(start, -1, weight, weight)]
    while stack:
        node, parent, lo, hi = stack.pop()
        if node > start and lo <= min_num and hi >= max_num:
            total += 1
        for nxt in node_adj[node]:
            if nxt != parent:  # the input is a tree, so the parent is the only back-edge
                nxt_weight = node_val[nxt]
                stack.append((nxt, node, min(lo, nxt_weight), max(hi, nxt_weight)))
    return total


# O(n^2) overall: fine for validating the fast solutions on small inputs.
print(sum(count_good_paths_from(s) for s in range(node_num)))