本博客是作者参考原文代码进行的代码重构的学习笔记。
一、ANT的定义
1、每个节点的定义
ANT中的每个节点都被定义成两部分的数据结构:节点图信息字典(meta)和节点网络模块(module)。
meta, module = define_node(args=args,node_index=0, level=0, parent_index=-1, tree_struct=tree_struct,)
其中meta的定义如下:
# Per-node graph-information dictionary (bookkeeping only; the network
# modules live in the companion `module` dict).
meta = {'index': node_index, # position of this node in the node array; the root is 0
'parent': parent_index, # index of this node's parent; the root's parent is -1
'left_child': 0,  # index of the left child (assigned when the node splits)
'right_child': 0,  # index of the right child (assigned when the node splits)
'level': level, # depth of this node in the tree; the root is at level 0
'is_leaf': True, # whether this node is currently a leaf; True at creation
'visited': False,# whether the growth phase has already processed this node; False at creation
'extended': False,  # presumably flipped once the node gets extended — TODO confirm against define_node's callers
'split': False,  # set to True once the node has been split (see growth phase)
'train_accuracy_gain_split': -np.inf,  # recorded training-set gain from splitting
'valid_accuracy_gain_split': -np.inf,  # validation-set gain from splitting (read by get_decision)
'test_accuracy_gain_split': -np.inf,  # recorded test-set gain from splitting
'train_accuracy_gain_ext': -np.inf,  # recorded training-set gain from extending
'valid_accuracy_gain_ext': -np.inf,  # validation-set gain from extending (read by get_decision)
'test_accuracy_gain_ext': -np.inf,  # recorded test-set gain from extending
'num_transforms': num_transforms  # NOTE(review): presumably the number of transform modules in this node — confirm
}
其中module的定义如下:
# Per-node network modules: feature transformer, leaf classifier ("solver"),
# and routing function.
module = {'transform': transformer,
'classifier': solver,
'router': router
}
2、ANT的定义
ANT的定义为若干个节点的集合:
tree_struct = [] # graph/meta information (the `meta` dict) of every node
tree_modules = [] # network modules (the `module` dict) of every node
3、将ANT转换为网络模型Tree
ANT的信息存储是两个字典列表,但是为了方便进行前向传播和反向传播,我们定义一个统一的模型类。
该模型类由以下部分组成:
(1)获取所有的叶子节点(self.leaves_list)
self.leaves_list = self.get_leaf_nodes(tree_struct) # a plain list of ints, each the array index of a leaf node
# built by scanning every node's meta['is_leaf'] flag
def get_leaf_nodes(self, struct):
    """Return the indices of all leaf nodes in the tree structure.

    Args:
        struct: list of node meta dicts, each carrying an 'is_leaf' flag.

    Returns:
        List of integer indices of nodes whose 'is_leaf' is truthy.
    """
    return [idx for idx, meta in enumerate(struct) if meta['is_leaf']]
(2)由叶节点遍历获得所有的预测路径(self.paths_list)
self.paths_list = [self.get_path_to_root(i, tree_struct) for i in self.leaves_list] # 二维数组,每个元素(root-leaf,is_left_child)
def get_path_to_root(self, node_idx, struct):
    """Trace the path between a node and the root.

    Args:
        node_idx: index of the starting (typically leaf) node.
        struct: list of node meta dicts.

    Returns:
        (nodes, is_left): `nodes` lists the node indices ordered
        root -> leaf; `is_left[i]` says whether nodes[i+1] is the left
        child of nodes[i] (the root itself contributes no flag).
    """
    nodes = []
    is_left = []
    current = node_idx
    # Walk upwards until we step past the root (whose parent is -1).
    while current >= 0:
        if current > 0:  # the root has no parent edge to record
            is_left.append(self.get_left_or_right(current, struct))
        nodes.append(current)
        current = self.get_parent(current, struct)
    # Collected leaf -> root; flip both lists to root -> leaf order.
    return nodes[::-1], is_left[::-1]
(3)定义每个节点的三个网络模块(self.tree_modules)
# Gather every node's three modules into one nn.ModuleList so the whole
# tree's parameters are registered with PyTorch; each node becomes an
# nn.Sequential whose children are reachable as .transform/.router/.classifier.
self.tree_modules = nn.ModuleList()
for i, node in enumerate(tree_modules):
node_modules = nn.Sequential()
node_modules.add_module('transform', node["transform"])
node_modules.add_module('router', node["router"])
node_modules.add_module('classifier', node["classifier"])
self.tree_modules.append(node_modules)
4、Tree的前向传播
(1)一条路径的前向传播
def node_pred(self, input, nodes, edges):
    """Compute one root-to-leaf path's contribution to the prediction.

    e.g.
    nodes = [0, 1, 4, 10]
    edges = [True, False, False]

    Args:
        input: batch of inputs fed through the path's transformers.
        nodes: node indices along the path, root first, leaf last.
        edges: per internal node, True if the path follows the left
            child, False if it follows the right child.

    Returns:
        Path-routing probability times the leaf classifier's output
        probabilities (the classifier is taken to emit log-probs,
        hence the exp).
    """
    path_prob = 1.0  # probability mass routed down this path
    # Internal nodes: transform the data, then fold in the routing
    # probability of the branch this path takes.
    for idx, goes_left in zip(nodes[:-1], edges):
        node = self.tree_modules[idx]
        input = node.transform(input)
        gate = node.router(input)
        path_prob = path_prob * (gate if goes_left else 1.0 - gate)
    # Add a trailing dim so a (batch,) probability broadcasts against the
    # (batch, classes) classifier output; skip when no router was applied.
    if not isinstance(path_prob, float):
        path_prob = torch.unsqueeze(path_prob, 1)
    # Leaf node: final transform, then weight the classifier's prediction.
    leaf = self.tree_modules[nodes[-1]]
    input = leaf.transform(input)
    return path_prob * torch.exp(leaf.classifier(input))
(2)所有路径的前向传播
def forward(self, input):
    """Whole-tree prediction: sum every root-to-leaf path's contribution.

    Args:
        input: batch of inputs.

    Returns:
        Log of the mixture prediction. In training mode, a
        (log-prediction, prob_last) tuple is returned instead.

    Bug fixed: the original snippet had an empty `if self.training:`
    branch directly followed by `else:`, which is a SyntaxError. The
    branch is reconstructed to return `prob_last` alongside the
    prediction — the only plausible reason `prob_last` exists here.
    """
    y_pred = 0.0  # accumulated mixture prediction over all paths
    # NOTE(review): nothing in this snippet ever assigns prob_last;
    # confirm its value against the full implementation.
    prob_last = None
    # Each (nodes, edges) pair describes one root-to-leaf path.
    for (nodes, edges) in self.paths_list:
        y_pred += self.node_pred(input, nodes, edges)
    # 1e-10 guards against log(0) if the summed probability underflows.
    if self.training:
        return torch.log(1e-10 + y_pred), prob_last
    else:
        return torch.log(1e-10 + y_pred)
二、ANT的固定优化
定义根节点
# Create the root node (index 0, level 0, parent -1) and register its
# meta dict and module dict with the tree.
root_meta, root_module = define_node(args=args,node_index=0, level=0, parent_index=-1, tree_struct=tree_struct,)
tree_struct.append(root_meta)
tree_modules.append(root_module)
设置模型的参数是否需要求梯度的属性(根据是否增长判断哪些参数需要更新)
# Select which parameters to optimise for this node/phase: freeze every
# named parameter not returned by get_params_node, unfreeze the rest.
params, names = get_params_node(grow, node_idx, model)
for i, (n, p) in enumerate(model.named_parameters()):
if not(n in names):
# print('(Fix) ' + n)
p.requires_grad = False
else:
# print('(Optimize) ' + n)
p.requires_grad = True
# Sanity check: report any selected parameter that is still frozen.
for i, p in enumerate(params):
if not(p.requires_grad):
print("(Grad not required)" + names[i])
定义优化器和学习率调整器
# Adam over only the trainable (requires_grad) parameters; optionally
# attach a learning-rate scheduler.
optimizer = optim.Adam(
filter(lambda p: p.requires_grad, params), lr=args.lr,
)
if args.scheduler:
scheduler = get_scheduler(args.scheduler, optimizer, grow)
训练no_epochs个回合(结合早期终止机制)
# Train for no_epochs epochs, with early stopping.
min_improvement = 0.0  # required drop in valid loss to reset patience
valid_loss = np.inf  # previous epoch's validation loss
patience_cnt = 1
for epoch in range(1, no_epochs + 1):
train(model, train_loader, optimizer, node_idx) # run one training epoch
valid_loss_new = valid(model, valid_loader, node_idx, tree_struct) # validate; the best model is checkpointed inside
# NOTE(review): scheduler may be undefined when args.scheduler is falsy
# (see its conditional creation above) — confirm against the full source.
scheduler.step() # update the learning rate
vtest(model, test_loader)
# Early stopping: count one patience epoch when the validation loss did
# not improve by more than min_improvement AND we are in the growth phase.
if not((valid_loss-valid_loss_new) > min_improvement) and grow:
patience_cnt += 1
valid_loss = valid_loss_new*1.0
if patience_cnt > args.epochs_patience > 0:
print('Early stopping')
break
三、ANT的增长阶段
ANT的增长阶段本质就是进行3种增长手段,看看哪一种增长后性能会变好,伪代码解读如下:
nextind = 1  # index the next newly created node will receive
last_node = 0
# Sweep the tree level by level.
for lyr in range(args.maxdepth):
# Visit every node currently in the tree.
for node_idx in range(len(tree_struct)):
# NOTE(review): `change` is reset inside the node loop, so only the last
# node's outcome reaches the break below — confirm against the full source.
change = False # tracks whether the structure changed (used to stop growing early)
# Only unvisited leaf nodes are candidates for growth.
if tree_struct[node_idx]['is_leaf'] and not(tree_struct[node_idx]['visited']):
# Create two candidate children; the old leaf's classifier is
# handed down to both of them below.
meta_l, node_l = define_node(
args,
node_index=nextind, level=lyr+1,parent_index=node_idx,
tree_struct=tree_struct,identity=identity,
)
meta_r, node_r = define_node(
args,
node_index=nextind+1, level=lyr+1,parent_index=node_idx,
tree_struct=tree_struct,identity=identity,
)
node_l['classifier'] = tree_modules[node_idx]['classifier']
node_r['classifier'] = tree_modules[node_idx]['classifier']
# Build a candidate tree model containing the new leaf modules.
model_split = Tree(tree_struct, tree_modules,
split=True, node_split=node_idx,
child_left=node_l, child_right=node_r)
# Optimise the candidate model.
model_split, tree_modules_split, node_l, node_r=optimize_fixed_tree(model_split, ...)
# Ask the decision rule whether the structure should change.
criteria = get_decision(args.criteria, node_idx, tree_struct)
# If the decision is to split, commit the structural change:
if criteria == 'split':
# Update the parent's bookkeeping.
tree_struct[node_idx]['is_leaf'] = False
tree_struct[node_idx]['left_child'] = nextind
tree_struct[node_idx]['right_child'] = nextind+1
tree_struct[node_idx]['split'] = True
# Register the two children.
tree_struct.append(meta_l)
tree_modules_split.append(node_l)
tree_struct.append(meta_r)
tree_modules_split.append(node_r)
# Adopt the optimised modules.
tree_modules = tree_modules_split
# Mark this node as processed.
tree_struct[node_idx]['visited'] = True
# Advance the counter past the two new children.
nextind += 2
change = True
# Persist the model and the tree structure.
checkpoint_model('model.pth', ...)
# Stop growing once a sweep produces no structural change.
if not change: break
如何判断一个节点是否需要变化:
# 获取决策
def get_decision(criteria, node_idx, tree_struct):
    """Decide how a candidate node should change the tree.

    Args:
        criteria: decision strategy, 'always' or 'avg_valid_loss'.
        node_idx: index of the node under consideration.
        tree_struct: list of node meta dicts holding the validation
            gains recorded for extending vs. splitting the node.

    Returns:
        One of 'extend', 'split', or 'keep'.

    Raises:
        NotImplementedError: for any other criteria value.
    """
    meta = tree_struct[node_idx]
    ext_gain = meta['valid_accuracy_gain_ext']
    split_gain = meta['valid_accuracy_gain_split']
    if criteria == 'always':
        # Always grow: extend only when extension beats a strictly
        # positive split gain; otherwise split.
        return 'extend' if ext_gain > split_gain > 0.0 else 'split'
    if criteria == 'avg_valid_loss':
        # Grow only when validation improves; otherwise keep the node.
        if ext_gain > split_gain and ext_gain > 0.0:
            print("Average valid loss is reduced by {} ".format(ext_gain))
            return 'extend'
        if split_gain > 0.0:
            print("Average valid loss is reduced by {} ".format(split_gain))
            return 'split'
        print("Average valid loss is aggravated by split/extension."
              " Keep the node as it is.")
        return 'keep'
    raise NotImplementedError(
        "specified growth criteria is not available. ",
    )