一、加载数据集方式

torchvision.datasets 用来加载预处理数据集,有些数据集是自带的,我们可以直接加载,例如:

# 定义数据预处理操作
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载数据集
trainset = datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)

但是大部分数据集是我们所没有的,所以就需要我们自定义数据集,该数据集一般要继承自Dataset类:

from torch.utils.data import Dataset
class PairedDataset(Dataset):
    def __init__(self):pass
    def __getitem__(self, index):pass
    def __len__(self):pass

​ 然后使用如下DataLoader读取dataset:

# 导入数据
trainloader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)
testloader = DataLoader(testset, batch_size=4,shuffle=False, num_workers=0)
# 输出信息
print("训练集一共有{}/{}={}个的批次,其中{}是mini-batch".format(len(trainset),4,len(trainloader),4))
print("测试集一共有{}/{}={}个的批次,其中{}是mini-batch".format(len(testset),4,len(testloader),4))

二、封装PairedDataset类

1、获取所有图片的路径

大多数数据集在读取图像后要使用一些图像处理手段,常见的是使用torchvision库里面的API:

import glob
input_file_path = './dataset/train/a/*.png'
target_file_path = './dataset/train/b/*.png'
# 获取样本文件的所有路径
def _get_dataset_path(self,input_file_path,target_file_path):
    files_a =sorted(glob.glob(input_file_path,recursive=True))
    files_b =sorted(glob.glob(target_file_path, recursive=True))
    assert len(files_a) == len(files_b)
    return files_a, files_b

2、数据集的初始化方法

def __init__(self,
             input_file_path: str,  # 输入文件所在的路径
             target_file_path: str,  # 输出文件所在的路径
             preprocess_fn: Callable, # 预处理手段
            ):
    self.preload =False
    # 获取所有图片的路径
    self.files_path_a,self.files_path_b = self._get_dataset_path(input_file_path,target_file_path)
    self.len = len(self.files_path_a)
    # 获取图像预处理函数
    self.preprocess_fn = preprocess_fn
    print(f'含有{self.len} 个样本的数据集已被创建')

3、数据集的初始化方法

def __getitem__(self, index):
    # 获取一组样本的路径
    a_path, b_path = self.files_path_a[index % self.len], self.files_path_b[index % self.len]
    # 读取该组样本的图片
    a, b = map(self._read_img, (a_path, b_path))
    a,b = self.preprocess_fn(a,b)
    # 向量化和归一化
    a_tensor,b_tensor =map(self._normalize(),(a,b))
    return {'a': a_tensor, 'b': b_tensor}
def __len__(self):
    return self.len

4、数据集的读取图片和向量化和归一化方法

  # 使用第三方图片库读取图片
    def _read_img(self, x: str):
        img = cv2.imread(x)
        if img is None:
            print(f'警告:无法使用OpenCV读取图片{x} ,正在切换为scikit-image进行读取...')
            img = io.imread(x)[:, :, ::-1]
        return img

    # 向量化和归一化
    def _normalize(self,):
        transform = transforms.Compose([
            transforms.ToTensor(), # 向量化,归一化
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        return transform

三、读取配置文件加载数据预处理手段

假设我们的预处理手段都从配置文件中读取,首先我们要将可能会用到的函数列举出来:

def _resolve_aug_fn(name):
    # 所有可能使用到的函数
    d = {
        'cutout': albu.Cutout,
        'rgb_shift': albu.RGBShift,
        'hsv_shift': albu.HueSaturationValue,
        'motion_blur': albu.MotionBlur,
        'median_blur': albu.MedianBlur,
        'snow': albu.RandomSnow,
        'shadow': albu.RandomShadow,
        'fog': albu.RandomFog,
        'brightness_contrast': albu.RandomBrightnessContrast,
        'gamma': albu.RandomGamma,
        'sun_flare': albu.RandomSunFlare,
        'sharpen': albu.Sharpen,
        'jpeg': albu.ImageCompression,
        'gray': albu.ToGray,
        'pixelize': albu.Downscale,
        # ToDo: partial gray
    }
    return d[name]

然后定义一个yaml文件来设置本次运行使用的函数:

# corrupt操作的函数字典列表
  corrupt: 
    - name: cutout
      prob: 0.5
      num_holes: 3
      max_h_size: 25
      max_w_size: 25
    - name: jpeg
      quality_lower: 70
      quality_upper: 90
    - name: motion_blur
    - name: median_blur
    - name: gamma
    - name: rgb_shift
    - name: hsv_shift
    - name: sharpen

然后我们将该参数读取到程序中:

if __name__=="__main__":
    with open('config1.yaml', 'r', encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    augs = []
    for aug_params in config['corrupt']:
        name = aug_params.pop('name') # 获取函数名称
        print(name)
        cls = _resolve_aug_fn(name) # 根据名称获取对应函数
        prob = aug_params.pop('prob') if 'prob' in aug_params else .5  # 尝试获取对应参数
        augs.append(cls(p=prob, **aug_params)) # 配置该函数并添加进列表
    print(len(augs))