  • 相关论文:Learning Efficient Convolutional Networks through Network Slimming (ICCV 2017)


先给出答案:在 Batch Normalization(BN)层的缩放因子上施加 L1 正则化(这是上面这篇论文的核心思想,更多细节请自行阅读论文😂)


  • 不需要对现有的CNN架构进行任何更改
  • 使用L1正则化将BN缩放因子的值推向零
    • 使我们能够识别不重要的通道(或神经元),因为每个缩放因子对应于特定的卷积通道(或全连接层的神经元)
    • 这有利于在接下来的步骤中进行通道级剪枝
  • 附加的正则化项很少会损害性能。不仅如此,在某些情况下,它会导致更高的泛化精度
  • 剪枝不重要的通道有时可能会暂时降低性能,但这个效应可以通过接下来的修剪网络的微调来弥补
  • 剪枝后,由此得到的较窄的网络在模型大小、运行时内存和计算操作方面比初始的宽网络更加紧凑。上述过程可以重复几次,得到一个多轮(multi-pass)的网络瘦身方案,从而实现更加紧凑的网络。

$$L = \sum_{(x,y)} l\big(f(x, W),\, y\big) + \lambda \sum_{\gamma \in \Gamma} g(\gamma)$$

其中 $(x, y)$ 为训练样本及其标签,$W$ 为网络的可训练权重,第一项是常规的训练损失;$g(\cdot)$ 是施加在 BN 缩放因子 $\gamma$ 上的稀疏惩罚(这里取 L1 范数,即 $g(\gamma) = |\gamma|$),$\lambda$ 用于平衡这两项。





2.1 说明



  • 在之前的课程中我们对 BatchNorm 进行了稀疏训练
  • 训练完成后我们获取所有的 BatchNorm 的参数数量,将 BatchNorm 所有参数取出来排序
  • 根据剪枝比例 $r$ 设置 threshold 阈值,通过 gt() (greater than) 方法得到 mask,小于等于 threshold 的置零
  • 根据 mask 计算剩余的数量,记录
    • cfg:用于创建新模型
    • cfg_mask:用于剪枝
  • 后面会用到这两个 mask,操作每一层的输入和输出



  • weights:(out_channels, in_channels, kernel_size, kernel_size)
  • 利用 mask 做索引,对应赋值
  • 使用 start_mask、end_mask


  • self.weight:存储 $\gamma$,(input_size)
  • self.bias:存储 $\beta$,(input_size)
  • 使用 end_mask
  • 更新 start_mask、end_mask


  • self.weight:(out_features, in_features)
  • self.bias:(out_features)
  • 使用 start_mask


2.2 test()


import argparse
from utils import get_test_dataloader
import torch


def parse_opt():
    """Parse the command-line options of the pruning script.

    Returns:
        argparse.Namespace: with the attributes ``dataset``,
        ``test_batch_size``, ``no_cuda``, ``depth``, ``percent``, ``model``
        and ``save``.
    """
    # Prune setting
    parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='training dataset (default: cifar10)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                        help='input batch size for test (default: 256)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--depth', type=int, default=11,
                        help='depth of the vgg')
    parser.add_argument('--percent', type=float, default=0.5,
                        help='scale sparse rate (default: 0.5)')
    parser.add_argument('--model', default='', type=str, metavar='PATH',
                        help='path to the model (default: none)')
    parser.add_argument('--save', default='logs/', type=str, metavar='PATH',
                        help='path to save pruned model (default: none)')
    args = parser.parse_args()
    return args


def test(model):
    """Evaluate ``model`` on the test loader and report its accuracy.

    Reads the module-level ``args`` (``args.cuda`` and
    ``args.test_batch_size``), so both must be set before calling.

    Args:
        model: network to evaluate; it is switched to eval mode.

    Returns:
        The fraction of correctly classified samples, in [0, 1]. It is a
        0-dim tensor because ``correct`` accumulates tensor sums.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    test_loader = get_test_dataloader(batch_size=args.test_batch_size, **kwargs)
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # Index of the largest score along dim 1 is the predicted class.
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    # len(test_loader) is the number of batches; the sample count is the
    # length of the underlying dataset, which is what the ratio is based on.
    num_samples = len(test_loader.dataset)
    accuracy = 100. * correct / num_samples
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(correct, num_samples, accuracy))
    return accuracy / 100


if __name__ == "__main__":
    args = parse_opt()

2.3 加载稀疏训练模型


python .\train.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 11 --epochs 10


import os
import argparse
from models.vgg import VGG
from utils import get_test_dataloader
import torch


def parse_opt():
    ...


def test(model):
    ...


if __name__ == "__main__":
    args = parse_opt()
    # Use CUDA only when it is both requested and available.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    os.makedirs(args.save, exist_ok=True)

    model = VGG(depth=args.depth)
    if args.cuda:
        model.cuda()

    # Restore the sparsity-trained weights when a checkpoint path is given.
    if args.model:
        if os.path.isfile(args.model):
            print("=> loading checkpoint '{}'".format(args.model))
            # Load onto the CPU first so that a checkpoint saved on a GPU can
            # still be read on a CPU-only machine; load_state_dict then copies
            # the values into the model's parameters on their own device.
            checkpoint = torch.load(args.model, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(
                args.model, checkpoint['epoch'], best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.model))
    print(model)

2.4 前处理


if __name__ == "__main__":
    ...
    # Count the channels of all BatchNorm2d layers.
    total = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]

    # Collect |gamma| of every BN layer into one flat tensor.
    # gamma is stored in nn.BatchNorm2d.weight.data,
    # beta is stored in nn.BatchNorm2d.bias.data.
    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index + size)] = m.weight.data.abs().clone()
            index += size

    # The threshold is the |gamma| found at position `percent` of the sorted
    # values.
    y, _ = torch.sort(bn)
    thre_index = int(total * args.percent)
    thre = y[thre_index]

    pruned = 0
    cfg = []       # channels kept per layer ('M' for max-pooling), used to build the new model
    cfg_mask = []  # one 0/1 mask per BN layer, used to copy the surviving weights
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            weight_copy = m.weight.data.abs().clone()
            # gt() is strict: channels whose |gamma| equals the threshold are
            # pruned too. The mask stays on the device of the BN weights, so
            # this also works when the model runs on the CPU.
            mask = weight_copy.gt(thre).float()
            remaining = int(torch.sum(mask))
            pruned += mask.shape[0] - remaining
            # Zero gamma and beta of the pruned channels in place.
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg.append(remaining)
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(
                k, mask.shape[0], remaining))
        elif isinstance(m, nn.MaxPool2d):
            cfg.append('M')

    pruned_ratio = pruned / total
    print("Pre-process Sucessful Pruned Ratio: {:.2f}%".format(pruned_ratio * 100.))
    # Accuracy of the original architecture with the pruned channels zeroed.
    acc = test(model)
    print(cfg)


layer index: 3   total channel: 64       remaining channel: 63
layer index: 7   total channel: 128      remaining channel: 126
layer index: 11          total channel: 256      remaining channel: 227
layer index: 14          total channel: 256      remaining channel: 162
layer index: 18          total channel: 512      remaining channel: 180
layer index: 21          total channel: 512      remaining channel: 194
layer index: 25          total channel: 512      remaining channel: 191
layer index: 28          total channel: 512      remaining channel: 232
Pre-process Sucessful Pruned Ratio: 50.04%
Files already downloaded and verified

Test set: Accuracy: 1757/40 (17.6%)

[63, 'M', 126, 'M', 227, 162, 'M', 180, 194, 'M', 191, 232]


# old
[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]
# new
[63, 'M', 126, 'M', 227, 162, 'M', 180, 194, 'M', 191, 232]

2.5 建立新模型并存储信息



if __name__ == "__main__":
    ...
    # Build the narrower network from the per-layer channel counts in cfg.
    newmodel = VGG(cfg=cfg)
    if args.cuda:
        newmodel.cuda()

    # Record the pruned configuration, its size and the pre-processing accuracy.
    num_parameters = sum(param.nelement() for param in newmodel.parameters())
    savepath = os.path.join(args.save, "prune.txt")
    with open(savepath, 'w') as fp:
        fp.write("Configuration: " + str(cfg) + "\n")
        fp.write("Number of parameters: " + str(num_parameters) + "\n")
        # float() writes a plain number even when acc is a 0-dim tensor.
        fp.write("Test accuracy: " + str(float(acc)))

    # start_mask / end_mask select the input / output channels of the current
    # Conv+BN pair. The first layer keeps all 3 input channels
    # (presumably the RGB channels of the image - confirm against the dataset).
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        pass

2.6 BatchNorm层的剪枝

说明:start_mask和end_mask => 对应于Conv+BN层的输入和输出


if __name__ == "__main__":
    ...
    for m0, m1 in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            # Positions of the channels kept by this BN layer. flatnonzero
            # always yields a 1-D index array, even for a single channel.
            kept = np.flatnonzero(end_mask.cpu().numpy()).tolist()

            # Copy gamma, beta and the running statistics of the kept channels.
            m1.weight.data = m0.weight.data[kept].clone()
            m1.bias.data = m0.bias.data[kept].clone()
            m1.running_mean = m0.running_mean[kept].clone()
            m1.running_var = m0.running_var[kept].clone()

            # The output mask of this layer is the input mask of the next one.
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]

2.7 Conv2d的剪枝


if __name__ == "__main__":
    ...
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            ...  # BatchNorm2d branch, see section 2.6
        elif isinstance(m0, nn.Conv2d):
            # Kept input channels (start_mask) and output channels (end_mask).
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print("In channels: {:d}, Out channels: {:d}".format(idx0.size, idx1.size))
            # squeeze() turns a single index into a 0-d array; restore shape (1,).
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # weight: (out_channels, in_channels, kernel_size, kernel_size);
            # slice the input channels first, then the output channels.
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            # A conv bias has one entry per output channel; copy it when the
            # layers were built with one, otherwise m1 keeps a random bias.
            if m0.bias is not None and m1.bias is not None:
                m1.bias.data = m0.bias.data[idx1.tolist()].clone()

2.8 Linear的剪枝


if __name__ == "__main__":
    ...
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            ...  # BatchNorm2d branch, see section 2.6
        elif isinstance(m0, nn.Conv2d):
            ...  # Conv2d branch, see section 2.7
        elif isinstance(m0, nn.Linear):
            # Only the input features are pruned: they follow the channels
            # kept by the last BN layer (start_mask).
            # NOTE(review): assumes one input feature per remaining channel
            # - confirm against the VGG classifier definition.
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            # squeeze() turns a single index into a 0-d array; restore shape (1,).
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            # weight: (out_features, in_features); the outputs are untouched,
            # so the bias is copied as is.
            m1.weight.data = m0.weight.data[:, idx0.tolist()].clone()
            m1.bias.data = m0.bias.data.clone()



import os
import argparse
import numpy as np
import torch
import torch.nn as nn

from models.vgg import VGG
from utils import get_test_dataloader


def parse_opt():
    """Parse the command-line options of the pruning script.

    Returns:
        argparse.Namespace: with the attributes ``dataset``,
        ``test_batch_size``, ``no_cuda``, ``depth``, ``percent``, ``model``
        and ``save``.
    """
    # Prune settings
    parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
    # The default has to be 'cifar10': it is what the help text documents and
    # the only value test() accepts, so any other default fails out of the box.
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='training dataset (default: cifar10)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                        help='input batch size for testing (default: 256)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--depth', type=int, default=19,
                        help='depth of the vgg')
    parser.add_argument('--percent', type=float, default=0.5,
                        help='scale sparse rate (default: 0.5)')
    parser.add_argument('--model', default='', type=str, metavar='PATH',
                        help='path to the model (default: none)')
    parser.add_argument('--save', default='logs/', type=str, metavar='PATH',
                        help='path to save pruned model (default: none)')
    args = parser.parse_args()
    return args


# simple test model after Pre-processing prune (simple set BN scales to zeros)
def test(model):
    """Evaluate ``model`` on the test loader and report its accuracy.

    Reads the module-level ``args`` (``dataset``, ``cuda`` and
    ``test_batch_size``), so it must be called after ``args`` is set.

    Args:
        model: network to evaluate; it is switched to eval mode.

    Returns:
        float: fraction of correctly classified test samples, in [0, 1].

    Raises:
        ValueError: if ``args.dataset`` is not ``'cifar10'``.
    """
    if args.dataset != 'cifar10':
        raise ValueError("No valid dataset is given.")
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    test_loader = get_test_dataloader(batch_size=args.test_batch_size, **kwargs)

    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # Index of the largest score along dim 1 is the predicted class.
            pred = output.data.max(1, keepdim=True)[1]
            # .item() keeps the counter a plain int, so the returned accuracy
            # is a float rather than a 0-dim tensor.
            correct += pred.eq(target.data.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    accuracy = 100. * correct / num_samples
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(correct, num_samples, accuracy))
    return accuracy / 100.


def _kept_indices(mask):
    """Return the positions of the non-zero entries of a 1-D mask tensor as a list."""
    return np.flatnonzero(mask.cpu().numpy()).tolist()


if __name__ == '__main__':
    args = parse_opt()
    # Use CUDA only when it is both requested and available.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    os.makedirs(args.save, exist_ok=True)

    model = VGG(depth=args.depth)
    if args.cuda:
        model.cuda()

    # Restore the sparsity-trained weights when a checkpoint path is given.
    if args.model:
        if os.path.isfile(args.model):
            print("=> loading checkpoint '{}'".format(args.model))
            # Load onto the CPU first so that a checkpoint saved on a GPU can
            # still be read on a CPU-only machine; load_state_dict then copies
            # the values into the model's parameters on their own device.
            checkpoint = torch.load(args.model, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(
                args.model, checkpoint['epoch'], best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.model))
    print(model)

    # ====================== Pre-processing: find the masks ======================
    # Count the channels of all BatchNorm2d layers.
    total = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]

    # Collect |gamma| (BatchNorm2d.weight) of every BN layer into one flat tensor.
    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index + size)] = m.weight.data.abs().clone()
            index += size

    # The threshold is the |gamma| found at position `percent` of the sorted
    # values.
    y, _ = torch.sort(bn)
    thre_index = int(total * args.percent)
    thre = y[thre_index]

    cfg = []       # channels kept per layer ('M' for max-pooling), used to build the new model
    cfg_mask = []  # one 0/1 mask per BN layer, used to copy the surviving weights
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            # gt() is strict: channels whose |gamma| equals the threshold are
            # pruned too. The mask stays on the device of the BN weights, so
            # this also works when the model runs on the CPU.
            mask = m.weight.data.abs().gt(thre).float()
            remaining = int(torch.sum(mask))
            # Zero gamma and beta of the pruned channels in place.
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg.append(remaining)
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(
                k, mask.shape[0], remaining))
        elif isinstance(m, nn.MaxPool2d):
            cfg.append('M')

    print('Pre-processing Successful!')
    # Accuracy of the original architecture with the pruned channels zeroed.
    acc = test(model)

    # ============================ Make real prune ============================
    print(cfg)
    newmodel = VGG(cfg=cfg)
    if args.cuda:
        newmodel.cuda()

    # Record the pruned configuration, its size and the pre-processing accuracy.
    num_parameters = sum(param.nelement() for param in newmodel.parameters())
    savepath = os.path.join(args.save, "prune.txt")
    with open(savepath, "w") as fp:
        fp.write("Configuration: \n" + str(cfg) + "\n")
        fp.write("Number of parameters: " + str(num_parameters) + "\n")
        fp.write("Test accuracy: " + str(acc))

    # start_mask / end_mask select the input / output channels of the current
    # Conv+BN pair. The first layer keeps all 3 input channels
    # (presumably the RGB channels of the image - confirm against the dataset).
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]

    # Walk both networks in lockstep and copy the surviving weights.
    for m0, m1 in zip(model.modules(), newmodel.modules()):
        # ============================ BatchNorm Layers ============================
        if isinstance(m0, nn.BatchNorm2d):
            idx1 = _kept_indices(end_mask)
            # Copy gamma, beta and the running statistics of the kept channels.
            m1.weight.data = m0.weight.data[idx1].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
            m1.running_mean = m0.running_mean[idx1].clone()
            m1.running_var = m0.running_var[idx1].clone()
            # The output mask of this layer is the input mask of the next one.
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
        # ============================ Conv2d Layers ============================
        elif isinstance(m0, nn.Conv2d):
            idx0 = _kept_indices(start_mask)
            idx1 = _kept_indices(end_mask)
            print('In shape: {:d}, Out shape {:d}.'.format(len(idx0), len(idx1)))
            # weight: (out_channels, in_channels, kernel_size, kernel_size);
            # slice the input channels first, then the output channels.
            w1 = m0.weight.data[:, idx0, :, :].clone()
            m1.weight.data = w1[idx1, :, :, :].clone()
            # A conv bias has one entry per output channel; copy it when the
            # layers were built with one, otherwise m1 keeps a random bias.
            if m0.bias is not None and m1.bias is not None:
                m1.bias.data = m0.bias.data[idx1].clone()
        # ============================ Linear Layers ============================
        elif isinstance(m0, nn.Linear):
            # Only the input features are pruned: they follow the channels
            # kept by the last BN layer (start_mask).
            # NOTE(review): assumes one input feature per remaining channel
            # - confirm against the VGG classifier definition.
            idx0 = _kept_indices(start_mask)
            m1.weight.data = m0.weight.data[:, idx0].clone()
            m1.bias.data = m0.bias.data.clone()

    torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()},
               os.path.join(args.save, 'pruned.pth'))
    print(newmodel)
    # Accuracy of the really pruned network (before any fine-tuning).
    model = newmodel
    test(model)


本次课程完成了对VGG模型的剪枝训练,主要是复现论文中对BN层的 $\gamma$ 参数进行稀疏训练,得到对应的mask后对Conv2d、BatchNorm以及Linear层进行剪枝,可以看到剪枝后的模型的参数量大大减少(71M=>9.6M),且预测准确率反而提高了(87.4%=>88.4%),而对YOLOv8采用该方法进行剪枝时,其精度会略微下降,但是其参数量会大大减少,具有可应用性,期待下基于YOLOv8的剪枝吧🤣