
import time

import numpy as np
import torch
from torch.backends import cudnn

# Import your model here, e.g. one of the project-local architectures, and
# pass an instance to calcTime(model=...):
# from module.amsnet import amsnet, anet, msnet, iresnet18, anet2, iresnet2, amsnet2
# from module.resnet import resnet18, resnet34
# from module.alexnet import AlexNet
# from module.vgg import vgg
# from module.lenet import LeNet
# from module.googLenet import GoogLeNet
# from module.ivgg import iVGG


def _progress(iterable):
    """Wrap *iterable* in a tqdm progress bar if tqdm is installed, else return it as-is."""
    try:
        import tqdm
    except ImportError:
        return iterable
    return tqdm.tqdm(iterable)


def calcTime(model=None, device='cuda:0', repetitions=1000, warmup=100,
             input_shape=(1, 3, 224, 224)):
    """Measure the average forward-pass latency of *model*, in milliseconds.

    The model is put in eval mode and run under ``torch.no_grad()`` on a
    random input tensor. It is first run ``warmup`` times untimed, then
    ``repetitions`` times timed. On CUDA devices each call is timed with a
    pair of CUDA events; on other devices ``time.perf_counter`` is used.

    Args:
        model: ``torch.nn.Module`` to benchmark. If ``None``, a freshly
            constructed torchvision ResNet-50 is used (requires torchvision).
        device: Device to run on (``str`` or ``torch.device``).
        repetitions: Number of timed forward passes; must be >= 1.
        warmup: Number of untimed forward passes run before timing.
        input_shape: Shape of the random input tensor fed to the model.

    Returns:
        The mean latency per forward pass in milliseconds (also printed).

    Raises:
        ValueError: If ``repetitions`` < 1.
    """
    if repetitions < 1:
        raise ValueError(f'repetitions must be >= 1, got {repetitions}')
    device = torch.device(device)
    use_cuda = device.type == 'cuda'

    if model is None:
        # The original script called ``anet()``, but its import was left inside
        # a string literal, so it raised NameError. Fall back to the ResNet-50
        # the script already imported.
        from torchvision.models import resnet50
        model = resnet50()

    if use_cuda:
        # Let cuDNN autotune kernels for the fixed input shape used below.
        cudnn.benchmark = True
    model = model.to(device)
    # Inference mode: without this, BatchNorm/Dropout behave as in training and
    # BN running statistics get updated while we time the model.
    model.eval()
    dummy_input = torch.rand(tuple(input_shape), device=device)

    print('warm up ...\n')
    with torch.no_grad():
        for _ in range(warmup):
            model(dummy_input)
    if use_cuda:
        # Make sure all warm-up kernels finished before timing starts.
        torch.cuda.synchronize(device)

    timings = np.zeros(repetitions)
    print('testing ...\n')
    with torch.no_grad():
        if use_cuda:
            starter = torch.cuda.Event(enable_timing=True)
            ender = torch.cuda.Event(enable_timing=True)
            for rep in _progress(range(repetitions)):
                starter.record()
                model(dummy_input)
                ender.record()
                # elapsed_time is only valid once both events have completed.
                torch.cuda.synchronize(device)
                timings[rep] = starter.elapsed_time(ender)  # milliseconds
        else:
            for rep in _progress(range(repetitions)):
                start = time.perf_counter()
                model(dummy_input)
                # perf_counter is in seconds; convert to ms to match the CUDA path.
                timings[rep] = (time.perf_counter() - start) * 1000.0

    avg = float(timings.mean())
    print('\navg={}\n'.format(avg))
    return avg