記録帳

クラウド、データ分析、ウイスキーなど。

誤差逆伝播法がなぜ速いのか考えてみた

GW最終日の昼下がり。いかがお過ごしでしょうか。
このGWで復習しよう!と思い、ゼロから作るの①を読み返しています。

今まで2回ほど読んでいて、毎回第5章の誤差逆伝播法で詰まっています。
今回は、以前よりは深く理解できたと思います。それは今度別記事に書く予定です。
今回の記事では、なぜ誤差逆伝播法はこんなに速いか?を解き明かそうという試みです。

実際、どれくらい速いのか?

以下のコードを実行して、数値微分であるnumerical_gradients関数と誤差逆伝播法であるgradient関数でどれくらい時間がかかっているかを調査します。
(具体的な関数の中身は後述)

import sys, os
from dataset.mnist import load_mnist
import tqdm

# MNISTのデータロード
(x_train, t_train), (x_test, t_test) = load_mnist(one_hot_label=True, normalize=True)

# パラメータ設定
iters_num = 10000
batch_size = 100
learning_rate = 0.1
train_size = x_train.shape[0]
network = TwoLayerNet(input_size=784, hidden_size=100, output_size=10)
train_loss_list = []
train_acc_list = []
test_acc_list = []

# 1バッチ目実行
batch_mask = np.random.choice(train_size, batch_size)
x = x_train[batch_mask]
t = t_train[batch_mask]
grads = {}
%time grads = network.gradient(x, t)
#%time grads = network.numerical_gradients(x, t)

結果は・・・
数値微分:239s
誤差逆伝播法:0.0256s
つまり、約1万倍も誤差逆伝播法の方が速いという結果になりました。
また、相対的ではなく絶対的に評価しても、239秒はかなり遅いですね。
本来はイテレーション数(今回は10000回)の数だけループするので、
239s × 10000 = 2,390,000s ≒ 664時間 ≒ 28日
MNISTの学習するだけで1か月かかってしまいます。

では、なぜここまで差が出るのかを考えてみます。

仮説1. 数値微分はパラメータごとに微分計算するから遅い

まずは、めちゃくちゃ遅い方の数値微分の関数numerical_gradients()の中身を見てみます。

  # x:入力データ、t:教師データ
  def numerical_gradients(self, x, t):
    loss_W = lambda W: self.loss(x, t)
    
    grads = {}
    grads['W1'] = numerical_gradient(loss_W, self.params['W1'])
    grads['W2'] = numerical_gradient(loss_W, self.params['W2'])
    grads['b1'] = numerical_gradient(loss_W, self.params['b1'])
    grads['b2'] = numerical_gradient(loss_W, self.params['b2'])
    
    return grads

ここでは、各パラメータW1,W2,b1,b2ごとにそれぞれ微分を計算しています。
この微分計算は、
 \dfrac{df\left( x\right) }{dx}=\lim _{h\rightarrow 0}\dfrac{f\left( x+h\right) -f\left( x-h\right) }{2h}
この数式を計算している関数です。
このf(x)に当たる関数は何かというと、損失関数です。
今回使っている損失関数は、このloss()関数。

  def predict(self, x):
    W1, W2 = self.params['W1'], self.params['W2']
    b1, b2 = self.params['b1'], self.params['b2']
    
    z1 = np.dot(x,W1) + b1
    a1 = sigmoid(z1)
    z2 = np.dot(a1, W2) + b2
    a2 = softmax(z2)

    return a2
  
  def loss(self, x, t):
    y = self.predict(x)
    return cross_entropy_error(y, t)

loss()関数の中でpredict()関数を呼んでいます。
このpredict()関数が、いわゆるニューラルネットワークの計算部分です。
np.dot()ということで、内積計算をしてます。いかにも重そうな気配です。
重そうなので、このpredict()関数を1回実行する単位を1predictとして名付けます。
数値微分と誤差逆電波法で何predictの処理を実行しているかで、処理時間の比較を試みます。

さて、先ほどの微分の数式を思い出すと、分子にf(x)が2回出てきました。
そのため、1回微分するごとに2predictとなります。
そして、それがnumerical_gradients()関数の中でパラメータごと(W1,W2,b1,b2)に計算されるため、
数値微分では合計8predictすることになりそうです。

次に、誤差逆伝播法の関数gradient()を見てみます。

  def gradient(self, x, t):
      W1, W2 = self.params['W1'], self.params['W2']
      b1, b2 = self.params['b1'], self.params['b2']
      grads = {}

      batch_num = x.shape[0]

      # forward
      a1 = np.dot(x, W1) + b1 
      z1 = sigmoid(a1)
      a2 = np.dot(z1, W2) + b2
      y = softmax(a2)

      # backward
      dy = (y - t) / batch_num
      grads['W2'] = np.dot(z1.T, dy)
      grads['b2'] = np.sum(dy, axis=0)

      dz1 = np.dot(dy, W2.T)
      da1 = sigmoid_grad(a1) * dz1
      grads['W1'] = np.dot(x.T, da1)
      grads['b1'] = np.sum(da1, axis=0)

      return grads

こちらは色々とやっていますが、つまるところforward処理とbackward処理の2つの処理を実行しています。
ここで注目してほしいのが、forward処理は先ほど出てきたpredict()関数と全く同じ内容です。
つまり、forward処理=1predictと考えられます。
次にbackward処理ですが、こちらは中でnp.dot処理を3回実施しています。
predict()関数は2回でしたから、大体1.5predictとカウントしましょう。
誤差逆伝播法では合計2.5predictすることになりそうです。

ここまでのまとめを図示すると、以下のようになります。
ただ、これだと8predictと2.5predictで3.2倍にしかなりません。
実際は1万倍ほどの差が出ていたので、パラメータごとに計算するからという仮説だけでは足りなさそうです。

f:id:supa25:20210509172113p:plain
比較

これだけだとまだ足りないようなので、今回は数式で表していた微分計算の具体的なコードを見ていきます。

仮説2. 数値微分はパラメータごとに微分計算するから遅い+微分計算をループで実行しているから遅い

めちゃくちゃ遅い方の数値微分の関数numerical_gradients()の中身を再掲します。

  # x:入力データ、t:教師データ
  def numerical_gradients(self, x, t):
    loss_W = lambda W: self.loss(x, t)
    
    grads = {}
    grads['W1'] = numerical_gradient(loss_W, self.params['W1'])
    grads['W2'] = numerical_gradient(loss_W, self.params['W2'])
    grads['b1'] = numerical_gradient(loss_W, self.params['b1'])
    grads['b2'] = numerical_gradient(loss_W, self.params['b2'])
    
    return grads

このnumerical_gradient()関数の中身を見ていく…前に、1つnumpyのメソッドを紹介します。
numpyのnditer関数です。
この関数は、arrayの要素数だけループを回すときに使う関数です。
以下のコードと実行結果を見ると、何をする関数かわかると思います。

実行するコード

np_array = np.random.randn(2, 3)
print('np_array')
print(np_array)

nditer = np.nditer(np_array, flags=['multi_index'])

print('loop_start')
while not nditer.finished:
  print(nditer.multi_index)
  print(np_array[nditer.multi_index])
  nditer.iternext()

実行結果

np_array
[[-0.55654663  1.02078965 -0.18507347]
 [ 0.14575905  1.21458774 -1.91342087]]
loop_start
(0, 0)
-0.5565466339864574
(0, 1)
1.0207896548415565
(0, 2)
-0.18507347073349137
(1, 0)
0.14575904769533748
(1, 1)
1.2145877434392889
(1, 2)
-1.9134208654465972

さて、では改めてnumerical_gradient()の中身です。

def numerical_gradient(f, x):
    h = 1e-4 # 0.0001
    grad = np.zeros_like(x)
    
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        tmp_val = x[idx]
        x[idx] = tmp_val + h
        fxh1 = f(x) # f(x+h)
        
        x[idx] = tmp_val - h 
        fxh2 = f(x) # f(x-h)
        grad[idx] = (fxh1 - fxh2) / (2*h)
        
        x[idx] = tmp_val # 値を元に戻す
        it.iternext()   
        
    return grad

先ほどのnumpyのnditer関数を用いて、ループを回しています。
仮説1では、この微分計算でf(x)を2回計算している、としていましたが、
それが間違いだったようです。
確かに、f(x)は2回計算していますが、それをxの要素数分ループしているようです。

f(x)は損失関数、つまりはpredict()を計算しているため、ループ回数×2predictが本来のpredict数となります。
では、ループ回数=xの要素数はいくつなのか?を調べます。

仮説1の最後の図にも合った通り、今回はW1,W2,b1,b2それぞれでこのnumerical_gradient()を計算します。
つまり、ループ回数は、W1,W2,b1,b2の要素数の合計になります。
それぞれの要素数は以下の通り。

W1:748×100
W2:100×10
b1:1×100
b2:1×10

素数の合計は、75,910個となるので、ループ回数×2predictを計算すると151,820となります。
数値微分では合計151,820predictすることになりそうです。
(仮説1では8predictにしていたので、めちゃくちゃ増えた…)

誤差逆伝播法の方は、特に考えたかは変わらないので2.5predictのままです。
つまり、差は151,820predict ÷ 2.5predictを計算して、60,728倍となりました。
実際の計測値が約1万倍だったので、ズレてはいますがオーダーは合っていますね。

まとめ

数値微分よりも誤差逆伝播法の方が、なぜ1万倍も処理速度が速いのか?
それは
・数値微分はパラメータごとに微分計算をする
・その微分計算をパラメータの要素数分のループで実装している
ためということがわかりました。

これ、せめて微分計算をループではなくて行列で計算できれば時間短縮になるのでは?と思いましたが、
数値微分は1要素だけちょびっと変化させて計算する方法なので、この方法でやる限りは無理そうですね。
こう考えると、微分を別の方法で求められるようになった誤差逆伝播法は偉大ですね。

このつぶやきをしてから、色々手元でいじりながら、ようやく答えっぽいのにたどり着けました。
みんな気になるところだと思うんだけど、検索しても特に見つからなかったんだよなぁ。
どうしてだろうか。まぁ、今回は自分で解消できたので良かった良かった。