一、加载数据集方式
torchvision.datasets 用来加载预处理数据集,有些数据集是自带的,我们可以直接加载,例如:
# 定义数据预处理操作
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载数据集
trainset = datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
但是大部分数据集是我们所没有的,所以就需要我们自定义数据集,该数据集一般要继承自Dataset类:
from torch.utils.data import Dataset
class PairedDataset(Dataset):
def __init__(self):pass
def __getitem__(self, index):pass
def __len__(self):pass
然后使用如下DataLoader读取dataset:
# 导入数据
trainloader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)
testloader = DataLoader(testset, batch_size=4,shuffle=False, num_workers=0)
# 输出信息
print("训练集一共有{}/{}={}个的批次,其中{}是mini-batch".format(len(trainset),4,len(trainloader),4))
print("测试集一共有{}/{}={}个的批次,其中{}是mini-batch".format(len(testset),4,len(testloader),4))
二、封装PairedDataset类
1、获取所有图片的路径
大多数数据集在读取图像后要使用一些图像处理手段,常见的是使用torchvision库里面的API:
import glob
input_file_path = './dataset/train/a/*.png'
target_file_path = './dataset/train/b/*.png'
# 获取样本文件的所有路径
def _get_dataset_path(self,input_file_path,target_file_path):
files_a =sorted(glob.glob(input_file_path,recursive=True))
files_b =sorted(glob.glob(target_file_path, recursive=True))
assert len(files_a) == len(files_b)
return files_a, files_b
2、数据集的初始化方法
def __init__(self,
input_file_path: str, # 输入文件所在的路径
target_file_path: str, # 输出文件所在的路径
preprocess_fn: Callable, # 预处理手段
):
self.preload =False
# 获取所有图片的路径
self.files_path_a,self.files_path_b = self._get_dataset_path(input_file_path,target_file_path)
self.len = len(self.files_path_a)
# 获取图像预处理函数
self.preprocess_fn = preprocess_fn
print(f'含有{self.len} 个样本的数据集已被创建')
3、数据集的初始化方法
def __getitem__(self, index):
# 获取一组样本的路径
a_path, b_path = self.files_path_a[index % self.len], self.files_path_b[index % self.len]
# 读取该组样本的图片
a, b = map(self._read_img, (a_path, b_path))
a,b = self.preprocess_fn(a,b)
# 向量化和归一化
a_tensor,b_tensor =map(self._normalize(),(a,b))
return {'a': a_tensor, 'b': b_tensor}
def __len__(self):
return self.len
4、数据集的读取图片和向量化和归一化方法
# 使用第三方图片库读取图片
def _read_img(self, x: str):
img = cv2.imread(x)
if img is None:
print(f'警告:无法使用OpenCV读取图片{x} ,正在切换为scikit-image进行读取...')
img = io.imread(x)[:, :, ::-1]
return img
# 向量化和归一化
def _normalize(self,):
transform = transforms.Compose([
transforms.ToTensor(), # 向量化,归一化
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
return transform
三、读取配置文件加载数据预处理手段
假设我们的预处理手段都从配置文件中读取,首先我们要将可能会用到的函数列举出来:
def _resolve_aug_fn(name):
# 所有可能使用到的函数
d = {
'cutout': albu.Cutout,
'rgb_shift': albu.RGBShift,
'hsv_shift': albu.HueSaturationValue,
'motion_blur': albu.MotionBlur,
'median_blur': albu.MedianBlur,
'snow': albu.RandomSnow,
'shadow': albu.RandomShadow,
'fog': albu.RandomFog,
'brightness_contrast': albu.RandomBrightnessContrast,
'gamma': albu.RandomGamma,
'sun_flare': albu.RandomSunFlare,
'sharpen': albu.Sharpen,
'jpeg': albu.ImageCompression,
'gray': albu.ToGray,
'pixelize': albu.Downscale,
# ToDo: partial gray
}
return d[name]
然后定义一个yaml文件来设置本次运行使用的函数:
# corrupt操作的函数字典列表
corrupt:
- name: cutout
prob: 0.5
num_holes: 3
max_h_size: 25
max_w_size: 25
- name: jpeg
quality_lower: 70
quality_upper: 90
- name: motion_blur
- name: median_blur
- name: gamma
- name: rgb_shift
- name: hsv_shift
- name: sharpen
然后我们将该参数读取到程序中:
if __name__=="__main__":
with open('config1.yaml', 'r', encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
augs = []
for aug_params in config['corrupt']:
name = aug_params.pop('name') # 获取函数名称
print(name)
cls = _resolve_aug_fn(name) # 根据名称获取对应函数
prob = aug_params.pop('prob') if 'prob' in aug_params else .5 # 尝试获取对应参数
augs.append(cls(p=prob, **aug_params)) # 配置该函数并添加进列表
print(len(augs))