サブロウ丸

Sabrou-mal サブロウ丸

主にプログラミングと数学

【Python】multiprocessing.Queueを活用した並列処理の最適化(サンプルコード付き)

Pythonでプロセス並列処理を効率的に行うためのツールとして、multiprocessing.Poolは非常に便利です。しかし、下記のような状況では並列化により逆に処理速度が低下することがあります。

data = xxx # 非常に大きなオブジェクト (巨大なリストなど)

def func(data, i, j):
    """並列化したい関数: data、i、jの3つの引数を受け取り、何かの処理を実行する関数"""
    return data[i] + data[j]

# funcの入力
args = [
    (data, i, j) for i in range(10) for j in range(10)
]

with multiprocessing.Pool as pool:
    pool.map(func, args)

これはなぜかというと、multiprocessing.Poolやmultiprocessing.Processはプロセスを起動する際に関数の入力がシリアライズ(pickle化)されるからです。つまり、データ量の大きいオブジェクトが関数の引数に含まれる場合、そのオブジェクトのシリアライズ処理に時間がかかります。その処理がオーバーヘッドとなり、並列化の効果が失われてしまうのです。

上記の例では、argsのタプルオブジェクトがpickle化され、それぞれのプロセスに配布されます。ここでdata変数が巨大であるため、pickleの作成に時間がかかります。このシリアライズが何回も起きるので全体として処理が遅くなるのです。しかし、dataはこの処理の間変更されることはないため、シリアライズを複数回行うのは明らかに無駄です。

しかし筆者が知る限りmultiprocessing.Poolを用いるときに関数の引数をシリアライズを行わずにworkerに渡す方法はありません。 けれども、シリアライズの回数の削減を行うことはできるのです。

multiprocessing.Pool + multiprocessing.Queue による解決策

multiprocessing.Queueを活用した解決策を紹介します。

この方法では、関数の普遍な入力値(data)はworkerの立ち上げのときのみシリアライズし、可変な入力値、上記の場合(i, j)、のみをworkerに処理の都度渡します。

# funcの入力 (data: 不変, i: 可変, j: 可変)
args = [
    (data, i, j) for i in range(10) for j in range(10)
]

以下の手順で進行します。

  1. 関数への可変な入力値(上記例では(i, j)のペア)をQueueに格納します。
  2. 子プロセスを一つずつ立ち上げます。この際にdataのpickle化が行われ、立ち上げに時間がかかる場合があります。
  3. 立ち上がった子プロセスはQueueから変数(i, j)を取り出し、処理を実行します。
  4. 子プロセスが処理を終えると、次の変数をQueueから取り出します。

言い換えると、multiprocessing.Poolによる並列化では「指示ごとに現地へ社員を派遣する」(すべての変数を都度シリアライズのに対し、multiprocessing.Processとmultiprocessing.Queueの組み合わせでは「携帯を持たせた社員を現地に派遣し、その社員が指示をこなした後に電話で新たな指示を受け取る」(大きな変数は1回だけシリアライズし、その他の軽量な変数を都度受け取る)というようなイメージです。

また、multiprocessing.Queueを用いると、関数の入力を生成してQueueに格納するプロセスと、Queueから値を取り出して処理を行うプロセスを同時に動かすことも可能です。

以下に、具体的な実装例を示します。

1. 可変入力値をQueueに格納

可変な入力(i, j)をQueueに格納します。

# queue for input
M = 10
q = multiprocessing.Queue()
for i in range(M):
    for j in range(M):
        q.put((i, j))

# queue for output ( result queue )
rq = multiprocessing.Queue()

2. 子プロセスの立ち上げ

Processの作成とprocess.start()によってプロセス(worker)を立ち上げます。 その際に、args

# heavy object
data = [i for i in range(N)]

# worker factory
def create_worker(worker_index):
    return multiprocessing.Process(
            target=worker_queue,
            args=(worker_index, data, q, rq))  # 立ち上げの際に変数を渡す

# create workers
processes = list()
worker_index = 1
for worker_index in range(1, num_worker):
    process = create_worker(worker_index)
    process.start()
    processes.append(process)

3. 立ち上がった子プロセスはQueueから変数(i, j)を受け取り, 処理 + 子プロセスが処理を終えると, 次の変数をQueueから受け取る

def work(data, i, j):
    return i, j, data[i] + data[j]

def worker_queue(worker_index, data, q, rq):
    """calculate i + j using tuple (i, j) in queue q
    Parameters
    ----------
    worker_index : int
    data : list
    q : multiprocessing.Queue
        queue for get input
    rq : multiprocessing.Queue
        queue for store results
    """
    print(f'start worker_index {worker_index}')
    timeout = 1
    while True:
        try:
            i, j = q.get(timeout=timeout)  # 変数の受け取り
            _, _, result = work(data, i, j)  # 処理
            rq.put((i, j, result))  # 結果の格納
        except queue.Empty:  # queueが空になっていれば終了
            break
    return

簡単な実験

一番下にサンプルコードを載せています.

  • main_queue() が上記で言及している Queueを用いたwork()関数の並列化
  • main_pool() が Poolを用いたwork()関数の並列化実装 になります.

重たいオブジェクトとして長さNのリスト(data変数)を作成しています. 3プロセス並列で簡単なテストを実行したところ 私の環境では実行時間の比較したところ, ある程度Nが大きい(objectが重たい)と Proecss + Queueの方が優位になりました. Queueによるデータの受け渡しにもオーバーヘッドがあるため, オブジェクがある程度大きくないと優位性がでないということですね.

Process + Queue Pool
N = int(1e6) 2.38 s 0.42 s
N = int(1e7) 4.44 s 3.36 s
N = int(1e8) 27.97 s 62.93 s

それぞれ, 下記のサンプルコードの実行時間.

サンプルコードのコアの部分の説明ですが Process + Queue事項の場合, worker_queue関数が子プロセスで動く処理になります. q.get(timeout=timeout)でqueueからデータを取り出しますが, 指定したtimeoutでデータを取り出せない場合はqueue.Empty エラーを返すのでそれを受け取ると終了します.

データを取り出せない場合はすでにqueueが空になっているケースがほとんどです.

def worker_queue(worker_index, data, q, rq):
    """calculate b ** 2

    Parameters
    ----------
    worker_index : int
    data : list
    q : multiprocessing.Queue
        queue for get input
    rq : multiprocessing.Queue
        aueue for store results
    """
    print(f'start worker_index {worker_index}')
    timeout = 1
    while True:
        try:
            i, j = q.get(timeout=timeout)
            _, _, result = work(data, i, j)
            rq.put((i, j, result))
        except queue.Empty:
            break
    return

サンプルコード

gist.github.com