> 文章列表 > Pytorch构建自己的数据集

Pytorch构建自己的数据集

Pytorch构建自己的数据集

1.Pytorch内置的Dataset

Pytorch中内置了许多数据集,我们可以从torchvision库中进行导入。比如,我们可以导入Fashion-MNIST数据集

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

但如果torchvision库中没有该数据集,我们需要自己构建一个。
其中一个方法就是把构建好的数据集使用torch.utils.data.TensorDataset()封装以下,然后再传入torch.utils.data.DataLoader

trainloader  =  torch.utils.data.DataLoader(training_data, batch_size=32, shuffle=True)
testloader  =  torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

但是如果自己写一个类的话会高达上一些,嘻嘻。下面看看如何自己构建一个Dataset Class。

2.Build Custom Dataset

构建一个Custom Dataset需要继承``三个函数__init__, __len__, 和 __getitem__

  • __init__: 对类进行初始化
  • __len__: 使该类可以返回dataset样本数量
  • __getitem__: 给定一个idx,从数据集中导入并返回一个样本

下面我们来看看该如何构建Custom Dataset:

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file) # load labelself.img_dir = img_dirself.transform = transform # transformationself.target_transform = target_transformdef __len__(self):return len(self.img_labels) # 返回sample的个数def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path) # load idx-th imagelabel = self.img_labels.iloc[idx, 1] # load idx-th labelif self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

注意:同时,__len__控制着产生样本的总个数。例如,如果总共有20个样本,我们希望20个样本全都放入dataloader中,则:

def __len(self):return 20

但如果我们只希望有20个样本中的15个放入到dataloader中,则:

def __len(self):return 15

但值得注意的是,return返回的数不能大于样本的总个数,即要小于等于20。并且,当返回的数小于总样本个数的时候,是取索引的前几个数,最后的几个数不会被放入dataloader中。比如return 15,是将data[:15]个数放入dataloader,而后5个数要舍去。可以用如下代码验证:

>>> data = np.arange(15).reshape(5,3)
>>> print(data)
array([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11],[12, 13, 14]])
>>> class Data(Dataset):
...		def __init__(self, data) -> None:
...			super(Data, self).__init__()
...			self.data = data
...		def __len__(self):
...			return 4
...		def __getitem__(self, index):
...			out = self.data[index]
...			return torch.from_numpy(out)
>>> loader = DataLoader(Data(data), batch_size=4, shuffle=True)
>>> for i, x in enumerate(loader):
...		print(i, x)0 tensor([[ 3, 4, 5],[ 9, 10, 11],[ 0, 1, 2],[ 6, 7, 8]])

可以发现,无论如何都不会输出[12, 13, 14]

Reference:
Pytorch official tutorial
Writing custom datasets dataloaders and transforms