WARush

SRMの結果とか、解けた問題のコードを書いていきます

SRM605 Div1 Medium "AlienAndSetDiv1"

問題

TopCoder Statistics - Problem Statement

エイリアンのフレッドは地球を破壊しようとしていた。しかし、それを行う前に下記のような問題を解決しておきたかった。

{1, 2, 3, ..., 2N}のセットがある。彼はこれをセットAとBに分けたい。しかし、次のような制約がある。

  • 元のセットのそれぞれの要素は、セットAかBに1つだけ存在しなければならない。
  • セットA, Bは同じサイズにしなくてはならない。(つまりそれぞれにN個の要素が入っている)
  • A[i]をセットAのi番目に小さい要素だとする。
 同じくB[i]をセットBのi番目に小さい要素だとする。  このとき、A[i] - B[i]の絶対値はK以上でなくてはならない。

あなたにはint N, Kが与えられる。セットA, Bに分ける方法を数を返せ。

制約

1 <= N <= 50
1 <= K <= 10


考えたこと

1から順番に、セットA,Bに振り分けていくことを考えてみる。セットAにa個目の要素を置けるかどうかは、セットBのa個目に置いた要素が直近のK-1個の中にあるかないかで判断できる。

dp[今まで置いた要素の数][セットAに置いた要素の数][直近のK-1個をABのどっちに置いたかビット]のDPでおk。


ソースコード

const int MOD = 1000000007;
long long dp[105][55][1 << 10];

class AlienAndSetDiv1 {
public:

    int getNumber(int N, int K) {

        K--; // K-1

        if (K == 0) {
            // K-1 = 0だとうまくDP出来ないため別途
            return another(N);
        }

        // dp初期化
        memset(dp, 0, sizeof(dp));
        dp[K + 1][0][0] = 1; // 最初のK個を白
        dp[K + 1][K + 1][(1 << K) - 1] = 1; // 最初のK個を黒

        // dp更新
        for (int i = K + 1; i < N * 2; i++) {
            for (int j = 0; j <= N; j++) {
                for (int b = 0; b < (1 << K); b++) {
                    if (dp[i][j][b] == 0) continue;
                    int c = cntBit(b, K);
                    int k = i - j;
                    int sj = j - c + 1;
                    int sk = k - (K - c) + 1;

                    // 白を置く
                    if (k + 1 < sj || j < k + 1) add(dp[i + 1][j][nextBit(b, K)], dp[i][j][b]);

                    // 黒を置く
                    if (j + 1 < sk || k < j + 1) add(dp[i + 1][j + 1][nextBit(b, K) + 1], dp[i][j][b]);

                }
            }
        }

        // dp集計
        long long res = 0;
        for (int b = 0; b < (1 << K); b++) {
            add(res, dp[N * 2][N][b]);
        }

        return (int)res;
    }

    int another(int N) {
        int c[55];
        int m[55];

        for (int i = 1; i <= N; i++) {
            c[i] = i + N;
            m[i] = i;
        }

        for (int i = 1; i <= N; i++) {
            for (int j = 1; j <= N; j++) {
                int d = gcd(m[i], c[j]);
                m[i] /= d;
                c[j] /= d;
            }
        }

        long long res = 1;
        for (int i = 1; i <= N; i++) {
            res = res * c[i] % MOD;
        }

        return res;
    }
    
    void add(long long & target, long long add) {
        target += add;
        target %= MOD;
    }

    int cntBit(int bit, int K) {
        int res = 0;
        for (int mask = 1; mask < (1 << K); mask <<= 1) {
            if ((bit & mask) != 0) {
                res++;
            }
        }
        return res;
    }

    int nextBit(int bit, int K) {
        int max = 1 << K;
        bit <<= 1;
        if (max <= bit) bit -= max;
        return bit;
    }

    int gcd(int a, int b) {
        if (a < b) swap(a, b);
        if (a % b == 0) return b;
        return gcd(b, a % b);
    }
};