ベイズ推論で正規分布使うなら標準形のほうがいいかもよ?

2019-12-26

指数型分布族とは

確率変数 xRmx \in \mathbb{R}^m を規定する密度関数がパラメータ θRd\theta \in \mathbb{R}^d を用いて以下の形に表現できるとき、この分布を指数型分布族といいます。

p(xθ)=1Z(θ)h(x)eθϕ(x)p(x|\theta) = \frac{1}{Z(\theta)}h(x)e^{\theta^\top \phi(x)}

ただし Z(θ)Z(\theta) は正規化項なので以下の積分で与えられます。

Z(θ)=Rmh(x)eθϕ(x)dxZ(\theta) = \int_{\mathbb{R}^m}h(x)e^{\theta^\top \phi(x)} dx

θ\theta を標準パラメータ (canonical parameter) と言い、この密度関数の形を標準形と言います。

正規分布や二項分布、ポアソン分布など、よく出てくる分布は指数型分布族です。

普通の正規分布と指数型分布族としての正規分布の関係

普段正規分布を扱うときは、以下のように平均パラメータと分散パラメータ (両者をあわせてモーメントパラメータといいます) を使って表記することが多いかと思います。

N(xμ,σ2)\mathcal{N}(x | \mu, \sigma^2)

指数型分布族の標準形ではパラメータのとり方が異なります。これを以下のように表すことにしましょう。

Nc(xξ,λ)\mathcal{N}_c(x | \xi, \lambda)

ξ,λ\xi, \lambda 2つあわせたものが標準パラメータです。

これらのパラメータの変換則を1次元と多次元に分けて紹介していきます。

1次元

1次元正規分布におけるパラメータの変換則は次のようになります。

N(xμ,σ2)=Nc(xμσ2, 1σ2)Nc(xξ,λ)=N(xξλ, 1λ)\begin{aligned} &\mathcal{N}(x | \mu, \sigma^2) = \mathcal{N}_c\left(x \Bigg| \frac{\mu}{\sigma^2},\ \frac{1}{\sigma^2}\right) \\[.5em] &\mathcal{N}_c(x \vert \xi, \lambda) = \mathcal{N}\left(x \Bigg| \frac{\xi}{\lambda},\ \frac{1}{\lambda}\right) \end{aligned}

次に、密度関数の数式としてはどのような違いがあるのかを見てみましょう。ただし指数の部分だけを取り出して比べてみます。

N(xμ,σ2):(xμ)2σ2Nc(xξ,λ):λ2x2+ξx\begin{aligned} &\mathcal{N}(x | \mu, \sigma^2): -\frac{(x - \mu)^2}{\sigma^2} \\[.5em] &\mathcal{N}_c(x | \xi, \lambda): -\frac{\lambda}{2}x^2 + \xi x \end{aligned}

とくにベイズ推論では密度関数の全体を計算する必要がないことが多いので、上記の形だけ覚えておけば十分です。

多次元

多次元正規分布におけるパラメータの変換則は次のようになります。

N(μ,Σ)=Nc(Σ1μ, Σ1)Nc(ξ,Λ)=N(Λ1ξ, Λ1)\begin{aligned} &\mathcal{N}(\mu, \Sigma) = \mathcal{N}_c\left(\Sigma^{-1}\mu,\ \Sigma^{-1}\right) \\[.5em] &\mathcal{N}_c(\xi, \Lambda) = \mathcal{N}\left(\Lambda^{-1}\xi,\ \Lambda^{-1}\right) \end{aligned}

こちらも密度関数の指数部分を比べてみます。

N(xμ,Σ):12(xμ)Σ1(xμ)Nc(xξ,Λ):12xΛx+xξ\begin{aligned} &\mathcal{N}(x | \mu, \Sigma): -\frac{1}{2}(x - \mu)^\top \Sigma^{-1} (x - \mu) \\[.5em] &\mathcal{N}_c(x | \xi, \Lambda): -\frac{1}{2}x^\top\Lambda x + x^\top \xi \end{aligned}

標準形のメリット

ベイズ推論で正規分布を扱う場合, 事後分布は標準形で表したほうがシンプルになることがあります。

しかも Julia の Distributions パッケージでは NormalCanonMvNormalCanon という型を用意してくれているため, コードもきれいに保てます。

計算公式

ベイズ推論では複数の密度関数の積を計算することがよくありますが、2つの正規分布の積は次のように簡単に計算できます。ただし本記事において \simeq は定数倍を除いて等しいことを表すことにします。

Nc(xξ0,Λ0)Nc(xξ1,Λ1)Nc(xξ0+ξ1,Λ0+Λ1)\mathcal{N}_c(x | \xi_0, \Lambda_0)\mathcal{N}_c(x | \xi_1, \Lambda_1) \simeq \mathcal{N}_c(x | \xi_0 + \xi_1, \Lambda_0 + \Lambda_1)

ベイズ推論における例

ユーザー uu のアイテム ii に対する評価値が Xu,iX_{u, i} であるような行列 XX を生成するモデルを考えます。各評価値 Xu,iX_{u, i} は user factor vector WuW_u と item factor vector HiH_i の内積 WuHiW_u^\top H_i を平均とする正規分布から生成されていると仮定します。各 Wu,HiW_u, H_i の事前分布にも正規分布を設定します。まとめると、以下のようなモデルを考えているということです。

XRU×IWuRD(u=1,,U)HiRD(i=1,,I)p(Xu,iWu,Hi)=N(Xu,iWuHi,λ1)p(Wu)=N(WuμuW,(ΛuW)1)p(Hi)=N(HiμiH,(ΛiH)1)\begin{aligned} & X \in \mathbb{R}^{U \times I} \\ & W_u \in \mathbb{R}^D \quad (u = 1, \ldots, U) \\ & H_i \in \mathbb{R}^D \quad (i = 1, \ldots, I) \\[1em] & p(X_{u, i} | W_u, H_i) = \mathcal{N}\left(X_{u, i} | W_u^\top H_i, \lambda^{-1}\right) \\ & p(W_u) = \mathcal{N}\left(W_u \Big| \mu_u^W, \left(\Lambda_u^W\right)^{-1}\right) \\ & p(H_i) = \mathcal{N}\left(H_i \Big| \mu_i^H, \left(\Lambda_i^H\right)^{-1}\right) \\ \end{aligned}

モデル設計時には平均などの意味を考えるため、モーメントパラメータのほうが適していますね。

さて、このモデルを学習するためにギブスサンプリングのアルゴリズムを導出してみましょう。

まずは WuW_u をサンプルすることを考えます。答えを先に言ってしまうと、WuW_u の条件付き分布は正規分布になります。しかも (記事の流れから当然ですが) 標準形で表したほうがシンプルになります。

p(Xu,iWu,Hi)p(X_{u, i} | W_u, H_i) の指数部分を取り出すと、以下のように変形できます。ただし本記事において \simWuW_u に関係ない定数部分を除いて等しいことを表します。

λ2(Xu,iWuHi)2λ2(WuHi)2+λXu,iWuHi=12Wu(λHiHi)Wu+Wu(λXu,iHi)\begin{aligned} -\frac{\lambda}{2}(X_{u, i} - W_u^\top H_i)^2 &\sim -\frac{\lambda}{2}(W_u^\top H_i)^2 +\lambda X_{u, i}\cdot W_u^\top H_i \\ &= -\frac{1}{2}W_u^\top (\lambda H_i H_i^\top) W_u + W_u^\top (\lambda X_{u, i}H_i) \end{aligned}

これは Nc(WuλXu,iHi, λHiHi)\mathcal{N}_c(W_u | \lambda X_{u, i}H_i,\ \lambda H_i H_i^\top) の指数部分と同じです。

さて、今は uu を固定して WuW_u について考えていましたから、計算すべきは ip(Xu,iWu,Hi)\prod_i p(X_{u, i} | W_u, H_i) です。事前分布とまとめることで、WuW_u の条件付き分布は以下のように求まります。

iNc(WuλXu,iHi, λHiHi)N(WuμuW, (ΛuW)1)Nc(WuλiXu,iHi, λiHiHi)Nc(WuΛuWμuW, ΛuW)Nc(WuΛuWμuW+λiXu,iHi,ΛuW+λiHiHi)\begin{aligned} &\prod_i\mathcal{N}_c\left(W_u | \lambda X_{u, i}H_i,\ \lambda H_i H_i^\top\right) \mathcal{N}\left(W_u | \mu_u^W,\ (\Lambda_u^W)^{-1}\right) \\ &\simeq \mathcal{N}_c\left(W_u \Bigg| \lambda\sum_i X_{u, i}H_i,\ \lambda\sum_i H_i H_i^\top\right) \mathcal{N}_c\left(W_u | \Lambda_u^W\mu_u^W,\ \Lambda_u^W\right) \\ &\simeq \mathcal{N}_c\left(W_u \Bigg| \Lambda_u^W\mu_u^W + \lambda\sum_i X_{u, i}H_i,\quad \Lambda_u^W + \lambda\sum_i H_i H_i^\top\right) \end{aligned}

HiH_i の条件付き分布も同様です。

Nc(HiΛiHμiH+λuXu,iWu,ΛiH+λuWuWu)\mathcal{N}_c\left(H_i \Bigg| \Lambda_i^H\mu_i^H + \lambda\sum_u X_{u, i}W_u,\quad \Lambda_i^H + \lambda\sum_u W_u W_u^\top \right)

パラメータの足し算だけで計算できるので楽ちんですね!