数据pipeline优化-2

之前因为服务器性能原因,对ImageNet数据集进行训练时出现数据加载瓶颈,当时搜索了多种方式,尝试对数据pipeline进行优化,从而提高数据加载能力。详情参见数据pipeline优化

这一次遇到了新的问题,就是在千万级别数据训练情况下如何在固定内存空间和其他硬件性能的情况下提高数据加载和预处理能力。

章节

  1. 图像读取
  2. 图像预处理
  3. 多进程数据加载

图像读取

图像读取耗费了不少时间,目前Python中常用的图像读取方式有:

  1. PIL

    1
    2
    from PIL import Image
    img = Image.open('...')

  2. OpenCV

    1
    2
    import cv2
    img = cv2.imread('...')

  3. imageio

    1
    2
    import imageio
    img = imageio.imread('...')

  4. Scipy

    1
    2
    from scipy.misc import imread
    img = imread('...')

文章Python 图像读写谁最快?不信就比一比对上面几种方式进行了对比,发现OpenCV效率最高;

另外文章0.伏笔:图像读取方式以及效率对比比较了MXNet提供的读取函数,发现其效率最高。

我也比较了上面5种读取实现

  1. 如果不关心后续的使用,那么PIL的读取效果最好;
  2. 如果后续需要np.ndarray,那么OpenCV效果最均衡,在各个环境下都能实现非常好的效果。

图像预处理

Pytorch提供的torchvision库使用PIL为后端提供了不少的图像预处理操作,但是其实现速度不如其他图像库。虽然torchvision也支持Pillow后端,但是相比于其他实现还是约有差别。

之前虽然也找过相关的图像预处理库,但是并没有真的去使用(毕竟torchvision的使用体验还是棒棒的)。这一次花了不少时间去熟悉和使用albumentation,自己也做了一些包装类,以更符合torchvision的使用体验,详见transforms/realization

多进程数据加载

除了上述的优化外,最大的一个改进就是实现了一个自定义数据加载类zjykzj/MPDataset

使用Pytorch进行训练,可以通过设置DataLoadernum_workers参数来实现多进程数据加载。虽然实际使用时仅需其中的一小部分,但是每个子进程都会拷贝一份完整的数据集数据,内存也因此约束了进程数目扩充。随着数据量的大大增加,单个进程所占据的内存大小也越来越大,这导致了能够设置的进程数变少了,从而减慢了训练速度。

文章RankDataset:超大规模数据集加载利器 [炼丹炉番外篇-1]也遇到了这个情况,它给出了几种思路

  1. 分离数据加载和数据训练服务器,通过网络传输方式逐次请求训练数据;
  2. 对数据进行分片操作,每个进程保留各自所需的数据量。每轮训练后重新创建数据加载器。

另外,也在pytorch相关网站上找到一些讨论

它们的解决思路是依据Python实现来规范你的数据加载方式,比如使用Numpy/Tensor而不是list,这样能够保存最大程度的在多进程中共享内存,从而减少内存占用。

上面两种方式都进行了尝试,对于第二种,并没有发现有效(比如用numpy/pandas来替代list的使用),实际上反而内存占用提高了。而对于第一种方式,我也很认同它的在各自进程中保留各自使用数据的思路,原文的实现方式比较简单,而且不太符合自己的使用场景,所以打算自定义一个数据加载类来实现多进程数据加载。

在实现过程中,发现Pytorch提供了两种数据类,一类是map-style,也是最常用的,实现__getitem____len__函数即可自定义数据集,接着通过DataLoader类和Sampler类进行数据加载和数据采样(这种情况下数据采样器在主进程进行操作);另外一种是iterable-style,从v1.2开始,Pytorch支持通过生成器方式进行数据加载,在这种情况下,数据采样可以在各自进程中进行操作,这天然符合我的多进程数据加载需求。相关文章详见torch.utils.data

自定义多进程数据加载类的实现难点如下:

  1. 如何在各种环境下独立采样数据(顺序/随机、单进程/多进程、单卡/多卡),又能够保证相互不重叠;
  2. 如何从整体数据中分离出各自数据的同时保存内存占用的稳定。

针对第一点,参考了Sampler类实现;针对第二点,通过自定义数据文件格式(每行表示一个图片路径以及对应标签值)以及采用文件遍历的方式(将文件读取器作为生成器进行遍历)来保证内存稳定。完整的实现可以参考zjykzj/MPDataset

小结

经过上述的优化之后,能够有效的提高图像读取和预处理速度,同时能够有效的保证多进程数据加载时内存的利用,在实际工作中能够很好的负载千万级数据量的训练。