Day3 自学Pytorch 数据集 torchvision.transforms类&torchvision.datasets.ImageFolder类
1.torchvision.transforms类
可调用的函数列表https://pytorch.org/vision/stable/transforms.html
介绍几个常用的函数:
① transforms.Resize()
将图像转换成目标大小
参数列表:
- size (sequence or int): (h,w)目标图像的大小,若只输入一个数字i,默认(i,i)
- interpolation (InterpolationMode, optional):插值方法
- max_size (int,optional):允许调整大小的图像的较长边缘的最大值。如果图像的较长边缘在根据大小调整后大于max_size,那么图像将再次调整大小,使较长边缘等于max_size。
- antialias (bool, optional) :(True/ False/None)是否应用抗锯齿。
②transforms.CenterCrop()
从图像的中心裁剪为目标大小
参数列表:
- size (sequence or int): (h,w)目标图像的大小,若只输入一个数字i,默认(i,i)
③transforms.ToTensor()
将PIL图像或narray转换为张量并相应地缩放值。
无参
③transforms.Normalize()
用均值和标准差归一化张量图像。此转换不支持PIL图像。
参数列表:
- mean (sequence) – Sequence of means for each channel.每个通道的均值
- std (sequence) – Sequence of standard deviations for each channel.每个通道的标准差
- inplace (bool,optional) – Bool to make this operation in-place.
实例:
# 单通道图像归一化
transforms.Normalize([0.485], [0.229])
# 三通道图像归一化
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
2.torchvision.datasets.ImageFolder类
① 构造函数:
torchvision.datasets.ImageFolder(root,transform,target_transform,loader,is_valid_file)
-
root:数据集所在路径
默认的数据集格式如下:
此时,root = r’/root/’ (r防止转义) -
transform(可选):图像预处理操作
torchvision.transforms类
-
target_transform(可选):对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
-
loader(可选):表示数据集加载方式,通常默认加载方式即可。
-
is_valid_file :检查图像文件是否有效(用于检查损坏的文件)
-
实例:
数据集目录结构
data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}train_dataset= torchvision.datasets.ImageFolder(root=r'/root/src/data/oneimg_v_t/train/' ,transform=data_transform["train"])
val_dataset = torchvision.datasets.ImageFolder(root=r'/root/src/data/oneimg_v_t/val/' ,transform=data_transform["val"])
②属性:
- classes (list): List of the class names sorted alphabetically.
类别名称列表 - class_to_idx (dict): Dict with items (class_name, class_index).
一个字典,类别名称:类别索引 - imgs (list): List of (image path, class_index) tuples
一个tuples的列表,返回(图像路径,类别索引) - 实例:
数据集目录结构
from torchvision import datasets
dataset = datasets.ImageFolder(root=r'F:\\DataSetOneDir\\\\')
print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs)