サブロウ丸

Sabrou-mal サブロウ丸

主にプログラミングと数学

Pytorch 分散処理プロファイルを読む

Pytorchで用意されている分散処理機能(Distributed Data Parallel)の中では特にbackward時に勾配の計算と、その勾配の共有(集団通信)が頻繁に行われます。Pytorchではそのデバッグと性能計測用にその計算と通信の開始/終了のタイミングの計測が行われています。本稿ではその部分のコードについてまとめます。

Reducer

計算と通信時間の統計を管理しているのはtorch/csrc/distributed/c10d/logger.cppです。具体的に計測しているのはこのあたり。logger.cpp::Logger::set_construction_data_and_logで初期設定を行い、Reducerの次の関数たちを呼び出すことで計算と通信のチェックポイントの開始と終了時間を記録しています。

  • Reducer::record_forward_compute_start_time
  • Reducer::record_backward_compute_start_time
  • Reducer::record_backward_compute_end_time
  • Reducer::record_backward_comm_start_time
  • Reducer::record_backward_comm_end_time

record_forward_compute_start_timeがあってrecord_forward_compute_end_timeがない理由は、record_forward_compute_end_time = record_backward_compute_start_timeであるからなんでしょうね。 ただ、全てのタイミングを記録しているわけではなく、次の関数の条件を満たす場合にのみ計測結果が保存されます。

bool Reducer::should_collect_runtime_stats() {
  if (num_iterations_ > 0 &&
      (num_iterations_ <= 10 ||
       num_iterations_ % get_ddp_runtime_logging_sample_rate() == 0)) {
    return true;
  }
  return false;
}

Reducer::should_collect_runtime_stats

なので飛び飛びのタイミングで計測される感じですね。例えば常に計測したい場合は、ddp_runtime_logging_sample_rateを更新する必要があって、pythonスクリプトからだと次のように変更します。

    ddp_model = DDP(
        model,
        find_unused_parameters=not True,
        bucket_cap_mb=1,
    )
    # ddp_runtime_logging_sample_rateを1に変更
    ddp_model.reducer._set_ddp_runtime_logging_sample_rate(1)

上のコードではddp_runtime_logging_sample_rateを1に設定していますが、この状態だと全ての通信と計算の開始終了のタイミングの計測が行われます。

また、時間の取得自体はtorch::profiler::impl::getTime()本体はここで行っており、単位はナノ秒。時間計測は

の関数で行われており、Linuxではchronoよりclock_gettimeの方が高速とのこと。