Neural Differential Equations: A ResNet Variant That Cuts Memory While Keeping Accuracy

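In this post:

1. These are my notes from studying neural differential equations, mainly as practice in picking up new material and in reading papers that are heavy on mathematical theory.

2. Neural differential equations can be used for time-series modeling, dynamics modeling, and more, but this post focuses on classification with a ResNet variant, which is the easier case to understand.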

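My understanding:

The code for the adjoint sensitivity method is fairly involved, but its logic follows the algorithm steps one-to-one, so reading the two side by side makes it easy to follow. In essence, it reduces gradient computation to solving a differential equation:

In practice, OdeintAdjointMethod subclasses torch.autograd.Function and implements the forward and backward methods. Both are carried out by an ODE solver: backward obtains the gradients by solving the adjoint ODE, instead of letting autograd apply the chain rule through the solver's internal operations.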
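As a rough illustration of that mechanism, here is a minimal sketch on a scalar toy problem. The class name ToyAdjoint, the helper euler, and the dynamics dz/dt = theta * z are all assumptions made for this example; the code is not taken from torchdiffeq and uses a fixed-step Euler solver rather than the library's solvers.

```python
import math
import torch


def euler(f, s, t0, t1, steps=1000):
    """Fixed-step Euler solve of ds/dt = f(s) from t0 to t1 (t1 < t0 runs backwards)."""
    h = (t1 - t0) / steps
    for _ in range(steps):
        s = s + h * f(s)
    return s


class ToyAdjoint(torch.autograd.Function):
    """Returns z(1) for dz/dt = theta * z with z(0) = z0 (hypothetical toy, not library code)."""

    @staticmethod
    def forward(ctx, z0, theta):
        # Solve forward without recording the solver's operations for autograd.
        with torch.no_grad():
            z1 = euler(lambda z: theta * z, z0, 0.0, 1.0)
        ctx.save_for_backward(z1, theta)
        return z1

    @staticmethod
    def backward(ctx, grad_z1):
        z1, theta = ctx.saved_tensors

        def augmented(s):
            z, a, _ = s
            # dz/dt = f, da/dt = -a * df/dz, dg/dt = -a * df/dtheta
            return torch.stack([theta * z, -a * theta, -a * z])

        # Augmented state at t = 1: the state, the adjoint dL/dz(1), and a zero accumulator.
        s1 = torch.stack([z1, grad_z1, torch.zeros_like(z1)])
        s0 = euler(augmented, s1, 1.0, 0.0)  # integrate from t = 1 back to t = 0
        return s0[1], s0[2]  # dL/dz0, dL/dtheta


z0 = torch.tensor(2.0, requires_grad=True)
theta = torch.tensor(0.5, requires_grad=True)
ToyAdjoint.apply(z0, theta).backward()
print(z0.grad.item(), theta.grad.item())
```

For this toy problem the exact values are dL/dz0 = exp(theta) and dL/dtheta = z0 * exp(theta), so the printed gradients should land close to those, up to the Euler discretization error.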

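Basic principle:

The adjoint method is a way to train a Neural ODE that avoids the scalability problem of backpropagating through every internal operation of the solver. The forward pass uses an ordinary differential equation (ODE) solver. The backward pass uses the adjoint sensitivity method, which expresses the gradients as the solution of a second ODE, so the same ODE solver can be run again, this time backwards in time. Updating the parameters of the derivative function requires the gradient of the loss with respect to the parameters theta of the dynamics function, and the adjoint sensitivity method provides it. The final algorithm packs the state, the adjoint, and the parameter gradient into one augmented state, calls the ODE solver backwards to obtain the gradient with respect to theta, and then updates the parameters of the CNN encoder and the derivative function.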

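The algorithm proceeds as follows:

[Figure: algorithm flow of the adjoint method]

A more complete version:

[Figure: more complete version of the algorithm]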
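For reference, the adjoint relations that this kind of algorithm rests on can be written as below. This follows the notation of the Neural ODE paper (Chen et al., 2018) rather than anything shown in this post: z(t) is the state (y in the code), a(t) is the adjoint, f is the dynamics function, and the loss L is assumed to depend on the state at time t_1.

```latex
\frac{dz(t)}{dt} = f\bigl(z(t), t, \theta\bigr), \qquad
a(t) = \frac{\partial L}{\partial z(t)}

\frac{da(t)}{dt} = -\,a(t)^{\top}\,\frac{\partial f\bigl(z(t), t, \theta\bigr)}{\partial z}

\frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^{\top}\,\frac{\partial f\bigl(z(t), t, \theta\bigr)}{\partial \theta}\,dt
```

All three quantities (z, a, and the running integral for dL/dθ) can be stacked into one augmented state and integrated from t_1 back to t_0 in a single solver call, which is what the augmented_dynamics function in the code does with y, vjp_y, and vjp_params.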

The adjoint sensitivity code follows the algorithm steps above one-to-one. odeint is the ODE solver used in this post.

The repository provides the following solvers:

SOLVERS = {
    'dopri8': Dopri8Solver,
    'dopri5': Dopri5Solver,
    'bosh3': Bosh3Solver,
    'fehlberg2': Fehlberg2,
    'adaptive_heun': AdaptiveHeunSolver,
    'euler': Euler,
    'midpoint': Midpoint,
    'rk4': RK4,
    'explicit_adams': AdamsBashforth,
    'implicit_adams': AdamsBashforthMoulton,
    # Backward compatibility: use the same name as before
    'fixed_adams': AdamsBashforthMoulton,
    # ~Backwards compatibility
    'scipy_solver': ScipyWrapperODESolver,
}
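To give a feel for how the fixed-step entries in that table differ, here is a small sketch with textbook versions of the Euler, midpoint, and classical RK4 steps. These are plain-Python illustrations written for this post, not torchdiffeq's implementations; the names TOY_SOLVERS and solve are made up for the example.

```python
import math


def euler_step(f, t, y, h):
    return y + h * f(t, y)


def midpoint_step(f, t, y, h):
    return y + h * f(t + h / 2, y + h / 2 * f(t, y))


def rk4_step(f, t, y, h):
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


# Hypothetical lookup table, mirroring the name -> solver idea of SOLVERS.
TOY_SOLVERS = {'euler': euler_step, 'midpoint': midpoint_step, 'rk4': rk4_step}


def solve(method, f, y0, t0, t1, steps):
    """Integrate dy/dt = f(t, y) from t0 to t1 with a fixed number of steps."""
    step = TOY_SOLVERS[method]
    h = (t1 - t0) / steps
    t, y = t0, y0
    for _ in range(steps):
        y = step(f, t, y, h)
        t += h
    return y


# Test problem: dy/dt = -y, y(0) = 1, whose exact solution at t = 1 is exp(-1).
errors = {
    name: abs(solve(name, lambda t, y: -y, 1.0, 0.0, 1.0, 10) - math.exp(-1.0))
    for name in TOY_SOLVERS
}
for name, err in errors.items():
    print(name, err)
```

With the same 10 steps, the error drops sharply from Euler to midpoint to RK4, which is the usual trade-off between function evaluations per step and accuracy. The adaptive solvers such as dopri5 go one step further and choose the step size from an error estimate.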

The full source is at rtqichen/torchdiffeq: Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation. (github.com). Below is the core part, the forward and backward passes, with comments:

class OdeintAdjointMethod(torch.autograd.Function):

    @staticmethod
    def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method,
                adjoint_options, t_requires_grad, *adjoint_params):

        ctx.shapes = shapes
        ctx.func = func
        ctx.adjoint_rtol = adjoint_rtol
        ctx.adjoint_atol = adjoint_atol
        ctx.adjoint_method = adjoint_method
        ctx.adjoint_options = adjoint_options
        ctx.t_requires_grad = t_requires_grad
        ctx.event_mode = event_fn is not None

        with torch.no_grad():
            ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)

            if event_fn is None:
                y = ans
                ctx.save_for_backward(t, y, *adjoint_params)
            else:
                event_t, y = ans
                ctx.save_for_backward(t, y, event_t, *adjoint_params)

        return ans

    @staticmethod
    def backward(ctx, *grad_y):
        with torch.no_grad():
            func = ctx.func
            adjoint_rtol = ctx.adjoint_rtol
            adjoint_atol = ctx.adjoint_atol
            adjoint_method = ctx.adjoint_method
            adjoint_options = ctx.adjoint_options
            t_requires_grad = ctx.t_requires_grad

            # Backprop as if integrating up to event time.
            # Does NOT backpropagate through the event time.
            event_mode = ctx.event_mode
            if event_mode:
                t, y, event_t, *adjoint_params = ctx.saved_tensors
                _t = t
                t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
                grad_y = grad_y[1]
            else:
                t, y, *adjoint_params = ctx.saved_tensors
                grad_y = grad_y[0]

            adjoint_params = tuple(adjoint_params)

            ##      Set up the initial state      ##

            # [-1] because y and grad_y are both of shape (len(t), *y0.shape)
            aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]]  # vjp_t, y, vjp_y
            aug_state.extend([torch.zeros_like(param) for param in adjoint_params])  # vjp_params

            ##    Set up the backward ODE function    ##

            # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
            def augmented_dynamics(t, y_aug):
                # Dynamics of the original system augmented with
                # the adjoint wrt y, and an integrator wrt t and args.
                y = y_aug[1]
                adj_y = y_aug[2]
                # ignore gradients wrt time and parameters

                with torch.enable_grad():
                    t_ = t.detach()
                    t = t_.requires_grad_(True)
                    y = y.detach().requires_grad_(True)

                    # If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
                    # doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
                    # wrt t here means we won't compute that if we don't need it.
                    func_eval = func(t if t_requires_grad else t_, y)

                    # Workaround for PyTorch bug #39784
                    _t = torch.as_strided(t, (), ())  # noqa
                    _y = torch.as_strided(y, (), ())  # noqa
                    _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params)  # noqa

                    vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
                        func_eval, (t, y) + adjoint_params, -adj_y,
                        allow_unused=True, retain_graph=True
                    )

                # autograd.grad returns None if no gradient, set to zero.
                vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
                vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
                vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
                              for param, vjp_param in zip(adjoint_params, vjp_params)]

                return (vjp_t, func_eval, vjp_y, *vjp_params)

            ##      Solve the adjoint ODE       ##

            if t_requires_grad:
                time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
            else:
                time_vjps = None
            for i in range(len(t) - 1, 0, -1):
                if t_requires_grad:
                    # Compute the effect of moving the current time measurement point.
                    # We don't compute this unless we need to, to save some computation.
                    func_eval = func(t[i], y[i])
                    dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1))
                    aug_state[0] -= dLd_cur_t
                    time_vjps[i] = dLd_cur_t

                # Run the augmented system backwards in time.
                aug_state = odeint(
                    augmented_dynamics, tuple(aug_state),
                    t[i - 1:i + 1].flip(0),
                    rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
                )
                aug_state = [a[1] for a in aug_state]  # extract just the t[i - 1] value
                aug_state[1] = y[i - 1]  # update to use our forward-pass estimate of the state
                aug_state[2] += grad_y[i - 1]  # update any gradients wrt state at this time point

            if t_requires_grad:
                time_vjps[0] = aug_state[0]

            ##      Collect the gradients       ##

            # Only compute gradient wrt initial time when in event handling mode.
            if event_mode and t_requires_grad:
                time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])])

            adj_y = aug_state[2]
            adj_params = aug_state[3:]

        return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params)

Model architecture:

Start from a small residual network. After downsampling, the 6 standard residual blocks are replaced by a single ODESolve module, which gives ODE-Net.

Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): GroupNorm(32, 64, eps=1e-05, affine=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (4): GroupNorm(32, 64, eps=1e-05, affine=True)
  (5): ReLU(inplace=True)
  (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ODEBlock(
    (odefunc): ODEfunc(
      (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv1): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
      (conv2): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
    )
  )
  (8): GroupNorm(32, 64, eps=1e-05, affine=True)
  (9): ReLU(inplace=True)
  (10): AdaptiveAvgPool2d(output_size=(1, 1))
  (11): Flatten()
  (12): Linear(in_features=64, out_features=10, bias=True)
)
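The reason a stack of residual blocks can be swapped for an ODE solve is that a residual update h + f(h) has the same form as one Euler step of dh/dt = f(h). The sketch below illustrates this on a scalar toy problem. The function f and the 1/n step scaling are assumptions chosen for the illustration (a standard ResNet uses a step of 1 and different weights in each block), and residual_stack is a made-up name.

```python
import math


def f(h, t):
    """Toy dynamics standing in for a residual branch (hypothetical)."""
    return -0.5 * h


def residual_stack(h, n_blocks):
    """Apply n_blocks updates h <- h + dt * f(h, t), i.e. an Euler solve on [0, 1]."""
    dt = 1.0 / n_blocks
    for k in range(n_blocks):
        h = h + dt * f(h, k * dt)
    return h


# Exact solution of dh/dt = -0.5 * h at t = 1 with h(0) = 1.
exact = math.exp(-0.5)
gaps = {n: abs(residual_stack(1.0, n) - exact) for n in (6, 60, 600)}
for n, gap in gaps.items():
    print(n, gap)
```

As the number of blocks grows, the stacked residual updates approach the continuous solution. ODE-Net takes that limit directly: one shared function (ODEfunc above) defines the dynamics, and the solver decides how many evaluations it needs.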

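Results:

ODE-Net does better than ResNet on parameter count and memory. Actual running time is not reported explicitly, though; only the time complexity is shown.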
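The parameter-count part of that comparison can be checked by hand from the layer shapes in the demo script. The sketch below does the arithmetic in plain Python, assuming the 'conv' downsampling path; the helper names are made up for this example, and the layer shapes are read off the script's ResBlock, ODEfunc, and downsampling definitions.

```python
def conv_params(c_in, c_out, k, bias):
    """Parameter count of a k x k Conv2d layer."""
    return c_in * c_out * k * k + (c_out if bias else 0)


def groupnorm_params(channels):
    """GroupNorm with affine=True has one weight and one bias per channel."""
    return 2 * channels


def resblock_params(c=64):
    # ResBlock(64, 64): two GroupNorms and two 3x3 convolutions without bias.
    return 2 * groupnorm_params(c) + 2 * conv_params(c, c, 3, bias=False)


def odefunc_params(c=64):
    # ODEfunc(64): three GroupNorms and two ConcatConv2d layers, whose inner
    # Conv2d takes c + 1 input channels (the extra one carries t) and has a bias.
    return 3 * groupnorm_params(c) + 2 * conv_params(c + 1, c, 3, bias=True)


# Layers shared by both networks: 'conv' downsampling plus the classifier head.
shared = (
    conv_params(1, 64, 3, True) + groupnorm_params(64)
    + conv_params(64, 64, 4, True) + groupnorm_params(64)
    + conv_params(64, 64, 4, True)
    + groupnorm_params(64) + (64 * 10 + 10)
)

resnet_total = shared + 6 * resblock_params()
odenet_total = shared + odefunc_params()
print(resnet_total, odenet_total)
```

The six residual blocks account for 443,904 parameters, while the single ODEfunc has 75,392, so the whole network drops from 576,778 to 208,266 parameters. The memory saving during training is a separate matter that depends on using the adjoint backward pass.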

[Figure: results comparison]

Running the ResNet-variant demo (the script defaults to --network odenet):

import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=160)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

parser.add_argument('--save', type=str, default='./experiment1')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()

if args.adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)


class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut


class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    return train_loader, test_loader, train_eval_loader


def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    initial_learning_rate = args.lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn


def one_hot(x, K):
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader):
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
    logger = logging.getLogger()
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)
    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    for f in package_files:
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger


if __name__ == '__main__':

    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    is_odenet = args.network == 'odenet'

    if args.downsampling_method == 'conv':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
    elif args.downsampling_method == 'res':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]

    feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
    fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    criterion = nn.CrossEntropyLoss().to(device)

    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size
    )

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    for itr in range(args.nepochs * batches_per_epoch):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        optimizer.zero_grad()
        x, y = data_gen.__next__()
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        loss = criterion(logits, y)

        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        loss.backward()
        optimizer.step()

        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        batch_time_meter.update(time.time() - end)
        if is_odenet:
            f_nfe_meter.update(nfe_forward)
            b_nfe_meter.update(nfe_backward)
        end = time.time()

        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader)
                val_acc = accuracy(model, test_loader)
                if val_acc > best_acc:
                    torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
                    best_acc = val_acc
                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "Train Acc {:.4f} | Test Acc {:.4f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc
                    )
                )

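Training log (run with the default --adjoint False, so the backward pass makes no calls to the ODE function and NFE-B stays at 0):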

Epoch 0000 | Time 3.425 (3.425) | NFE-F 32.0 | NFE-B 0.0 | Train Acc 0.0987 | Test Acc 0.0958
Epoch 0001 | Time 3.279 (0.840) | NFE-F 20.3 | NFE-B 0.0 | Train Acc 0.9755 | Test Acc 0.9779
Epoch 0002 | Time 3.500 (0.839) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9858 | Test Acc 0.9875
Epoch 0003 | Time 3.403 (0.828) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9884 | Test Acc 0.9879
Epoch 0004 | Time 3.303 (0.807) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9926 | Test Acc 0.9921
Epoch 0005 | Time 3.308 (0.801) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9940 | Test Acc 0.9930
Epoch 0006 | Time 3.255 (0.804) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9917 | Test Acc 0.9894
Epoch 0007 | Time 3.376 (0.808) | NFE-F 20.2 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9929
Epoch 0008 | Time 3.260 (0.806) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9935 | Test Acc 0.9934
Epoch 0009 | Time 3.248 (0.832) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9909
Epoch 0010 | Time 3.286 (0.817) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9959 | Test Acc 0.9947
Epoch 0011 | Time 3.281 (0.827) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9967 | Test Acc 0.9951
Epoch 0012 | Time 3.382 (0.825) | NFE-F 20.9 | NFE-B 0.0 | Train Acc 0.9949 | Test Acc 0.9929
Epoch 0013 | Time 3.299 (0.862) | NFE-F 22.0 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9949
Epoch 0014 | Time 3.326 (0.824) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9947 | Test Acc 0.9936
Epoch 0015 | Time 3.291 (0.839) | NFE-F 21.3 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9948
Epoch 0016 | Time 3.467 (0.935) | NFE-F 24.4 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9941
Epoch 0017 | Time 3.483 (0.900) | NFE-F 23.2 | NFE-B 0.0 | Train Acc 0.9970 | Test Acc 0.9939
Epoch 0018 | Time 3.309 (0.872) | NFE-F 22.2 | NFE-B 0.0 | Train Acc 0.9961 | Test Acc 0.9932
Epoch 0019 | Time 3.294 (0.913) | NFE-F 23.6 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9954
Epoch 0020 | Time 3.504 (0.984) | NFE-F 25.9 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9951
Epoch 0021 | Time 3.589 (0.966) | NFE-F 25.2 | NFE-B 0.0 | Train Acc 0.9966 | Test Acc 0.9929
Epoch 0022 | Time 3.503 (0.994) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9949
Epoch 0023 | Time 3.457 (0.995) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9978 | Test Acc 0.9939
Epoch 0024 | Time 3.529 (0.985) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9958
Epoch 0025 | Time 3.459 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9973 | Test Acc 0.9947
Epoch 0026 | Time 3.541 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9979 | Test Acc 0.9946
Epoch 0027 | Time 3.513 (0.993) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9959
Epoch 0028 | Time 3.505 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9982 | Test Acc 0.9953
Epoch 0029 | Time 3.501 (0.990) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9953
Epoch 0030 | Time 3.475 (0.992) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9954
Epoch 0031 | Time 3.506 (0.993) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9947
Epoch 0032 | Time 3.527 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9981 | Test Acc 0.9954
Epoch 0033 | Time 3.529 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9945
Epoch 0034 | Time 3.545 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9959
Epoch 0035 | Time 3.479 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9990 | Test Acc 0.9953
Epoch 0036 | Time 3.479 (0.997) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9989 | Test Acc 0.9963
Epoch 0037 | Time 3.540 (0.998) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9957

References:

1. Neural Ordinary Differential Equations (Neural ODE): An Introductory Tutorial - Zhihu (zhihu.com)

2. rtqichen/torchdiffeq: Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation. (github.com)