> 文章列表 > pytorch lightning之验证与测试

pytorch lightning之验证与测试

pytorch lightning之验证与测试

训练

训练部分已在《入门篇》介绍。

验证集和测试集中评估模型

通常将数据集分为三部分,train/val/test,val集在训练时评估模型的泛化性,选择其中表现最好的checkpoint。test集只在模型训练完成后使用,用于评估模型的真实性能。

添加test流程

划分数据集

以下代码使用torchvision包内实现的MNIST。如果使用自定义的数据集,先用pytorch实现Dataset子类,再继承pl.LightningDataModule类,实现相应接口。

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

实现test_step()接口

在trainer.test()阶段会自动调用test_step方法,根据需要内部可以增加保存图片、评估模型等功能。

class LitAutoEncoder(pl.LightningModule):def training_step(self, batch, batch_idx):...def test_step(self, batch, batch_idx):# this is the test loopx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)test_loss = F.mse_loss(x_hat, x)self.log("test_loss", test_loss)

测试

模型训练完成后,即可调用test()方法进入测试流程

from torch.utils.data import DataLoader# initialize the Trainer
trainer = Trainer()# 训练模型
trainer.fit(model, data)# 训练完成后测试
trainer.test(model, dataloaders=DataLoader(test_set))

验证阶段validation的流程

与test 流程类似,实现validation_step()接口,可以配合on_validation_epoch_end()方法在计算所有样例后评估模型。

class LitAutoEncoder(pl.LightningModule):def training_step(self, batch, batch_idx):...def validation_step(self, batch, batch_idx):# this is the validation loopx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)val_loss = F.mse_loss(x_hat, x)self.log("val_loss", val_loss)self.metric.update(x_hat, y)# metric是任务相关的评价方法,比如更新混淆矩阵def on_validation_epoch_step(self, batch, batch_idx):# 从混淆矩阵中计算tp,fp, tn, fn, acc, F1等指标 score = self.metric.get_scores()# 记录,横坐标为epochself.log('val/F1', score['F1'], logger=True, on_epoch=True)

预测predict流程

实现predict_step方法,然后调用trainer.predict()

其它HOOK见LightningModule,了解LightningModule的接口基本就会用pl了。