题解 | 小红的01子序列构造(hard)

小红的01子序列构造(hard)

https://www.nowcoder.com/practice/c3c222f5b0a54966a6df34637291f4ce

import sys
import math

sys.setrecursionlimit(300000)


def solve():
    input = sys.stdin.readline
    n, m = map(int, input().split())
    cons = []
    for _ in range(m):
        l, r, x, y, k = map(int, input().split())
        cons.append((l, r, x, y, k))

    # 按左端点排序,左相同右降序(保证父区间在前)
    cons.sort(key=lambda x: (x[0], -x[1]))

    # 建树:每个区间的父节点是第一个完全包含它的区间
    parent = [-1] * m
    stack = []
    for i, (l, r, x, y, k) in enumerate(cons):
        while stack and cons[stack[-1]][1] < r:
            stack.pop()
        if stack:
            parent[i] = stack[-1]
        stack.append(i)

    children = [[] for _ in range(m)]
    for i in range(m):
        if parent[i] != -1:
            children[parent[i]].append(i)

    ans = [""] * n

    # ------------------------------------------------------------
    # 构造一段连续的区间(无子区间约束),位置从 L 到 R(1‑based)
    # 需要恰好 zeros 个 0,且内部 01 子序列数等于 target_k
    def build_segment(L, R, zeros, target_k):
        length = R - L + 1
        ones = length - zeros
        if target_k < 0 or target_k > zeros * ones:
            return False
        arr = ["0"] * length
        cnt0 = zeros
        need = target_k
        rem_ones = ones
        # 从右往左放 1
        for pos in range(length - 1, -1, -1):
            if rem_ones == 0:
                break
            if need >= cnt0:
                arr[pos] = "1"
                rem_ones -= 1
                need -= cnt0
            if arr[pos] == "0":
                cnt0 -= 1
        if need != 0 or rem_ones != 0:
            return False
        for i in range(length):
            ans[L - 1 + i] = arr[i]
        return True

    # ------------------------------------------------------------

    def solve_interval(idx):
        l, r, x, y, k = cons[idx]
        childs = children[idx]

        if not childs:
            # 叶子区间:直接构造
            return build_segment(l, r, x, k)

        # 有孩子区间
        childs.sort(key=lambda i: cons[i][0])

        # 计算各部分长度
        parts = (
            []
        )  # 每个元素: ('left', length) 或 ('child', length, child_idx) 或 ('right', length)
        prev_r = l - 1
        for ci in childs:
            cl, cr, cx, cy, ck = cons[ci]
            left_len = cl - (prev_r + 1)
            if left_len > 0:
                parts.append(("left", left_len))
            parts.append(("child", cr - cl + 1, ci))
            prev_r = cr
        right_len = r - prev_r
        if right_len > 0:
            parts.append(("right", right_len))

        # 孩子区间总和
        x2_sum = sum(cons[ci][2] for ci in childs)
        y2_sum = sum(cons[ci][3] for ci in childs)
        k2_sum = sum(cons[ci][4] for ci in childs)

        # 计算左右空白段总长度
        left_len_total = 0
        right_len_total = 0
        for item in parts:
            if item[0] == "left":
                left_len_total += item[1]
            elif item[0] == "right":
                right_len_total += item[1]

        # 剩余可分配的 0 和 1
        U = x - x2_sum  # 剩余 0 的个数
        V = y - y2_sum  # 剩余 1 的个数
        if U < 0 or V < 0 or U + V != left_len_total + right_len_total:
            return False

        # 系数准备
        D = y2_sum + right_len_total - U + x2_sum
        C0 = k2_sum + x2_sum * (right_len_total - U)

        # 辅助函数:求二次不等式 a^2 + B*a + C <= 0 的整数解区间
        def solve_quad_leq(B, C, a_min, a_max):
            # 解 a^2 + B*a + C <= 0,a 在 [a_min, a_max] 内
            # 返回 (L, R) 闭区间,若无解返回 None
            # 使用浮点数近似求根,再检查附近整数
            delta = B * B - 4 * C
            if delta < 0:
                return None
            sqrt_delta = math.sqrt(delta)
            # 左根和右根(实数)
            left_root = (-B - sqrt_delta) / 2.0
            right_root = (-B + sqrt_delta) / 2.0
            # 考虑整数范围,扩展一点以防浮点误差
            l_cand = max(a_min, int(math.floor(left_root)) - 2)
            r_cand = min(a_max, int(math.ceil(right_root)) + 2)
            # 收集所有满足不等式的整数
            good = []
            for a in range(l_cand, r_cand + 1):
                if a * a + B * a + C <= 0:
                    good.append(a)
            if not good:
                return None
            return (good[0], good[-1])

        # 第一个不等式:a^2 + D*a + (C0 - k) <= 0
        interval1 = solve_quad_leq(D, C0 - k, 0, left_len_total)
        if interval1 is None:
            return False

        # 第二个不等式:a^2 - (D + left_len_total + 2*U - right_len_total)*a + (k - C0 - U*(right_len_total - U)) <= 0
        B2 = -(
            D + left_len_total + 2 * U - right_len_total
        )  # 注意我们的公式是 a^2 + B2*a + C2 <= 0,所以 B2 是负的
        C2 = k - C0 - U * (right_len_total - U)
        interval2 = solve_quad_leq(B2, C2, 0, left_len_total)
        if interval2 is None:
            return False

        # 取交集
        L1, R1 = interval1
        L2, R2 = interval2
        L = max(L1, L2)
        R = min(R1, R2)
        if L > R:
            return False

        # 还需要考虑 a 的实际可行域:由 b = U - a 必须在 [0, right_len_total]
        a_low = max(L, 0, U - right_len_total)
        a_high = min(R, left_len_total, U)
        if a_low > a_high:
            return False

        # 任选一个 a,取最小的
        a = a_low
        b = U - a
        # 检查 b 范围(已保证)

        # 计算左右内部需要的 01 数
        F = a * a + D * a + C0
        target = k - F  # 必须等于 kL + kR
        # target 应该在 [0, a*(left_len_total - a) + b*(right_len_total - b)] 内(由不等式保证)

        # 分配 kL 和 kR
        max_kL = a * (left_len_total - a)
        max_kR = b * (right_len_total - b)
        # 取 kL 尽量小或尽量大,这里取尽量小(0)?但 target 可能大于0,需要合理分配
        # 简单取 kL = min(max_kL, target),则 kR = target - kL,只要 kR <= max_kR(由 target <= max_kL+max_kR 保证)
        kL = min(max_kL, target)
        kR = target - kL
        # 确保 kR 在范围内(应该成立)
        if kR < 0 or kR > max_kR:
            # 若不行,尝试另一种分配
            kR = min(max_kR, target)
            kL = target - kR
            if kL < 0 or kL > max_kL:
                return False

        # 现在按 parts 顺序构造
        pos = l - 1
        for item in parts:
            if item[0] == "left":
                length = item[1]
                seg_l = pos + 1
                seg_r = pos + length
                if not build_segment(seg_l, seg_r, a, kL):
                    return False
                a = 0  # 已用完
                pos += length
            elif item[0] == "child":
                length, ci = item[1], item[2]
                if not solve_interval(ci):
                    return False
                pos += length
            else:  # 'right'
                length = item[1]
                seg_l = pos + 1
                seg_r = pos + length
                if not build_segment(seg_l, seg_r, b, kR):
                    return False
                b = 0
                pos += length
        return True

    # 处理森林
    roots = [i for i in range(m) if parent[i] == -1]
    for ri in roots:
        if not solve_interval(ri):
            print(-1)
            return

    # 补全未填位置(填0)
    for i in range(n):
        if ans[i] == "":
            ans[i] = "0"
    print("".join(ans))


if __name__ == "__main__":
    solve()

思路不算复杂,但是解决起来感觉很复杂,其实我还是对使用深度优先搜索有点疑惑,代码我也打算之后再磨一磨

全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务