批处理数据Dataset+DataLoader使用介绍【Pytorch】
目录
写在前面{\\color{Purple}写在前面}写在前面
由于我们使用的数据可能是多且杂乱的,为了更有效的处理数据同时也方便之后的使用,Pytorch提供了Dataset和DataLoader来帮助我们批量处理数据。
各自的作用{\\color{Purple}各自的作用}各自的作用
∙\\bullet∙ Dataset:主要实现按索引访问对应的数据以及标签
∙\\bullet∙ DataLoader:主要将数据划分batch方便之后训练
引入相应的库{\\color{Purple}引入相应的库}引入相应的库
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
使用Dataset{\\color{Purple}使用Dataset}使用Dataset
自定义实现{\\color{Purple}自定义实现}自定义实现
∙\\bullet∙主要包含三个函数:__init__, __len__, __getitem__.
函数 | 功能 |
---|---|
init | 初始化函数,为自定义类设置成员变量 |
len | 返回样本个数 |
getitem | 核心函数,实现按索引返回数据及标签 |
∙\\bullet∙代码实现 (简洁版:以图像数据为例)
#继承Dataset类,实现自己的tensorDataset类
class tensorDataset(Dataset):'''Inputs:- images: a [Batch size, Channels, Height, Width] tensor- labels: a 1-dimensional tensor corresponding to image labels'''def __init__(self, images, labels, train:bool):#灰度图,将像素压缩到0~1之间self.images = images/255self.labels = labels#魔法方法,实例化对象后,可以支持下标索引,通过下标来读取对应的图片像素与标签def __getitem__(self, index):# Load the image (as tensor)img = self.images[index]label = self.labels[index]return img, labeldef __len__(self):return len(self.images)
使用DataLoader{\\color{Purple}使用DataLoader}使用DataLoader
快速使用{\\color{Purple}快速使用}快速使用
∙\\bullet∙其中batch_size表示按多少个样本进行一次划分,shuffle表示是否打乱划分
∙\\bullet∙比如有1、2、3、4四个样本,batch_size取2,当shuffle = false时,就一定会划分为(1,2)和(3,4);而shuffle = true时,就是划分为任意组合
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
自定义实现{\\color{Purple}自定义实现}自定义实现
可以结合自定义的Dataset一起使用
def create_dataloaders(batch_size, X_train, y_train, X_val, y_val, X_test=None, y_test=None):#先使用Dataset进行数据处理train_dataset = tensorDataset(X_train, y_train, train=True)val_dataset = tensorDataset(X_val, y_val, train=False)#使用Dataloader进行处理,得到训练集和验证集的dataloadertrain_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)val_dataloader = DataLoader(val_dataset , batch_size=batch_size, shuffle=False)#自定义是否也处理测试集if X_test is not None:test_dataset = tensorDataset(X_test, y_test, train=False)test_dataloader = DataLoader(test_dataset , batch_size=batch_size, shuffle=False)return train_dataloader, val_dataloader, test_dataloaderreturn train_dataloader, val_dataloader