サブロウ丸

主にプログラミングと数学

Transformerによる翻訳システム自作; part2 プロトタイプの作成(シンプルなTransformer)

本稿では翻訳モデルのプロトタイプとして簡易化したTransformerを作成します。英語→日本語の翻訳モデルを下図の構成で作成します。

上図は"dog is cute"をencoder、"犬はかわいい。"をdecoderに入力して"犬はかわいい。"を推論させるように学習させている様子です。 翻訳元の英語はそのまま入力し、日本語は一つのtokenをスライドさせた文章を予測させます。(つまり文末を除いた文章を入力して、文頭を除いた文章を予測させる。)

次回以降でAttentionをMasked Multi-head attentionにしたり、positional encodingを追加したりします。今回はあくまでプロトタイプの説明ということでご容赦ください。

またコードは GitHub - nariaki3551/transformer_scratch at v1.0 で見ることができます。

モデルを構成する要素

text encode

ここではトークン(単語)を事前に割り当てたidに置き換えます。図では"犬"と2、"は"と10が対応しています。また文章の長さを統一的な長さLにするためにPADという特殊トークンで埋めます。文頭や文末はBOSEOSという特殊トークンを用います。

実装としてはtorchtextのvocabを用いました。これにより"犬"、"雨"というtokenがidへ変換されます。

(utils/text.py)

indices = vocab.lookup_indices(tokens)

参考:

Linear embedding

text encodeでid化したトークンをベクトルの形に変換します。one-hotベクトルで表すこともできますが、それだと語彙数分の巨大なベクトルになり処理が大変なので通常はより小さな次元(たかだか数百、数千次元)に変換します。

ここでは最もシンプルな例を実装します。vocabularyサイズ(語彙数)を Vと、圧縮したい次元を Mとすると

  1.  W \in \mathbb{R}^{V\times M} のランダム行列を作成する
  2. id  iの単語について  i行目の Wの要素を返す (numpy風だと W[i, :])

というものですね、one-hotベクトル x について  x^{T}Wの計算をしていることに相当します。 またに対応するtokenは空白を表す"意味のない"tokenなのでその後の順/逆伝播に影響しないように0ベクトルに変換します。

コード

Attention

attention(注意機構)は無順序集合に対する特徴量取り出し機構とも言えます。 名前の通りどの入力に注目すれば良いか、を元に特徴量を算出して出力します。 Transformerにおけるコアの部分ですね。

backwardは頑張って計算します。3回以上のテンソルが含まれる内積はeinsumを用いると便利ですね。

コード

Encoder

text encode → Linear Embedding → Attentionと計算を行います。Attentionの部分はSelf-Attentionを行います。すなわち、

  • Q ← x WQ
  • K ← x WK
  • V ← x WV

と変換してsoftmax(QT K)Vを計算します。(Wはパラメタ)。

Decoder

text encode → Linear Embedding → Attention → Attention → Affine → Softmaxと計算を行います。1つ目のAttentionはSelf-Attention、2つ目のAttentionはQをDecoder、K、VをEncoderの出力から作成します。すなわち、Decoderの出力をx、Encoderの出力をyとすると、

  • Q ← x WQ
  • K ← y WK
  • V ← y WV

と変換してsoftmax(QT K)Vを計算します。(Wはパラメタ)。

# second attention
x = self.attention2.forward(x, encoder_output, encoder_output)

backwardは

dxq, dxk, dxv = self.attention2.backward(dout)
dout = dxq
dencoder_output = dxk + dxv

ですね。

また、Affine変換によって出力されるベクトルのshapeを(B, S, E) →(B, S, V)に変換します。ここでBはミニバッチサイズ、Sは文の長さ(sentence length)、Eは埋め込み次元(embedded dimension)、Vは語彙数(Vocabulary size)です。変換自体は入力層と出力層の2層からなる全結合層で作成します。

最後に出力をSoftmax変換します。これにより出力ベクトルx[i, j, k]はミニバッチiのj番目のtokenがid kの語彙である確率を表すとみなすことができます。

Loss

そして、このDecoderの出力と正解の単語のidとのcross entropyを損失関数とします。 この損失関数が最小になるように学習させるので、学習モデルとしては次の単語を当てられるようなモデルを目指したパラメタの更新を行います。

推論

さて、訓練を終えたモデルを用いて実際に翻訳を行う場合の流れを説明します。今回は"dog is cute"の日本語化を例にしましょう。まずはEncoderに"dog is cute"の文章を入力し、Decoderには空白の日本語を入力します(text encodeにより文頭を表すBOSと残りが全てPADのベクトルが入力されます)。(うまくモデルを学習できている場合は、次の日本語である"犬"を出力できるようになっているはずです。)モデルの出力ベクトルの文頭に相当する部分の中で最大の要素をもつインデックスiに対応するトークン(単語)(図では"犬")をDecoderに再度入力して次の語を予測させます。そしてまた得られた単語を"犬"に付け足して、Decoderでその次のその次のtokenに相当する単語を推測して、、という作業を文章が一定の長さになるか文末を表すEOSトークンがDecoderの出力になるまで継続させます。これにより"dog is cute"に対応するであろう日本語のトークン列を得られるので、あとはそれをつなげれば終わりです。

コード

実験

学習がうまく進んでいるのかを確認しましょう。

Optimizer & Evaluate metrics

今回のoptimizerはAdamを用いました。翻訳の評価としてはBLEUを用いています。

実験設定

実行コマンド

python prototype.py

実験結果

10epoch実行時のlossの推移です。順調に下がっているのでパラメタの更新はうまくいっていますね。

args Namespace(epoch=10, batch_size=16, max_tokens=128, sentence_length=18, translate_test_interval=10, batch_progress=False, quiet=False, embed_dim=16)
INFO:utils.text:create_vocabs(95):size of vocab_ja 128
INFO:utils.text:create_vocabs(96):size of vocab_en 128
INFO:__main__:create_data_loader(133):max length of japanes train sentense 16
INFO:__main__:create_data_loader(136):max length of english train sentense 16
INFO:__main__:create_data_loader(139):max length of japanes test sentense 16
INFO:__main__:create_data_loader(140):max length of english test sentense 10
epoch     0, loss 15.93064, time 0.37s
epoch     1, loss 15.88592, time 0.75s
epoch     2, loss 15.81100, time 1.13s
epoch     3, loss 15.74778, time 1.51s
epoch     4, loss 15.59901, time 1.89s
epoch     5, loss 15.53568, time 2.28s
epoch     6, loss 15.40725, time 2.68s
epoch     7, loss 15.29555, time 3.07s
epoch     8, loss 15.20547, time 3.47s
epoch     9, loss 14.96988, time 3.87s

コードまとめ

実行ファイルはsrc/prototype.pyです。

まとめ

本稿では日英翻訳を行うためのプロトタイプを作成し、損失関数の値が小さくなるようにパラメタの更新が行えていることを確認しました。

他の記事