> 文章列表 > Consistency Models

Consistency Models

Consistency Models

Consistency Models- 理解

      • 问题定义
          • 研究动机
          • 本文中心论点
      • 相关工作和进展
          • Consistency Models创新点
          • review扩散模型
      • Consistency Model-Definition
          • 一致性模型的定义
          • 一致性模型参数化
          • 一致性模型采样
      • Training Consistency Models via Distillation
      • Training Consistency Models in Isolation

Consistency Models

pdf:https://arxiv.org/pdf/2303.01469.pdf
github:https://github.com/openai/consistency_models

问题定义

图像编辑等

研究动机

扩散模型依赖于迭代生成过程,导致采样速度较慢,实时应用有限。

本文中心论点
  • 给定一个概率流(PF) ODE,它能平滑地将数据转换为噪声。作者学习将ODE轨迹上的任何点(例如, x t , x t ′ x_t, x'_t xt,xt)映射到它的原点(例如, x 0 x_0 x0),用于生成建模。
  • 这些映射的模型称为一致性模型,因为对于同一轨迹上的点,他们的输出被训练为的一致的。

相关工作和进展

扩散模型受限于时间。

Consistency Models创新点
  • 一致性模型在设计上支持快速的一步生成,同时仍然允许少步采样以换取样本质量的计算。
  • 一种新的生成模型家族,可以在没有对抗性训练的情况下实现高样本质量
  • 支持zero-shot数据编辑,如图像修补、着色和超分辨率,而不需要对这些任务进行明确的训练。
  • 一致性模型既可以作为提取预训练扩散模型的一种方式训练,也可以作为独立的生成模型训练。
review扩散模型

将扩散理解为一个在时间上连续的变换过程(引入SDE形式来描述扩散模型的本质好处是“将理论分析和代码实现分离开来”,借助连续性SDE的数学工具做分析,实践的时候,则只需要用任意适当的离散化方案对SDE进行数值计算)
用随机微分方程(Stochastic Differential Equation,SDE)来描述扩散模型:
在这里插入图片描述
可以理解为下式(离散化):
在这里插入图片描述
在以前的论文中推导出上述SDE存在一个ODE形式的解轨迹(Probability Flow ODE)
在这里插入图片描述
在这里插入图片描述
SDE设计为让 p T ( x ) p_T(x) pT(x) 接近于易处理的高斯分布。采用别人论文中的设置,带入到(2)中
在这里插入图片描述

首先训练一个得分模型 s ϕ ( x , t ) ≈ ▽ l o g p t ( x ) s_\\phi (\\mathbf{x},t) \\approx\\bigtriangledown logp_t(\\mathbf{x}) sϕ(x,t)logpt(x),(2)转化为。称为empirical PF ODE
在这里插入图片描述
采样 x ^ ∼ π = N ( 0 , T 2 I ) \\widehat{\\mathbf{x}} \\sim \\pi = N (0,T^2 I) x π=N(0,T2I)来初始化empirical PF ODE

  • 利用现有的数值ODE solver来求解(Euler,Heun solvers等)
  • 得到的 x ^ \\widehat{\\mathbf{x}} x 可以被看作是数据分布 p d a t a ( x ) p_{data}(\\mathbf{x}) pdata(x)的一个近似样本。
  • 考虑到数值稳定性,往往不会直接求出原图,而是取一个很小的值逐步来进行近似,并持续这个过程来求出。(导致速度慢)

Consistency Model-Definition

一致性模型的定义

假设存在一个函数f,对于同一条PF ODE轨迹上的任意点都有相同的输出 f ( x t , t ) = f ( x t ′ , t ′ ) for all  t , t ′ ∈ [ ϵ , T ] \\boldsymbol{f}\\left(\\mathrm{x}_{t}, t\\right)=\\boldsymbol{f}\\left(\\mathrm{x}_{t^{\\prime}}, t^{\\prime}\\right) \\text { for all } t, t^{\\prime} \\in[\\epsilon, T] f(xt,t)=f(xt,t) for all t,t[ϵ,T]
在这里插入图片描述
consistency model的目标是从数据中估计一致性函数 f f f,来迫使self-consistency性质

一致性模型参数化

对于任意的一致性函数 f ( ⋅ , ⋅ ) f(\\cdot, \\cdot) f(,),用神经网络来拟合。但要满足两个条件:①同一个轨迹上的点输出一致;②在起始点f为一个对于x的恒等函数

  • 第一种做法简单地参数化consistency models
    在这里插入图片描述
  • 第二种做法使用跳跃连接(作者和许多其他的都用这个)
    在这里插入图片描述
一致性模型采样

有了训练好的一致性模型 f θ ( ⋅ , ⋅ ) f_\\theta(\\cdot, \\cdot) fθ(,) ,就可以通过初始分布采样来产生样本。(这里指的是训练好后怎么来生成样本)

在这里插入图片描述

在一致性模型中,可以一步生成样本。也可以多步生成,算法1为多步生成。
想法就是预测出x后回退然后再进行预测减小误差。实际中,采用贪心算法来寻找时间点,通过三值搜索每次确定一个时间点,优化算法得到的样本的FID(不太重要)

Training Consistency Models via Distillation

第一种训练consistency model的方式——蒸馏预训练好的score model s ϕ ( x , t ) s_{\\phi}(\\mathrm{x}, t) sϕ(x,t)
假设采样轨迹的时间序列为 t 1 = ϵ < t 2 < ⋯ < t N = T t_{1}=\\epsilon<t_{2}<\\cdots<t_{N}=T t1=ϵ<t2<<tN=T
通过运行数值ODE求解器的一个离散化步骤从 x t n + 1 \\mathbf{x}_{t_{n+1}} xtn+1得到 x t n \\mathbf{x}_{t_{n}} xtn
在这里插入图片描述
Φ ( . . . ; ϕ ) \\Phi(...;\\phi) Φ(...;ϕ)为ODE solver
例如使用Euler solver d x d t = − t s ϕ ( x t , t ) \\frac{\\mathrm{dx}}{\\mathrm{d} t}=-t s_{\\phi}\\left(\\mathrm{x}_{t}, t\\right) dtdx=tsϕ(xt,t) Φ ( x , t ; ϕ ) = − t s ϕ ( x , t ) \\Phi(\\mathrm{x}, t ; \\phi)=-t s_{\\phi}(\\mathrm{x}, t) Φ(x,t;ϕ)=tsϕ(x,t)带入上式得到
在这里插入图片描述
沿着ODE轨迹的分布进行第一次采样 x \\mathrm{x} x~ p d a t a p_{data} pdata,然后添加高斯噪声,生成一对在PF ODE轨迹上相邻的数据点 ( x ^ t n ϕ , x t n + 1 ) \\left(\\hat{\\mathbf{x}}_{t_n}^\\phi, \\mathbf{x}_{t_{n+1}}\\right) (x^tnϕ,xtn+1)

通过最小化这一对的输出差异来训练一致性模型,作者遵循一致性蒸馏损失来训练一致性模型,就有如下的consistency distillation loss:
在这里插入图片描述
Consistency Models
在蒸馏的过程中,作者用预训练模型来估计得分.
采用EMA来更新模型会提高训练的稳定性,并且性能会更好
在这里插入图片描述

Training Consistency Models in Isolation

Consistency models也可以单独进行训练,而不依赖于预训练好的扩散模型。
作者说这与扩散蒸馏技术不同,使一致性模型成为一个新的独立的生成模型家族。
在consistency distillation中,使用了预训练的score model s ϕ ( x , t ) s_{\\phi}(\\mathrm{x}, t) sϕ(x,t)来近似ground truth score function ▽ l o g p t ( x ) \\bigtriangledown logp_t(\\mathbf{x}) logpt(x)
作者证明了 ▽ l o g p t ( x ) \\bigtriangledown logp_t(\\mathbf{x}) logpt(x)的一种无偏估计,即证明了一种新的得分函数的估计
在这里插入图片描述

即给定x, xt,可以用 − ( x t − x ) / t 2 -(\\mathbf{x}_t -\\mathbf{x})/t^2 (xtx)/t2 形式化 ▽ l o g p t ( x ) \\bigtriangledown logp_t(\\mathbf{x}) logpt(x)的蒙特卡罗估计,可以理解为
在这里插入图片描述
利用该得分估计,作者构建了新的consistency training (CT) loss记作 L C T N ( θ , θ − ) L_{CT}^{N}(\\theta,\\theta^-) LCTN(θ,θ)
在这里插入图片描述
在这里插入图片描述