学習する天然ニューラルネット

主に機械学習に関する覚書や情報の整理

らくらくp進全探索 コピペで使えるPython実装

何をしたか?

連続するp進数を次々返してくれるiteratorを実装しました(といっても標準ライブラリにラップしただけ)。 例えば、3桁の3進数だったら000, 001, 002, 010, 012 ..., 222 というものを次々に返してくれます。 実際には桁ごとにリストの1要素を構成していて,pが10以上でも問題なく動作します。

1種類あたりp個の選択肢があり、n種に対して全探索したい場合、これをn桁のp進全探索と呼ぶことにします。 この全探索を生成するコードは以下です。

2020/04/18訂正

自分の知識不足でこんな記事を書きましたが、本ブログの内容は以下の一行で書けることが判明しました

product(range(p),repeat=n)

コピペ用

def iter_p_adic(p, n):
    '''
    連続して増加するp進数をリストとして返す。nはリストの長さ
    return
    ----------
    所望のp進数リストを次々返してくれるiterator
    '''
    from itertools import product
    tmp = [range(p)] * n
    return product(*tmp)

テスト的な

#3桁の4進全探索
iterator = iter_p_adic(4, 3)
for idxs in iterator:
    print(idxs)
(0, 0, 0)
(0, 0, 1)
(0, 0, 2)
(0, 0, 3)
(0, 1, 0)
(0, 1, 1)
(0, 1, 2)
(0, 1, 3)
(0, 2, 0)
中略
(3, 2, 3)
(3, 3, 0)
(3, 3, 1)
(3, 3, 2)
(3, 3, 3)

ちゃんと出力できています。

なぜこれが必要なのか?

競技プログラミングに限らず、プログラミングではしばしば全探索をすることがあります。たとえば、n種のフルーツがあったときにありえる組み合わせを全部列挙しようとすれば、使う使わないで2n通りの全探索をすることになります(いわゆるbit全探索)。

3種の場合、組み合わせのbit全探索は 000, 001, 010, 011, ..., 111 となります。

一般的にbit全探索は

for i in range(1 << n): #nは桁数 000→111まで探索したいなら3
    for j in range(n): # 各桁について
        if i >> j:
            #ここにbitが1だったときの処理
        else:
            #ここにbitが0だったときの処理

のようになります。しかし、これは実装が煩雑になりがちです。

for pattern in [(0,0,0), (0,0,1), (0,1,0),...,(1,1,1)]:
    for keta in pattern:
        #各ketaについての処理

としたほうが直感的です。

この[(0,0,0), (0,0,1), (0,1,0),...,(1,1,1)]を生成するのにiter_p_adic(p=2,n=3)とすればよいです。

また、「n種を使う使わない」を拡張して「半分使う」みたいな選択肢もあったときに、各種に対しては3通りの選択肢を持つことになります。素直に全探索をイメージするなら以下のようなコードになると思います。

for pattern in [(0,0,0), (0,0,1), (0,0,2), (0,1,0)...,(2,2,2)]:
    for keta in pattern:
        #各ketaについての処理

bit全探索のコードをすこし改変することでこれを実現することはできませんが、iter_p_adic(p=3,n=3)は素直に[(0,0,0), (0,0,1), (0,0,2), (0,1,0)...,(2,2,2)]を返します。

(p>2の場合の全探索は深さ優先探索で実装する人が多いように感じますが、バグを生じやすいので個人的には回避したい。)

具体例

ABC015 の C問題

C - 高橋くんのバグ探し

各行から一つずつ要素を選んでxorする。という操作を全通りやればいい。これは問題文を読めばすぐにわかります。

しかし、1行につき選択肢が最大5種類あるので、bit全探索はできません。 一般的には深さ優先探索で実装するひとが多いでしょう(ほんまか?)。しかしバグらせやすい自分としてはなるべく避けたいところです。(実際jは再帰には関係無いのにjも再帰関数に入れてしまっている)。

N, K = map(int, input().split())
T = [list(map(int, input().split())) for _ in range(N)]
def dfs(i, j, x):  # i行目j列目を使うときについて、xは今までの経路のxor
    # 終了条件
    if i == N - 1:
        if x == 0:  # バグがあればTrueを返す
            print('Found')
            exit()
        return
    # 探索
    for jj in range(K):
        dfs(i + 1, jj, x ^ T[i + 1][jj])

dfs(-1, -1, 0)
print('Nothing')

一方で、iter_p_adicを用いると以下のようにシンプルに実装することができます。

N, K = map(int, input().split())
T = [list(map(int, input().split())) for _ in range(N)]
for idxes in iter_p_adic(K, N):
    sumxor = 0
    for i, j in zip(range(N), idxes):
        sumxor ^= T[i][j]
    if sumxor == 0:
        print('Found')
        exit()
print('Nothing')