【AtCoder】ARC169 C - Not So Consecutive

問題

解法?(本題とは関係なし)

まず、dp[i][j][k] = i 個目まで見て末尾が j で k 個続いているものの個数 として各要素の遷移に  O(N) かけて全体  O(N ^ 4) のコードができる。 (下記コードのコメントアウト部分に対応)

簡単に思いつく高速化としてほとんどの adp[i + 1][a][1] に遷移するのでまとめて計算する。

#include "my_template.hpp"
#include "math/static_modint.hpp"
using namespace std;

using mint = mint998;

int main() {
    INT(N);
    VEC(int, A, N);
    REP(i, N) if (A[i] != -1) A[i]--;
    vector dp(N, vector<mint>(N + 2));
    if (A[0] == -1) {
        REP(a, N) dp[a][1] = 1;
    } else {
        dp[A[0]][1] = 1;
    }
    REP(i, 1, N) {
        vector np(N, vector<mint>(N + 2));
        if (A[i] == -1) {
            mint s = 0;
            REP(j, N) {
                REP(k, j + 2) {
                    s += dp[j][k];
                    np[j][1] -= dp[j][k];
                    if (k <= j) np[j][k + 1] += dp[j][k];
                    /*
                    REP(a, N) {
                        if (j == a) {
                            if (k + 1 <= j + 1) np[j][k + 1] += dp[j][k];
                        } else {
                            np[a][1] += dp[j][k];
                        }
                    }
                    */
                }
            }
            REP(a, N) np[a][1] += s;
        } else {
            REP(j, N) {
                REP(k, j + 2) {
                    if (j == A[i]) {
                        if (k <= j) np[j][k + 1] += dp[j][k];
                    } else {
                        np[A[i]][1] += dp[j][k];
                    }
                    /*
                    REP(a, A[i], A[i] + 1) {
                        if (j == a) {
                            if (k + 1 <= j + 1) np[j][k + 1] += dp[j][k];
                        } else {
                            np[a][1] += dp[j][k];
                        }
                    }
                    */
                }
            }
        }
        swap(dp, np);
    }
    mint ans = 0;
    REP(j, N) REP(k, j + 2) ans += dp[j][k];
    print(ans);
    return 0;
}

ここまではバーチャルコンテスト中に考えた解法で、結論としてはこれを  O(N ^ 2) に落とすことはできなかった。

解法

dp[i][j] = i 個目まで見て連続部分列の先頭要素として j から始まっているものの個数 とする。

貰う DP で遷移を考える。

イメージとしては LIS の DP の 2 次元版のように、配列を更新しながら計算していく。 (つまり、dp[i + 1] の計算が終われば dp[i] は不要になるタイプの DP ではないということ)

どういうことかというと、すでに計算が終了している dp[i][j] についても、今見ている dp[i] に遷移できなくなったら個数を 0 に更新するという処理を加える。 つまり、dp[i][j] について、[ ... ][ j ... ということはわかっているが、そこから今見ている要素の 1 つ前まで [ ... ][ j ... j ] で埋めることができなくなったら 0 で更新してその遷移が不可能なことを表現する。

できなくなるというのは、同じ要素が連続しすぎてダメになった場合と、A[i] != -1 なので埋める値が決まってしまった場合である。 同じ要素が連続しすぎてダメになる場合は各 dp[i] について各ループごとに先頭から順番に 1 つずつ増えていくので先頭要素を取り除けば良い。 埋める値が決まってしまった場合は残るのは 1 つなのでその手前の要素を後ろの要素を全て取り除けば良い。

遷移を考慮すると dp 配列の行方向の和と列方向の和をそれぞれ持っておく必要があるので別で用意して、以下のような高速化を行うことができる。

#include "my_template.hpp"
#include "math/static_modint.hpp"
using namespace std;

using mint = mint998;

int main() {
    INT(N);
    VEC(int, A, N);
    REP(i, N) if (A[i] != -1) A[i]--;
    {
        using pr = pair<int, mint>;  // dp[i][j] = x -> (j, x) を格納した deque
        vector<deque<pr>> dp(N + 1);
        vector<mint> dp_sum_i(N + 1);  // dp_sum_i[i] = SUM(dp[i][a] for a)
        vector<mint> dp_sum_a(N + 1);  // dp_sum_a[a] = SUM(dp[i][a] for i)

        dp[0].push_back({N, 1});
        dp_sum_i[0] += 1;
        dp_sum_a[N] += 1;

        REP(i, N) {
            mint dp_sum = 0;
            REP(j, i + 1) dp_sum += dp_sum_i[j];
            vector<mint> np(N + 1, dp_sum);
            REP(a, N) np[a] -= dp_sum_a[a];

            // apply
            REP(a, N) {
                dp[i + 1].push_back({a, np[a]});
                dp_sum_i[i + 1] += np[a];
                dp_sum_a[a] += np[a];
            }

            // shrink
            if (i == 0) {
                dp[0].clear();
                dp_sum_i[0] = 0;
                dp_sum_a[N] = 0;
            }

            REP(j, i + 1) {
                if (LEN(dp[j]) > 0 and dp[j].front().first == i - j) {
                    dp_sum_i[j] -= dp[j].front().second;
                    dp_sum_a[dp[j].front().first] -= dp[j].front().second;
                    dp[j].pop_front();
                }
            }

            if (A[i] != -1) {
                REP(j, i + 2) {
                    while (LEN(dp[j]) > 0 and dp[j].front().first != A[i]) {
                        dp_sum_i[j] -= dp[j].front().second;
                        dp_sum_a[dp[j].front().first] -= dp[j].front().second;
                        dp[j].pop_front();
                    }
                    while (LEN(dp[j]) > 0 and dp[j].back().first != A[i]) {
                        dp_sum_i[j] -= dp[j].back().second;
                        dp_sum_a[dp[j].back().first] -= dp[j].back().second;
                        dp[j].pop_back();
                    }
                }
            }
        }
        mint ans = 0;
        REP(i, N + 1) ans += dp_sum_i[i];
        print(ans);
    }
#if 0
    // 以下の O(N ^ 4) のコードを高速化する
    {
        vector dp(N + 1, vector<mint>(N + 1));
        dp[0][N] = 1;
        REP(i, N) {
            // calc dp[i + 1]
            REP(a, N) REP(j, i + 1) dp[i + 1][a] += SUM(dp[j], mint(0)) - dp[j][a];
            // shrink
            if (i == 0) dp[0][N] = 0;
            REP(j, i + 1) dp[j][i - j] = 0;
            REP(j, i + 2) {
                REP(a, N + 1) {
                    if (A[i] != -1 and a != A[i]) {
                        dp[j][a] = 0;
                    }
                }
            }
        }
        mint ans = 0;
        REP(i, N + 1) ans += SUM(dp[i], mint(0));
        print(ans);
    }
#endif
    return 0;
}

感想

全然高速化ができなくて解説を読んでもかなり理解に苦しんだが、LIS の DP の 2 次元版っぽく考えることでやっと理解できた気がする。

本番中に通せるようになるのはまだまだ先だなと感じる。(そもそもいつか通せるようになるんですかね…?)

おまけ

よく考えると上記のコードで dp_sum_i はいらない。(dp_sum の計算で使っているが、dp_sum_a で代用できる。) 以下のコードと解説の maroon さんのコードを比較してようやく z[j][i] が自分のコードにおける dp[i][j] に対応しているということがわかった。 転置するとだいぶ配列に対する 0 埋め操作が簡単になる。

#include "my_template.hpp"
#include "math/static_modint.hpp"
using namespace std;

using mint = mint998;

int main() {
    INT(N);
    VEC(int, A, N);
    REP(i, N) if (A[i] != -1) A[i]--;
    {
        using pr = pair<int, mint>;  // dp[i][j] = x -> (j, x) を格納した deque
        vector<deque<pr>> dp(N + 1);
        // vector<mint> dp_sum_i(N + 1);  // dp_sum_i[i] = SUM(dp[i][a] for a)
        vector<mint> dp_sum_a(N + 1);  // dp_sum_a[a] = SUM(dp[i][a] for i)

        dp[0].push_back({N, 1});
        // dp_sum_i[0] += 1;
        dp_sum_a[N] += 1;

        REP(i, N) {
            mint dp_sum = 0;
            // REP(j, i + 1) dp_sum += dp_sum_i[j];
            REP(a, N + 1) dp_sum += dp_sum_a[a];
            vector<mint> np(N + 1, dp_sum);
            REP(a, N) np[a] -= dp_sum_a[a];

            // apply
            REP(a, N) {
                dp[i + 1].push_back({a, np[a]});
                // dp_sum_i[i + 1] += np[a];
                dp_sum_a[a] += np[a];
            }

            // shrink
            if (i == 0) {
                dp[0].clear();
                // dp_sum_i[0] = 0;
                dp_sum_a[N] = 0;
            }

            REP(j, i + 1) {
                if (LEN(dp[j]) > 0 and dp[j].front().first == i - j) {
                    // dp_sum_i[j] -= dp[j].front().second;
                    dp_sum_a[dp[j].front().first] -= dp[j].front().second;
                    dp[j].pop_front();
                }
            }

            if (A[i] != -1) {
                REP(j, i + 2) {
                    while (LEN(dp[j]) > 0 and dp[j].front().first != A[i]) {
                        // dp_sum_i[j] -= dp[j].front().second;
                        dp_sum_a[dp[j].front().first] -= dp[j].front().second;
                        dp[j].pop_front();
                    }
                    while (LEN(dp[j]) > 0 and dp[j].back().first != A[i]) {
                        // dp_sum_i[j] -= dp[j].back().second;
                        dp_sum_a[dp[j].back().first] -= dp[j].back().second;
                        dp[j].pop_back();
                    }
                }
            }
        }
        mint ans = 0;
        REP(a, N + 1) ans += dp_sum_a[a];
        // REP(i, N + 1) ans += dp_sum_i[i];
        print(ans);
    }
    return 0;
}