首页 > 试题广场 >

最大最小路

[编程题]最大最小路
  • 热度指数:1037 时间限制:C/C++ 2秒,其他语言4秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
\hspace{15pt}对于给定的无向无根树,第 i 个节点上有一个权值 w_i 。我们定义一条简单路径是好的,当且仅当:路径上的点的点权最小值小于等于 a ,路径上的点的点权最大值大于等于 b

\hspace{15pt}保证给定的 a < b ,你需要计算有多少条简单路径是好的。

输入描述:
\hspace{15pt}第一行输入三个整数 n, a, b\left(1 \leq n \leq 5 \times 10^5, 1 \leq a < b \leq 10^9\right) 代表节点数、给定的上下限。

\hspace{15pt}第二行输入 n 个整数 w_1, w_2, \dots, w_n\left(1 \leq w_i \leq 10^9\right) 代表每个节点的权值。

\hspace{15pt}此后 n - 1 行,每行输入两个整数 u, v\left(1 \leq u, v \leq n, u \neq v\right) 代表一条无向边连接树上 uv 两个节点。


输出描述:
\hspace{15pt}在一行上输出一个整数,代表好路径的条数。
示例1

输入

5 2 3
5 4 3 3 1
1 2
1 3
3 4
3 5

输出

4

说明

\hspace{15pt}对于这个样例,如下图所示。路径 2 \to 1 \to 3 \to 5 是好的,因为路径点权最小值 1 \leqq a 且点权最大值 5 \geqq b


\hspace{15pt}除此之外,以下路径也是好的:
\hspace{23pt}\bullet\,1 \to 3 \to 5
\hspace{23pt}\bullet\,3 \to 5
\hspace{23pt}\bullet\,4 \to 3 \to 5
import sys
inps = sys.stdin.read().split('\n')

from collections import deque
def count(n, inSet, tree):
    total = 0
    visited = [False] * n
    for i in range(n):
        if not visited[i] and inSet[i]:
            que = deque()
            que.append(i)
            _n = 1
            visited[i] = True
            while que:
                u = que.popleft()
                for v in tree[u]:
                    if not visited[v] and inSet[v]:
                        _n += 1
                        que.append(v)
                        visited[v] = True
            total += _n * (_n - 1) // 2
    return total

def main():
    n, a, b = map(int, inps[0].split(' '))
    ws = list(map(int, inps[1].split(' ')))
    tree = [[] for _ in range(n)]
    for uv in inps[2:-1]:
        u, v = map(int, uv.split(' '))
        tree[u - 1].append(v - 1)
        tree[v - 1].append(u - 1)
    total = n * (n - 1) // 2
    
    inA = [w > a for w in ws]
    inB = [w < b for w in ws]
    inAB = [a < w < b for w in ws]

    cntA = count(n, inA, tree)
    cntB = count(n, inB, tree)
    cntAB = count(n, inAB, tree)
    oup = total - (cntA + cntB - cntAB)
    print(oup)


if __name__ == '__main__':
    main()

发表于 2025-07-09 09:41:26 回复(0)
// 总路径数
// 1...2, 1...3, 1...4, 1...5
// 2...3, 2...4, 2...5
// 3...4, 3...5
// 4...5
// 4 + 3 + 2 + 1 + 0 = n * (n - 1) / 2

//   min < a && max > b
// = 至少一个 < a && 至少一个 > b
// 但是不好统计,所有改用补集思路
//   所有路径 - 坏路径
// = 所有路径 - ((全部 > a) + (全部 < b) - (a < 全部 < b))

// 筛选出的节点不一定是一棵完整的树,需要先合并成n棵树再计算各个树上的路径


#include <cstdint>
#include <cstdio>
#include <vector>

using namespace std;

enum SECTION {
    NONE = 0,
    MORE_THAN_A = 1 << 0,
    LESS_THAN_B = 1 << 1
};


class UnionFind {
public:
    UnionFind(int n) {
        parent.resize(n);
        size.resize(n, 1);
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }
        return;
    }

    int find(int node) {
        while (parent[node] != node) {
            parent[node] = parent[parent[node]];
            node = parent[node];
        }
        return node;
    }

    void combine(int x, int y) {
        int rootx = find(x);
        int rooty = find(y);
        if (rootx == rooty) {
            return;
        }

        if (rootx > rooty) {
            parent[rooty] = rootx;
            size[rootx] += size[rooty];
        } else {
            parent[rootx] = rooty;
            size[rooty] += size[rootx];
        }
    }

    int64_t calculate() {
        int64_t ret = 0;

        for (int i = 0; i < parent.size(); i++) {
            if ((parent[i] == i) && size[i] > 1) {
                ret += (int64_t)size[i] * (size[i] - 1) / 2;
            }
        }
        return ret;
    }
private:
    vector<int> parent;
    vector<int> size;
};

int main() {
    int n = 0, a = 0, b = 0;
    scanf("%d %d %d", &n, &a, &b);

    int value;
    vector<int> flag(n);
    for (int i = 0; i < n; i++) {
        scanf("%d", &value);
        if (value > a) {
            flag[i] |= MORE_THAN_A;
        }
        if (value < b) {
            flag[i] |= LESS_THAN_B;
        }
    }

    vector<vector<int>> neighbour(n);
    for (int i = 0; i < n - 1; i++) {
        int u = 0, v = 0;
        scanf("%d %d", &u, &v);
        u--;v--;
        neighbour[u].push_back(v);
    }

    UnionFind bad1(n), bad2(n), bad3(n);
    for (int i = 0; i < n; i++) {
        for (auto j : neighbour[i]) {
            if ((flag[i] & MORE_THAN_A) && (flag[j] & MORE_THAN_A)) {
                bad1.combine(i, j);
            }

            if ((flag[i] & LESS_THAN_B) && (flag[j] & LESS_THAN_B)) {
                bad2.combine(i, j);
            }

            if ((flag[i] & MORE_THAN_A) && (flag[j] & MORE_THAN_A)
             && (flag[i] & LESS_THAN_B) && (flag[j] & LESS_THAN_B)) {
                bad3.combine(i, j);
            }
        }
    }

    int64_t badPath1 = bad1.calculate();
    int64_t badPath2 = bad2.calculate();
    int64_t badPath3 = bad3.calculate();
    int64_t totalPath = (uint64_t)n * (n - 1) / 2;

    printf("%ld\n", totalPath - badPath1 - badPath2 + badPath3);

    return 0;
}

发表于 2025-11-20 02:08:22 回复(0)
看不来lambda表达式,写个三个并查集的版本,不知道为啥时间空间复杂度都比lambda表达式好点
#include <iostream>
#include <stack>
#include <utility>
#include <vector>
using namespace std;

class union_find
{
public:
    vector<int> father_map;
    vector<int> size_map;


    union_find (int n): father_map(n+1), size_map(n+1, 1)
    {
        for (int i = 1; i<=n; ++i)
        {
            father_map[i] = i;
        }
    }

    int find_father(int i)
    {
        stack<int> path;
        while (i != father_map[i])
        {
            path.push(i);
            i = father_map[i];
        }
        while (!path.empty()) {
            father_map[path.top()] = i;
            path.pop();
        }
        return i;
    }

    void uinon_(int x, int y)
    {
        int fx = find_father(x);
        int fy = find_father(y);
        if (fx!=fy)
        {
            father_map[fx] = fy;
            size_map[fy] += size_map[fx];
        }
    }

};



int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, a, b;
    cin >> n >> a >> b;

    vector<int> w(n+1);
    for (int i = 1; i<=n; ++i)
    {
        cin >> w[i];
    }
    vector<pair<int, int>> edges;
    for (int i = 0, u, v; i<n-1; ++i)
    {
        cin >> u >> v;
        edges.emplace_back(u,v);
    }

    //把所有点>a ,<b, >a&&<b的各建立一个并查集,然后按这三个条件合并,分别统计路径的数目,路径的数目就是组合数公式,从一个集合里挑选两个点组成路径。
    union_find uf1(n) , uf2(n), uf3(n);

    long long sum_a=0, sum_b=0, sum_ab=0;

    for (auto& [u, v] : edges)
    {
        if (w[u] > a && w[v] > a)
        {
            uf1.uinon_(u, v);
        }
        if (w[u] < b && w[v] < b) {
            uf2.uinon_(u, v);
        }
        if (w[u] > a && w[u] < b && w[v] > a && w[v] < b) {
            uf3.uinon_(u, v);
        }
    }

    long long total = (long long)(n-1)*n/2;

    for (int i = 1; i<=n; ++i)
    {
        if (uf1.father_map[i] == i && uf1.size_map[i] >= 2)
        {
            sum_a += (long long)uf1.size_map[i]*(uf1.size_map[i]-1)/2;
        }
        if (uf2.father_map[i] == i && uf2.size_map[i] >= 2)
        {
            sum_b += (long long)uf2.size_map[i]*(uf2.size_map[i]-1)/2;
        }
        if (uf3.father_map[i] == i && uf3.size_map[i] >= 2)
        {
            sum_ab += (long long)uf3.size_map[i]*(uf3.size_map[i]-1)/2;
        }
    }

    cout << total - sum_a -sum_b + sum_ab;

}


发表于 2025-09-09 19:35:02 回复(0)
为啥使用深度优先搜索只能通过示例和一个例子
line = [int(item) for item in input().split(" ")]
node_num = line[0]
min_num = line[1]
max_num = line[2]
node_val = [int(item) for item in input().split(" ")]
node_tu = [[0] * node_num for i in range(node_num)]
for i in range(0, node_num - 1):
    line = [int(item) for item in input().split(" ")]
    node_tu[line[0]-1][line[1]-1] = 1
    node_tu[line[1]-1][line[0]-1] = 1

# print(node_tu)

resa = []


def get_next_path(node_tu, searched_path):
    node = searched_path[-1]
    res = []
    for i in range(node_num):
        if node_tu[node - 1][i] == 1 and i + 1 not in searched_path:
            res.append(i + 1)
    return res


def search_path(node_tu, searched_path, res):
    next_path = get_next_path(node_tu, searched_path)
    if not next_path:
        val_list = []
        for item in searched_path:
            val_list.append(node_val[item - 1])
        if min(val_list) <= min_num and max(val_list) >= max_num:
            temp = list(reversed(searched_path))
            if searched_path not in res and temp not in res:
                res.append(searched_path)
    else:
        for node in next_path:
            search_path(node_tu, searched_path + [node], res)


for node in range(1,node_num+1):
    search_path(node_tu, [node,], resa)

print(len(resa))

发表于 2025-03-08 12:23:03 回复(1)