WARush

SRMの結果とか、解けた問題のコードを書いていきます

SRM574 Div1 Medium "TheNumberGameDiv2"

考えた事

Nが18までになっているので、調べなくてはならない頂点数は
最大で18 - 3の15。
15!回も交差判定をする訳にもいかず、今度こそbitDP。

巡った頂点を1、まだ巡っていない頂点を0としたbitを持つ。
あと交差判定では、今どの点にいるか?という情報が必要なため
その情報を持つ。

dp[from][bit] := 現在頂点fromにいて、巡った頂点がbitの時の線の引き方の数
とする。

初期値はdp[ pointsの最後の頂点番号 ][ pointsにある点を全て1にしたbit ] = 1

更新は
fromから他の線と交差させて線を引ける頂点をto、
bitからtoのビットを1にしたものをnewBitとすると、
dp[to][newBit] += dp[from][bit]となる。

全ての頂点を巡る事が出来たならば、
最後にpointsの最初の頂点に戻らなければいけない。
最初の頂点の、隣以外の頂点であれば、かならず交差させつつ
戻る事ができるので、答えはそのような頂点vの
dp[v][111...111]
を全て足し合わせたものとなる。

更新の方向は
SRM 574 Div1 450 PolygonTraversal - kojingharangの日記
こちらを参考にしました。

bitを00000....111111で単純に回しても漏れなく数えられるのに気付かなかった・・


ソースコード

const int MAX_N = 18;

int N;
// dp[i][bit] i番目の位置にいて、訪問した頂点がbitの時の場合の数
long long dp[MAX_N][1 << MAX_N];

class PolygonTraversal {
public:

    // 次の頂点番号
    int next( int i ){
        return i == N - 1 ? 0 : i + 1;
    }

    // 前の頂点番号
    int prev( int i ){
        return i == 0 ? N - 1: i - 1;
    }

    // from - toって交差する?
    bool isCross( int from, int to, int bit ){
        bool b1 = false;
        for( int i = next( from ); i != to; i = next( i ) ){
            if( bit & (1 << i) ) b1 = true;
        }
        bool b2 = false;
        for( int i = prev( from ); i != to; i = prev( i ) ){
            if( bit & (1 << i) ) b2 = true;
        }
        return b1 && b2;
    }

    long long count(int n, vector <int> points){
        N = n;
        
        // dp初期化
        for( int i = 0; i < MAX_N; i++ ){
            for( int j = 0; j < 1 << MAX_N; j++ ){
                dp[i][j] = 0;
            }
        }
        int bit = 0;
        int m = points.size();
        int S = points[0] - 1; // 簡単にするため頂点番号をbase 0に
        int T = points[m-1] - 1;
        for( int i = 0; i < m; i++ ){
            int si = points[i] - 1;
            bit |= (1 << si);
        }
        dp[T][bit] = 1; // pointsの最後の頂点だけ1に

        // dp更新
        for( int bit = 0; bit < (1 << N) - 1; bit++ ){
            for( int from = 0; from < N; from++ ){
                if( dp[from][bit] == 0 ) continue; // 0 なので意味なし
                // 線を引けるtoを見つける
                for( int to = 0; to < N; to++ ){
                    if( from == to ) continue;
                    if( bit & (1 << to) ) continue; // 既に行っていた
                    if( !isCross( from, to, bit ) ) continue; // 交差しなかった

                    // from - toに線引ける!
                    int newBit = bit | (1 << to);
                    dp[to][newBit] += dp[from][bit];
                }                    
            }
        }
        
        // 結果を計算
        long long res = 0;
        int resBit = (1 << N) - 1; // 111...111
        for( int v = 0; v < N; v++ ){
            if( v != S && v != next(S) && v != prev(S) ){
                res += dp[v][resBit];
            }
        }
        return res;
    }
};