学生による学生のためのデータサイエンス勉強会

【ゼロから作るディープラーニング#11】誤差逆伝播法の解説【p123-146】

https://bdarc.net/wp-content/uploads/2020/05/icon-1.pnghistoroid

こんにちは。清野です。

今回は誤差逆伝播法についてまとめました。

誤差逆伝播法の解説

今回から誤差逆伝播法について学んでいきます。前回は、勾配降下法について解説しています。

勾配を理解しないことには誤差逆伝播法を理解できないので、勾配について理解が足りないと思ったら前の記事を参考にしてください。

【ゼロから作るディープラーニング#7】数値微分と勾配法【p97-112】

計算グラフとは何か

計算グラフとは、ノードとエッジによって表現されるデータ構造を指します。

これまで使ってきたこの図もグラフです。計算ではないので計算グラフではありませんが、情報がどう流れていくのかを示している点でグラフといえます。

計算グラフの練習問題

なにはともあれ、計算グラフで問題を解いてみましょう。

  • 100円のりんごを2つ買う
  • 消費税は10%

このような計算をグラフで考えてみましょう。

ノードには計算(処理)を書き、矢印の上には計算過程にある数値を書きます。計算グラフはこのように表すことができます。

変数の導入

では、りんごを買う個数や税率を変数にしてみましょう。

こうすれば、りんごの個数や税率が変わっても計算グラフの形式は変化しませんね。

またこのように左から右へと計算することを順伝播といいます。

これから行う逆伝播は右から左というわけですね。

計算グラフによる逆伝播

計算グラフの利点は、全体としては複雑な計算を単純な計算に分割できることです。

計算過程が複雑になっても、ノードごとの計算で見れば単純です。

さらに、実は逆伝播の計算も楽に理解することができます

逆伝播の練習問題

では逆伝播の練習問題を解いてみましょう。

りんごの値段が変化したとき、最終的な支払いがどう変化するかを考えます。

この問題は、計算全体を\(f(x)\)としたとき、りんごの価格\(x\)で微分することを意味します。

上のイラストで、赤い文字と矢印で書かれた部分が逆伝播に相当します。

左から2つ目のノードは、税率で1.1倍する部分です。つまり\(1.1x\)なので、\(x\)で微分すれば1.1になります。

同じように、左から1つ目のノードの微分は2です。入ってくる値が1.1なので、計算結果は2.2になります。

連鎖律とは

連鎖律とは、微分による逆伝播を繰り返していくときのルールです。

ノードが多く連結していると、微分の微分、そのまた微分というように微分がつながっていきます。そのときの計算手順を連鎖律と呼んでいます。

$$z=(x+y)^2$$

このような\(z\)を微分するとします。

ここで、\(t = x + y\)とおきます。

したがって、\(z=t^2\)ですね。

連鎖律

ある関数が合成関数で表される場合、その合成関数の微分は、合成関数を構成するそれぞれの関数の微分の積によって表すことができる。

この連鎖律に従うと、\(z\)を\(x\)で微分することは、\(z\)を構成する関数の微分の積を求めることですね。

この場合、 \(z\)を構成する関数とは、\(t^2\)と\(t\)です。したがって、これらを微分すればOKです。

$$\frac{\partial t^2}{\partial t}=2t$$

$$\frac{\partial t}{\partial x}=1$$

したがって

$$\begin{align}
\frac{\partial z}{\partial x} &= \frac{\partial t^2}{\partial t} \frac{\partial t}{\partial x}\\
&= 2t \times 1 \\ &= 2(x+y)
\end{align}$$

連鎖律と計算グラフ

微分値自体は、\(\frac{\partial y}{\partial x}\)ですが、これにもとの値であるEを乗算していきます。

これが計算グラフおける連鎖律です。

逆伝播

加算ノードの逆伝播

加算ノードは簡単です。\(x+y\)に対して、\(x\)と\(y\)のそれぞれで微分するので1です。

つまり加算ノードはそのまま値が逆伝播するということです。

乗算ノードの逆伝播

乗算ノードは、\(xy\)に対する微分なので、加算ノードとは結果が異なりますよね。

レイヤーの実装

では実装していきましょう。

乗算ノードの実装

ノードの実装ですから、計算グラフの円の中の処理について書いています。

# layer_naive
class MulLayer:
    def __init__(self):
        self.x = None
        self.y = None

    def forward(self, x, y):
        self.x = x
        self.y = y                
        out = x * y

        return out

    def backward(self, dout):
        dx = dout * self.y
        dy = dout * self.x

        return dx, dy

ここでのdoutは上流のエッジの値です。

from layer_naive import *

apple = 100
apple_num = 2
tax = 1.1

mul_apple_layer = MulLayer()
mul_tax_layer = MulLayer()

# forward
apple_price = mul_apple_layer.forward(apple, apple_num)
price = mul_tax_layer.forward(apple_price, tax)

# backward
dprice = 1
dapple_price, dtax = mul_tax_layer.backward(dprice)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

計算グラフ全体を実装するとこのようになります。

加算ノードの実装

同じように加算レイヤーも実装していきましょう。

# # layer_naive
class AddLayer:
    def __init__(self):
        pass

    def forward(self, x, y):
        out = x + y

        return out

    def backward(self, dout):
        dx = dout * 1
        dy = dout * 1

        return dx, dy

これもそんなに難しくないですね。計算グラフですでにやった部分です。

活性化関数レイヤーの実装

では活性化関数ノードを実装していきましょう。

ReLUレイヤーの実装

まずReLUを思い出しましょう。

$$
ReLU(x)=
\begin{cases}
x \quad x \geqq 0 \\
0 \quad x<0 \\
\end{cases}
$$

はい。こんなのでしたね。

というわけで、微分すると以下のようになります。

$$
\frac{\partial ReLU(x)}{\partial x}=
\begin{cases}
1 \quad x \geqq 0 \\
0 \quad x<0 \\
\end{cases}
$$

ちょっと変な書き方ですけど、意味は伝わりますよね。

class Relu:
    def __init__(self):
        self.mask = None

    def forward(self, x):
        self.mask = (x <= 0)
        out = x.copy()  # xとは別のオブジェクトを作る
        out[self.mask] = 0

        return out

    def backward(self, dout):
        dout[self.mask] = 0
        dx = dout

        return dx

self.maskは真偽値ですね。

Sigmoidレイヤーの実装

Sigmoid関数でも同じことをします。

$$Sigmoid(x)=\frac{1}{1+exp(-x)}$$

class Sigmoid:
    def __init__(self):
        self.out = None

    def forward(self, x):
        out = sigmoid(x)
        self.out = out
        return out

    def backward(self, dout):
        dx = dout * (1.0 - self.out) * self.out

        return dx

https://bdarc.net/wp-content/uploads/2020/05/icon-1.pnghistoroid

ちょっと中途半端ですが、私の担当はここまでです。

難しく見えますが、今回の計算グラフを使った逆伝播についてはよく理解しておきましょう。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です