ChainerでLSTMでミニバッチ学習する時に注意するべきこと
自然言語処理でSequence to Sequenceモデルを学習する時や、単純にLSTMで入力文を固定次元ベクトルを計算したい時に、
入力が可変長であるため、ミニバッチ学習をする時に工夫が必要です。
他のフレームワーク(TensorFlowやTheano)でも同じような工夫は必要だと思われます。
例えば、
- A B C D E F G
- A B C D E F G H I J
上記のような2つの入力があるとします。
1つ目は長さ7のSequenceで
2つ目は長さ10のSequenceとします。
オンライン学習で学習する場合は特に可変の入力でも問題ありません。
ミニバッチ学習する時に問題になるのは、
入力をmatrixで表現する時に
- 0 1 2 3 4 5 6 -1 -1 -1
- 0 1 2 3 4 5 6 7 8 9
このように1つ目の長さ7のデータに対して-1で空白を埋める必要があります。
空白を埋める行為をpaddingと呼びます。
さて、このpaddingをするだけで良いかと言うと、もう一工夫する方が良いです。
例えばLSTMで入力文から固定次元ベクトルhを計算する時を考えましょう。
LSTMは
- 0 1 2 3 4 5 6 -1 -1 -1
- 0 1 2 3 4 5 6 7 8 9
のデータを先頭から読み取っていきます。
まず[0, 0]を読み取ってLSTMの隠れ状態hを更新、
次に[1, 1]を読み取ってLSTMの隠れ状態hを更新、
次に[2, 2]を読み取ってLSTMの隠れ状態hを更新、
・・・
次に[6, 6]を読み取ってLSTMの隠れ状態hを更新、
次に[-1, 7]を読み取ってLSTMの隠れ状態hを更新、
次に[-1, 8]を読み取ってLSTMの隠れ状態hを更新、
最後に[-1, 9]を読み取ってLSTMの隠れ状態hを更新。
この隠れ状態hが最終的なベクトルになります。
しかし1つ目のデータにおけるhはpadding文字「-1」を読み取った時の隠れ状態hになっています。
本当なら1つ目のデータの固定次元ベクトルhは赤文字で示した隠れ状態のものであるはずです。
(-1が入力としてある時に状態hをほぼ更新しない学習をLSTMがしてくれれば良いのですが、本来の学習に影響が出てしまう可能性があります。)
そこでやることは、padding文字「-1」があったら前の隠れ状態hをそのまま伝搬させる、という処理を書きましょう。
Chainerのwhereを使って実現できます。
@kitano_kumo 2回やると2倍時間かかっちゃうので、
— Yuya Unno (@unnonouno) 2016年2月1日
c, h = lstm(c_prev, lstm_in)
enable = (x != 0)
c_next = where(enable, c, c_prev)
h_next = where(en... が理想です
こちらのブログ記事で紹介されています。(とても分かりやすいのでおすすめです。)
c, h = lstm(c_prev, lstm_in) # 本来のLSTM
enable = Variable(x != -1) #xの-1であるかどうかのフラグを計算する
c_next = where(enable, c , c_prev) #x!=-1ならcを、x=-1ならc_prevをc_nextに代入
h_next = where(enable, h , h_prev) #x!=-1ならhを、x=-1ならh_prevをh_nextに代入
また単語ベクトルをChainerではEmbedIDを使うと簡単に用意できます。
padding文字「-1」をignore_labelに指定しておきましょう。
embed=L.EmbedID(n_vocab, n_units,ignore_label=-1), #ここで設定
本来whereでpadding文字-1の部分はbackwardしないので問題ないと思うのですが、
GPUでEmbedIDに-1を投げると挙動に困ったことがあるので、上記のようにignore_labelをしておきましょう。
まぁpadding文字を0と決めてあげれば問題ないのですが、
自然言語処理で辞書を作る時って最初の単語は0のインデックス渡すことが多いと思うので、個人的にはpadding文字を-1としています。
ミニバッチをもっと簡単にやる方法
データの長さを元にミニバッチを作成する
ミニバッチサイズ=2として、
ミニバッチ内の長さを統一してあげます。
ミニバッチ内のデータの長さを7のデータだけにする。
- A B C D E F G
- G Y L W B D G
長さ10のデータは10のデータでミニバッチを作成。
- A B C D E F G H I J
- B B W D E E R G H J
これが1番楽かもしれません。
ただ、NNに学習させるデータに偏りが生じるので、心配な人は上記のようにpaddingすると良いと思います。
paddingする場合もデータ長でソートしてあげて、ミニバッチを作成してあげた方がpaddingのサイズが小さくなりメモリ効率は良くなるはず。
(今後のTodo)
※Seq2SeqをChainerで実装してサンプルコードをまとめたい
また深層学習(+NLP)を実装する上で気をつけることが、PFIの海野さんの資料でとても良くまとまっています。