> 文章列表 > 双塔模型实践

双塔模型实践

双塔模型实践

双塔是“召回”+“粗排”的绝对主力模型。但是要让双塔在召回、粗排中发挥作用,带来收益,只改进双塔结构是远远不够的。如何采样以减少“样本选择偏差”、如何保证上下游目标一致性、如何在双塔中实现多任务间的信息转移,都是非常重要的课题。但是受篇幅限制,本文只聚集于双塔模型结构上的改进。
市面上关于双塔改进的文章有很多,本文不会一一罗列这些改进的细节。遵循本人文章的一贯风格,本文将为读者梳理这些改进背后的发展脉络,了解这些改进“为什么这样做”,希望能够激发读者在“改进双塔”上新的灵感。至于“怎么做”,请感兴趣的读者稳步原文。

双塔模型的结构

双塔的模型结构很简单,

训练的时候:

  1. 将用户侧的信息喂入一个DNN(aka, user tower),最终得到一个user embedding
  2. 将物料侧的信息喂入一个DNN(aka, item tower),最终得到一个item embedding
  3. 拿user embedding与item embedding,做点积或cosine,得到logit,代表user & item之间的匹配程度
  4. 设计loss,将user tower, item tower, 和各种特征的embedding,都训练出来
    注意,虽然训练流程类似,但是“双塔召回”与“双塔粗排”所需要的负样本,截然不同,详情见《负样本为王:评Facebook的向量化召回算法》。
    双塔用作召回,线上预测时:
  5. 离线、周期性、批量将item信息(i.e., 十万、百万级别)喂入item tower,得到item embedding
  6. 得到的item embedding,导入FAISS,建立索引
  7. 线上接到用户请求后,将user信息喂入user tower,得到user embedding
  8. 拿得到的user embedding,在faiss中做近邻搜索(ANN),得到与user embedding相邻item,作为召回内容返回
    双塔用作粗排,线上预测时:
  9. 与双塔召回时一样,item embedding依然是离线、周期性、批量生成。
  10. 线上接到粗排请求后,将user信息喂入user tower,得到user embedding
  11. 拿粗排请求中的candidate item id(千级别)去KV库中检索出对应的item embedding
  12. 拿user embedding与检索出来的item embedding,逐一做点积或cosine得到user & item的匹配度
  13. 将candidate item按与user匹配度降序排列,top N candidate item喂入下游精排。

双塔模型的优缺点

综上所述,我们发现双塔模型最大的特点就是“双塔分离”。只不过“部署时分离”成了双塔最大的优点,而“训练时分离”成了制约双塔性能的最大因素。
部署时分离
由于item tower完全不依赖于user信息,所以海量的item embedding可以周期性、批量、离线生成,大大减轻了线上serving的压力。由于user tower完全不依赖于item信息,所以无论候选集是几千(粗排)或十万级、百万级(召回),user embedding只需要生成一遍。反观精排模型,由于从最底层user & item信息就需要开始产生交叉,“难舍难分”。所以user信息必须与每一条candidate item过一遍精排模型,从而限制了精排候选集的规模
训练时分离
user信息只能喂入user tower, item信息只能喂入item tower,没有地方喂入user & item之间的交叉特征。user侧信息与item侧信息,只有唯一一次交叉机会,就是在双塔生成各自的embedding之后的那次点积或cosine。但是这时参与交叉的user/item embedding,已经是高度浓缩的了。一些细节信息已经损失,永远失去了与对侧信息交叉的机会。为了线上快速serving,交叉只能是简单的dot或cosine。一些复杂的、依赖于底层信息的交叉结构,比如target item对user action history的attention,也在双塔中找不到位置。
综上所述,“双塔分离”的结构,既是保障线上快速serving的优点,也是不能使用交叉特征与结构、导致两侧信息交叉过晚、制约模型表达能力的最大缺点。“线上快速serving”正好对召回、粗排这种“大候选集”场景的胃口,而由于后面还有能力强大的精排,所以“模型表达能力弱”的缺点,也能够为召回、粗排所容忍。因此,双塔模型成为召回+粗排的主流模型,几乎是粗排的不二选择。

双塔模型的改进

综上,双塔最大的缺点就在于,user&item两侧信息交叉得太晚,等到最终能够通过dot或cosine交叉的时候,user & item embedding已经是高度浓缩的了,一些细粒度的信息已经在塔中被损耗掉,永远失去了与对侧信息交叉的机会。所以,双塔改建最重要的一条主线就是:如何保留更多的信息在tower的final embedding中,从而有机会和对侧塔得到的embedding交叉?围绕着这条主线,勤劳的互联网打工人设计出很多的改进方案。

改进方案1:

这种思路以张俊林大佬的SENet为代表。既然把信息“鱼龙混杂”一古脑地喂入塔,其中的噪声造成污染,导致很多细粒度的重要信息未能“幸存”到final dot product那一刻。SENet的思路就是,在将信息喂入塔之前,插入SEBlock。SEBlock动态学习各特征的重要性,增强重要信息,弱化甚至过滤掉原始特征中的噪声,从而减少信息在塔中传播过程中的污染与损耗,能够让可能多的重要信息“撑”到final dot product那一刻。

方案2:

信息在塔中向上流动的过程,也是一个信息压缩的过程,不可避免地带来信息损耗。所以,我们很自然地想到,何不让那些重要信息抄近路,走捷径,把它们直接送到离final dot product更近的地方。

提到抄近路,大家自然而然地想到ResNet,如下图所示。
是喂入塔的原始信息,经过塔中的信息流动,到最后一层时已经损失了很多重要的、细粒度信息。
这时,我们将抄近路,送到最后一层与tower的输出融合(图中是element-wise add,但是显然那并不是唯一的融合方式),得到final embedding
这时的既包含了经过tower高度浓缩后的信息,又包含原始输入中的一些细粒度信息。特别是这些细粒度的重要信息,终于有了和对侧信息交互的机会。
抄近路的思路确定了,那么抄近路的方式,就五花八门,多种多样了。比如除了原始输入能够抄近路,塔中间的一些信息是不是也能抄一把?比如下图模样(BTW,有谁知道logo的出处吗?如果知道,咱俩除了调参炼丹,就有了另一个共同爱好_);既然信息在塔中流动过程中就已经损失了,重要信息没必要等到最后一刻再补充,补充到中间层也会大有帮助,就像马拉松选手的中途补水。
但是,这种抄近路的方式,也有其固有的缺点,就是会导致输入层的肿胀。比如原来tower final embedding是64维,你现在要将一些重要的、细粒度的信息也抄近路到最后一层。既然称这些信息是细粒度的,自然是未经过压缩提炼的,维度一般都很大,比如1024维。如果你将原来tower embedding与抄近路信息简单拼接,那么final embedding就会膨胀好几倍,会给线上存储、内存都带来巨大的压力。当然你可以将抄近路的信息,经过一层简单的线性映射,压缩到一个比较小的维度,但这也会引入额外的映射权重,严重时会导致训练时OOM。

所以,将所有原始信息无脑地抄近路,显然是行不通的。这就牵扯到另外一个问题:哪些信息值得抄近路?要回答这个问题,你当然可以跑一个SE block或其他什么算法,获得各特征的重要性。而从我的个人经验来看,我们要特别注意那些“极其个性化”(e.g., userId, itemId)的特征,和,对划分人群、物群有显著区分性的特征(e.g., 用户是新用户还是老用户?用户是否登陆?文章所使用的语言,等)。

方案3

另一个出发点,与“抄近路”的思路是类似的:我们不再相信(或者说,迷信)DNN的拟合能力。
传统双塔,只有一种信息上升通道,就是DNN。我估计很多同学有与我类似的经历,就是刚接触DNN的时候,听过这样一句话,“只要DNN足够复杂,能够模拟任意函数”。现在看来,这句话的可信性要大打折扣了,Google DCN的论文里宣称,“People generally consider DNNs as universal function approximators, that could potentially learn all kinds of feature interactions. However, recent studies found that DNNs are inefficient to even approximately model 2nd or 3rd-order feature crosses.”
既然如此,我们也就没必要将宝都押在DNN这一种通道上。即使是相同的信息,也可以沿多种信息通道向上流动,最终将各通道得到的embedding聚合成final embedding,与对侧交互。
这种思路的典型代表,就是腾讯的并联双塔,“通过并联多个双塔结构增加双塔模型的宽度,来缓解双塔内积的瓶颈从而提升效果”

信息沿着MLP, DCN, FM, CIN这4种通道向上流动。每种通道各有所长,比如MLP是implicit feature cross,FM和DNN都属于explicit and bounded-degree feature cross,大家相互取长补短。
最终各通道的融合,这里是直接拼接,只不过各通道的embedding乘上一个可学习的系数,以形成一个logistic regression的效果。比如我们只有MLP和DCN两个通道,
双塔模型实践
另外,涉及到并联双塔训练细节的是,

由于FM和DCN等结构,只能完成信息交叉,而无法信息压缩,所以只能喂入有限的重要特征,否则会引发维度膨胀。
虽然不同结构可能共享特征,但是它们却不共享这些特征的底层embedding。同一个特征,如果要同时喂入MLP和DCN,就必须定义两套embedding,供MLP和DCN分别加以训练。根据我之前的经验,分离embedding空间的确能够换来性能上的提升,但是也带来模型膨胀,给线上serving带来压力。

总结

本文是我和双塔模型死磕了6个月之后的心得体会。如前文所述,双塔分离,既是保障线上快速serving、以适应召回+粗排场景的优点,也是不能使用交叉特征与结构、导致两侧信息交叉过晚、制约模型表达能力的最大缺点。user&item两侧信息交叉得太晚,等到最终能够通过dot或cosine交叉的时候,user & item embedding已经高度浓缩,一些细粒度的信息已经在塔中被损耗,永远失去了与对侧信息交叉的机会。

为了克服这一缺点,业界同仁设计出许多改进方案。这些方案背后有一个共同的思路,就是减少信息沿塔上升过程中的损耗,让更多细粒度的重要信息能够“幸存”到final embedding中,能够“撑”到final dot product那一刻。然后我分析了Facebook、阿里、腾讯、新浪、美团等中外大厂的工作,看看我的中外同仁们如何从“净化输入”、“重要信息走捷径”、“拓宽信息上升通道”、“双塔相互模仿”等方面实现了这个思路,克服双塔的缺点,提升其性能。

正如我之前经常论述的“道”与“术”,“了解双塔的缺点,知道从哪里改进”是“道”,“怎么改进”是“术”,深刻理解“道”之后,才能将各大厂的“术”综合运用,甚至创造你自己的“术”,提升你自己模型的性能。比如,如果美团方案中的对偶增强向量真的那么重要,那么不将它们接入DNN最底层,而是直接抄近路到塔的最后一层,离final dot-product更近,是不是效果更佳?至于是否是这样,就要等GPU和AB平台告诉我们了。

最后,我也指出,尽管阿里妈妈已经在召回+粗排领域告别了“双塔”,但是现在和双塔彻底告别还为时过早,“双塔”模型仍然是我们广大算法同仁手中得心应手的一件兵器,你值得拥有。