深度学习中通常有大量的实验参数需要设置,有时候需要做大量的对比实验,每次实验的参数记录就变得很重要。

​ 本文介绍作者目前见到的两种比较常见的读取配置参数的方式,然后给出个人对于读取器的封装。

一、两种读取配置参数的方式

我们介绍命令行传入参数的方式和配置文件传入参数的方式。

1、使用命令行参数配置程序

argsparse是python的命令行解析的标准模块,内置于python,不需要安装。这个库可以让我们直接在命令行中就可以向程序中传入参数并让程序运行。

当我们不需要经常性地打开IDE时,就可以直接在命令行来启动程序并且输入参数来配置程序。

其基本用法如下:

import argparse
def get_options():
    # 参数解析
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, default='test',help="train or test")
    parse.add_argument("--batch_size", type=int, default=1)
    args = parse.parse_args()
    return args

options = get_options()
print("options:",options," type:",type(options))
print("options.action:",options.action," type:",type(options.action))
print("options.batch_size:",options.batch_size," type:",type(options.batch_size))

我们可以先在命令行中输入 python demo.py –help 来查看程序的配置参数用法:

usage: baseOption.py [-h] [–action ACTION] [–batch_size BATCH_SIZE]

optional arguments:

-h, –help show this help message and exit

–action ACTION train or test –batch_size BATCH_SIZE

然后输入相应参数:python demo.py –action=train

options: Namespace(action=’train’, batch_size=1) type: <class ‘argparse.Namespace’> options.action: train type: <class ‘str’> options.batch_size: 1 type: <class ‘int’>

说明上述打印成功,而且获取相应参数简单,不需要进行额外的数据类型转换。

2、使用yaml文件配置程序

(1)yaml文件的数据结构

YAML是一种简洁的以数据为中心的非标记语言,在深度学习中经常用来作为训练参数的配置文件。

其文件(config.yaml)的数据结构表示如下:

project: gan
experiment_desc: fpn

optimizer:
  name: adam
  lr: 0.0001
scheduler:
  name: linear
  start_epoch: 50
  min_lr: 0.0000001

可以看出其内容为“键值对”的形式,“键值对”可以嵌套“键值对”。

(2)读取yaml文件

# 读取yaml文件数据
file = open('config.yaml', 'r', encoding="utf-8")
file_data = file.read()
file.close()

# 将yaml文件数据转化为字典
data = yaml.load(file_dataLoader=yaml.FullLoader)
print(data)

打印结果为一个Python字典:

{‘project’: ‘deblur_gan’, ‘experiment_desc’: ‘fpn’, ‘optimizer’: {‘name’: ‘adam’, ‘lr’: 0.0001}, ‘scheduler’: {‘name’: ‘linear’, ‘start_epoch’: 50, ‘min_lr’: 1e-07}}

封装为一个函数:

# 读取yaml配置文件
def get_yaml_data(yaml_file):
    import yaml
    # 打开yaml文件
    with open(yaml_file, 'r', encoding="utf-8") as file:
        data = yaml.load(file, Loader=yaml.FullLoader) # 转化yaml数据为字典或列表,此处为字典
    print("成功获取yaml文件配置数据")
    return data
    

(3)获取配置参数的常见方式

batch_size = config.pop('batch_size') # 取出该参数并且删除,这个字典的方法

二、配置参数读取器的设计

我们的读取器目前只封装命令行方式,但是会默认将命令行参数保存为一个单独文件夹,并且针对深度学习多次运行实验的特点,每次运行都根据其主题参数theme来生成对应的文件夹,另外也包含了一些设置随机种子的常见操作。

1、读取器的整体设计

首先我们给出基于Python的Optioner()类的简略定义:

import argparse
from torch.utils.data import DataLoader
class Optioner():
    def __init__(self):self.options = self._set_options()
    # 【外部接口】,用来获取配置参数
    def get_options(self,theme:str,seed:int):pass
    # 【内部接口】,负责初始化配置参数
    def _set_options(self):pass
    # 【内部接口】,负责设置随机种子,非必要
    def _set_seed(self,seed):pass
    # 【内部接口】,负责设置输出路径,非必要
    def _check_output_path(self,options,use_time=False):pass
    # 【内部接口】,负责保存配置参数,非必要
    def _save_config_yml(self,options,is_print=False):pass

使用该配置器获取参数:

if __name__=="__main__":
    options = Optioner().get_options(theme='test',seed=0)
    print("options:", options, " type:", type(options))
    print("options.action:", options.action, " type:", type(options.action))
    print("options.batch_size:", options.batch_size, " type:", type(options.batch_size))

2、训练器的局部设计

(1)获取配置参数

def get_options(self,theme:str,seed:int):
    # 重新设置重要参数
    self.options.theme = theme
    self.options.seed = seed
    # 执行必要操作
    self._set_seed(self.options.seed) # 设置随机种子
    self._check_output_path(self.options) # 设置输出路径
    self._save_config_yml(self.options) # 保存配置文件
    return self.options

(2)初始化配置参数

def _set_options(self):
    p = argparse.ArgumentParser()  # 创建解析对象
    # 实验输出选项
    p.add_argument('--theme', type=str, default='deblur', help='实验主题')
    p.add_argument("--action", type=str, default='train', help="train 或者 test")
    p.add_argument('--device', type=str, default='cuda', help='计算设备名称:cpu or cuda:0')
    p.add_argument('--seed', type=int, default=0, help='实验环境的随机种子')
    p.add_argument('--output_path', type=str, default='./runs/',  help='输出路径')
    p.add_argument('--print_freq', type=int, default=1, help='训练时的输出间隔,单位:回合')
    p.add_argument('--save_freq', type=int, default=1, help='训练时的保存间隔,单位:回合')
    # 数据集配置选项
    p.add_argument('--batch_size', type=int, default=1,help='批次大小为128时可能出现loss为NAN的问题')
    # 训练配置选项
    p.add_argument('--num_epochs', type=int, default=10, help='训练的回合数')
    p.add_argument('--continus_train',type=bool, default=False,  help='是否继续训练')
    # 优化器配置选项
    p.add_argument('--optim_name', type=str, default='adam', help='优化器名称')
    p.add_argument('--lr', type=float, default=0.0001, help='学习率')
    # 学习率调度器配置选项
    p.add_argument('--scheduler_name', type=str, default='linear', help='学习率调度器名称')
    options = p.parse_args()  # 解析
    return options

(3)设置随机种子

    def _set_seed(self,seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

(4)检查输出路径

    def _check_output_path(self,options,use_time=False):
        # 记录时间
        #time_stamp = time.strftime("%F-%H-%M-%S")
        if use_time:
            import time
            time_stamp = time.strftime("%F")
            options.theme = time_stamp
        # 拼接文件路径
        output_path = os.path.join(options.output_path, options.theme)
        model_path = os.path.join(output_path, 'model')
        # 创建文件夹
        os.makedirs(options.output_path, exist_ok=True)
        os.makedirs(model_path, exist_ok=True)
        options.output_path = output_path
        options.model_path = model_path

(5)保存配置文件

    def _save_config_yml(self,options,is_print=False):
        # 写入配置信息
        with open(os.path.join(options.output_path, 'config.yml'), 'wt') as f:
            f.write('------------ Options -------------\n')
            f.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(options).items()]))
            f.write('\n-------------- End ----------------\n')
        # 输出配置信息
        if is_print:
            for key,value in vars(options).items():
                print('%s: %s' % (key, value))

###