剪枝与重参第四课:NVIDIA的2:4剪枝方案
目录
- NVIDIA的2:4 pattern稀疏方案
-
- 前言
- 1.稀疏性的研究现状
- 2.图解nvidia2-4稀疏方案
- 3.训练策略
- 4.手写复现
-
- 4.1 大体框架
- 4.2 ASP类的实现
- 4.3 mask的实现
- 4.4 模型初始化
- 4.5 Layer嵌入稀疏特性
- 4.6 优化器初始化
- 4.7 拓展-dynamic function assignment
- 4.8 完整示例代码
- 总结
NVIDIA的2:4 pattern稀疏方案
前言
手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。
本次课程主要讲解NVIDIA的2:4剪枝方案。
reference:
ASP nvidia 2:4 pattern pruning
paper:
- Accelerating Sparse Deep Neural Networks
code:
- https://github.com/NVIDIA/apex/tree/master/apex/contrib/sparsity
blog:
- https://developer.nvidia.com/blog/accelerating-inference-with-sparsity-using-ampere-and-tensorrt/
tensor core:
- https://developer.nvidia.com/blog/programming-tensor-cores-cuda-9/
课程大纲可看下面的思维导图
1.稀疏性的研究现状
许多研究集中在两方面:
- 大量(80-95%)的非结构化、细粒度稀疏
- 用于简单加速的粗粒度稀疏
这些方法所面临的挑战有:
-
精度损失
- 高稀疏度往往会导致准确率损失几个百分点,即使拥有先进的训练技术也是如此
-
缺少一种适用于不同任务和网络的训练方法
- 恢复准确性的训练方法因网络而异,通常需要超参数搜索
-
缺少加速
- Math:非结构数据难以利用现代向量/矩阵数学指令的优势
- Memory access:非结构化数据往往不能很好地利用内存总线,由于读操作之间存在依赖关系,导致延迟增加
- Storage overheads:metadata占用的存储空间比非零权重多消耗2倍,从而抵消了一些压缩的好处。(metadata通常指的是对于权重矩阵的稀疏性描述信息,例如哪些位置是零元素,哪些位置是非零元素)
2.图解nvidia2-4稀疏方案
NVIDIA在处理稀疏矩阵W时,会采用2:4稀疏方案。在这个方案中,稀疏矩阵W首先会被压缩,压缩后的矩阵存储着非零的数据值,而metadata则存储着对应非零元素在原矩阵W中的索引信息。具体来说,metadata会将W中非零元素的行号和列号压缩成两个独立的一维数组,这两个数组就是metadata中存储的索引信息。如下图所示:
对于大型矩阵相乘时,我们可以采用2:4稀疏方案减少计算量,假设矩阵A和B相乘得到C,正常运算如下图所示:
我们可以将A矩阵进行剪枝使其变得稀疏,如下图所示:
而针对于稀疏矩阵,我们可以通过上述的NVIDIA方案将其变为2:4的结构,可以将A矩阵进行压缩,而对矩阵B的稀疏是通过硬件上面的Sparse Tensor Cores进行选择,如下图所示:
3.训练策略
NVIDIA提供的2:4稀疏训练方案步骤如下:
- 1)训练网络
- 2)2:4稀疏剪枝
- 3)重复原始的训练流程
- 超参数的选择与步骤1一致
- 权重的初始化与步骤2一致
- 保持步骤2中的 0 patter:不需要重新计算mask
图示如下:
4.手写复现
4.1 大体框架
示例代码如下:
import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as Fmodel = None
optimizer = Noneclass ToyDataset(Dataset):def __init__(self):x = torch.round(torch.rand(1000) * 200) # (1000,)x = x.unsqueeze(1) # (1000,1)x = torch.cat((x, x * 2, x * 3, x * 4, x * 5, x * 6, x * 7, x * 8), 1) # (1000,8)self.X = xself.Y = self.Xdef __getitem__(self, index):return self.X[index], self.Y[index]def __len__(self):return len(self.X)training_loader = DataLoader(ToyDataset(), batch_size=100, shuffle=True)def train():criterion = nn.MSELoss()for i in range(500):for x, y in training_loader:loss = criterion(model(x.to("cuda")), y.to("cuda"))optimizer.zero_grad()loss.backward()optimizer.step()print("epoch #%d: loss: %f" % (i, loss.item()))def test():x = torch.tensor([2, 4, 6, 8, 10, 12, 14, 16]).float()y_hat = model(x.to("cuda"))print("input: ", x, "\\n", "predict: ", y_hat)def get_model(path):global model, optimizerif os.path.exists(path):model = torch.load(path).cuda()optimizer = optim.Adam(model.parameters(), lr=0.01)else:model = nn.Sequential(nn.Linear(8, 16),nn.PReLU(),nn.Linear(16, 8)).cuda()optimizer = optim.Adam(model.parameters(), lr=0.01)train()torch.save(model, path)class ASP():...if __name__ == "__main__":# ---------------- train ----------------get_model("./model.pt")print("-------orig-------")test()# ---------------- prune ----------------ASP.prune_trained_model(model, optimizer)print("-------pruned-------")test()# ---------------- finetune ----------------train()print("-------retrain-------")test()torch.save(model, "./model_sparse.pt")
上述示例代码演示了2:4稀疏方案的大体框架,包括数据集准备、模型训练、模型剪枝、模型微调和模型保存等步骤。剪枝方案为ASP(Automatic SParsity),主要实现的是前面提到过的2:4稀疏剪枝,其具体实现细节在ASP类中。
4.2 ASP类的实现
ASP类的实现示例代码如下:
class ASP():@classmethoddef init_model_for_pruning(model, mask_calculater, whitelist):pass@classmethoddef init_optimizer_for_pruning(optimizer):pass@classmethoddef compute_sparse_masks():pass@classmethoddef prune_trained_model(cls, model, optimizer):cls.init_model_for_pruning(model,mask_calculater = "m4n2_1d",whitelist = [torch.nn.Linear, torch.nn.Conv2d])cls.init_optimizer_for_pruning(optimizer)cls.compute_sparse_masks() # 2:4
在上面的示例代码中,ASP的类方法prune_trained_model
会对训练好的模型进行剪枝操作,首先它会去调用init_model_for_pruning
和init_optimizer_for_pruning
对模型和优化器进行初始化,然后调用compute_sparse_masks
生成稀疏掩码(具体首先见4.3),最后使用掩码对模型进行剪枝。
4.3 mask的实现
我们来看下核心部分,mask的实现,2:4的方案就是在一张密集的weights中实现每4个weight取其中两个比较大的,其他两个置0,如下图所示:
最简单的实现方案就是遍历所有的weights,每4个进行比较,然后将较大的weight所对应的mask置1,其他mask置0,如下图所示:
而NVIDIA的方案是首先创建一个patterns,如下图所示,由于是2:4的方案,所有总共有6种不同的pattern;然后将weight matrix变换成nx4的格式方便与pattern进行矩阵运算,运算后的结果为nx6的矩阵,在n的维度上进行argmax取得最大的索引(索引对应pattern),然后将索引对应的pattern值填充到mask中即可。
示例代码如下:
import sys
import torch
import numpy as np
from itertools import permutationsdef reshape_1d(matrix, m):# If not a nice multiple of m, fill with zerosif matrix.shape[1] % m > 0:mat = torch.cuda.FloatTensor(matrix.shape[0], matrix.shape[1] + (m - matrix.shape[1] % m)).fill_(0)mat[:, : matrix.shape[1]] = matrixshape = mat.shapereturn mat.view(-1, m), shapeelse:return matrix.view(-1, m), matrix.shapedef compute_valid_1d_patterns(m,n):patterns = torch.zeros(m)patterns[:n] = 1valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))return valid_patternsdef mn_1d_best(matrix, m, n):# find all possible patternspatterns = compute_valid_1d_patterns(m,n).cuda()# find the best m:n patternmask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)mat, shape = reshape_1d(matrix, m)pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)mask[:] = patterns[pmax[:]]mask = mask.view(matrix.shape)return maskdef m4n2_1d(mat, density):return mn_1d_best(mat, 4, 2)def m4n3_1d(mat, density):passdef create_mask(weight, pattern, density=0.5):t = weight.float().contiguous()shape = weight.shapettype = weight.type()func = getattr(sys.modules[__name__], pattern, None) # automatically find the function you want, and call itmask = func(t, density)return mask.view(shape).type(ttype)if __name__ == "__main__":weight = torch.randn(8, 16).to("cuda")def create_mask_from_pattern(weight):return create_mask(weight, "m4n2_1d").bool() # 工厂模式 factory method 不同的情况创建不同的对象mask = create_mask_from_pattern(weight)mask = ~mask # for visualize# visualize the weightimport matplotlib.pyplot as plt# Calculate the absolute valuesabs_weight = torch.abs(weight)# Convert to a numpy array for plottingabs_weight_np = abs_weight.cpu().numpy()# visualize the maskmask = mask.cpu().numpy().astype(float)# Plot the matrixdef annotate_image(image_data, ax=None, text_color='red', fontsize=50):if ax is None:ax = plt.gca()for i in range(image_data.shape[0]):for j in range(image_data.shape[1]):ax.text(j, i, f"{image_data[i, j]:.2f}", ha="center", va="center", color=text_color, fontsize=fontsize)fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(80, 20))ax1.imshow(abs_weight_np, cmap="gray", vmin=0, vmax=torch.max(abs_weight).item())ax1.axis("off")annotate_image(abs_weight_np, ax=ax1)ax1.set_title("weight", fontsize=100)ax2.imshow(mask, cmap="gray", vmin=0, vmax=np.max(mask).item())ax2.axis("off")ax2.set_title("mask", fontsize=100)plt.savefig("param_and_mask.jpg", bbox_inches='tight', dpi=100)
在上面的示例代码中,mn_1d_best
函数实现了在指定大小的矩阵中寻找最佳的m:4的mask矩阵,具体实现可看上述的图示流程,m4n2_1d
函数则是对mn_1d_best
函数的进一步封装,指定了m=4,n=4,即寻找最佳的2:4的mask矩阵,create_mask
函数则根据给定的权重矩阵、mask生成函数名和稀疏度生成相应的mask矩阵。
可视化结果如下,其中mask中白色区域填充的是0,黑色区域代表的是1:
4.4 模型初始化
示例代码如下:
class ASP():model = Noneoptimizer = Nonesparse_parameters = []calculate_mask = None@classmethoddef init_model_for_pruning(cls,model,mask_calculater,whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d],custom_layer_dict={}):assert cls.model is None, "ASP has initialized already"cls.model = modelif isinstance(mask_calculater, str):def create_mask_from_pattern(param):return create_mask(param, mask_calculater).bool()cls.calculate_mask = create_mask_from_pattern # dynamic function assignmentsparse_parameter_list = {torch.nn.Linear: ["weight"],torch.nn.Conv1d: ["weight"],torch.nn.Conv2d: ["weight"]}if (custom_layer_dict):# Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prunesparse_parameter_list.update(custom_layer_dict)whitelist += list(custom_layer_dict.keys())for module_type in whitelist:assert module_type in sparse_parameter_list, ("Module %s : Don't know how to sparsify module." % module_type)# find all sparse modules, extract sparse parameters and decoratedef add_sparse_attributes(module_name, module):...def eligible_modules(model, whitelist_layer_types):eligible_modules_list = []for name, mod in model.named_modules():if(isinstance(mod, whitelist_layer_types)):eligible_modules_list.append((name, mod))return eligible_modules_listfor name, sparse_module in eligible_modules(model, tuple(whitelist)):add_sparse_attributes(name, sparse_module)
上面示例代码主要实现ASP
中的类方法init_model_for_pruning
,它的作用是初始化模型,该类方法主要有以下几点说明:
- 该方法通过传入的参数
mask_calculater
调用函数create_mask_from_pattern
,这个函数的作用是根据传入的参数生成一个稀疏矩阵掩码,也就是4.3小节的内容 - 该方法会根据传入的参数
whitelist
和custom_layer_dict
找到所有需要进行稀疏化的模块,这些模块的类型必须在whitelist
中指定,并且每种模块包含一个或多个需要稀疏化的参数。这些信息都被保存在一个字典sparse_parameter_list
中 - 如果这个模块的类型在
whitelist
中,那么就会调用add_sparse_attributes
方法对这个模块进行稀疏化处理(该函数的具体实现可参考4.5小节)
4.5 Layer嵌入稀疏特性
示例代码如下:
# find all sparse modules, extract sparse parameters and decorate
def add_sparse_attributes(module_name, module):sparse_parameters = sparse_parameter_list[type(module)]for p_name, p in module.named_parameters():if p_name in sparse_parameters and p.requires_grad:# check for NVIDIA's TC compatibility: we check along the horizontal directionif p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): # User defines FP32 and APEX internally uses FP16 mathprint("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"% (module_name, p_name, str(p.size()), str(p.dtype)))continueif p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): # For Conv2d dim= K x CRS; we prune along Cprint("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"% (module_name, p_name, str(p.size()), str(p.dtype)))continuemask = torch.ones_like(p).bool()buffname = p_name.split(".")[-1] # buffer name cannot contain "."module.register_buffer("__%s_mma_mask" % buffname, mask)cls.sparse_parameters.append((module_name, module, p_name, p, mask))
函数add_sparse_attributes
的作用是给模型中的每个可稀疏化的参数添加相应的稀疏度掩码。具体来说,函数首先检查模型中每个模块的参数是否在可稀疏化的参数列表中,并且梯度需要计算。然后,函数会检查参数的尺寸是否满足NVIDIA TC(Tensor Cores)的要求。如果满足,则添加一个与参数形状相同的稀疏度掩码。掩码是一个布尔张量,对应于参数中的每个元素。掩码初始化全1,表示所有参数都被保留。在后续的稀疏化操作中,将根据每个参数的稀疏度掩码来确定哪些参数需要被稀疏化。最后,函数将所有的稀疏化参数(包括稀疏度掩码)的元组添加到类变量sparse_parameters
中。
4.6 优化器初始化
示例代码如下:
def init_optimizer_for_pruning(cls, optimizer):assert cls.optimizer is None, "ASP has initialized optimizer already."assert (cls.calculate_mask is not None), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."cls.optimizer = optimizercls.optimizer.__step = optimizer.stepdef __step(opt_self, *args, kwargs): # two pruning part: 1) grad 2) weight# p.grad p.datawith torch.no_grad():for (module_name, module, p_name, p, mask) in cls.sparse_parameters:if p.grad is not None:p.grad.mul_(mask) # inplace# call original optimizer.steprval = opt_self.__step(*args, kwargs)# prune parameter after step methodwith torch.no_grad():for (module_name, module, p_name, p, mask) in cls.sparse_parameters:p.mul_(mask)return rval
上面的示例代码为初始化优化器,主要是为优化器注册一个新的step
方法,以便在每次更新权重之前进行剪枝。__step
方法先对梯度进行剪枝操作,再调用原优化器对象的__step
方法完成权重更新,然后对权重进行裁剪。先对梯度进行裁剪是因为最终的结果会影响权重的裁剪,如果不对梯度进行裁剪而只对权重进行裁剪可能导致权重大的元素被裁剪。
4.7 拓展-dynamic function assignment
动态函数赋值(Dynamic Function Assignment)是指在运行时动态地指定对象的某个方法实现。
在Python中,我们可以使用函数名作为变量名,将函数赋值给变量。这意味着我们可以根据不同的条件,将不同的函数赋值给同一个变量,以便在后续的代码中调用该变量的函数时,根据不同的条件执行不同的函数。这就是动态函数赋值。
下面是使用了DFA的示例代码:
class Pruner:def __init__(self, pruning_pattern):self.pruning_pattern = pruning_patternif pruning_pattern == 'pattern_A':self.prune = self.prune_pattern_Aelif pruning_pattern == 'pattern_B':self.prune = self.prune_pattern_Belif pruning_pattern == 'pattern_C':self.prune = self.prune_pattern_Cdef prune_pattern_A(self, network):# Perform pruning with pattern A logicpruned_network = ...return pruned_networkdef prune_pattern_B(self, network):# Perform pruning with pattern B logicpruned_network = ...return pruned_networkdef prune_pattern_C(self, network):# Perform pruning with pattern B logicpruned_network = ...return pruned_networkpruner_A = Pruner('pattern_A')
pruned_network_A = pruner_A.prune(network)pruner_B = Pruner('pattern_B')
pruned_network_B = pruner_B.prune(network)pruner_C = Pruner('pattern_C')
pruned_network_C = pruner_C.prune(network)
下面是没有使用DFA的示例代码:
class Pruner:def __init__(self, pruning_pattern):self.pruning_pattern = pruning_patterndef prune(self, network):if self.pruning_pattern == 'pattern_A':return self.prune_pattern_A(network)elif self.pruning_pattern == 'pattern_B':return self.prune_pattern_B(network)def prune_pattern_A(self, network):# Perform pruning with pattern A logicpruned_network = ...return pruned_networkdef prune_pattern_B(self, network):# Perform pruning with pattern B logicpruned_network = ...return pruned_networkpruner_A = Pruner('pattern_A')
pruned_network_A = pruner_A.prune(network)pruner_B = Pruner('pattern_B')
pruned_network_B = pruner_B.prune(network)
从二者的对比可以看出动态函数赋值的优点在于它可以使代码更加灵活、可扩展和可维护。它使我们能够动态地改变函数的行为,从而根据不同的条件来处理数据或执行任务。这使得我们的代码更容易理解和维护,也更具可读性和可重用性。此外,动态函数赋值还可以提高代码的灵活性,使得我们可以更容易地在不同的上下文中使用相同的代码。
4.8 完整示例代码
完整的示例代码如下:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from itertools import permutationsmodel = None
optimizer = Nonedef reshape_1d(matrix, m):# If not a nice multiple of m, fill with zerosif matrix.shape[1] % m > 0:mat = torch.cuda.FloatTensor(matrix.shape[0], matrix.shape[1] + (m - matrix.shape[1] % m)).fill_(0)mat[:, : matrix.shape[1]] = matrixshape = mat.shapereturn mat.view(-1, m), shapeelse:return matrix.view(-1, m), matrix.shapedef compute_valid_1d_patterns(m,n):patterns = torch.zeros(m)patterns[:n] = 1valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))return valid_patternsdef mn_1d_best(matrix, m, n):# find all possible patternspatterns = compute_valid_1d_patterns(m,n).cuda()# find the best m:n patternmask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)mat, shape = reshape_1d(matrix, m)pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)mask[:] = patterns[pmax[:]]mask = mask.view(matrix.shape)return maskdef m4n2_1d(mat, density):return mn_1d_best(mat, 4, 2)def m4n3_1d(mat, density):passdef create_mask(weight, pattern, density=0.5):t = weight.float().contiguous()shape = weight.shapettype = weight.type()func = getattr(sys.modules[__name__], pattern, None) # automatically find the function you want, and call itmask = func(t, density)return mask.view(shape).type(ttype)class ToyDataset(Dataset):def __init__(self):x = torch.round(torch.rand(1000) * 200) # (1000,)x = x.unsqueeze(1) # (1000,1)x = torch.cat((x, x * 2, x * 3, x * 4, x * 5, x * 6, x * 7, x * 8), 1) # (1000,8)self.X = xself.Y = self.Xdef __getitem__(self, index):return self.X[index], self.Y[index]def __len__(self):return len(self.X)training_loader = DataLoader(ToyDataset(), batch_size=100, shuffle=True)def train():criterion = nn.MSELoss()for i in range(500):for x, y in training_loader:loss = criterion(model(x.to("cuda")), y.to("cuda"))optimizer.zero_grad()loss.backward()optimizer.step()print("epoch #%d: loss: %f" % (i, loss.item()))def test():x = torch.tensor([2, 4, 6, 8, 10, 12, 14, 16]).float()y_hat = model(x.to("cuda"))print("input: ", x, "\\n", "predict: ", y_hat)def get_model(path):global model, optimizerif os.path.exists(path):model = torch.load(path).cuda()optimizer = optim.Adam(model.parameters(), lr=0.01)else:model = nn.Sequential(nn.Linear(8, 16),nn.PReLU(),nn.Linear(16, 8)).cuda()optimizer = optim.Adam(model.parameters(), lr=0.01)train()torch.save(model, path)class ASP():model = Noneoptimizer = Nonesparse_parameters = []calculate_mask = None@classmethoddef init_model_for_pruning(cls,model,mask_calculater,whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d],custom_layer_dict={}):assert cls.model is None, "ASP has initialized already"cls.model = modelif isinstance(mask_calculater, str):def create_mask_from_pattern(param):return create_mask(param, mask_calculater).bool()cls.calculate_mask = create_mask_from_pattern # dynamic function assignmentsparse_parameter_list = {torch.nn.Linear: ["weight"],torch.nn.Conv1d: ["weight"],torch.nn.Conv2d: ["weight"]}if (custom_layer_dict):# Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prunesparse_parameter_list.update(custom_layer_dict)whitelist += list(custom_layer_dict.keys())for module_type in whitelist:assert module_type in sparse_parameter_list, ("Module %s : Don't know how to sparsify module." % module_type)# find all sparse modules, extract sparse parameters and decoratedef add_sparse_attributes(module_name, module):sparse_parameters = sparse_parameter_list[type(module)]for p_name, p in module.named_parameters():if p_name in sparse_parameters and p.requires_grad:# check for NVIDIA's TC compatibility: we check along the horizontal directionif p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): # User defines FP32 and APEX internally uses FP16 mathprint("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"% (module_name, p_name, str(p.size()), str(p.dtype)))continueif p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): # For Conv2d dim= K x CRS; we prune along Cprint("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"% (module_name, p_name, str(p.size()), str(p.dtype)))continuemask = torch.ones_like(p).bool()buffname = p_name.split(".")[-1] # buffer name cannot contain "."module.register_buffer("__%s_mma_mask" % buffname, mask)cls.sparse_parameters.append((module_name, module, p_name, p, mask))def eligible_modules(model, whitelist_layer_types):eligible_modules_list = []for name, mod in model.named_modules():if(isinstance(mod, whitelist_layer_types)):eligible_modules_list.append((name, mod))return eligible_modules_listfor name, sparse_module in eligible_modules(model, tuple(whitelist)):add_sparse_attributes(name, sparse_module)@classmethoddef init_optimizer_for_pruning(cls, optimizer):assert cls.optimizer is None, "ASP has initialized optimizer already."assert (cls.calculate_mask is not None), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."cls.optimizer = optimizercls.optimizer.__step = optimizer.stepdef __step(opt_self, *args, kwargs): # two pruning part: 1) grad 2) weight# p.grad p.datawith torch.no_grad():for (module_name, module, p_name, p, mask) in cls.sparse_parameters:if p.grad is not None:p.grad.mul_(mask) # inplace# call original optimizer.steprval = opt_self.__step(*args, kwargs)# prune parameter after step methodwith torch.no_grad():for (module_name, module, p_name, p, mask) in cls.sparse_parameters:p.mul_(mask)return rval@classmethoddef compute_sparse_masks():pass@classmethoddef prune_trained_model(cls, model, optimizer):cls.init_model_for_pruning(model,mask_calculater = "m4n2_1d",whitelist = [torch.nn.Linear, torch.nn.Conv2d])cls.init_optimizer_for_pruning(optimizer)cls.compute_sparse_masks() # 2:4if __name__ == "__main__":# ---------------- train ----------------get_model("./model.pt")print("-------orig-------")test()# ---------------- prune ----------------ASP.prune_trained_model(model, optimizer)print("-------pruned-------")test()# ---------------- finetune ----------------train()print("-------retrain-------")test()torch.save(model, "./model_sparse.pt")
总结
本次课程主要学习了NVIDIA的2:4 pattern稀疏方案,并手写复现了一部分重要的功能。