Pytorch学习笔记(二):载入数据
本篇笔记主要记录了Pytorch项目程序中作为第一步的“载入数据”的常用代码、注解和心得,后续遇到更新的表达方式之后会持续更新。
部分学习资源链接:
- Pytorch官方在优达学城公布的教学视频:
https://classroom.udacity.com/courses/ud188 - 与其配套的代码和练习答案:
https://github.com/udacity/deep-learning-v2-pytorch - 陈云《深度学习框架Pytorch入门与实践》:
https://github.com/chenyuntc/pytorch-book - 《Pytorch模型训练实用教程》:
https://github.com/tensor-yu/PyTorch_Tutorial
之前的相关文章:
数据处理
在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据。数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果。考虑到这点,PyTorch提供了几个高效便捷的工具,以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。
加载数据的方法
在PyTorch中,数据加载可通过自定义的数据集对象。数据集对象被抽象为 Dataset 类,实现自定义的数据集需要继承 Dataset,并实现两个Python魔法方法(关于魔术方法详细解释可以参考:《python的魔法方法是什么?》,同时也可以参考魔法方法解释最详细的论文《MagicMethods》——RafeKettler):
- __getitem__:返回一条数据,或一个样本。obj[index] 等价于obj.__getitem__(index)
- __len__:返回样本的数量。len(obj) 等价于obj.__len__()
这里重点看 __getitem__ 函数,__getitem__ 接收一个index,然后返回图片数据和标签,这个 index 通常指的是一个 list 的 index ,这个 list 的每个元素就包含了图片数据的路径和标签信息。
然而,如何制作这个 list 呢?通常的方法是将图片的路径和标签信息存储在一个 txt 中,然后从该 txt 中读取。
那么读取自己数据的基本流程就是:
- 制作存储了图片的路径和标签信息的txt
- 将这些信息转化为 list,该 list 每一个元素对应一个样本
- 通过 __getitem__ 函数,读取数据和标签,并返回数据和标签
在训练代码里是感觉不到这些操作的,只会看到通过 DataLoader 就可以获取一个 batch 的数据,其实触发去读取图片这些操作的是 DataLoader 里的 __iter__(self) ,后面会详细说明。
因此,要让PyTorch能读取自己的数据集,只需要两步:
- 制作图片数据的索引
- 构建Dataset子类
制作图片数据的索引
这个比较简单,就是读取图片路径,标签,保存到txt文件中,这里注意格式就好。特别注意的是,txt中的路径,是以训练时的那个 py 文件所在的目录为工作目录,所以这里需要提前算好相对路径!
# coding:utf-8
import os
'''
为数据集生成对应的txt文件
'''
train_txt_path = os.path.join("..", "..", "Data", "train.txt")
train_dir = os.path.join("..", "..", "Data", "train")
valid_txt_path = os.path.join("..", "..", "Data", "valid.txt")
valid_dir = os.path.join("..", "..", "Data", "valid")
def gen_txt(txt_path, img_dir):
f = open(txt_path, 'w')
for root, s_dirs, _ in os.walk(img_dir, topdown=True): # 获取 train文件下各文件夹名称
for sub_dir in s_dirs:
i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径
img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径
for i in range(len(img_list)):
if not img_list[i].endswith('png'): # 若不是png文件,跳过
continue
label = img_list[i].split('_')[0]
img_path = os.path.join(i_dir, img_list[i])
line = img_path + ' ' + label + '\n'
f.write(line)
f.close()
if __name__ == '__main__':
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)
运行代码,即会在/Data/文件夹下面看到 train.txt valid.txt 。txt中是这样的:

构建Dataset子类
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform=None, target_transform=None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(fn).convert('RGB') # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.imgs)
首先看看初始化,初始化中从我们准备好的txt里获取图片的路径和标签,并且存储在 self.imgs ,self.imgs就是上面提到的 list ,其一个元素对应一个样本的路径和标签,其实就是 txt 中的一行。
初始化中还会初始化 transform ,transform 是一个 Compose 类型,里边有一个 list ,list 中就会定义了各种对图像进行处理的操作,可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作。
在这里我们可以知道,一张图片读取进来之后,会经过数据处理(数据增强),最终变成输入模型的数据。这里就有一点需要注意,PyTorch 的数据增强是将原始图片进行了处理,并不会生成新的一份图片,而是“覆盖”原图,当采用 randomcrop 之类的随机操作时,每个 epoch 输入进来的图片几乎不会是一模一样的,这达到了样本多样性的功能。
然后看看核心的 __getitem__ 函数:
第一行:self.imgs 是一个 list ,也就是一开始提到的 list ,self.imgs 的一个元素是一个 str ,包含图片路径,图片标签,这些信息是从 txt 文件中读取。
第二行:利用 Image.open 对图片进行读取,img 类型为 Image ,mode=‘RGB’
第三行与第四行: 对图片进行处理,这个 transform 里边可以实现 减均值,除标准差,随机裁剪,旋转,翻转,放射变换,等等操作。
当 Mydataset 构建好,剩下的操作就交给 DataLoder ,在 DataLoder 中,会触发 Mydataset 中的 __getitem__ 函数读取一张图片的数据和标签,并拼接成一个 batch 返回,作为模型真正的输入。下面将会通过一个小例子,说明 DataLoder 是如何获取一个 batch ,以及一张图片是如何被 PyTorch 读取,最终变为模型的输入的。
例子就是上一篇笔记的举例代码。
大体流程:
- main.py: train_data = MyDataset(txt_path=train_txt_path, …) --->
- main.py: train_loader = DataLoader(dataset=train_data, …) --->
- main.py: for i, data in enumerate(train_loader, 0) --->
- dataloder.py: class DataLoader(): def __iter__(self): return _DataLoaderIter(self) --->
- dataloder.py: class _DataLoderIter(): def __next__(self): batch = self.collate_fn([self.dataset[i] for i in indices]) --->
- tool.py: class MyDataset(): def __getitem__(): img = Image.open(fn).convert('RGB') --->
- tool.py: class MyDataset(): img = self.transform(img) --->
- main.py: inputs, labels = data inputs, labels = Variable(inputs), Variable(labels) outputs = net(inputs)
一开始通过 MyDataset 创建一个实例,在该实例中有路径,有读取图片的方法(函数)。
然后需要 Pytroch 的一系列规范化流程,在第6步中,才会调用 MyDataset 中的 __getitem__() 函数,最终通过 Image.open() 读取图片数据。
然后对原始图片数据进行一系列预处理( transform 中设置),最后回到 main.py,对数据进行转换成 Variable 类型,最终成为模型的输入。
流程详细描述:
- 从 MyDataset 类中初始化 txt,txt 中有图片路径和标签
- 初始化 DataLoder 时,将 train_data 传入,从而使 DataLoder 拥有图片的路径
- 在一个 iteration 进行时,才读取一个 batch 的图片数据 enumerate()函数会返回可迭代数据的一个“元素”。在这里 data 是一个 batch 的图片数据和标签,data是一个list。
- class DataLoader() 中再调用 class _DataLoderIter()
- 在 _DataLoderiter() 类中会跳到 __next(self)__ 函数,在该函数中会通过 indices = next(self.sample_iter) 获取一个 batch 的 indices ,再通过 batch = self.collate_fn([self.dataset[i] for i in indices]) 获取一个 batch 的数据在 batch = self.collate_fn([self.dataset[i] for i in indices]) 中会调用 self.collate_fn 函数
- self.collate_fn 中会调用 MyDataset 类中的 __getitem()__ 函数,在 __getitem()__ 中通过 Image.open(fn).convert('RGB') 读取图片
- 通过 Image.open(fn).convert('RGB') 读取图片之后,会对图片进行预处理,例如减均值,除以标准差,随机裁剪等等一系列提前设置好的操作。最后返回 img,label,再通过 self.collate_fn 来拼接成一个 batch 。一个 batch 是一个 list ,有两个元素,第一个元素是图片数据,是一个 4D 的 Tensor ,shape 为(64,3,32,32),第二个元素是标签 shape 为(64)。
- 将图片数据转换成 Variable 类型,然后称为模型真正的输入
inputs, labels = Variable(inputs), Variable(labels)
outputs = net(inputs)
通过了解图片从硬盘到模型的过程,可以更好的对数据做处理(减均值,除以标准差,裁剪,翻转,放射变换等等),也可以灵活的为模型准备数据,最后总结两个需要注意的地方。
- 图片是通过Image.open()函数读取进来的,当涉及如下问题:
图片的通道顺序(RGB ? BGR ?)
图片是w*h*c ? c*w*h ?
像素值范围[0-1] or [0-255] ?
就要查看 MyDataset() 类中 __getitem__()下读取图片用的是什么方法 - 从 MyDataset() 类中 __getitem__()函数中发现,PyTorch 做数据增强的方法是在原始图片上进行的,并覆盖原始图片,这一点需要注意。
数据增强与数据标准化
在实际应用过程中,我们会在数据进入模型之前进行一些预处理,例如数据中心化(仅减均值),数据标准化(减均值,再除以标准差),随机裁剪,旋转一定角度,镜像等一系列操作。
在PyTorch中,这些数据增强方法放在了 transforms.py 文件中。这些数据处理可以满足我们大部分的需求,通过熟悉 transforms.py ,也可以自定义数据处理函数,实现自己的数据增强。
先在这里从宏观地说明 transform 的使用,不久之后我会总结 transform 的详细用法。参考代码依然是上一篇笔记的示例代码,如下。
# 数据预处理设置
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
transforms.Resize(32),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
normTransform
])
validTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
])
前三行设置均值,标准差,以及数据标准化:transforms.Normalize() 函数,这里是以通道为单位进行计算均值,标准差。
然后用 transforms.Compose 将所需要进行的处理给组合起来,并且需要注意顺序!
在训练时,依次对图片进行以下操作:
- 随机裁剪:在裁剪之前先对图片的上下左右均填充上4个pixel,值为0,即变成一个40*40的数据,然后再随机进行32*32的裁剪。
- Totensor:在这里会对数据进行transpose,原来是hwc,会经过img = img.transpose(0, 1).transpose(0, 2).contiguous(),变成chw再除以255,使得像素值归一化至[0-1]之间
- 数据标准化(减均值,除以标准差):关于标准化的分析文章后面有详细说明。
数据加载实例分析
接下来以Kaggle的“Dog vs. Cat”数据为例再示范一些加载数据的方法,"Dogs vs. Cats"是一个分类问题,判断一张图片是狗还是猫,其所有图片都存放在一个文件夹下,根据文件名的前缀判断是狗还是猫。
首先先在程序中以树形图显示文件列表:
%env LS_COLORS = None
!tree ".\data\dogcat" /F
#输出如下 env: LS_COLORS=None data/dogcat/ |-- cat.12484.jpg |-- cat.12485.jpg |-- cat.12486.jpg |-- cat.12487.jpg |-- dog.12496.jpg |-- dog.12497.jpg |-- dog.12498.jpg `-- dog.12499.jpg 0 directories, 8 files
这里需要注意的是第二行代码与陈云给出的代码不一样,原因是因为陈云使用的是Linux系统,陈云原本的代码在Windows上运行会报错,Windows运行以上即可。
接下来定义一下自己的数据集,并依次获取图片的数据。
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
class DogCat(data.Dataset):
def __init__(self, root):
imgs = os.listdir(root)
# 所有图片的绝对路径
# 这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
self.imgs = [os.path.join(root, img) for img in imgs]
def __getitem__(self, index):
img_path = self.imgs[index]
# dog->1, cat->0
label = 1 if 'dog' in img_path.split('/')[-1] else 0
pil_img = Image.open(img_path)
array = np.asarray(pil_img)
data = t.from_numpy(array)
return data, label
def __len__(self):
return len(self.imgs)
dataset = DogCat('./data/dogcat/')
img, label = dataset[0] # 相当于调用dataset.__getitem__(0)
for img, label in dataset:
print(img.size(), img.float().mean(), label)
#输出如下 torch.Size([375, 499, 3]) tensor(116.8187) 1 torch.Size([499, 379, 3]) tensor(171.8088) 0 torch.Size([236, 289, 3]) tensor(130.3022) 0 torch.Size([377, 499, 3]) tensor(151.7141) 1 torch.Size([374, 499, 3]) tensor(115.5157) 0 torch.Size([375, 499, 3]) tensor(150.5086) 1 torch.Size([400, 300, 3]) tensor(128.1548) 1 torch.Size([500, 497, 3]) tensor(106.4917) 0
通过以上代码我们可以发现两个主要问题:
- 返回样本的形状不一,因每张图片的大小不一样,这对于需要取batch训练的神经网络来说很不友好
- 返回样本的数值较大,未归一化至[-1, 1]
所以这种方法在实际使用中其实并不常用。
针对上述问题,PyTorch提供了 torchvision 。它是一个视觉工具包,提供了很多视觉图像处理的工具,其中 transforms 模块提供了对 PIL Image 对象和 Tensor 对象的常用操作。
对PIL Image的操作包括:
- Scale:调整图片尺寸,长宽比保持不变
- CenterCrop、RandomCrop、RandomResizedCrop: 裁剪图片
- Pad:填充
- ToTensor:将PIL Image对象转成Tensor,会自动将[0, 255]归一化至[0, 1]
对Tensor的操作包括:
- Normalize:标准化,即减均值,除以标准差
- ToPILImage:将Tensor转为PIL Image对象
我曾经在代码中总是无法理解 Normalize( mean , std ) 这个函数的意义,后来经过搜索,我发现标准化函数的计算其实非常简单:
\[ image = \frac{image-mean}{std}\]
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1]
而我们时常看见将标准化平均值和标准差设置为0.5的代码,其注释都会表明将图片标准化至 [-1,1] ,就是因为我们原本的图片归一化值均在[0,1]之间,最小值0经过运算 \(\frac{0-0.5}{0.5}=-1\) ,而1经过运算 \(\frac{1-0.5}{0.5}=1\) ,进而将图片标准化为了[-1,1]。
如果要对图片进行多个操作,可通过 Compose 函数将这些操作拼接起来,类似于 nn.Sequential 。注意,这些操作定义后是以函数的形式存在,真正使用时需调用它的 __call__ 方法,这点类似于 nn.Module 。例如要将图片调整为224×224,首先应构建这个操作 trans = Resize((224, 224)) ,然后调用 trans(img) 。下面我们就用 transforms 的这些操作来优化上面实现的dataset。
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
transform = T.Compose([
T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
T.CenterCrop(224), # 从图片中间切出224*224的图片
T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])
class DogCat(data.Dataset):
def __init__(self, root, transforms=None):
imgs = os.listdir(root)
self.imgs = [os.path.join(root, img) for img in imgs]
self.transforms=transforms
def __getitem__(self, index):
img_path = self.imgs[index]
label = 0 if 'dog' in img_path.split('/')[-1] else 1
data = Image.open(img_path)
if self.transforms:
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
dataset = DogCat('./data/dogcat/', transforms=transform)
img, label = dataset[0]
for img, label in dataset:
print(img.size(), label)
#输出如下 torch.Size([3, 224, 224]) 0 torch.Size([3, 224, 224]) 1 torch.Size([3, 224, 224]) 1 torch.Size([3, 224, 224]) 0 torch.Size([3, 224, 224]) 1 torch.Size([3, 224, 224]) 0 torch.Size([3, 224, 224]) 0 torch.Size([3, 224, 224]) 1
除了上述操作之外,transforms 还可通过 Lambda 封装自定义的转换策略。例如想对PIL Image进行随机旋转,则可写成这样 trans=T.Lambda(lambda img: img.rotate(random()*360)) 。
torchvision 已经预先实现了常用的 Dataset ,包括前面使用过的 CIFAR-10,以及 ImageNet、COCO、MNIST、LSUN 等数据集,可通过诸如 torchvision.datasets.CIFAR10 来调用,具体使用方法请参看官方文档。
在这里介绍一个会经常使用到的 Dataset——ImageFolder,它的实现和上述的DogCat很相似。ImageFolder 假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
它主要有四个参数:
- root:在 root 指定的路径下寻找图片
- transform:对PIL Image进行的转换操作,transform 的输入是使用loader读取图片的返回对象
- target_transform:对 label 的转换
- loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象
label 是按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和 ImageFolder 实际的 label 一致,如果不是这种命名规范,建议看看 self.class_to_idx 属性以了解 label 和文件夹名的映射关系。
接下来看一下 self.class_to_idx 的使用例子。
!tree ".\data\dogcat_2" /F
#输出 data/dogcat_2/ |-- cat | |-- cat.12484.jpg | |-- cat.12485.jpg | |-- cat.12486.jpg | `-- cat.12487.jpg `-- dog |-- dog.12496.jpg |-- dog.12497.jpg |-- dog.12498.jpg `-- dog.12499.jpg 2 directories, 8 files
from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/dogcat_2/')
# cat文件夹的图片对应label 0,dog对应1
dataset.class_to_idx
#输出 {'cat': 0, 'dog': 1}
#所有图片的路径和对应的label
dataset.imgs
#输出 [('data/dogcat_2/cat/cat.12484.jpg', 0), ('data/dogcat_2/cat/cat.12485.jpg', 0), ('data/dogcat_2/cat/cat.12486.jpg', 0), ('data/dogcat_2/cat/cat.12487.jpg', 0), ('data/dogcat_2/dog/dog.12496.jpg', 1), ('data/dogcat_2/dog/dog.12497.jpg', 1), ('data/dogcat_2/dog/dog.12498.jpg', 1), ('data/dogcat_2/dog/dog.12499.jpg', 1)]
# 没有任何的transform,所以返回的还是PIL Image对象
dataset[0][1] # 第一维是第几张图,第二维为1返回label
dataset[0][0] # 为0返回图片数据
# 加上transform
normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
transform = T.Compose([
T.RandomResizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize,
])
dataset = ImageFolder('data/dogcat_2/', transform=transform)
# 深度学习中图片数据一般保存成CxHxW,即通道数x图片高x图片宽
dataset[0][0].size()
#输出 torch.Size([3, 224, 224])
to_img = T.ToPILImage()
# 0.2和0.4是标准差和均值的近似
to_img(dataset[0][0]*0.2+0.4)
Dataset只负责数据的抽象,一次调用 __getitem__ 只返回一个样本。前面提到过,在训练神经网络时,最好是对一个 batch 的数据进行操作,同时还需要对数据进行 shuffle 和并行加速等。对此,PyTorch 提供了 DataLoader 帮助我们实现这些功能。
DataLoader 的函数定义如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
- dataset:加载的数据集(Dataset对象)
- batch_size:batch size
- shuffle:是否将数据打乱
- sampler: 样本抽样,后续会详细介绍
- num_workers:使用多进程加载的进程数,0代表不使用多进程
- collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
- pin_memory:是否将数据保存在 pin memory 区,pin memory 中的数据转到GPU会快一些
- drop_last:dataset中的数据个数可能不是 batch_size 的整数倍, drop_last 为True会将多出来不足一个batch的数据丢弃
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight
#输出 torch.Size([3, 3, 224, 224])
dataloader 是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它,例如:
for batch_datas, batch_labels in dataloader:
train()
或者:
dataiter = iter(dataloader)
batch_datas, batch_labesl = next(dataiter)
在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在 __getitem__ 函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回 None 对象,然后在 Dataloader 中实现自定义的 collate_fn ,将空对象过滤掉。但要注意,在这种情况下 dataloader 返回的 batch 数目会少于 batch_size 。
class NewDogCat(DogCat): # 继承前面实现的DogCat数据集
def __getitem__(self, index):
try:
# 调用父类的获取函数,即 DogCat.__getitem__(self, index)
return super(NewDogCat,self).__getitem__(index)
except:
return None, None
from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
'''
batch中每个元素形如(data, label)
'''
# 过滤为None的数据
batch = list(filter(lambda x:x[0] is not None, batch))
if len(batch) == 0: return t.Tensor()
return default_collate(batch) # 用默认方式拼接过滤后的batch数据
dataset = NewDogCat('data/dogcat_wrong/', transforms=transform)
dataset[5]
(tensor([[[ 3.0000, 3.0000, 3.0000, ..., 2.5490, 2.6471, 2.7059], [ 3.0000, 3.0000, 3.0000, ..., 2.5686, 2.6471, 2.7059], [ 2.9804, 3.0000, 3.0000, ..., 2.5686, 2.6471, 2.6863], ..., [-0.9020, -0.8627, -0.8235, ..., 0.0392, 0.2745, 0.4314], [-0.9216, -0.8824, -0.8235, ..., 0.0588, 0.2745, 0.4314], [-0.9412, -0.9020, -0.8431, ..., 0.0980, 0.3137, 0.4510]], [[ 3.0000, 2.9608, 2.8824, ..., 2.5294, 2.6275, 2.6863], [ 3.0000, 2.9804, 2.9216, ..., 2.5098, 2.6275, 2.6863], [ 3.0000, 2.9804, 2.9608, ..., 2.5098, 2.6078, 2.6667], ..., [-1.0392, -1.0000, -0.9608, ..., -0.0196, 0.1961, 0.3333], [-1.0392, -1.0000, -0.9608, ..., 0.0000, 0.1961, 0.3333], [-1.0392, -1.0196, -0.9804, ..., 0.0392, 0.2353, 0.3529]], [[ 2.1765, 2.1961, 2.2157, ..., 1.5882, 1.6863, 1.7451], [ 2.2157, 2.2157, 2.2353, ..., 1.6078, 1.7255, 1.7843], [ 2.2157, 2.2353, 2.2549, ..., 1.6667, 1.7647, 1.8235], ..., [-1.4118, -1.3725, -1.3333, ..., -0.5098, -0.3922, -0.2941], [-1.4118, -1.3725, -1.3333, ..., -0.4902, -0.3725, -0.2941], [-1.4314, -1.3922, -1.3529, ..., -0.4510, -0.3333, -0.2745]]]), 0)
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1,shuffle=True)
for batch_datas, batch_labels in dataloader:
print(batch_datas.size(),batch_labels.size())
torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([1, 3, 224, 224]) torch.Size([1]) torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([1, 3, 224, 224]) torch.Size([1])
来看一下上述 batch_size 的大小。其中第2个的 batch_size 为1,这是因为有一张图片损坏,导致其无法正常返回。而最后1个的 batch_size 也为1,这是因为共有9张(包括损坏的文件)图片,无法整除2(batch_size),因此最后一个batch的数据会少于 batch_size,可通过指定 drop_last=True 来丢弃最后一个不足 batch_size 的batch。
对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:
class NewDogCat(DogCat):
def __getitem__(self, index):
try:
return super(NewDogCat, self).__getitem__(index)
except:
new_index = random.randint(0, len(self)-1)
return self[new_index]
相比较丢弃异常图片而言,这种做法会更好一些,因为它能保证每个 batch 的数目仍是 batch_size 。但在大多数情况下,最好的方式还是对数据进行彻底清洗。
DataLoader 里面并没有太多的魔法方法,它封装了Python的标准库 multiprocessing ,使其能够实现多进程加速。在此提几点关于 Dataset 和 DataLoader 使用方面的建议:
- 高负载的操作放在 __getitem__ 中,如加载图片等。
- dataset中应尽量只包含只读对象,避免修改任何可变对象,利用多线程进行操作。
第一点是因为多进程会并行的调用 __getitem__ 函数,将负载高的放在 __getitem__ 函数中能够实现并行加速。
第二点是因为 dataloader 使用多进程加载,如果在 Dataset 实现中使用了可变对象,可能会有意想不到的冲突。在多线程/多进程中,修改一个可变对象,需要加锁,但是 dataloader 的设计使得其很难加锁(在实际使用中也应尽量避免锁的存在),因此最好避免在 dataset 中修改可变对象。
例如下面就是一个不好的例子,在多进程处理中 self.num 可能与预期不符,这种问题不会报错,因此难以发现。如果一定要修改可变对象,建议使用 Python 标准库 Queue 中的相关数据结构。
class BadDataset(Dataset):
def __init__(self):
self.datas = range(100)
self.num = 0 # 取数据的次数
def __getitem__(self, index):
self.num += 1
return self.datas[index]
使用 Python multiprocessing 库的另一个问题是,在使用多进程时,如果主程序异常终止(比如用 Ctrl+C 强行退出),相应的数据加载进程可能无法正常退出。这时你可能会发现程序已经退出了,但GPU显存和内存依旧被占用着,或通过top、ps aux依旧能够看到已经退出的程序,这时就需要手动强行杀掉进程。建议使用如下命令:
ps x | grep <cmdline> | awk '{print $1}' | xargs kill
- ps x:获取当前用户的所有进程
- grep :找到已经停止的PyTorch程序的进程,例如你是通过 python train.py 启动的,那你就需要写 grep 'python train.py'
- awk '{print $1}':获取进程的pid
- xargs kill:杀掉进程,根据需要可能要写成xargs kill -9强制杀掉进程
在执行这句命令之前,建议先打印确认一下是否会误杀其它进程 :
ps x | grep <cmdline> | ps x
PyTorch中还单独提供了一个 sampler 模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的 shuffle 参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用 SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler ,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。
构建 WeightedRandomSampler 时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement 用于指定是否可以重复选取某一个样本,默认为 True,即允许在一个 epoch 中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到 num_samples 时,sampler 将不会再从该类中选择数据,此时可能导致 weights 参数失效。下面举例说明。
dataset = DogCat('data/dogcat/', transforms=transform)
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
weights
#输出 [1, 2, 2, 1, 2, 1, 1, 2]
from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
#输出 [1, 0, 1] [0, 0, 1] [1, 0, 1]
可见猫狗样本比例约为1:2,另外一共只有8个样本,但是却返回了9个,说明肯定有被重复返回的,这就是 replacement 参数的作用,下面将 replacement 设为 False 试试。
sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
#输出 [0, 1, 1, 1] [1, 0, 0, 0]
在这种情况下,num_samples 等于 dataset 的样本总数,为了不重复选取, sampler 会将每个样本都返回,这样就失去 weight 参数的意义了。
实战解析—— Kaggle 上的经典比赛:Dogs vs. Cats 的数据载入
数据的相关处理主要保存在data/dataset.py中。关于数据加载的相关操作,在上面已经提到过,其基本原理就是使用 Dataset 提供数据集的封装,再使用 Dataloader 实现数据并行加载。
Kaggle 提供的数据包括训练集和测试集,而我们在实际使用中,还需专门从训练集中取出一部分作为验证集。对于这三类数据集,其相应操作也不太一样,而如果专门写三个Dataset,则稍显复杂和冗余,因此这里通过加一些判断来区分。
对于训练集,我们希望做一些数据增强处理,如随机裁剪、随机翻转、加噪声等,而验证集和测试集则不需要。下面看 dataset.py 的代码:
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""
目标:获取所有图片地址,并根据训练、验证、测试划分数据
"""
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]
# test1: data/test1/8973.jpg
# train: data/train/cat.10004.jpg
if self.test:
# 如果是进行测试,截取得到数据的数字标码,
# 如上面的8973,根据key=8973的值进行排序,返回对所有的图片路径进行排序后返回
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
else:
# 如果是进行训练,截取得到数据的数字标识,
# 如上面的10004,根据key=10004的值进行排序,返回对所有的图片路径进行排序后返回
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
imgs_num = len(imgs) # 然后就可以得到数据的大小
# 划分训练、验证集,验证:训练 = 3:7
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7*imgs_num)]
else :
self.imgs = imgs[int(0.7*imgs_num):]
if transforms is None:
# 数据转换操作,测试验证和训练的数据转换有所区别
normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])
# 测试集和验证集
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224), #重新设定大小
T.CenterCrop(224), #从图片中心截取
T.ToTensor(), #转成Tensor格式,大小范围为[0,1]
normalize #归一化处理,大小范围为[-1,1]
])
# 训练集
else :
self.transforms = T.Compose([
T.Resize(256),
T.RandomReSizedCrop(224), #从图片的任何部位随机截取224*224大小的图
T.RandomHorizontalFlip(), #随机水平翻转给定的PIL.Image,翻转概率为0.5
T.ToTensor(),
normalize
])
def __getitem__(self, index):
"""
返回一张图片的数据
对于测试集,没有label,返回图片id,如1000.jpg返回1000
"""
img_path = self.imgs[index]
if self.test:
#如果是测试,得到图片路径中的数字标识作为label
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
#如果是训练,判断图片路径中是猫狗来设定label,猫为0,狗为1
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path) #打开该路径获得数据
data = self.transforms(data) #然后对图片数据进行transform
return data, label #最后得到统一的图片信息和label信息
def __len__(self):
"""
返回数据集中所有图片的个数
"""
return len(self.imgs)
关于数据集使用的注意事项:
将文件读取等费时操作放在 __getitem__ 函数中,利用多进程加速。避免一次性将所有图片都读进内存,不仅费时也会占用较大内存,而且不易进行数据增强等操作。
另外在这里,将训练集中的30%作为验证集,可用来检查模型的训练效果,避免过拟合。在使用时,可通过 dataloader 加载数据。
#进行训练,定义训练数据和label集合train_dataset
train_dataset = DogCat(opt.train_data_root, train=True)
#开始进行数据的加载,将数据打乱(shuffle = True),
#随机将数据分批,一批有opt.batch_size个,并行处理,打开opt.num_workers个进程
trainloader = DataLoader(train_dataset,
batch_size = opt.batch_size,
shuffle = True,
num_workers = opt.num_workers)
#显示得到的数据
for ii, (data, label) in enumerate(trainloader):
train() #然后进行训练