機器之心整理
參與:思、Jamin
一直以來,自動微分都在 DL 框架背後默默地運行着,本文希望探讨它到底是什麼,通過 JAX,自動微分又能怎麼用。
自動微分現在已經是深度學習框架的标配,我們寫的任何模型都需要靠自動微分機制分配模型損失信息,從而更新模型。在廣闊的科學世界中,自動微分也是必不可少的。說到底,大多數算法都是由基本數學運算與基本函數組建的。
在 ICLR 2020 的一篇 Oral 論文中(滿分 8/8/8),圖賓根大學的研究者表示,目前深度學習框架中的自動微分模塊隻會計算批量數據反傳梯度,但批量梯度的方差、海塞矩陣等其它量也很重要,它們可以在計算梯度的過程中快速算出來。
目前自動微分框架隻計算出梯度,因此就限定了研究方向隻能放在梯度下降變體之上,而不能做更廣的探讨。為此,研究者構建了 BACKPACK,它建立在 PyTorch 之上,還擴展了自動微分與反向傳播能獲得的信息。
選自論文 BACKPACK,arXiv:1912.10985。
除此之外,Julia Computing 團隊去年 7 月份也發表了一份論文,提出了可微編程系統,它能将自動微分内嵌于 Julia 語言,從而将其作為第一級的語言特性。由于廣泛的科學計算和機器學習領域都需要線性代數的支持,因此這種可微編程能成為更加通用的一種模式。
從這些前沿研究可以清晰地感受到,自動微分越來越重要。
自動微分是什麼
在數學與計算代數學中,自動微分也被稱為微分算法或數值微分。它是一種數值計算的方式,用來計算因變量對某個自變量的導數。此外,它還是一種計算機程序,與我們手動計算微分的「分析法」不太一樣。
自動微分基于一個事實,即每一個計算機程序,不論它有多麼複雜,都是在執行加減乘除這一系列基本算數運算,以及指數、對數、三角函數這類初等函數運算。通過将鍊式求導法則應用到這些運算上,我們能以任意精度自動地計算導數,而且最多隻比原始程序多一個常數級的運算。
一般而言會存在兩種不同的自動微分模式,即前向累積梯度(前向模式)和反向累計梯度(反向模式)。前向累積會指定從内到外的鍊式法則遍曆路徑,即先計算 d_w1/d_x,再計算 d_w2/d_w1,最後計算 dy/dw_2。
反向梯度累積正好相反,它會先計算 dy/dw_2,然後計算 d_w2/d_w1,最後計算 d_w1/d_x。這是我們最為熟悉的反向傳播模式,它非常符合「沿模型誤差反向傳播」這一直觀思路。
如圖所示,兩種自動微分模式都在求 dy/dx,隻不過根據鍊式法則展開的形式不太一樣。
來一個實例:誤差傳播
在統計學上,由于變量含有誤差,使得函數也含有誤差,我們将其稱之為誤差傳播。闡述這種關系的定律叫做誤差傳播定律。
先定義一個函數 q(x,y) ,我們想通過 q 傳遞 x 與 y 的不确定性信息,即 _x 與 _y。最直接的方式是随機采樣 x 與 y,并計算 q 的值,然後查看它的分布。這就是「傳播不确定性」這個概念的意義。
誤差傳播的積分公式可以是一個近似值, q(x,y) 的一般表達式可以寫為:
如果我們定義一個特殊案例,即 q(x,y)=x±y,那麼總不确定性可以寫為:
對于特例 q(x,y)=xy 與 q(x,y)=x/y ,不确定性分别為 (σ_q/q)^2 = (σ_x/x)^2 (σ_y/y)^2 與 σ_q=(x/y)* sqrt((σ_x/x)^2 (σ_y/y)^2)。
我們可以嘗試這些方法,并對比根據這些近似公式算出來的反傳誤差,以及實際發生的反傳誤差。
實戰 JAX 自動微分
Jax 是谷歌開源的一個科學計算庫,能對 Python 程序與 NumPy 運算執行自動微分,而且能夠在 GPU 和 TPU 上運行,具有很高的性能。
如下先導入 JAX,然後用三行代碼就能定義之前給出的反傳不确定性度量。
from jax *import* grad, jacfwd
import jax.numpy *as* np
def error_prop_jax_gen(q,x,dx):
jac = jacfwd(q)
return np.sqrt(np.sum(np.power(jac(x)*dx,2)))
這裡計算的反傳梯度是根據 jax 完成的,後面的反傳誤差會直接通過公式計算,并對比兩者。
1. 配置兩個具有不确定性的觀察值
我們需要使用 x 與 y 作為符号推理,但可以把它們都儲存在數組 x 中,x[0]=x、x[1]=y。
x_ = np.array([2.,3.])
dx_ = np.array([.1,.1])
2. 加減法
在 (,)=± 這一特例情況下,誤差傳播公式可以簡化為
上圖所示,通過誤差傳播公式計算出來的值與 JAX 計算出來的是一緻地。
3. 乘除法
在 (,)= 與 (,)=/ 這兩種特例中,誤差傳播公式可以寫為:
4. 幂
對于特例 (,)=^*^,傳播公式可以表示為:
我們可以寫成
JAX 的使用非常多樣,甚至能直接使用它搭建神經網絡。例如 JAXnet 框架,它是一個基于 JAX 的深度學習庫,它的 API 提供了便利的模型搭建體驗。比如說,以下代碼就能建個神經網絡:
from jaxnet import *
net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), logsoftmax)
此外,不久之前,DeepMind 也發布了兩個新庫:在 Jax 上進行面向對象開發 的 Haiku 和 Jax 上的強化學習庫 RLax。JAX 這樣的通用自動微分庫也許能在更廣泛的領域發揮作用。
,更多精彩资讯请关注tft每日頭條,我们将持续为您更新最新资讯!