> 文章列表 > 【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

【视频分割】【深度学习】MiVOS官方Pytorch代码–Propagation模块PropagationNet网络解析

MiVOS模型将交互到掩码和掩码传播分离,从而实现更高的泛化性和更好的性能。单独训练的交互模块将用户交互转换为对象掩码,传播模块使用一种新的top-k过滤策略在读取时空存储器时进行临时传播,本博客将讲解Propagation(用户交互产生分割图)模块的深度网络代码,Propagation模块封装了PropagationNet和FusionNet模型。
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

文章目录

  • 【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析
  • 前言
  • PropagationNetwork类
    • __init__函数
    • Memory Encoder
    • Query Encoder
    • Decoder
  • EvalMemoryReader类
  • Decoder类
  • modules.py
    • MaskRGBEncoder类
    • RGBEncoder类
    • KeyValue类
  • 总结

前言

在详细解析MiVOS代码之前,首要任务是成功运行MiVOS代码【win10下参考教程】,后续学习才有意义。
本博客讲解Propagation模块的深度网络(PropagationNetwork)代码,不再复述其他功能模块代码。
MiVOS原论文中关于Propagation Module的示意图:
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

关键帧是用户在某一帧有交互行为,传播帧是根据这些交互行为而需要改变的帧。


PropagationNetwork类

在model/propagation/prop_net.py内

__init__函数

def __init__(self, top_k=50):super().__init__()# Memory Encoder过程主干网络self.mask_rgb_encoder = MaskRGBEncoder()# Query Encoder过程的主干网络self.rgb_encoder = RGBEncoder() # 主干网络+Memory KeyValue网络=>Memory Encoder的key和valueself.kv_m_f16 = KeyValue(1024, keydim=128, valdim=512)# 主干网络+Query KeyValue网络=>Query Encoder的key和valueself.kv_q_f16 = KeyValue(1024, keydim=128, valdim=512)# 获得Memory Encoder中前top_k有价值的valueself.memory = EvalMemoryReader(top_k, km=None)# 获得原始图像的注意区域self.attn_memory = AttentionMemory(top_k)# 上采样Decoder获得mask,正确区分背景和多个目标前景self.decoder = Decoder()

Memory Encoder

memorize方法是Memory Encoder过程,mask_rgb_encoder是主干网络,kv_m_f16是编码网络。通过原始图片、mask以及other获得Memory key/value,mask是由S2M生成。

def memorize(self, frame, masks): k, _, h, w = masks.shape            # [k, 1, h, w]  # 扩展图片batchsize-->1到k [k,3,h,w]frame = frame.view(1, 3, h, w).repeat(k, 1, 1, 1)# Compute the "others" maskif k != 1:others = torch.cat([torch.sum(masks[[j for j in range(k) if i != j]], dim=0, keepdim=True)  # 计算除了i以外的其他k-1个obj mask的和,并在0维拼接for i in range(k)], 0)          # [k, 1, h, w]    else:others = torch.zeros_like(masks)f16 = self.mask_rgb_encoder(frame, masks, others)   # 数字16代表下采样后特征图为原图大小1/16k16, v16 = self.kv_m_f16(f16)               # [k, channel(k128 v512), H/16, W/16]return k16.unsqueeze(2), v16.unsqueeze(2)   # [k, channel(k128 v512), 1, h, w ]

Memory Encoder过程在论文原图中所示:
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

T存放着所有关键帧和已传播完成帧的Memory key/value,已传播完成帧指的根据关键帧信息完成PropagationNet和FusionNet完整过程的帧。

Memory Encoder的详细过程示意图如下所示:
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

这里的other图只是随机生成的示意图,只是为了方便说明,并不是真在根据masks计算得出

Query Encoder

get_query_values方法是Query Encoder过程,rgb_encoder是主干网络,kv_q_f16是编码网络。通过原始图片获得Query key/value。

def get_query_values(self, frame):f16, f8, f4 = self.rgb_encoder(frame)k16, v16 = self.kv_q_f16(f16)return f16, f8, f4, k16, v16

Query Encoder过程在论文原图中所示:
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

Query Encoder仅用于当前传播帧,传播完成后变为已传播完成帧,就需要Memory Encoder存到T

Query Encoder的详细过程示意图如下所示:
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

Query只有一个,Memory 有T个,具体请查看博文【Propagation功能模块】

Decoder

首先需要memory方法为所有目标k分别获取加权处理Memory value后有价值的特征并结合Query value,而后与rgb_encoder主干网络生成的中间浅层特征一起进行decoder解码过程获得最终的mask。

   def segment_with_query(self, keys, values, f16, f8, f4, k16, v16): k = keys.shape[0]# Do it batch by batch to reduce memory usagebatched = 1m4 = torch.cat([self.memory(keys[i:i+batched], values[i:i+batched], k16) for i in range(0, k, batched)], 0)   # [k,C,H,W]  C:channelv16 = v16.expand(k, -1, -1, -1)         # expand必须有一个维度的值为1m4 = torch.cat([m4, v16], 1)return torch.sigmoid(self.decoder(m4, f8, f4))

segment_with_query过程在论文原图中所示:
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析
Memory value和Query value结合详细过程如下图所示:
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

EvalMemoryReader类

通过Memory key特征和Query key特征计算得到weight map(权重图)【个人理解】,然后Memory value和weight map做加权获得新的Memory new value特征。

class EvalMemoryReader(nn.Module):def __init__(self, top_k, km):super().__init__()self.top_k = top_k              # 选取相似度最近的top50self.km = kmdef forward(self, mk, mv, qk):B, CK, T, H, W = mk.shape       # B是1,即当前的obj类的key/value T是memory中已存的图片数_, CV, _, _, _ = mv.shapemi = mk.view(B, CK, T*H*W).transpose(1, 2)                  # [B,THW,CK]qi = qk.view(1, CK, H*W).expand(B, -1, -1) / math.sqrt(CK)  # [B,CK,HW]affinity = torch.bmm(mi, qi)    # 矩阵相乘 [B,THW,HW]  shape只能是3维# --------源码没有使用if self.km is not None:# Make a bunch of Gaussian distributionsargmax_idx = affinity.max(2)[1]y_idx, x_idx = argmax_idx//W, argmax_idx%Wg = make_gaussian(y_idx, x_idx, H, W, sigma=self.km)g = g.view(B, T*H*W, H*W)affinity = softmax_w_g_top(affinity, top=self.top_k, gauss=g)           # [B,THW,HW]# --------else:if self.top_k is not None:affinity = softmax_w_g_top(affinity, top=self.top_k, gauss=None)    # mv特征图的权重[B,THW,HW]else:affinity = F.softmax(affinity, dim=1)mo = mv.view(B, CV, T*H*W)      # [B,CV,THW]mem = torch.bmm(mo, affinity)   # [B, CV, HW]mem = mem.view(B, CV, H, W)return mem

EvalMemoryReader详细过程如下图所示:
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

weight map(权重图)是所有Memory key 和当前传播的帧Query key矩阵相乘计算而来,而后加权到所有Memory value获得新的Memory new value。FusionNet也有一部类似的操作,注意区分。

生成Memory value特征的weight map(权重图)的代码,权重图仅保留top-50的权重,其他置零。

def softmax_w_g_top(x, top=None, gauss=None):#  x[B,THW,HW]if top is not None:# ----源码未使用部分if gauss is not None:maxes = torch.max(x, dim=1, keepdim=True)[0]x_exp = torch.exp(x - maxes)*gaussx_exp, indices = torch.topk(x_exp, k=top, dim=1)# -----else:values, indices = torch.topk(x, k=top, dim=1)   #在THW 选择前top个的(值,索引)的元组x_exp = torch.exp(values - values[:, 0])        # e^vx_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)   # 求和之后这个dim的元素个数为1,所以要被去掉,如果要保留这个维度,则应当keepdim=Truex_exp /= x_exp_sum      # x_exp 归一化# The types should be the same already# some people report an error here so an additional guard is addedx.zero_().scatter_(1, indices, x_exp.type(x.dtype))     # 用x_exp[B,THW,HW]output = xelse:maxes = torch.max(x, dim=1, keepdim=True)[0]if gauss is not None:x_exp = torch.exp(x-maxes)*gaussx_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)x_exp /= x_exp_sumoutput = x_expreturn output

Decoder类

Decoder通过rgb_encoder主干网络生成的中间浅层特征f8/f4,以及处理合并Memory value和Query value的特征f16共同生成mask。

Decoder

class Decoder(nn.Module):def __init__(self):super().__init__()self.compress = ResBlock(1024, 512)self.up_16_8 = UpsampleBlock(512, 512, 256)     # 1/16 -> 1/8self.up_8_4 = UpsampleBlock(256, 256, 256)      # 1/8 -> 1/4self.pred = nn.Conv2d(256, 1, kernel_size=(3, 3), padding=(1, 1), stride=1)def forward(self, f16, f8, f4):x = self.compress(f16)x = self.up_16_8(f8, x)x = self.up_8_4(f4, x)x = self.pred(F.relu(x))x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)return x

网络结构如下图所示:
【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块PropagationNet网络解析

ResBlockh和UpsampleBlock代码位置model/propagation/modules.py

ResBlock模块

class ResBlock(nn.Module):def __init__(self, indim, outdim=None):super(ResBlock, self).__init__()if outdim == None:outdim = indimif indim == outdim:self.downsample = Noneelse:self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1)self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)def forward(self, x):r = self.conv1(F.relu(x))r = self.conv2(F.relu(r))if self.downsample is not None:x = self.downsample(x)return x + r

UpsampleBlock模块

class UpsampleBlock(nn.Module):def __init__(self, skip_c, up_c, out_c, scale_factor=2):super().__init__()self.skip_conv1 = nn.Conv2d(skip_c, up_c, kernel_size=3, padding=1)self.skip_conv2 = ResBlock(up_c, up_c)self.out_conv = ResBlock(up_c, out_c)self.scale_factor = scale_factordef forward(self, skip_f, up_f):x = self.skip_conv2(self.skip_conv1(skip_f))x = x + F.interpolate(up_f, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)x = self.out_conv(x)return x

modules.py

在model/propagation目录下

MaskRGBEncoder类

采用了resnet50网络,是Memory Encoder过程的主干网络。

这里的resnet50输入channels是5,不是3

class MaskRGBEncoder(nn.Module):def __init__(self):super().__init__()resnet = mod_resnet.resnet50(pretrained=True, extra_chan=2)self.conv1 = resnet.conv1self.bn1 = resnet.bn1self.relu = resnet.relu         # 1/2, 64self.maxpool = resnet.maxpool   # 1/4, 64self.layer1 = resnet.layer1     # 1/4, 256self.layer2 = resnet.layer2     # 1/8, 512self.layer3 = resnet.layer3     # 1/16, 1024def forward(self, f, m, o):f = torch.cat([f, m, o], 1)x = self.conv1(f)x = self.bn1(x)x = self.relu(x)        # 1/2, 64x = self.maxpool(x)     # 1/4, 64x = self.layer1(x)      # 1/4, 256x = self.layer2(x)      # 1/8, 512x = self.layer3(x)      # 1/16, 1024return x

RGBEncoder类

采用了resnet50网络,是Query Encoder过程的主干网络。

class RGBEncoder(nn.Module):def __init__(self):super().__init__()resnet = models.resnet50(pretrained=True)self.conv1 = resnet.conv1self.bn1 = resnet.bn1self.relu = resnet.relu         # 1/2, 64self.maxpool = resnet.maxpool   # 1/4, 64self.res2 = resnet.layer1       # 1/4, 256self.layer2 = resnet.layer2     # 1/8, 512self.layer3 = resnet.layer3     # 1/16, 1024def forward(self, f):x = self.conv1(f) x = self.bn1(x)x = self.relu(x)        # 1/2, 64x = self.maxpool(x)     # 1/4, 64f4 = self.res2(x)       # 1/4, 256f8 = self.layer2(f4)    # 1/8, 512f16 = self.layer3(f8)   # 1/16, 1024return f16, f8, f4

KeyValue类

编码网络,key用于评估当前帧和之前帧的相似性,value用来生成最后mask精细结果信息。

class KeyValue(nn.Module):def __init__(self, indim, keydim, valdim):super().__init__()self.key_proj = nn.Conv2d(indim, keydim, kernel_size=3, padding=1)self.val_proj = nn.Conv2d(indim, valdim, kernel_size=3, padding=1)def forward(self, x):  return self.key_proj(x), self.val_proj(x)

总结

尽可能简单、详细的介绍MiVOS中Propagation模块中PropagationNetwork网络的代码。后续会讲解Propagation中FusionNet网络代码以及MiVOS的训练。