ややプログラム紀行

博士2年のプログラムに関する日記

natural gradient

natural gradient descentは見かけるたびに怪しい理解のままスルーしてたので、ちょっとここら辺でちゃんと確認しておくかと思って↓を読んだ
arxiv.org

natural gradient descentといえば勾配にフィッシャー情報行列の逆行列をかけたものを使って更新してく奴だけど、どうも文献によってフィッシャー情報行列という名前で経験フィッシャーの方を指したりで定義がばらついてるみたいなので、改めて状況設定を定義する

natural gradient descent

データが q(y,x) = q(y\mid x)q(x)からi.i.d.に生成されるとし、一方モデルの分布を p(x,y\mid \theta) = p(y\mid x, \theta)q(x)とする*1とき、フィッシャー情報行列は

 \displaystyle F = \mathbb{E}_{p(x,y)} \left[ \nabla \log p(x,y \mid \theta) \nabla \log p(x,y \mid \theta)^\top \right]
として定義される

最急降下法で通常考える勾配はユークリッド空間での勾配であるのに対し、natural gradientはフィッシャー情報行列 Fによって定義されるリーマン計量からなるリーマン多様体での勾配を指し、目的関数を h(\theta)とすれば

 \displaystyle \tilde{\nabla}h = F^{-1} \nabla h
と表される


フィッシャー情報行列によるリーマン多様体を考える直感的理解として以下のようなものがある

分布の違いを表すKLダイバージェンステイラー展開によって

 \displaystyle \mathrm{KL}(p(x,y\mid \theta+d) || p(x,y\mid \theta)) = \frac{1}{2}d^\top F d + O(d^3)
と展開できることを利用すると、
 \displaystyle -\sqrt{2}\frac{\tilde{\nabla}h}{\| \nabla h \|_{F^{-1}}} = \lim_{\epsilon \to 0} \frac{1}{\epsilon} \underset{d: \mathrm{KL}(p(x,y\mid \theta+d) || p(x,y\mid \theta) \leq \epsilon^2}{\arg\min} h(\theta + d)
を示すことができ、natural gradientの方向に下ることは、同程度のKLダイバージェンスの変化で最大限目的関数を小さくすることに対応することがわかる*2


以上がnatural gradientの概要だが、これだけでは多くの場合natural gradient descentが最急降下法より少ない回数で収束することの直接的な説明にはなっていない

そこでここからは、natural gradient descentがフィッシャー情報行列をヘッセ行列の代替として用いた上での2nd order methodになっていることを確認する

2nd-order optimization

目的関数の \theta_kを中心とする空間を二次関数 M_k(\delta) = h(\theta_k) + \nabla h(\theta_k)^\top \delta + \frac{1}{2}\delta^\top B_k \deltaで近似することを考える

ここで B_kはcurvature matrixであり、 B_kが正定値ならば平方完成によってM_k(\delta) \delta^* = -B_k^{-1}\nabla hで最小値を取ることがわかるので、特にnatural gradient descentは B_kとしてフィッシャー情報行列 Fを採用した2nd-order optimizationだと考えることができる

フィッシャー情報行列とヘッセ行列の期待値

入力の分布 q(x)を学習せず、かつその分布が未知である場合は q(x)を経験分布\hat{q}(x) = \frac{1}{|S|}\sum_{(x,y) \in S} \delta_xで近似することが考えられ、この場合

 \begin{align*} 
F &= \mathbb{E}_{q(x)} \left[ \mathbb{E}_{p(y\mid x, \theta)} \left[ \nabla \log p(y \mid x, \theta) \nabla \log p(y \mid x, \theta)^\top \right] \right] \\
&\approx \frac{1}{|S|} \sum_{(x,y) \in S} \mathbb{E}_{p(y\mid x, \theta)} \left[ \nabla \log p(y \mid x, \theta) \nabla \log p(y \mid x, \theta)^\top \right]
\end{align*}
となる

一方、モデルとなる分布p(y\mid x, \theta)が決定的関数 f(x, \theta)と分布 r(y\mid z)を用いて p(y\mid x ,\theta) = r(y \mid f(x, \theta))と分解できるとし、さらに損失関数L(y,z) L(y,z) = -\log r(y \mid z)によって定義され*3、目的関数 h h(\theta) = \frac{1}{|S|} \sum_{(x,y) \in S} L(y, f(x, \theta))と定義されているとすると、目的関数 hのヘッセ行列は

 \begin{align*} 
H &= \frac{1}{|S|} \sum_{(x, y) \in S} \left[ \nabla L(y, f(x,\theta)) \nabla L(y, f(x, \theta))^\top \right] \\
&= \frac{1}{|S|} \sum_{(x, y) \in S} \mathbb{E}_{\hat{q}(y \mid x)} \left[ \nabla L(y, f(x, \theta)) \nabla L(y, f(x, \theta))^\top \right] \\
&=  \frac{1}{|S|} \sum_{(x,y) \in S} \mathbb{E}_{\hat{q}(y\mid x)} \left[ \nabla \log p(y \mid x, \theta) \nabla \log p(y \mid x, \theta)^\top \right]
\end{align*}
と表すことができる(ここで \hat{q}(y \mid x) q(y \mid x)の経験分布とする)

以上より、フィッシャー情報行列とヘッセ行列の違いは、期待値の分布を p(y\mid x, \theta)とするか q(y \mid x)とするかによるものだと考えることができる

generalized Gauss-Newton (GGN) matrix

上の類似点に加えて、更なるフィッシャー情報行列とヘッセ行列の類似点を見るために、まずはGGN matrixというものの定義を確認する

 L(y,z)z=f(x,\theta)に関するヘッセ行列を H_L f(x,\theta)\thetaに関するJacobianを J_fと表すことにするとき、GGN matrix G G = \frac{1}{|S|} \sum_{(x,y) \in S} J_f^\top H_L J_fと定義される

目的関数h(\theta)のヘッセ行列H

 \begin{align*} 
H = \frac{1}{|S|} \sum_{(x,y) \in S} \left( J_f^\top H_L J_f + \sum_{j=1}^m \left[ \left. \nabla_z L(y,z) \right|_{z = f(x,\theta)} \right]_j H_{[f]_j} \right)
\end{align*}
と式変形できるので、GGN matrixは上式の二項目を無視したものと捉えることが出来、また \thetaが局所解である時には損失関数の勾配がほぼ0になることから二項目の値は0に近似でき、GGN matrixとヘッセ行列が一致することがわかる

また、GGN matrixのヘッセ行列に対する利点として、GGN matrixは常に正定値行列であることが挙げられる*4

フィッシャー情報行列とGGN matrix

モデルとなる分布p(y\mid x, \theta)が決定的関数 f(x, \theta)と分布 r(y,z)を用いて p(y\mid x ,\theta) = r(y, f(x, \theta))と分解できるとすると、フィッシャー情報行列は

 \begin{align*} 
F &= \mathbb{E}_{q(x)} \left[ \mathbb{E}_{p(y \mid x, \theta)} \left[ \nabla \log p(y \mid x, \theta) \nabla \log p(y \mid x, \theta)^\top \right] \right] \\
&=  \mathbb{E}_{q(x)} \left[ \mathbb{E}_{p(y \mid x, \theta)} \left[ J_f^\top \nabla_z \log r(y,z)  \nabla_z \log r(y,z)^\top J_f \right] \right] \\
&=  \mathbb{E}_{q(x)} \left[ J_f^\top \mathbb{E}_{p(y \mid x, \theta)} \left[ \nabla_z \log r(y,z)  \nabla_z \log r(y,z)^\top \right] J_f \right] \\
&=  \mathbb{E}_{q(x)} \left[ J_f^\top F_R J_f \right] \\
&\approx \frac{1}{|S|} \sum_{(x,y) \in S}  J_f^\top F_R J_f
\end{align*}
と式変形できる

よってGGN matrixとフィッシャー情報行列の違いは H_L F_Rにあり、損失関数L L(y,z) = -\log r(y\mid z)と表されるときは

 \begin{align*} 
F_R &= \mathbb{E}_{p(y \mid x, \theta)} \left[ \nabla_z \log r(y,z)  \nabla_z \log r(y,z)^\top \right] \\
H_L &= \nabla_z \log r(y,z)  \nabla_z \log r(y,z)^\top
\end{align*}
と表されることから両者の差は期待値を取るか否かにあることがわかり、これらが等しくなる重要なケースとして r(y \mid z)が指数方分布族であることが挙げられる*5

以上の話より、パラメータが局所解にある時はヘッセ行列とGGN matrixが近似的に等しくなり、かつ例えば r(y \mid z)が指数方分布族の時はGGN matrixとフィッシャー情報行列が等しくなるため、natural gradient descentは近似的にヘッセ行列による2nd-order optimizationになっていることがわかる



このほかに、natural gradient descentが2nd-order optimizationと見做せることから、後者において用いられるdampingなどのテクニックを援用する話などがあって、割とためになった気がする

*1:今回は xが与えられた時の yの分布を推定することのみが問題であり、入力そのものの分布 q(x)は学習しなくて良い

*2:また、このことからnatural gradientは分布のパラメータの取り方に依存しないことがわかる

*3:すなわち L(y, f(x, \theta) ) = -\log p(y \mid x, \theta )となり、例えば二乗損失は正規分布のnegative log likelihoodだと解釈することができる

*4:ので、例えば2nd-order optimizationに利用できる

*5:二乗誤差やクロスエントロピーなどがその場合に含まれる