サブロウ丸

Sabrou-mal サブロウ丸

主にプログラミングと数学

beam search; ビームサーチのpython実装

幅優先探索の亜種ですね。

アルゴリズムの流れは下記です。

  1. rootノードのみからなるpath、を持つpathsリストを生成(paths = [ [root] ])
  2. paths内の全てのpathを1階層分だけ展開しpathsを更新
  3. スコアが最も高いk個のpathのみをpathsに残す
  4. 2に戻る

補足: https://www.baeldung.com/cs/beam-search

下記のように実装しました。

コード

Notesに書いてあるような次の関数を持つapiを作って実行すると動きます。

  1. init : 初期化関数 api.init() がbeam_searchの初めに1度呼ばれます
  2. step : pathを入力として探索を1階層進めた pathのiterator/generatorを返す関数です
  3. score : pathを入力としてそのpathの評価値を返す関数です。値が高いほど良いことを表します。
  4. count : 探索が1ラウンド(上記の説明の3)が終了すると呼び出されます。
  5. terminate : beam_searchを終了するべきであればtrue, そうでないならfalseを返します。

ヒープの管理はpython標準のheapqを用いています。また今回はheapの大きさが固定長なので、heapの大きさがkになったらheapq.pushpop関数を用いてpushとpopを一つの関数で実行しています。(別々にやるよりもこの関数を使う方が効率が良いらしい; 公式ドキュメントより)

TSP(traveling salesman problem)を例に実行してみます。

class TSPAPI:
    def __init__(self, n):
        self.iter = 0
        self.n = n

    def init(self):
        self.iter = 0

    def step(self, path):
        for i in range(self.n):
            if i not in path:
                yield path[:] + [i]

    def score(self, path):
        """total distance of path"""
        return -sum(D[i][j] for i, j in zip(path, path[1:]))

    def count(self):
        self.iter += 1

    def terminate(self):
        return self.iter >= self.n

# TSP
N = 4
D = [
    [0, 1, 2, 3],
    [4, 0, 2, 1],
    [1, 2, 0, 3],
    [2, 3, 4, 0],
]

Nが都市数。Dが移動コスト行列です。

  • step(path)ではまだ訪問していない都市(i)をpathの最後に加えて返します
  • score(path)ではpathの巡回コストを出力
  • count()では訪問都市数のカウントをして
  • terminate()では訪問都市数が全都市数になったタイミングでtrueを出すようにしています

さて、このコードを k=1で実行してみます。

from beam_search import beam_search
root = []
k = 1
api = TSPAPI(N)
paths, scores = beam_search(root, k, api)
for path, score in zip(paths[::-1], scores[::-1]):
    print("path", path, "distance", -score)
path [3, 0, 1, 2] distance 5

k=2で実行してみます。

from beam_search import beam_search
root = []
k = 3
api = TSPAPI(N)
paths, scores = beam_search(root, k, api)
for path, score in zip(paths[::-1], scores[::-1]):
    print("path", path, "distance", -score)
path [2, 0, 1, 3] distance 3
path [3, 0, 1, 2] distance 5

ということでk = 2の方がより少ないコストのパス([2, 0, 1, 3])をつけることができていますね。

まとめ

beam searchの実装を行いました。