Consistency Models
Consistency Models- 理解
-
-
- 问题定义
-
-
- 研究动机
- 本文中心论点
-
- 相关工作和进展
-
-
- Consistency Models创新点
- review扩散模型
-
- Consistency Model-Definition
-
-
- 一致性模型的定义
- 一致性模型参数化
- 一致性模型采样
-
- Training Consistency Models via Distillation
- Training Consistency Models in Isolation
-
pdf:https://arxiv.org/pdf/2303.01469.pdf
github:https://github.com/openai/consistency_models
问题定义
图像编辑等
研究动机
扩散模型依赖于迭代生成过程,导致采样速度较慢,实时应用有限。
本文中心论点
- 给定一个概率流(PF) ODE,它能平滑地将数据转换为噪声。作者学习将ODE轨迹上的任何点(例如, x t , x t ′ x_t, x'_t xt,xt′)映射到它的原点(例如, x 0 x_0 x0),用于生成建模。
- 这些映射的模型称为一致性模型,因为对于同一轨迹上的点,他们的输出被训练为的一致的。
相关工作和进展
扩散模型受限于时间。
Consistency Models创新点
- 一致性模型在设计上支持快速的一步生成,同时仍然允许少步采样以换取样本质量的计算。
- 一种新的生成模型家族,可以在没有对抗性训练的情况下实现高样本质量
- 支持zero-shot数据编辑,如图像修补、着色和超分辨率,而不需要对这些任务进行明确的训练。
- 一致性模型既可以作为提取预训练扩散模型的一种方式训练,也可以作为独立的生成模型训练。
review扩散模型
将扩散理解为一个在时间上连续的变换过程(引入SDE形式来描述扩散模型的本质好处是“将理论分析和代码实现分离开来”,借助连续性SDE的数学工具做分析,实践的时候,则只需要用任意适当的离散化方案对SDE进行数值计算)
用随机微分方程(Stochastic Differential Equation,SDE)来描述扩散模型:
可以理解为下式(离散化):
在以前的论文中推导出上述SDE存在一个ODE形式的解轨迹(Probability Flow ODE)
SDE设计为让 p T ( x ) p_T(x) pT(x) 接近于易处理的高斯分布。采用别人论文中的设置,带入到(2)中
首先训练一个得分模型 s ϕ ( x , t ) ≈ ▽ l o g p t ( x ) s_\\phi (\\mathbf{x},t) \\approx\\bigtriangledown logp_t(\\mathbf{x}) sϕ(x,t)≈▽logpt(x),(2)转化为。称为empirical PF ODE
采样 x ^ ∼ π = N ( 0 , T 2 I ) \\widehat{\\mathbf{x}} \\sim \\pi = N (0,T^2 I) x ∼π=N(0,T2I)来初始化empirical PF ODE
- 利用现有的数值ODE solver来求解(Euler,Heun solvers等)
- 得到的 x ^ \\widehat{\\mathbf{x}} x 可以被看作是数据分布 p d a t a ( x ) p_{data}(\\mathbf{x}) pdata(x)的一个近似样本。
- 考虑到数值稳定性,往往不会直接求出原图,而是取一个很小的值逐步来进行近似,并持续这个过程来求出。(导致速度慢)
Consistency Model-Definition
一致性模型的定义
假设存在一个函数f,对于同一条PF ODE轨迹上的任意点都有相同的输出 f ( x t , t ) = f ( x t ′ , t ′ ) for all t , t ′ ∈ [ ϵ , T ] \\boldsymbol{f}\\left(\\mathrm{x}_{t}, t\\right)=\\boldsymbol{f}\\left(\\mathrm{x}_{t^{\\prime}}, t^{\\prime}\\right) \\text { for all } t, t^{\\prime} \\in[\\epsilon, T] f(xt,t)=f(xt′,t′) for all t,t′∈[ϵ,T]
consistency model的目标是从数据中估计一致性函数 f f f,来迫使self-consistency性质
一致性模型参数化
对于任意的一致性函数 f ( ⋅ , ⋅ ) f(\\cdot, \\cdot) f(⋅,⋅),用神经网络来拟合。但要满足两个条件:①同一个轨迹上的点输出一致;②在起始点f为一个对于x的恒等函数
- 第一种做法简单地参数化consistency models
- 第二种做法使用跳跃连接(作者和许多其他的都用这个)
一致性模型采样
有了训练好的一致性模型 f θ ( ⋅ , ⋅ ) f_\\theta(\\cdot, \\cdot) fθ(⋅,⋅) ,就可以通过初始分布采样来产生样本。(这里指的是训练好后怎么来生成样本)
在一致性模型中,可以一步生成样本。也可以多步生成,算法1为多步生成。
想法就是预测出x后回退然后再进行预测减小误差。实际中,采用贪心算法来寻找时间点,通过三值搜索每次确定一个时间点,优化算法得到的样本的FID(不太重要)
Training Consistency Models via Distillation
第一种训练consistency model的方式——蒸馏预训练好的score model s ϕ ( x , t ) s_{\\phi}(\\mathrm{x}, t) sϕ(x,t)
假设采样轨迹的时间序列为 t 1 = ϵ < t 2 < ⋯ < t N = T t_{1}=\\epsilon<t_{2}<\\cdots<t_{N}=T t1=ϵ<t2<⋯<tN=T
通过运行数值ODE求解器的一个离散化步骤从 x t n + 1 \\mathbf{x}_{t_{n+1}} xtn+1得到 x t n \\mathbf{x}_{t_{n}} xtn
Φ ( . . . ; ϕ ) \\Phi(...;\\phi) Φ(...;ϕ)为ODE solver
例如使用Euler solver d x d t = − t s ϕ ( x t , t ) \\frac{\\mathrm{dx}}{\\mathrm{d} t}=-t s_{\\phi}\\left(\\mathrm{x}_{t}, t\\right) dtdx=−tsϕ(xt,t) , Φ ( x , t ; ϕ ) = − t s ϕ ( x , t ) \\Phi(\\mathrm{x}, t ; \\phi)=-t s_{\\phi}(\\mathrm{x}, t) Φ(x,t;ϕ)=−tsϕ(x,t)带入上式得到
沿着ODE轨迹的分布进行第一次采样 x \\mathrm{x} x~ p d a t a p_{data} pdata,然后添加高斯噪声,生成一对在PF ODE轨迹上相邻的数据点 ( x ^ t n ϕ , x t n + 1 ) \\left(\\hat{\\mathbf{x}}_{t_n}^\\phi, \\mathbf{x}_{t_{n+1}}\\right) (x^tnϕ,xtn+1)
通过最小化这一对的输出差异来训练一致性模型,作者遵循一致性蒸馏损失来训练一致性模型,就有如下的consistency distillation loss:
在蒸馏的过程中,作者用预训练模型来估计得分.
采用EMA来更新模型会提高训练的稳定性,并且性能会更好
Training Consistency Models in Isolation
Consistency models也可以单独进行训练,而不依赖于预训练好的扩散模型。
作者说这与扩散蒸馏技术不同,使一致性模型成为一个新的独立的生成模型家族。
在consistency distillation中,使用了预训练的score model s ϕ ( x , t ) s_{\\phi}(\\mathrm{x}, t) sϕ(x,t)来近似ground truth score function ▽ l o g p t ( x ) \\bigtriangledown logp_t(\\mathbf{x}) ▽logpt(x)。
作者证明了 ▽ l o g p t ( x ) \\bigtriangledown logp_t(\\mathbf{x}) ▽logpt(x)的一种无偏估计,即证明了一种新的得分函数的估计
即给定x, xt,可以用 − ( x t − x ) / t 2 -(\\mathbf{x}_t -\\mathbf{x})/t^2 −(xt−x)/t2 形式化 ▽ l o g p t ( x ) \\bigtriangledown logp_t(\\mathbf{x}) ▽logpt(x)的蒙特卡罗估计,可以理解为
利用该得分估计,作者构建了新的consistency training (CT) loss记作 L C T N ( θ , θ − ) L_{CT}^{N}(\\theta,\\theta^-) LCTN(θ,θ−)