Batch normalizationの逆伝播の算出式を計算グラフを辿って求める

2024年5月18日

◆仕事や勉強の息抜きに。。。

Batch Normalizationの計算式

ニューラルネットワークの学習において、各層での出力値の分布(アクティベーション分布)が適当な広がりを持つように調整することは重要で、そのために重みの初期値の設定がポイントでした。

【勾配消失しない重みの初期値】Excelでモンテカルロシミュレーションしてみた

 

しかし初期値が適切に設定されていても、学習が進むにつれて重みが更新されてくるとアクティベーション分布が乱れてくる心配があります。

乱れないように学習率を小さくすると学習が遅くなります。

そこで、アクティベーション分布が適当な広がりを持つように強制的に調整しようというのがBatch normalizationのアイデアです。

 

Batch normalizationでは下記の式で各層の出力値を調整します。

最初の式は平均を、2番目の式は分散を求める式です。

そして、この2つを使って平均ゼロ、分散1の標準正規分布に変換しているのが最後の式です。

ところが、このような標準正規分布を次の層への入力とするのが必ずしも良いとは限らず、かえって元の出力の方が良かったということもありえます。

そのため次式で更に調整します。

γβハイパーパラメータです。

仮にγをσB2、βをμBとすると、元の分布に戻ります

つまり、γとβを調整することによって色々な正規分布に変えることができるということになります。

実際、これらγやβの勾配を求めて最適値を求めて学習させていくということが行われます。

 

学習といえばγやβだけでなく、yiからxiへの逆伝播も求める必要がありますが、「ゼロから作るDeep Learning」では逆伝播式の導出が省かれていて、

Understanding the backward pass through Batch Normalization Layer

を参照して下さいと書かれています。

そこで、これを読んで理解したことをまとめてみました。

 

Batch Normalizationの計算式を計算グラフで描く

Batch normalizationの計算グラフは次のようになります。

クリックすると拡大します

 

順伝播ではこの中のそれぞれのノードを信号が流れていきますが、各段階の信号のうちのいくつかは逆伝播でも使いますので、最初に順伝播についてザッと見ておきます。

ノードを流れる信号は行列の形になっていますが、行列のサイズには次の2種類あります。

 

各ニューロンが受ける入力信号の数をD、ミニバッチとしてNのデータをまとめて処理するとするとN×Dの行列になるのが基本ですが、γやβはN個のデータについて共通で、平均や分散もN個のデータから1つだけできるものなので、γ/β/平均/分散に関わる変数1×Dの行列になります。

 

これを念頭に置いて、最初のノードから見ていきます。

クリックすると拡大します

 

ここではN×D行列の入力信号xの平均を計算します。

1バッチでN個のデータがあるので、列ごとの平均を計算して結果を1×Dの行列に格納します。

クリックすると拡大します

 

ここでは先ほど計算したxの1バッチの平均を、各xの値から引いて、結果をN×D行列に格納します。

N×D行列と1×D行列の引き算は数学的にはできませんが、Pythonのブロードキャスト機能により1×DをN×Dに引き延ばすことにより行います。

Excelでも絶対参照を使えばできます。

クリックすると拡大します

 

ここでは分散を計算して、それを分母に持ってくる計算をしています。

平方和の平均を取るところから1×Dの行列に変わります。

クリックすると拡大します

 

ここではxと平均との偏差に分散γを掛け、βを足して最終結果を算出します。

分散、γ、βは1×Dの行列なので、ブロードキャスト機能によりN×Dに拡大してから計算します。

 

逆伝播の計算式を計算グラフから求める

次に、順伝播で作ったいくつかの信号を使いながら逆伝播の計算を追っていきます。

まずBatch normalizationの下流のレイヤーからN×D行列の信号が逆伝播してきます。(dout

その信号が最初に加算(+)のノードに逆伝播として入ります。

加算ノードではそのまま信号が逆伝播しますが、βの方向へ伝わる時には1×Dのサイズになるので列ごとに値を合算します。

次の乗算(*)ノードの逆伝播では、順伝播の時の信号をひっくり返した値が掛けられます。

γ方向へはdoutにxhatが掛けられますが、この乗算はアダマール積になります。

xhat方向へはγが掛けられますが、γは1×Dの行列なのでN×Dにブロードキャストして掛けられます。

クリックすると拡大します

 

次にdxhatが上流に逆伝播します。

乗算ノードでは順伝播信号がひっくり返されて掛けられるので、xmu方向へはdxhatにivarが掛けられますが、ivarは1×DなのでN×Dにブロードキャストしてからアダマール積として掛けられます。

一方ivar方向へはdxhatにxmuがアダマール積で掛けられますが、ivarが1×Dなのでdivarも1×Dにならないといけないため、列単位で合計します。

次にdivarが1/xのノードに逆伝播して通ると、-1/x2されます。

これは1/xの微分が-1/x2だからです。

同様に次の√(x+ε)のノードを逆伝播すると、1/2*1/√(x+ε)されます。

√xの微分は1/2*1/√xだから)

クリックすると拡大します

 

次にdvarが平均値を計算するノードに逆伝播します。

このノードの順伝播ではN×D行列を1×D行列に変換したため、逆伝播では1×DをN×Dに変換します。

そのため1×Dを行方向にブロードキャストしてから各要素を1/Nします。

次のx2を計算するノードではx2を微分すると2xなので、順伝播の信号を2倍した値がアダマール積で掛けられます。

クリックすると拡大します

 

ようやく最後のノードまで辿り着きました。

減算(-)ノードの逆伝播は、順伝播でマイナスする方向へは-1倍された値が逆伝播し、プラスする方向へはそのまま伝播します。

平均を取るノードの逆伝播は先述の通りです。

クリックすると拡大します

 

以上の細切れの計算を全部繋げれば、doutからdxへの逆伝播関数が計算できます。

また、γβ勾配も求まります。

Pythonのコードは以下の通りです。

def batchnorm_backward(dout, cache):

    #unfold the variables stored in cache
    xhat,gamma,xmu,ivar,sqrtvar,var,eps = cache

    #get the dimensions of the input/output
    N,D = dout.shape

    #step9
    dbeta = np.sum(dout, axis=0)
    dgammax = dout #not necessary, but more understandable

    #step8
    dgamma = np.sum(dgammax*xhat, axis=0)
    dxhat = dgammax * gamma

    #step7
    divar = np.sum(dxhat*xmu, axis=0)
    dxmu1 = dxhat * ivar

    #step6
    dsqrtvar = -1. /(sqrtvar**2) * divar

    #step5
    dvar = 0.5 * 1. /np.sqrt(var+eps) * dsqrtvar

    #step4
    dsq = 1. /N * np.ones((N,D)) * dvar

    #step3
    dxmu2 = 2 * xmu * dsq

    #step2
    dx1 = (dxmu1 + dxmu2)
    dmu = -1 * np.sum(dxmu1+dxmu2, axis=0)

    #step1
    dx2 = 1. /N * np.ones((N,D)) * dmu

    #step0
    dx = dx1 + dx2

    return dx, dgamma, dbeta

Understanding the backward pass through Batch Normalization Layer より抜粋