A VAE

我们先从 VAE 的角度理解生成模型的优化目标如何获得

A.1 背景知识:数值计算vs采样计算

给定已知概率密度函数p(x) ,那么x的期望定义为

(0)E[x]=xp(x)dx

如果求解基于概率密度函数的期望的数据计算,也就是数据积分,可以选择若干有代表性的点 x0<x1<...<xn ,得到数值解:

(2)E[x]i=1nxip(xi)(xixi1)

这里本质就是借助离散化区间计算积分数值。
此外,我们可以利用采样方法计算期望的近似数值解:从p(x)中采样若干个点x1,x2,...,xn ,我们有:

(3)E[x]1ni=1nxi,xp(x)

要理解Eq.2Eq.3的区别,重点在于理解概率密度函数的含义:描述样本 x 出现的可能性
Eq.2中包含了概率的计算而Eq.3中仅有x的计算,这是因为Eq.3中的采样过程已经包含了概率过程,概率大的xi出现的次数也多。
更一般的,我们有

(4)E[f(x)]=f(x)p(x)dx1ni=1nf(xi),xip(xi)

上式为蒙特卡洛模拟的基础。

此外,关于概率期望,我们后面会重点用到如下视角的变化:

(5)KL(p(x)||q(x))=p(x)lnp(x)q(x)dx=Exp(x)[lnp(x)q(x)]

VAE的优化目标

VAE 作为生成模型的经典任务:给定若干数据样本:{x1,...,xn},整体用x描述,我们希望借助隐变量z描述x的分布p~(x):

(6)q(x)=q(x|z)q(z)dz

最终我们希望构造的分布q(x)能逼近目标数据分布p~(x),即两分布的 KL散度足够小。

由于引入了隐变量z,推导原始KL散度是复杂的,需要引入各种变分推理以及 ELBO。因此,我们考虑优化两个概率的联合分布距离:

(7)KL(p(x,z)||q(x,z))=p(x,z)lnp(x,z)q(x,z)dzdx

对于联合分布,借助贝叶斯公式有

(8)q(x,z)=q(x|z)q(z)(9)p(x,z)=p~(x)p(z|x)

注意Eq.8Eq.9中我们分别从不同从面引入了条件概率。具体的,q(x|z)可视为解码器,而p(z|x)则为编码器,因为p(x,z)=p~(x)p(z|x)来源与真实数据样本x~

(10)KL(p(x,y)||q(x,y))=p~(x)[p(z|x)lnp~(x)p(z|x)q(x,z)dz]dx=Exp~(x)[p(z|x)lnp~(x)+p(z|x)lnp(z|x)q(x,z)dz]

对于积分的前一项,有

(11)Exp~(x)[p(z|x)lnp~(x)dz]=Exp~(x)[lnp~(x)p(z|x)dz]=Exp~(x)[lnp~(x]=C

为常数。所以我们将优化目标设为

(12)L=Exp~(x)[p(z|x)lnp(z|x)q(x,z)dz]=Exp~(x)[p(z|x)lnp(z|x)q(x|z)q(z)dz]=Exp~(x)[p(z|x)lnq(x|z)dz+p(z|x)lnp(z|x)q(z)dz]=Exp~(x)[Ezp(z|x)[ln(q(x|z))]+KL(p(z|x)||q(z))]

Eq.12中的两项,即分别对应了 VAE 的 MSE 损失和 KL 损失。

具体实现

分析得到了上面的优化目标,其实还不能直接实现的。因为要求解最小化损失L,我们是需要其中的q(z),q(x|z),p(z|x)的具体解析式的。下面我们逐项分析。
首先为了方便采样,我们假设zN(0,1),如何q(z)有了解析式。
对于另外两项,我们需要在基于一定先验假设的情况下,使用神经网络拟合!
首先,假设编码器p(z|x)服从正态分布,其均值和方差由x决定,由此,我们构建网络μθ(x) , σθ(x)分别建模均值和方差,那么就可以获得解析式:

(13)p(z|x)=1k=1d2πσθ2(x)exp(12zμθ(x)σθ(x)2)

这样一来,借助两个正态分布的 KL 散度,Eq.12中的 KL 散度就可以详细计算

(14)KL(p(z|x)q(z))=μθ2(x)+σθ2(x)lnσθ2(x)1

现在只剩解码器q(x|z),该选择什么分布呢?。在 VAE 的原始论文中,给出了两种候选方案:伯努利分布和正态分布。这里我们只看后者。

我们将解码器建模成高斯分布,是否过于简单?在生成模型领域,其实我们并没有什么其他选择。因为我们要构造一个分布,而不是任意一个函数,既然是分布就得满足归一化的要求,而要满足归一化,又要容易算,候选分布模型并不多

已知q(x|z)为高斯分布,那么我们同样我们用神经网络构建均值和方差μ~θ(z)σ~θ(z)。这里由于我们的优化目标中的两项是同时优化,所以我们统一使用θ代表网络参数。注意到解码器输入为z,因此我们有解码器的解析式

(15)q(x|z)=1k=1D2πσ~θ2(x)exp(12zμ~θ(x)σ~θ(z)2)

通常情况下,对于生成器的方差我们设置为固定值σ~2,此时

(16)lnq(x|z)=12zμ~θ(x)σ~θ(z)2+D2ln2π+12k=1Dlnσ~θ2(z)12σ~θ2zμ~θ(x)2

对整个过程来看,从p~(x)从随机采样样本x,经过编码器,我们获得p(z|x)的均值和方差,进而得知p(z|x)的具体分布解析式,从中采样得到得到隐变量z,并可以计算 KL 损失。然后将 z经过解码器,获得输出x^,从本质上这代表估计的均值,我们将其作为生成目标,此外根据Eq.16计算 MSE 损失。

分析

Eq.12的两项损失中我们可以得知,VAE 中也存在一种对抗过程。假设 KL loss优化为 0,那么编码器的输出为正态分布,隐变量z没有任何辨识度,此时的生成器q(x|z)是不可能和样本x~预测准确。而如果前一项 MSE 损失足够小,此时q(x|z)大,预测的准确,此时的网络非常接近自编解码器,此时的隐变量z不会太随机,KL损失也不会小。所以 VAE 损失中的两项也存在互相对抗的过程,而且相较于 GAN,是共同优化,不需要交替优化过程。

但 VAE 存在的模糊的缺点,其本质也是因为我们将解码器q(x|z)预定义成正态分布,显然正态分布的表征能力是不足的,因此 VAE 从理论层面就不具备足够的生成能力。而 GAN 的生成器不需要满足正态分布的先验。对于 vae 来讲,通过迭代多层解码器(HVAE),既使用多个嵌套正态分布提高表征能力,能够大大提升生成效果。从这个角度来看,VAE 和 DDPM 是同根同源的。