WARush

SRMの結果とか、解けた問題のコードを書いていきます

SRM603 Div1 Medium "PairsOfStrings"

問題

TopCoder Statistics - Problem Statement

我々は文字列のペア(A, B)で次のようなものを探している。

・AとBはn個の文字で構成されている。
・AとBに含まれる文字は、最初のk個のアルファベット小文字のどれかである。
・A + C = C + Bとなるような任意の文字列をCが存在する。

例えば、n = 3, k = 4だとして、("aab", "daa")は両方とも3文字だし、最初の4つのアルファベット小文字しか使ってないし、C = "aa"としてA + C = C + Bが成り立つ。つまり我々が探している文字列のペアである。

あなたはint n, k が与えられる。我々が探している文字列のペアが何通りあるかを返せ。

制約

1 <= n <= 10^9
1 <= k <= 26


考えたこと

どのような(A, B)であればA + C = C + Bが成り立つかというと、BがAをローテーションさせた様な文字列である事が条件であることが分かる。具体的にはAが"aaac"だとすると、Bは"aaac", "aaca", "acaa", "caaa"のどれかでなくてはならない。

まずAの文字列のパターンとして、k^n通りとなる。それぞれのAに対し、Bはローテーションでnパターン取れるので、数え上げとしてはk^n * nとなる。

だがしかし、Aが"abab"とかだとBは"abab", "baba"のどちらかしか取れない。Aの中で"ab"で反復していた場合、その反復文字列の2文字分しかBでローテーションできないことになる。

さらに、Aにおいて2文字の反復文字列もk^2パターン取れるかといったらそうでもない。"aaaa" , "bbbb"といった文字列の反復文字列は"a", "b"の1文字である。

つまり、ある長さがLであるような反復文字列のパターン数は、それよりの短い反復文字列のパターン数を引かなくてはならない。具体的には以下のように処理する。

n = 6, k = 2

6の約数は1, 2, 3, 6。これが反復文字列の長さの種類となる。

L = 1
aかbの2パターン。
1より小さい反復文字列の長さはないので、2パターン取れる。

L = 2
aa, ab, ba, bbの4パターン
L=1のパターン数は引かなくてはならないので、(4-2)で2パターン取れる。

L = 3
aaa, aab, aba, abb, baa, bab, bba, bbbの8パターン
L=1のパターン数は引かなくてはならないので、(8-2)で6パターン取れる。
L=2のパターン数は引かなくてもよい。
なぜなら3文字ずつ反復していれば、2文字ずつ反復するのは不可能だから。
3 % 2 != 0 から判断できる。

L = 6
2^6で64パターン
L=1のパターン数は引かなくてはならないので、(64-2)で62パターン取れる。
L=2のパターン数は引かなくてはならないので、(62-2)で60パターン取れる。
L=0のパターン数は引かなくてはならないので、(60-6)で54パターン取れる。
結局54パターンとなる。

最後
L = n  反復文字列の数 * ローテパターン
L = 1  2 * 1 = 2
L = 2  2 * 2 = 4
L = 3  6 * 3 = 18
L = 6  54 * 6 = 324
  
2 + 4 + 18 + 324 = 348

ソースコード

long long cnt[100000];

class PairsOfStrings {

public:

    long long pow_mod(long long x, long long n) {
        long long r = 1;
        while (n > 0) {
            if (n & 1) r = r * x % MOD;
            n >>= 1;
            x = x * x % MOD;
        }
        return r;
    }

    int getNumber(int n, int k) {

        // 約数列挙
        vector<int> D;
        for (int d = 1; d * d <= n; d++) {
            if (n % d != 0) continue;
            D.push_back(d);
            D.push_back(n / d);
        }

        // 昇順にソート
        sort(D.begin(), D.end());
        
        // 答え算出
        long long res = 0;
        for (int i = 0; i < D.size(); i++) {
            long long c = pow_mod(k, D[i]);
            for (int j = 0; j < i; j++) {
                if (D[i] % D[j] != 0) continue;
                c -= cnt[j];
                if (c < 0) c += MOD;
            }
            cnt[i] = c;
            res = (res + c * D[i]) % MOD;            
        }

        return (int)res;
    }
};