Axial Attention (轴向注意力)

Medical Transformer: Gated Axial-Attention for Medical Image Segmentation
Paper walkthrough:
https://zhuanlan.zhihu.com/p/408662947

1. axialAttentionUNet

1.1 The original axialAttentionUNet

model = ResAxialAttentionUNet(AxialBlock, [1, 2, 4, 1], s=0.125, **kwargs)

  1. A U-Net built from the original axial-attention blocks on a residual (ResNet-style) backbone; a minimal sketch of the axial-attention idea follows the printed model below.
ResAxialAttentionUNet((conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(layer1): Sequential((0): AxialBlock((conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(layer2): Sequential((0): AxialBlock((conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock((conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), 
stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer3): Sequential((0): AxialBlock((conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): AxialBlock((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(3): AxialBlock((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer4): Sequential((0): AxialBlock((conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True))))(decoder1): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(decoder2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))(soft): Softmax(dim=1)
)
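
Reading the printout: each AxialBlock applies two 1D self-attentions in sequence, one along the image height (hight_block) and one along the width (width_block), so every pixel gets a global receptive field at O(HW(H+W)) cost instead of the O((HW)^2) of full 2D self-attention. Below is a minimal PyTorch sketch of single-axis attention; it deliberately omits the multi-head grouping, the relative positional encodings, and the BatchNorm layers (bn_qkv, bn_similarity, bn_output) that appear in the printout, so treat it as an illustration of the idea rather than the repo's implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAxialAttention(nn.Module):
    # Self-attention along a single spatial axis (here: height).
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # One 1x1 conv produces q, k and v together; q and k get
        # out_channels // 2 channels each, v gets out_channels, matching
        # the qkv_transform(C_in, 2 * C_out) convention in the printout.
        self.qkv = nn.Conv1d(in_channels, out_channels * 2, kernel_size=1, bias=False)
        self.out_channels = out_channels

    def forward(self, x):
        # x: (N, C, H, W); fold W into the batch and attend along H.
        N, C, H, W = x.shape
        x = x.permute(0, 3, 1, 2).reshape(N * W, C, H)
        q, k, v = torch.split(
            self.qkv(x),
            [self.out_channels // 2, self.out_channels // 2, self.out_channels],
            dim=1)
        attn = torch.einsum('bci,bcj->bij', q, k) / (q.shape[1] ** 0.5)
        attn = F.softmax(attn, dim=-1)               # (N*W, H, H)
        out = torch.einsum('bij,bcj->bci', attn, v)  # (N*W, C_out, H)
        return out.reshape(N, W, self.out_channels, H).permute(0, 2, 3, 1)

The width_block is the same module with the roles of H and W swapped; stacking the two restores full-image context.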

1.2 Axial-attention network with added gating units

model = ResAxialAttentionUNet(AxialBlock_dynamic, [1, 2, 4, 1], s=0.125, **kwargs)

In the gated axial-attention network:
  1. Every axial attention layer is replaced with a gated axial-attention layer (AxialBlock_dynamic); a sketch of the gating follows the printed model below.

ResAxialAttentionUNet((conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(layer1): Sequential((0): AxialBlock_dynamic((conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_dynamic((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(layer2): Sequential((0): AxialBlock_dynamic((conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_dynamic((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock_dynamic((conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): 
AxialAttention_dynamic((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer3): Sequential((0): AxialBlock_dynamic((conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_dynamic((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock_dynamic((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_dynamic((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): 
ReLU(inplace=True))(2): AxialBlock_dynamic((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_dynamic((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(3): AxialBlock_dynamic((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_dynamic((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer4): Sequential((0): AxialBlock_dynamic((conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_dynamic((qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(decoder1): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(decoder2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))(soft): Softmax(dim=1)
)
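
Compared with 1.1, the only change is the block type: AxialBlock becomes AxialBlock_dynamic. In the gated layer, each learned relative-positional term in the attention computation is multiplied by a learnable scalar gate, so on small medical datasets the network can down-weight positional information it has not learned reliably. Here is a sketch of just the gating arithmetic, with parameter names following the MedT repo (f_qr, f_kr, f_sv, f_sve); the initial values below are illustrative, not the repo's:

import torch
import torch.nn as nn

class GatedPositionalTerms(nn.Module):
    # Learnable scalar gates on the positional terms of axial attention.
    def __init__(self):
        super().__init__()
        self.f_qr = nn.Parameter(torch.tensor(0.1))   # gates the q-positional logits
        self.f_kr = nn.Parameter(torch.tensor(0.1))   # gates the k-positional logits
        self.f_sv = nn.Parameter(torch.tensor(1.0))   # gates the value aggregation
        self.f_sve = nn.Parameter(torch.tensor(0.1))  # gates the v-positional output

    def combine_logits(self, qk, qr, kr):
        # attention logits = content similarity + gated positional terms
        return qk + self.f_qr * qr + self.f_kr * kr

    def combine_output(self, sv, sve):
        # output = gated value term + gated positional output term
        return self.f_sv * sv + self.f_sve * sve

When a gate stays near zero the layer behaves like content-only attention; when it grows, the positional bias contributes as in standard axial attention.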

2. Medical Transformer

Note that during training, the gated units are not activated for the first 10 epochs; they are only switched on after epoch 10 (a hypothetical sketch of such a schedule follows).
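
A hypothetical sketch of that warm-up, assuming the gates are the scalar nn.Parameters named f_qr / f_kr / f_sv / f_sve as in the sketch under 1.2 (the actual mechanism in the MedT training code may differ):

import torch.nn as nn

GATE_NAMES = {'f_qr', 'f_kr', 'f_sv', 'f_sve'}

def set_gates_trainable(model: nn.Module, trainable: bool) -> None:
    # Freeze or unfreeze the scalar gate parameters of every gated layer.
    for name, param in model.named_parameters():
        if name.rsplit('.', 1)[-1] in GATE_NAMES:
            param.requires_grad = trainable

# In the training loop (driver code assumed, not from the repo):
#   for epoch in range(epochs):
#       set_gates_trainable(model, trainable=(epoch >= 10))
#       train_one_epoch(model, loader, optimizer)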

2.1 Local-global (LoGo)

model = medt_net(AxialBlock, AxialBlock, [1, 2, 4, 1], s=0.125, **kwargs)

LoGo network:

In the local + global network, the approach is:

  1. A U-Net built from the original axial attention; the gated axial-attention units proposed in the paper are not used here.
  2. The local + global (LoGo) training strategy; see the sketch after the printed model below.
medt_net((conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(layer1): Sequential((0): AxialBlock((conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(layer2): Sequential((0): AxialBlock((conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock((conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), 
bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))(soft): Softmax(dim=1)(conv1_p): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(conv2_p): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(conv3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn2_p): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn3_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu_p): ReLU(inplace=True)(layer1_p): Sequential((0): AxialBlock((conv_down): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(layer2_p): Sequential((0): AxialBlock((conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock((conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer3_p): Sequential((0): AxialBlock((conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): AxialBlock((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(3): AxialBlock((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer4_p): Sequential((0): AxialBlock((conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention((qkv_transform): 
qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention((qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(decoder1_p): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(decoder2_p): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder4_p): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder5_p): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoderf): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(adjust_p): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))(soft_p): Softmax(dim=1)
)
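
The LoGo strategy itself lives in the forward pass: the full image goes through the shallow global branch (only layer1 and layer2 above), while the same image is cut into patches that pass through the deeper local branch (layer1_p through layer4_p); the patch predictions are stitched back into place and fused with the global map. A sketch, under the assumptions that both branches return maps at input resolution with equal channel counts and that a 4x4 patch grid is used:

import torch

def logo_forward(global_branch, local_branch, x, grid=4):
    # Global branch sees the whole image.
    out_global = global_branch(x)
    # Local branch sees each patch independently; outputs are re-tiled.
    N, C, H, W = x.shape
    ph, pw = H // grid, W // grid
    out_local = torch.zeros_like(out_global)
    for i in range(grid):
        for j in range(grid):
            patch = x[:, :, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw]
            out_local[:, :, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw] = local_branch(patch)
    # Fuse by addition; the final 1x1 conv (adjust) and softmax that follow
    # in the network are left out of this sketch.
    return out_global + out_local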

2.2 MedT (Medical Transformer)

model = medt_net(AxialBlock_dynamic, AxialBlock_wopos, [1, 2, 4, 1], s=0.125, **kwargs)

The approach used here:

  1. The global branch uses the proposed gated axial-attention units (AxialBlock_dynamic), while the local branch uses the original axial attention without positional encoding (AxialBlock_wopos); see the sketch after the printed model below.

  2. The local + global (LoGo) training strategy from 2.1 is used.

medt_net((conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(layer1): Sequential((0): AxialBlock_dynamic((conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_dynamic((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(layer2): Sequential((0): AxialBlock_dynamic((conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_dynamic((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock_dynamic((conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): 
AxialAttention_dynamic((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_dynamic((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))(soft): Softmax(dim=1)(conv1_p): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(conv2_p): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(conv3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn2_p): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn3_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu_p): ReLU(inplace=True)(layer1_p): Sequential((0): AxialBlock_wopos((conv_down): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_wopos((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_wopos((qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(layer2_p): Sequential((0): AxialBlock_wopos((conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_wopos((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), 
bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_wopos((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock_wopos((conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_wopos((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_wopos((qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer3_p): Sequential((0): AxialBlock_wopos((conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_wopos((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_wopos((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): AxialBlock_wopos((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_wopos((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_wopos((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): AxialBlock_wopos((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_wopos((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_wopos((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(3): AxialBlock_wopos((conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_wopos((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_wopos((qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): 
BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer4_p): Sequential((0): AxialBlock_wopos((conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(hight_block): AxialAttention_wopos((qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(width_block): AxialAttention_wopos((qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0))(conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(decoder1_p): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(decoder2_p): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder4_p): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoder5_p): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(decoderf): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(adjust_p): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))(soft_p): Softmax(dim=1)
)
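
One detail worth noticing in the printout: the local branch's AxialAttention_wopos layers have bn_similarity with 8 channels instead of 24, because only the q·k similarity remains once the qr and kr positional terms are dropped (bn_output shrinks accordingly). Attention without positional terms reduces to plain scaled dot-product attention along one axis; a minimal sketch (whether the repo applies the 1/sqrt(d) scaling is not shown here, so treat this as the generic form):

import torch
import torch.nn.functional as F

def content_only_axial_attention(q, k, v):
    # q, k: (B, C_qk, L); v: (B, C_v, L). L is the attended axis (H or W);
    # B folds the batch together with the other spatial axis.
    attn = torch.einsum('bci,bcj->bij', q, k) / (q.shape[1] ** 0.5)
    attn = F.softmax(attn, dim=-1)
    return torch.einsum('bij,bcj->bci', attn, v)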

References:

Axial attention networks:
https://blog.csdn.net/hxxjxw/article/details/121445561
https://blog.csdn.net/weixin_43718675/article/details/106760382
https://zhuanlan.zhihu.com/p/408662947
