Pytorch构建自己的数据集
1.Pytorch内置的Dataset
Pytorch中内置了许多数据集,我们可以从torchvision
库中进行导入。比如,我们可以导入Fashion-MNIST数据集
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)
但如果torchvision
库中没有该数据集,我们需要自己构建一个。
其中一个方法就是把构建好的数据集使用torch.utils.data.TensorDataset()
封装以下,然后再传入torch.utils.data.DataLoader
trainloader = torch.utils.data.DataLoader(training_data, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)
但是如果自己写一个类的话会高达上一些,嘻嘻。下面看看如何自己构建一个Dataset Class。
2.Build Custom Dataset
构建一个Custom Dataset需要继承``三个函数__init__
, __len__
, 和 __getitem__
。
__init__
: 对类进行初始化__len__
: 使该类可以返回dataset样本数量__getitem__
: 给定一个idx
,从数据集中导入并返回一个样本
下面我们来看看该如何构建Custom Dataset:
import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file) # load labelself.img_dir = img_dirself.transform = transform # transformationself.target_transform = target_transformdef __len__(self):return len(self.img_labels) # 返回sample的个数def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path) # load idx-th imagelabel = self.img_labels.iloc[idx, 1] # load idx-th labelif self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label
注意:同时,__len__
控制着产生样本的总个数。例如,如果总共有20个样本,我们希望20个样本全都放入dataloader中,则:
def __len(self):return 20
但如果我们只希望有20个样本中的15个放入到dataloader中,则:
def __len(self):return 15
但值得注意的是,return
返回的数不能大于样本的总个数,即要小于等于20。并且,当返回的数小于总样本个数的时候,是取索引的前几个数,最后的几个数不会被放入dataloader中。比如return 15
,是将data[:15]个数放入dataloader,而后5个数要舍去。可以用如下代码验证:
>>> data = np.arange(15).reshape(5,3)
>>> print(data)
array([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11],[12, 13, 14]])
>>> class Data(Dataset):
... def __init__(self, data) -> None:
... super(Data, self).__init__()
... self.data = data
... def __len__(self):
... return 4
... def __getitem__(self, index):
... out = self.data[index]
... return torch.from_numpy(out)
>>> loader = DataLoader(Data(data), batch_size=4, shuffle=True)
>>> for i, x in enumerate(loader):
... print(i, x)0 tensor([[ 3, 4, 5],[ 9, 10, 11],[ 0, 1, 2],[ 6, 7, 8]])
可以发现,无论如何都不会输出[12, 13, 14]
。
Reference:
Pytorch official tutorial
Writing custom datasets dataloaders and transforms