記録帳

クラウド、データ分析、ウイスキーなど。

word2vecの高速化

ゼロから作る2の第4章「word2vecの高速化」を読んだ。
第3章で出てきたCBOWモデルの、計算が複雑な箇所を単純化した。

全体のうちの問題はこの部分で、それぞれの対応はこれ、とわかりやすい章だった。

・CBOWモデルのおさらい

CBOWモデルは、コンテキスト(ある複数の単語)から、ターゲット(コンテキストに囲まれた単語)を推定するモデルである。学習させていくと、重みであるWが各単語の分散表現となっていく。
以下がCBOWモデルの全体像である。確率まで出したら、あとは損失関数を算出し、それが小さくなるよう何回も学習させる…とおきまりのパターン。
今回は、入力が100万個の単語を持つコーパス、中間層は100で計算している。

f:id:supa25:20190211150031p:plain

 

・CBOWモデルの問題点

以下の絵の①、②、③の箇所で計算量が多くなっている。
今回は、それぞれに対処をしていく。

f:id:supa25:20190211150106p:plain

CBOW問題点

・①入力とWinの積

以下の図を用いて説明する。
まず、元の計算は入力(1行100万列)とWin(100万行100列)の積である。
これを真面目にやると、大量の計算が必要になる。
そこで、入力をよく見ると、one-hot表現であるため、1つの要素が1で、それ以外は全て0となっている。これはつまり、入力で1となっている列数(例では0列目)を、Winから抜き出してきて足し合わせると答えになることを意味する。(もちろん他の列は0で埋めておく)

これで、計算を(1×100万)と(100万×1)の積とすることができ、計算量が大幅に減った。

f:id:supa25:20190211150636p:plain

 

 

・②中間層とWoutの積

ここも、以下の図を用いて説明する。
元の計算は、中間層(1行100列)とWout(100行100万列)の積である。
今回は、Woutの行列の列をサンプリングして抽出する。
まず、Woutの中で1番必要な列は、正解となる単語を示す列である。(例でいうとsayを表す1列目)そのため、正解となる列(正例と呼ぶ)は抽出対象とする※(1)。
他は全て不正解となる列(負例と呼ぶ)だが、文中にほとんど出てこない単語を抽出すると、うまく学習できず全体の精度が落ちる可能性がある。そこで、それぞれの単語の出現頻度順に抽出する。何個抽出するかは、パラメータで決める。(例では、3つとしている)

f:id:supa25:20190211151401p:plain

※(1)

正解となるラベルを、学習に使ってしまっていいのか?と疑問が湧いた。
テキストにも載っていないが、こう考えることで納得させた。

---------------

学習は、以下のサイクルを経る。

①入力層から中間層を経て出力する。

②それを正解データと比較し、損失関数を求める。

③損失関数を少なくするように重みを更新する。

④①に戻る。

この②で正解データを用いている。今回は出力層の計算、つまり①で正解データを用いている。通常とタイミングは違うが、①ー④という学習の1サイクルの中で使用するということは変わらないので、特に問題はない。

---------------

 

・③Softmax関数計算

②で抽出されてくるので、そのまま計算するだけで良い。
ただ、今回の修正で、多値分類から二値分類となった(どの単語が推測される?から、推測される単語はsayかどうか?になった)ことで、使用する出力層の関数が変わる。
多値分類ではSoftmax関数を使用するが、二値分類ではSigmoid関数を使用する。損失関数は変わらず交差エントロピー誤差を使用する。

 

・まとめ

以上の改修をして、CBOWモデルを改善した。
どの改修も、全てを実施するのではなく一部を計算することで速度改善を図っていた。この考え方は色々と応用できると感じた。