Pytorch学习笔记(二):载入数据

重要:本文最后更新于,某些文章具有时效性,若有错误或失效,请在下方留言

本篇笔记主要记录了Pytorch项目程序中作为第一步的“载入数据”的常用代码、注解和心得,后续遇到更新的表达方式之后会持续更新。

部分学习资源链接:

之前的相关文章:

数据处理

在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据。数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果。考虑到这点,PyTorch提供了几个高效便捷的工具,以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。

加载数据的方法

在PyTorch中,数据加载可通过自定义的数据集对象。数据集对象被抽象为 Dataset 类,实现自定义的数据集需要继承 Dataset,并实现两个Python魔法方法(关于魔术方法详细解释可以参考:《python的魔法方法是什么?》,同时也可以参考魔法方法解释最详细的论文《MagicMethods》——RafeKettler):

  • __getitem__:返回一条数据,或一个样本。obj[index] 等价于obj.__getitem__(index)
  • __len__:返回样本的数量。len(obj) 等价于obj.__len__()

这里重点看 __getitem__ 函数,__getitem__ 接收一个index,然后返回图片数据和标签,这个 index 通常指的是一个 listindex ,这个 list 的每个元素就包含了图片数据的路径和标签信息。

然而,如何制作这个 list 呢?通常的方法是将图片的路径和标签信息存储在一个 txt 中,然后从该 txt 中读取。

那么读取自己数据的基本流程就是:

  1. 制作存储了图片的路径和标签信息的txt
  2. 将这些信息转化为 list,该 list 每一个元素对应一个样本
  3. 通过 __getitem__ 函数,读取数据和标签,并返回数据和标签

在训练代码里是感觉不到这些操作的,只会看到通过 DataLoader 就可以获取一个 batch 的数据,其实触发去读取图片这些操作的是 DataLoader 里的 __iter__(self) ,后面会详细说明。

因此,要让PyTorch能读取自己的数据集,只需要两步:

  1. 制作图片数据的索引
  2. 构建Dataset子类

制作图片数据的索引

这个比较简单,就是读取图片路径,标签,保存到txt文件中,这里注意格式就好。特别注意的是,txt中的路径,是以训练时的那个 py 文件所在的目录为工作目录,所以这里需要提前算好相对路径!

# coding:utf-8
import os
'''
    为数据集生成对应的txt文件
'''

train_txt_path = os.path.join("..", "..", "Data", "train.txt")
train_dir = os.path.join("..", "..", "Data", "train")

valid_txt_path = os.path.join("..", "..", "Data", "valid.txt")
valid_dir = os.path.join("..", "..", "Data", "valid")


def gen_txt(txt_path, img_dir):
    f = open(txt_path, 'w')
    
    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径
            img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):         # 若不是png文件,跳过
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                f.write(line)
    f.close()


if __name__ == '__main__':
    gen_txt(train_txt_path, train_dir)
    gen_txt(valid_txt_path, valid_dir)

运行代码,即会在/Data/文件夹下面看到 train.txt valid.txt 。txt中是这样的:

构建Dataset子类

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)

首先看看初始化,初始化中从我们准备好的txt里获取图片的路径和标签,并且存储在 self.imgsself.imgs就是上面提到的 list ,其一个元素对应一个样本的路径和标签,其实就是 txt 中的一行。

初始化中还会初始化 transformtransform 是一个 Compose 类型,里边有一个 listlist 中就会定义了各种对图像进行处理的操作,可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作。

在这里我们可以知道,一张图片读取进来之后,会经过数据处理(数据增强),最终变成输入模型的数据。这里就有一点需要注意,PyTorch 的数据增强是将原始图片进行了处理,并不会生成新的一份图片,而是“覆盖”原图,当采用 randomcrop 之类的随机操作时,每个 epoch 输入进来的图片几乎不会是一模一样的,这达到了样本多样性的功能。

然后看看核心的 __getitem__ 函数:

第一行:self.imgs 是一个 list ,也就是一开始提到的 listself.imgs 的一个元素是一个 str ,包含图片路径,图片标签,这些信息是从 txt 文件中读取。

第二行:利用 Image.open 对图片进行读取,img 类型为 Imagemode=‘RGB’

第三行与第四行: 对图片进行处理,这个 transform 里边可以实现 减均值,除标准差,随机裁剪,旋转,翻转,放射变换,等等操作。

Mydataset 构建好,剩下的操作就交给 DataLoder ,在 DataLoder 中,会触发 Mydataset 中的 __getitem__ 函数读取一张图片的数据和标签,并拼接成一个 batch 返回,作为模型真正的输入。下面将会通过一个小例子,说明 DataLoder 是如何获取一个 batch ,以及一张图片是如何被 PyTorch 读取,最终变为模型的输入的。

例子就是上一篇笔记的举例代码

大体流程:

  1. main.py: train_data = MyDataset(txt_path=train_txt_path, …) --->
  2. main.py: train_loader = DataLoader(dataset=train_data, …) --->
  3. main.py: for i, data in enumerate(train_loader, 0) --->
  4. dataloder.py: class DataLoader(): def __iter__(self): return _DataLoaderIter(self) --->
  5. dataloder.py: class _DataLoderIter(): def __next__(self): batch = self.collate_fn([self.dataset[i] for i in indices]) --->
  6. tool.py: class MyDataset(): def __getitem__(): img = Image.open(fn).convert('RGB') --->
  7. tool.py: class MyDataset(): img = self.transform(img) --->
  8. main.py: inputs, labels = data inputs, labels = Variable(inputs), Variable(labels) outputs = net(inputs)

一开始通过 MyDataset 创建一个实例,在该实例中有路径,有读取图片的方法(函数)。

然后需要 Pytroch 的一系列规范化流程,在第6步中,才会调用 MyDataset 中的 __getitem__() 函数,最终通过 Image.open() 读取图片数据。

然后对原始图片数据进行一系列预处理( transform 中设置),最后回到 main.py,对数据进行转换成 Variable 类型,最终成为模型的输入。

流程详细描述:

  1. MyDataset 类中初始化 txt,txt 中有图片路径和标签
  2. 初始化 DataLoder 时,将 train_data 传入,从而使 DataLoder 拥有图片的路径
  3. 在一个 iteration 进行时,才读取一个 batch 的图片数据 enumerate()函数会返回可迭代数据的一个“元素”。在这里 data 是一个 batch 的图片数据和标签,data是一个list。
  4. class DataLoader() 中再调用 class _DataLoderIter()
  5. _DataLoderiter() 类中会跳到 __next(self)__ 函数,在该函数中会通过 indices = next(self.sample_iter) 获取一个 batch indices ,再通过 batch = self.collate_fn([self.dataset[i] for i in indices]) 获取一个 batch 的数据在 batch = self.collate_fn([self.dataset[i] for i in indices]) 中会调用 self.collate_fn 函数
  6. self.collate_fn 中会调用 MyDataset 类中的 __getitem()__ 函数,在 __getitem()__ 中通过 Image.open(fn).convert('RGB') 读取图片
  7. 通过 Image.open(fn).convert('RGB') 读取图片之后,会对图片进行预处理,例如减均值,除以标准差,随机裁剪等等一系列提前设置好的操作。最后返回 imglabel,再通过 self.collate_fn 来拼接成一个 batch 。一个 batch 是一个 list ,有两个元素,第一个元素是图片数据,是一个 4D 的 Tensor ,shape 为(64,3,32,32),第二个元素是标签 shape 为(64)。
  8. 将图片数据转换成 Variable 类型,然后称为模型真正的输入
    inputs, labels = Variable(inputs), Variable(labels)
    outputs = net(inputs)

通过了解图片从硬盘到模型的过程,可以更好的对数据做处理(减均值,除以标准差,裁剪,翻转,放射变换等等),也可以灵活的为模型准备数据,最后总结两个需要注意的地方。

  1. 图片是通过Image.open()函数读取进来的,当涉及如下问题:
    图片的通道顺序(RGB ? BGR ?)
    图片是w*h*c ? c*w*h ?
    像素值范围[0-1] or [0-255] ?
    就要查看 MyDataset() 类中 __getitem__()下读取图片用的是什么方法
  2. MyDataset() 类中 __getitem__(函数中发现,PyTorch 做数据增强的方法是在原始图片上进行的,并覆盖原始图片,这一点需要注意。

数据增强与数据标准化

在实际应用过程中,我们会在数据进入模型之前进行一些预处理,例如数据中心化(仅减均值),数据标准化(减均值,再除以标准差),随机裁剪,旋转一定角度,镜像等一系列操作。

在PyTorch中,这些数据增强方法放在了 transforms.py 文件中。这些数据处理可以满足我们大部分的需求,通过熟悉 transforms.py ,也可以自定义数据处理函数,实现自己的数据增强。

先在这里从宏观地说明 transform 的使用,不久之后我会总结 transform 的详细用法。参考代码依然是上一篇笔记的示例代码,如下。

# 数据预处理设置
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
    transforms.Resize(32),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normTransform
])

validTransform = transforms.Compose([
    transforms.ToTensor(),
    normTransform
])

前三行设置均值,标准差,以及数据标准化:transforms.Normalize() 函数,这里是以通道为单位进行计算均值,标准差。

然后用 transforms.Compose 将所需要进行的处理给组合起来,并且需要注意顺序!

在训练时,依次对图片进行以下操作:

  1. 随机裁剪:在裁剪之前先对图片的上下左右均填充上4个pixel,值为0,即变成一个40*40的数据,然后再随机进行32*32的裁剪。
  2. Totensor:在这里会对数据进行transpose,原来是hwc,会经过img = img.transpose(0, 1).transpose(0, 2).contiguous(),变成chw再除以255,使得像素值归一化至[0-1]之间
  3. 数据标准化(减均值,除以标准差):关于标准化的分析文章后面有详细说明。

数据加载实例分析

接下来以Kaggle的“Dog vs. Cat”数据为例再示范一些加载数据的方法,"Dogs vs. Cats"是一个分类问题,判断一张图片是狗还是猫,其所有图片都存放在一个文件夹下,根据文件名的前缀判断是狗还是猫。

首先先在程序中以树形图显示文件列表:

%env LS_COLORS = None
!tree ".\data\dogcat" /F
#输出如下
env: LS_COLORS=None
data/dogcat/
|-- cat.12484.jpg
|-- cat.12485.jpg
|-- cat.12486.jpg
|-- cat.12487.jpg
|-- dog.12496.jpg
|-- dog.12497.jpg
|-- dog.12498.jpg
`-- dog.12499.jpg
0 directories, 8 files

这里需要注意的是第二行代码与陈云给出的代码不一样,原因是因为陈云使用的是Linux系统,陈云原本的代码在Windows上运行会报错,Windows运行以上即可。
接下来定义一下自己的数据集,并依次获取图片的数据。

import torch as t
from torch.utils import data
import os
from PIL import  Image
import numpy as np
class DogCat(data.Dataset):
    def __init__(self, root):
        imgs = os.listdir(root)
        # 所有图片的绝对路径
        # 这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
        self.imgs = [os.path.join(root, img) for img in imgs]
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        # dog->1, cat->0
        label = 1 if 'dog' in img_path.split('/')[-1] else 0
        pil_img = Image.open(img_path)
        array = np.asarray(pil_img)
        data = t.from_numpy(array)
        return data, label
    
    def __len__(self):
        return len(self.imgs)
dataset = DogCat('./data/dogcat/')
img, label = dataset[0] # 相当于调用dataset.__getitem__(0)
for img, label in dataset:
    print(img.size(), img.float().mean(), label)
#输出如下
torch.Size([375, 499, 3]) tensor(116.8187) 1
torch.Size([499, 379, 3]) tensor(171.8088) 0
torch.Size([236, 289, 3]) tensor(130.3022) 0
torch.Size([377, 499, 3]) tensor(151.7141) 1
torch.Size([374, 499, 3]) tensor(115.5157) 0
torch.Size([375, 499, 3]) tensor(150.5086) 1
torch.Size([400, 300, 3]) tensor(128.1548) 1
torch.Size([500, 497, 3]) tensor(106.4917) 0

通过以上代码我们可以发现两个主要问题:

  • 返回样本的形状不一,因每张图片的大小不一样,这对于需要取batch训练的神经网络来说很不友好
  • 返回样本的数值较大,未归一化至[-1, 1]

所以这种方法在实际使用中其实并不常用。

针对上述问题,PyTorch提供了 torchvision 。它是一个视觉工具包,提供了很多视觉图像处理的工具,其中 transforms 模块提供了对 PIL Image 对象和 Tensor 对象的常用操作。

对PIL Image的操作包括:

  • Scale:调整图片尺寸,长宽比保持不变
  • CenterCropRandomCropRandomResizedCrop: 裁剪图片
  • Pad:填充
  • ToTensor:将PIL Image对象转成Tensor,会自动将[0, 255]归一化至[0, 1]

对Tensor的操作包括:

  • Normalize:标准化,即减均值,除以标准差
  • ToPILImage:将Tensor转为PIL Image对象

我曾经在代码中总是无法理解 Normalize( mean , std ) 这个函数的意义,后来经过搜索,我发现标准化函数的计算其实非常简单:

\[ image = \frac{image-mean}{std}\]

transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1]

而我们时常看见将标准化平均值和标准差设置为0.5的代码,其注释都会表明将图片标准化至 [-1,1] ,就是因为我们原本的图片归一化值均在[0,1]之间,最小值0经过运算 \(\frac{0-0.5}{0.5}=-1\) ,而1经过运算 \(\frac{1-0.5}{0.5}=1\) ,进而将图片标准化为了[-1,1]。

如果要对图片进行多个操作,可通过 Compose 函数将这些操作拼接起来,类似于 nn.Sequential 。注意,这些操作定义后是以函数的形式存在,真正使用时需调用它的 __call__ 方法,这点类似于 nn.Module 。例如要将图片调整为224×224,首先应构建这个操作 trans = Resize((224, 224)) ,然后调用 trans(img) 。下面我们就用 transforms 的这些操作来优化上面实现的dataset

import os
from PIL import  Image
import numpy as np
from torchvision import transforms as T
transform = T.Compose([
    T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
    T.CenterCrop(224), # 从图片中间切出224*224的图片
    T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
    T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])
class DogCat(data.Dataset):
    def __init__(self, root, transforms=None):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms=transforms
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        label = 0 if 'dog' in img_path.split('/')[-1] else 1
        data = Image.open(img_path)
        if self.transforms:
            data = self.transforms(data)
        return data, label
    
    def __len__(self):
        return len(self.imgs)
dataset = DogCat('./data/dogcat/', transforms=transform)
img, label = dataset[0]
for img, label in dataset:
    print(img.size(), label)
#输出如下
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1

除了上述操作之外,transforms 还可通过 Lambda 封装自定义的转换策略。例如想对PIL Image进行随机旋转,则可写成这样 trans=T.Lambda(lambda img: img.rotate(random()*360))

torchvision 已经预先实现了常用的 Dataset ,包括前面使用过的 CIFAR-10,以及 ImageNet、COCO、MNIST、LSUN 等数据集,可通过诸如 torchvision.datasets.CIFAR10 来调用,具体使用方法请参看官方文档

在这里介绍一个会经常使用到的 Dataset——ImageFolder,它的实现和上述的DogCat很相似。ImageFolder 假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

它主要有四个参数:

  • root:在 root 指定的路径下寻找图片
  • transform:对PIL Image进行的转换操作,transform 的输入是使用loader读取图片的返回对象
  • target_transform:对 label 的转换
  • loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

label 是按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和 ImageFolder 实际的 label 一致,如果不是这种命名规范,建议看看 self.class_to_idx 属性以了解 label 和文件夹名的映射关系。

接下来看一下 self.class_to_idx 的使用例子。

!tree ".\data\dogcat_2" /F
#输出
data/dogcat_2/
|-- cat
|   |-- cat.12484.jpg
|   |-- cat.12485.jpg
|   |-- cat.12486.jpg
|   `-- cat.12487.jpg
`-- dog
    |-- dog.12496.jpg
    |-- dog.12497.jpg
    |-- dog.12498.jpg
    `-- dog.12499.jpg
2 directories, 8 files
from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/dogcat_2/')
# cat文件夹的图片对应label 0,dog对应1
dataset.class_to_idx
#输出
{'cat': 0, 'dog': 1}
#所有图片的路径和对应的label
dataset.imgs
#输出
[('data/dogcat_2/cat/cat.12484.jpg', 0),
 ('data/dogcat_2/cat/cat.12485.jpg', 0),
 ('data/dogcat_2/cat/cat.12486.jpg', 0),
 ('data/dogcat_2/cat/cat.12487.jpg', 0),
 ('data/dogcat_2/dog/dog.12496.jpg', 1),
 ('data/dogcat_2/dog/dog.12497.jpg', 1),
 ('data/dogcat_2/dog/dog.12498.jpg', 1),
 ('data/dogcat_2/dog/dog.12499.jpg', 1)]
# 没有任何的transform,所以返回的还是PIL Image对象
dataset[0][1] # 第一维是第几张图,第二维为1返回label
dataset[0][0] # 为0返回图片数据
# 加上transform
normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
transform  = T.Compose([
         T.RandomResizedCrop(224),
         T.RandomHorizontalFlip(),
         T.ToTensor(),
         normalize,
])
dataset = ImageFolder('data/dogcat_2/', transform=transform)
# 深度学习中图片数据一般保存成CxHxW,即通道数x图片高x图片宽
dataset[0][0].size()
#输出
torch.Size([3, 224, 224])
to_img = T.ToPILImage()
# 0.2和0.4是标准差和均值的近似
to_img(dataset[0][0]*0.2+0.4)

Dataset只负责数据的抽象,一次调用 __getitem__ 只返回一个样本。前面提到过,在训练神经网络时,最好是对一个 batch 的数据进行操作,同时还需要对数据进行 shuffle 和并行加速等。对此,PyTorch 提供了 DataLoader 帮助我们实现这些功能。

DataLoader 的函数定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
  • dataset:加载的数据集(Dataset对象)
  • batch_size:batch size
  • shuffle:是否将数据打乱
  • sampler: 样本抽样,后续会详细介绍
  • num_workers:使用多进程加载的进程数,0代表不使用多进程
  • collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
  • pin_memory:是否将数据保存在 pin memory 区,pin memory 中的数据转到GPU会快一些
  • drop_last:dataset中的数据个数可能不是 batch_size 的整数倍, drop_last 为True会将多出来不足一个batch的数据丢弃
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight
#输出
torch.Size([3, 3, 224, 224])

dataloader 是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它,例如:

for batch_datas, batch_labels in dataloader:
    train()

或者:

dataiter = iter(dataloader)
batch_datas, batch_labesl = next(dataiter)

在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在 __getitem__ 函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回 None 对象,然后在 Dataloader 中实现自定义的 collate_fn ,将空对象过滤掉。但要注意,在这种情况下 dataloader 返回的 batch 数目会少于 batch_size

class NewDogCat(DogCat): # 继承前面实现的DogCat数据集
    def __getitem__(self, index):
        try:
            # 调用父类的获取函数,即 DogCat.__getitem__(self, index)
            return super(NewDogCat,self).__getitem__(index)
        except:
            return None, None
from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
    '''
    batch中每个元素形如(data, label)
    '''
    # 过滤为None的数据
    batch = list(filter(lambda x:x[0] is not None, batch))
    if len(batch) == 0: return t.Tensor()
    return default_collate(batch) # 用默认方式拼接过滤后的batch数据
  dataset = NewDogCat('data/dogcat_wrong/', transforms=transform)
  dataset[5]
(tensor([[[ 3.0000,  3.0000,  3.0000,  ...,  2.5490,  2.6471,  2.7059],
          [ 3.0000,  3.0000,  3.0000,  ...,  2.5686,  2.6471,  2.7059],
          [ 2.9804,  3.0000,  3.0000,  ...,  2.5686,  2.6471,  2.6863],
          ...,
          [-0.9020, -0.8627, -0.8235,  ...,  0.0392,  0.2745,  0.4314],
          [-0.9216, -0.8824, -0.8235,  ...,  0.0588,  0.2745,  0.4314],
          [-0.9412, -0.9020, -0.8431,  ...,  0.0980,  0.3137,  0.4510]],
 
         [[ 3.0000,  2.9608,  2.8824,  ...,  2.5294,  2.6275,  2.6863],
          [ 3.0000,  2.9804,  2.9216,  ...,  2.5098,  2.6275,  2.6863],
          [ 3.0000,  2.9804,  2.9608,  ...,  2.5098,  2.6078,  2.6667],
          ...,
          [-1.0392, -1.0000, -0.9608,  ..., -0.0196,  0.1961,  0.3333],
          [-1.0392, -1.0000, -0.9608,  ...,  0.0000,  0.1961,  0.3333],
          [-1.0392, -1.0196, -0.9804,  ...,  0.0392,  0.2353,  0.3529]],
 
         [[ 2.1765,  2.1961,  2.2157,  ...,  1.5882,  1.6863,  1.7451],
          [ 2.2157,  2.2157,  2.2353,  ...,  1.6078,  1.7255,  1.7843],
          [ 2.2157,  2.2353,  2.2549,  ...,  1.6667,  1.7647,  1.8235],
          ...,
          [-1.4118, -1.3725, -1.3333,  ..., -0.5098, -0.3922, -0.2941],
          [-1.4118, -1.3725, -1.3333,  ..., -0.4902, -0.3725, -0.2941],
          [-1.4314, -1.3922, -1.3529,  ..., -0.4510, -0.3333, -0.2745]]]), 0)
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1,shuffle=True)
for batch_datas, batch_labels in dataloader:
    print(batch_datas.size(),batch_labels.size())
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])

来看一下上述 batch_size 的大小。其中第2个的 batch_size 为1,这是因为有一张图片损坏,导致其无法正常返回。而最后1个的 batch_size 也为1,这是因为共有9张(包括损坏的文件)图片,无法整除2(batch_size),因此最后一个batch的数据会少于 batch_size,可通过指定 drop_last=True 来丢弃最后一个不足 batch_size 的batch。

对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:

class NewDogCat(DogCat):
    def __getitem__(self, index):
        try:
            return super(NewDogCat, self).__getitem__(index)
        except:
            new_index = random.randint(0, len(self)-1)
            return self[new_index]

相比较丢弃异常图片而言,这种做法会更好一些,因为它能保证每个 batch 的数目仍是 batch_size 。但在大多数情况下,最好的方式还是对数据进行彻底清洗。

DataLoader 里面并没有太多的魔法方法,它封装了Python的标准库 multiprocessing ,使其能够实现多进程加速。在此提几点关于 Dataset 和 DataLoader 使用方面的建议:

  1. 高负载的操作放在 __getitem__ 中,如加载图片等。
  2. dataset中应尽量只包含只读对象,避免修改任何可变对象,利用多线程进行操作。

第一点是因为多进程会并行的调用 __getitem__ 函数,将负载高的放在 __getitem__ 函数中能够实现并行加速。

第二点是因为 dataloader 使用多进程加载,如果在 Dataset 实现中使用了可变对象,可能会有意想不到的冲突。在多线程/多进程中,修改一个可变对象,需要加锁,但是 dataloader 的设计使得其很难加锁(在实际使用中也应尽量避免锁的存在),因此最好避免在 dataset 中修改可变对象。

例如下面就是一个不好的例子,在多进程处理中 self.num 可能与预期不符,这种问题不会报错,因此难以发现。如果一定要修改可变对象,建议使用 Python 标准库 Queue 中的相关数据结构。

class BadDataset(Dataset):
    def __init__(self):
        self.datas = range(100)
        self.num = 0 # 取数据的次数
    def __getitem__(self, index):
        self.num += 1
        return self.datas[index]

使用 Python multiprocessing 库的另一个问题是,在使用多进程时,如果主程序异常终止(比如用 Ctrl+C 强行退出),相应的数据加载进程可能无法正常退出。这时你可能会发现程序已经退出了,但GPU显存和内存依旧被占用着,或通过top、ps aux依旧能够看到已经退出的程序,这时就需要手动强行杀掉进程。建议使用如下命令:

ps x | grep <cmdline> | awk '{print $1}' | xargs kill
  • ps x:获取当前用户的所有进程
  • grep :找到已经停止的PyTorch程序的进程,例如你是通过 python train.py 启动的,那你就需要写 grep 'python train.py'
  • awk '{print $1}':获取进程的pid
  • xargs kill:杀掉进程,根据需要可能要写成xargs kill -9强制杀掉进程

在执行这句命令之前,建议先打印确认一下是否会误杀其它进程 :

ps x | grep <cmdline> | ps x

PyTorch中还单独提供了一个 sampler 模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的 shuffle 参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用 SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler ,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

构建 WeightedRandomSampler 时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement 用于指定是否可以重复选取某一个样本,默认为 True,即允许在一个 epoch 中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到 num_samples 时,sampler 将不会再从该类中选择数据,此时可能导致 weights 参数失效。下面举例说明。

dataset = DogCat('data/dogcat/', transforms=transform)

# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
weights
#输出
[1, 2, 2, 1, 2, 1, 1, 2]
from torch.utils.data.sampler import  WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
                                num_samples=9,\
                                replacement=True)
dataloader = DataLoader(dataset,
                        batch_size=3,
                        sampler=sampler)
for datas, labels in dataloader:
    print(labels.tolist())
#输出
[1, 0, 1]
[0, 0, 1]
[1, 0, 1]

可见猫狗样本比例约为1:2,另外一共只有8个样本,但是却返回了9个,说明肯定有被重复返回的,这就是 replacement 参数的作用,下面将 replacement 设为 False 试试。

sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
    print(labels.tolist())
#输出
[0, 1, 1, 1]
[1, 0, 0, 0]

在这种情况下,num_samples 等于 dataset 的样本总数,为了不重复选取, sampler 会将每个样本都返回,这样就失去 weight 参数的意义了。

实战解析—— Kaggle 上的经典比赛:Dogs vs. Cats 的数据载入

数据的相关处理主要保存在data/dataset.py中。关于数据加载的相关操作,在上面已经提到过,其基本原理就是使用 Dataset 提供数据集的封装,再使用 Dataloader 实现数据并行加载。

Kaggle 提供的数据包括训练集测试集,而我们在实际使用中,还需专门从训练集中取出一部分作为验证集。对于这三类数据集,其相应操作也不太一样,而如果专门写三个Dataset,则稍显复杂和冗余,因此这里通过加一些判断来区分。

对于训练集,我们希望做一些数据增强处理,如随机裁剪、随机翻转、加噪声等,而验证集和测试集则不需要。下面看 dataset.py 的代码:

import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T

class DogCat(data.Dataset):
    
    def __init__(self, root, transforms=None, train=True, test=False):
        """
        目标:获取所有图片地址,并根据训练、验证、测试划分数据
        """
        self.test = test
        imgs = [os.path.join(root, img) for img in os.listdir(root)] 

        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg 
        if self.test:
            # 如果是进行测试,截取得到数据的数字标码,
            # 如上面的8973,根据key=8973的值进行排序,返回对所有的图片路径进行排序后返回
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
        else:
            # 如果是进行训练,截取得到数据的数字标识,
            # 如上面的10004,根据key=10004的值进行排序,返回对所有的图片路径进行排序后返回
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))  
            
        imgs_num = len(imgs) # 然后就可以得到数据的大小
        
        # 划分训练、验证集,验证:训练 = 3:7
        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7*imgs_num)]
        else :
            self.imgs = imgs[int(0.7*imgs_num):]            
    
        if transforms is None:
        
            # 数据转换操作,测试验证和训练的数据转换有所区别
	        
            normalize = T.Normalize(mean = [0.485, 0.456, 0.406], 
                                     std = [0.229, 0.224, 0.225])

            # 测试集和验证集
            if self.test or not train: 
                self.transforms = T.Compose([
                    T.Resize(224), #重新设定大小
                    T.CenterCrop(224), #从图片中心截取
                    T.ToTensor(), #转成Tensor格式,大小范围为[0,1]
                    normalize #归一化处理,大小范围为[-1,1]
                ]) 
            # 训练集
            else :
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomReSizedCrop(224), #从图片的任何部位随机截取224*224大小的图
                    T.RandomHorizontalFlip(), #随机水平翻转给定的PIL.Image,翻转概率为0.5
                    T.ToTensor(),
                    normalize
                ]) 
                
        
    def __getitem__(self, index):
        """
        返回一张图片的数据
        对于测试集,没有label,返回图片id,如1000.jpg返回1000
        """
        img_path = self.imgs[index]
        if self.test: 
          	 #如果是测试,得到图片路径中的数字标识作为label
       		 label = int(self.imgs[index].split('.')[-2].split('/')[-1])
        else: 
          	 #如果是训练,判断图片路径中是猫狗来设定label,猫为0,狗为1
             label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path) #打开该路径获得数据
        data = self.transforms(data) #然后对图片数据进行transform
        return data, label #最后得到统一的图片信息和label信息
    
    def __len__(self):
        """
        返回数据集中所有图片的个数
        """
        return len(self.imgs)

关于数据集使用的注意事项:

将文件读取等费时操作放在 __getitem__ 函数中,利用多进程加速。避免一次性将所有图片都读进内存,不仅费时也会占用较大内存,而且不易进行数据增强等操作。

另外在这里,将训练集中的30%作为验证集,可用来检查模型的训练效果,避免过拟合。在使用时,可通过 dataloader 加载数据。

#进行训练,定义训练数据和label集合train_dataset
train_dataset = DogCat(opt.train_data_root, train=True) 
#开始进行数据的加载,将数据打乱(shuffle = True),
#随机将数据分批,一批有opt.batch_size个,并行处理,打开opt.num_workers个进程
trainloader = DataLoader(train_dataset,
                        batch_size = opt.batch_size,
                        shuffle = True,
                        num_workers = opt.num_workers)
#显示得到的数据                  
for ii, (data, label) in enumerate(trainloader):
    train() #然后进行训练
展开阅读更多
来源:孟繁阳的博客
本篇文章来源于孟繁阳的博客,如需转载,请注明出处,谢谢配合。
孟繁阳的博客 » Pytorch学习笔记(二):载入数据
Loading...

发表评论

表情
图片 链接 代码

孟繁阳的博客

博客首页 文章归档