Skip to content

ABC370F Cake Division

原题链接:F - Cake Division

Tag:二分、倍增

题目描述

给定一个长度为 \(N\) 的环,数值为 \(A_i\), 现在要把这个环分成 \(K\) 段。 记每一段的总和为 \(w\), 求 \(\min(w_1,w_2,\cdots,w_K)\) 的最大值, 以及满足条件的分割中从未被切割的切割线的数量。

数据说明:

  • \(2 \leq K \leq N \leq 2 \times 10^5\)
  • \(1 \leq A_i \leq 10^4\)

分析

看到我们要求一个最小值的最大值, 很容易想到要二分答案, 对于这个题而言,记我们要二分出来一个 \(mid\)。 这个 \(mid\) 的意义为,如果我让每一个区间都至少总和为 \(mid\), 能否成功把这个环分成 \(K\) 个区间。

问题的难点转化成了如何在 \(check\) 中快速判断在规定区间最小和的情况下能不能分成 \(K\) 个区间, 这个问题可以用倍增来优化。

首先把环看成链, 我们定义一个 \(nxt_{i,j}\) 代表从 \(i\) 位置开始向后选择了 \(2^j\) 个满足条件的区间之后的位置, 对于 \(nxt_{i,0}\) 我们将其设置为从位置 \(i\) 开始的下一个满足区间元素之和 \(\geq k\) 的位置, 这个地方可以用前缀和进行优化。 对于 \(nxt_{i,j}\) 我们可以这样更新 \(nxt[i][j] = nxt[nxt[i][j - 1]][j - 1]\), 其意义为我们从 \(i\) 跳了 \(2^{j-1}\) 次的位置再跳 \(2^{j-1}\) 次, 就是从 \(i\) 跳了 \(2^j\) 的位置。

在上述处理的过程中可能出现跳到了大于 \(n\) 的位置的情况, 所以我们可以把环拆成链之后再拼接一个完全一致的链上去来解决这种问题。

我们要实现的是检查能否分成 \(K\) 个区间, 所以我们维护一个 \(cur\), 对于每一个 \(K\) 的二进制位,我们用维护的倍增去跳, 就可以快速判断是否能分成 \(K\) 个区间了。

第一问做完了,考虑第二问, 一共有 \(N\) 个位置可以切, 我们在 \(check\) 的过程中记录一个变量 \(sum\), 如果我们从 \(i\) 个点开始跳可以有一个满足的切割, 就给 \(sum += 1\) 这样最后的答案就是 \(N - sum\)

上述做法中,二分答案的复杂度为 \(O(logV)\), 每次 \(check\) 的倍增复杂度是 \(O(nlogn)\) 的, 总体复杂度为 \(O(nlognlogV)\)

代码实现的过程中注意可以通过 \(nxt\) 的维度来减少 \(cache\) 丢失次数优化运行时间, 下文附上两版代码,代码不同之处在于 \(nxt\) 数组。

代码实现

代码一(446ms)

int nxt[22][N];
void NeverSayNever() {
    int n, k;
    cin >> n >> k;
    vector<int> vec(2 * n + 2), pre(2 * n + 2);
    for (int i = 1; i <= n; ++i) {
        cin >> vec[i];
        vec[i + n] = vec[i];
    }
    for (int i = 1; i <= 2 * n; ++i) {
        pre[i] = pre[i - 1] + vec[i];
    }

    int ans = 0;
    auto check = [&](int x) -> bool {
        for (int i = 1, j = 1; i <= 2 * n; ++i) {
            while (j <= 2 * n && pre[j] - pre[i - 1] < x) j++;
            nxt[0][i] = j + 1;
        }
        for (int j = 1; j < 22; ++j) {
            for (int i = 1; i <= 2 * n; ++i) {
                if (nxt[j - 1][i] > 2 * n) {
                    nxt[j][i] = nxt[j - 1][i];
                } else {
                    nxt[j][i] = nxt[j-1][nxt[j-1][i]];
                }
            }
        }
        bool flag = false;
        int sum = 0;
        for (int i = 1; i <= n; ++i) {
            int cur = i;
            for (int j = 21; j >= 0; --j) {
                if ((k >> j) & 1) {
                    cur = nxt[j][cur];
                }
                if (cur > 2 * n) break;
            }
            if (cur - n <= i && cur <= 2 * n) {
                flag = true;
                sum++;
            }
        }
        if (flag) ans = n - sum;
        return flag;
    };

    int L = 1, R = INT_MAX;
    while (L + 1 < R) {
        int mid = (L + R) >> 1;
        if (check(mid)) L = mid;
        else R = mid;
    }
    cout << L << ' ' << ans << endl;
}

代码二(2655ms)

int nxt[N][22];
void NeverSayNever() {
    int n, k;
    cin >> n >> k;
    vector<int> vec(2 * n + 2), pre(2 * n + 2);
    for (int i = 1; i <= n; ++i) {
        cin >> vec[i];
        vec[i + n] = vec[i];
    }
    for (int i = 1; i <= 2 * n; ++i) {
        pre[i] = pre[i - 1] + vec[i];
    }

    int ans = 0;
    auto check = [&](int x) -> bool {
        for (int i = 1, j = 1; i <= 2 * n; ++i) {
            while (j <= 2 * n && pre[j] - pre[i - 1] < x) j++;
            nxt[i][0] = j + 1;
        }
        for (int j = 1; j < 22; ++j) {
            for (int i = 1; i <= 2 * n; ++i) {
                if (nxt[i][j - 1] > 2 * n) {
                    nxt[i][j] = nxt[i][j - 1];
                } else {
                    nxt[i][j] = nxt[nxt[i][j - 1]][j - 1];
                }
            }
        }
        bool flag = false;
        int sum = 0;
        for (int i = 1; i <= n; ++i) {
            int cur = i;
            for (int j = 21; j >= 0; --j) {
                if ((k >> j) & 1) {
                    cur = nxt[cur][j];
                }
                if (cur > 2 * n) break;
            }
            if (cur - n <= i && cur <= 2 * n) {
                flag = true;
                sum++;
            }
        }
        if (flag) ans = n - sum;
        return flag;
    };

    int L = 1, R = INT_MAX;
    while (L + 1 < R) {
        int mid = (L + R) >> 1;
        if (check(mid)) L = mid;
        else R = mid;
    }
    cout << L << ' ' << ans << endl;
}

日志

本页面创建于 2024/09/10 18:16