2017年7月6日 星期四

Backpropagation

      在training neural network 時,需要將 cost function 對每個變數作微分。因為network 很龐大的情況下,cost function 非常複雜,直接對每個變數做微分相當複雜且耗費資源。所以引進了 backpropagation  (BP)來完成這件事。BP 的原理其實就是利用微分的chain rule,先對每個local function 微分後,再從尾巴往前乘起來,結果就是cost function 對變數的微分。

        如下示意圖。變數X,Y 經過F function作用後再經過 G 的作用輸出。如果想計算G對X,Y的微分,先將每個local function 作微分 ( dG/dF , dF/dX, dF/dY) 。再從尾巴往前相乘,其結果就是G對X,Y的微分。這方法的好處是不管function 多複雜 ,只要把function 一步步用流程圖做拆解,對每個小部份做微分後(前題是可微分),從後往前乘就可以得到結果。因為與計算cost-function 計算是往前傳後不同,微分是由後往前計算,所以稱為back-propagation。
黑色為一般計算,由左往右。紅色微分由後往前傳















以下用幾個例子做說明

Example1 :X/(X+Y)

拆解成如下圖 data flow,BP 為紅色部分。從尾巴往前流動,遇到分岔時,計算local function 的微分在乘上分岔前的微分。遇到下個分岔點,再做同樣的事。特別注意,如果匯流到同一個結點,微分結果需要相加,如下圖變數X

data flow

equation



























Example2

試試比較複雜的function:

其中σ代表Sigmoid function:

一樣把data flow 圖畫出來,可以幫助計算:




Forward propagation:






Backpropagation:




因為 x 三條BP路線匯集,所以需要相加,y 同樣也是。

矩陣微分運算

上面提到的BP運算,data都是scalar 型式。如果是vector 或是matrix 呢? 該怎麼取微分呢?取完微分後該怎麼運算呢? 這邊有一個連結文件可以參考矩陣微分技巧。其實對矩陣微分看起來複雜,但只要先拆解成單一元素運算,就會比較容易理解。

Example3: Y=XW 

先從簡單的來,一層NN model,input layer 為X,shape :[ I , K ]。output layer 為 Y,shape [ I , J ]。weight 為 W, shape [ K , J ]。model 示意圖如下。


為了要簡化矩陣運算,拆解成單一element 運算,表示成:
現在計算某一個Y 元素對 X元素的微分,公式如下,其中 a = c 否則結果為0


計算某一個Y元素對W元素的微分: (必須 d = b , 否則為0)


Example4: Y=XWV

至於兩層NN model,Y元素對 X 的微分如何計算呢? 假設input layer 為X, shape [ I , K] , W shape [ K , P] , V shape [ P , J ], output layer Y shape [ I , J ]


其中 中間layer M = XW , Y=MV。現在要計算某一個Y element 對 X element 的微分公式如下:


微分相加是因為每一個Y element ,所有M element 都有貢獻,就像之前說到BP時,必須把所有回流到X element 的微分都加起來。


結果剛好等於WV 相乘後的某一個element。

Example5: 矩陣BP運算

現在我們要計算一個用矩陣運算的model,BP 該怎麼運算。假設一個model,input layer X shape [ I , K ] ,weight W shape [ K , J ],下一層layer D = X*W , shape [ I , J ]。之後的layer 我們不管,只看第一層的BP 怎麼計算。示意圖如下:



假設 dD 是後面layer 往前累加的微分,假設已知,我們想要計算 dX ,公式如下:

拆解成element 運算讓我們好理解。相加同樣是因為所有D element  在BP時都匯集到 每一個X element。




結果就跟scalar BP 一樣 dD 跟 W 相乘,差別就是矩陣相乘要考慮誰先誰後,然後有無轉置。

計算dW 基本一樣,公式如下:



sum 是因為W 對D 的每一個 row 都有貢獻。結果也很簡單 X轉置 跟 dD 相乘。