​ 本博客是作者参考原文代码进行的代码重构的学习笔记。

一、ANT的定义

1、每个节点的定义

ANT中的每个节点都被定义成两部分的数据结构:节点图信息字典(meta)和节点网络模块(module)。

meta, module = define_node(args=args,node_index=0, level=0, parent_index=-1, tree_struct=tree_struct,)

其中meta的定义如下:

meta = {'index': node_index,  # 当前节点在数组中的序号,根节点为0
        'parent': parent_index,  # 当前节点的父节点的序号,根节点的父节点为-1
        'left_child': 0,
        'right_child': 0,
        'level': level, # 当前节点在树结构中的层,根节点为0
        'is_leaf': True, # 当前节点是否是叶节点,根节点初始化时默认True
        'visited': False,# 当前节点是否可访问,根节点初始化时默认False
        
        'extended': False,
        'split': False,
        
        
        'train_accuracy_gain_split': -np.inf,
        'valid_accuracy_gain_split': -np.inf,
        'test_accuracy_gain_split': -np.inf,
        'train_accuracy_gain_ext': -np.inf,
        'valid_accuracy_gain_ext': -np.inf,
        'test_accuracy_gain_ext': -np.inf,
        'num_transforms': num_transforms
       }

其中module的定义如下:

module = {'transform': transformer,
          'classifier': solver,
          'router': router
         }

2、ANT的定义

ANT的定义为若干个节点的结合:

tree_struct = []  # 存储每个节点的图信息
tree_modules = []  # 存储每个节点的网络模块

3、将ANT转换为网络模型Tree

ANT的信息存储是两个字典列表,但是为了方便进行前向传播和反向传播,我们定义一个统一的模型类。

该模型类由以下部分组成:

(1)获取所有的叶子节点(self.leaves_list)

self.leaves_list = self.get_leaf_nodes(tree_struct) # 本质是一个整数列表,每个整数表示节点的序号
# 遍历ANT每个节点的meta[is_leaf]
def get_leaf_nodes(self,struct):
    """ Get the list of leaf nodes.
        """
    leaf_list = []
    for idx, node in enumerate(struct):
        if node['is_leaf']:
            leaf_list.append(idx)
    return leaf_list

(2)由叶节点遍历获得所有的预测路径(self.paths_list)

self.paths_list = [self.get_path_to_root(i, tree_struct) for i in self.leaves_list] # 二维数组,每个元素(root-leaf,is_left_child)
def get_path_to_root(self,node_idx, struct):
    paths_list = [] # 整数数组,每个数表示节点,该序号组成一条root-leaf路径
    left_child_status = [] # bool列表,是否属于左节点
    while node_idx >= 0:
    # 记录当前节点是否是左节点
    if node_idx > 0:  # ignore parent node
    lcs = self.get_left_or_right(node_idx, struct) # 获取当前节点是否是左节点
    left_child_status.append(lcs)
    # 记录当前节点
    paths_list.append(node_idx)
    node_idx = self.get_parent(node_idx, struct) # 获取父节点的序号

    paths_list = paths_list[::-1] # 反序,root-leaf
    left_child_status = left_child_status[::-1] # 反序,root-leaf
    return paths_list, left_child_status

(3)定义每个节点的三个网络模块(self.tree_modules)

# 获取所有的节点模块
self.tree_modules = nn.ModuleList()
for i, node in enumerate(tree_modules):
    node_modules = nn.Sequential()
    node_modules.add_module('transform', node["transform"])
    node_modules.add_module('router', node["router"])
    node_modules.add_module('classifier', node["classifier"])
    self.tree_modules.append(node_modules)

4、Tree的前向传播

(1)一条路径的前向传播


    def node_pred(self, input, nodes, edges):
        """ Perform prediction on a given node given its path on the tree.
        e.g.
        nodes = [0, 1, 4, 10]
        edges = [True, False, False]
        """
        prob = 1.0 # 叶节点分配概率
        # 每条路径的内部节点进行前向传播
        for node, state in zip(nodes[:-1], edges):
            input = self.tree_modules[node].transform(input) # 变换数据
            # 是否是左路径
            if state:
                prob = prob * self.tree_modules[node].router(input)
            else:
                prob = prob * (1.0 - self.tree_modules[node].router(input))

        if not (isinstance(prob, float)):
            prob = torch.unsqueeze(prob, 1) # 升维
        # 每条路径的叶子节点进行前向传播
        node_final = nodes[-1]
        input = self.tree_modules[node_final].transform(input)

        # 最后一个节点的预测分类概率
        y_pred = prob * torch.exp(self.tree_modules[node_final].classifier(input))

        return y_pred

(2)所有路径的前向传播

    def forward(self, input):
        y_pred = 0.0 # 整个模型的预测概率
        prob_last = None
        # 每条路径进行一次预测
        for (nodes, edges) in self.paths_list:
            y_pred += self.node_pred(input, nodes, edges)

        if self.training:
        else:
            return torch.log(1e-10 + y_pred)

二、ANT的固定优化

定义根节点

# 定义根节点
root_meta, root_module = define_node(args=args,node_index=0, level=0, parent_index=-1, tree_struct=tree_struct,)
tree_struct.append(root_meta)
tree_modules.append(root_module)

设置模型的参数是否需要求梯度的属性(根据是否增长判断哪些参数需要更新)

params, names = get_params_node(grow, node_idx,  model)
    for i, (n, p) in enumerate(model.named_parameters()):
        if not(n in names):
            # print('(Fix)   ' + n)
            p.requires_grad = False
        else:
            # print('(Optimize)     ' + n)
            p.requires_grad = True

    for i, p in enumerate(params):
        if not(p.requires_grad):
            print("(Grad not required)" + names[i])

定义优化器和学习率调整器

optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, params), lr=args.lr,
    )
if args.scheduler:
    scheduler = get_scheduler(args.scheduler, optimizer, grow)

训练no_epochs个回合(结合早期终止机制)

# 训练no_epochs个回合
min_improvement = 0.0
valid_loss = np.inf
patience_cnt = 1
for epoch in range(1, no_epochs + 1):
    train(model, train_loader, optimizer, node_idx) # 运行一个回合
    valid_loss_new = valid(model, valid_loader, node_idx, tree_struct) # 验证一个回合,将最好的模型保存下来
    scheduler.step() # 更新学习率
    vtest(model, test_loader)
    # 早期终止机制
    # 如果新损失值>旧损失值并且增长,则计数一个回合
    if not((valid_loss-valid_loss_new) > min_improvement) and grow:
        patience_cnt += 1
    valid_loss = valid_loss_new*1.0
        
    if patience_cnt > args.epochs_patience > 0:
       print('Early stopping')
       break

三、ANT的增长阶段

ANT的增长阶段本质就是进行3种增长手段,看看哪一种增长后性能会变好,伪代码解读如下:

nextind = 1
last_node = 0
# 遍历每一层
for lyr in range(args.maxdepth):
    # 遍历树中的每个节点
	for node_idx in range(len(tree_struct)):
        change = False # 用来剪枝,提高算法效率
        # 如果当前节点是叶节点并且没有访问过
        if tree_struct[node_idx]['is_leaf'] and not(tree_struct[node_idx]['visited']):
            # 定义两个子节点,并将原叶节点的属性赋给新叶节点
            meta_l, node_l = define_node(
                	args,
                    node_index=nextind, level=lyr+1,parent_index=node_idx, 
                	tree_struct=tree_struct,identity=identity,
                )
             meta_r, node_r = define_node(
                    args,
                    node_index=nextind+1, level=lyr+1,parent_index=node_idx, 
                 	tree_struct=tree_struct,identity=identity,
                )
            node_l['classifier'] = tree_modules[node_idx]['classifier']
            node_r['classifier'] = tree_modules[node_idx]['classifier']
            #定义一个新的树模型,新的树模型具有新的叶节点网络模块
            model_split = Tree(tree_struct, tree_modules,
                               split=True, node_split=node_idx,
                               child_left=node_l, child_right=node_r)
            # 优化更新该模型
            model_split, tree_modules_split, node_l, node_r=optimize_fixed_tree(model_split, ...)
            # 获取决策,是否更新树结构
            criteria = get_decision(args.criteria, node_idx, tree_struct)
            # 加入更新该结构:
            if criteria == 'split':
                # 更新父节点的的属性
                tree_struct[node_idx]['is_leaf'] = False
                tree_struct[node_idx]['left_child'] = nextind
                tree_struct[node_idx]['right_child'] = nextind+1
                tree_struct[node_idx]['split'] = True
                # 添加子节点
                tree_struct.append(meta_l)
                tree_modules_split.append(node_l)
                tree_struct.append(meta_r)
                tree_modules_split.append(node_r)
                # 更新树节点的模块
                tree_modules = tree_modules_split
                # 标记当前节点已经访问过
                tree_struct[node_idx]['visited'] = True
                # 更新节点序号
                nextind += 2
                change = True
                # 保存模型和树结构
                checkpoint_model('model.pth', ...)
     # 如果模型不需要发生结构变化则不再进行增长
     if not change: break

如何判断一个节点是否需要变化:

# 获取决策
def get_decision(criteria, node_idx, tree_struct):
    if criteria == 'always':  # always split or extend
        # 如果节点扩展的精度>节点分裂的精度>0.0,则选择'extend'
        if tree_struct[node_idx]['valid_accuracy_gain_ext'] > tree_struct[node_idx]['valid_accuracy_gain_split'] > 0.0:
            return 'extend'
        # 否则选择'split'
        else:
            return 'split'
    elif criteria == 'avg_valid_loss':
        # 如果节点扩展的精度>节点分裂的精度并且节点扩展的精度>0.0,则选择'extend'
        if tree_struct[node_idx]['valid_accuracy_gain_ext'] > tree_struct[node_idx]['valid_accuracy_gain_split'] and \
                        tree_struct[node_idx]['valid_accuracy_gain_ext'] > 0.0:
            print("Average valid loss is reduced by {} ".format(tree_struct[node_idx]['valid_accuracy_gain_ext']))
            return 'extend'
		# 如果节点分裂的精度>0.0,则选择'split'
        elif tree_struct[node_idx]['valid_accuracy_gain_split'] > 0.0:
            print("Average valid loss is reduced by {} ".format(tree_struct[node_idx]['valid_accuracy_gain_split']))
            return 'split'
		# 保持原样
        else:
            print("Average valid loss is aggravated by split/extension."
                  " Keep the node as it is.")
            return 'keep'
    else:
        raise NotImplementedError(   
            "specified growth criteria is not available. ",
        )