pytorch | 记录一次register_hook不起作用以及为什么
0. 问题描述
register_hook
用于给某个tensor注册hooks,- 笔者这个时候loss不收敛,debug发现梯度为0,因此通过加钩子,试图发现在传播时哪里出了问题。因此发现了
register_hook
并不是100%能work,而且不是100%的可以打印出grad
1. 探究过程
- 钩子函数:
def save_grad(name):print("") # 这行可以说明有没有执行钩子函数def hook(grad):print(f"name={name}, grad={grad}")return hook
- 注册过程
# U_head是loss function中的一个中间tensor,需要计算梯度
U_head.register_hook(save_grad("U_head"))
2. 不work的原因
-
来自Stack Overflow的建议:
- ‘register_hook’ won’t only in two cases:
- It was registered on a Tensor for which the gradients was never computed./ 梯度没计算
- The register_hook function is some part of your code that did not run during the forward. / 正向传播没执行这条语句
- ‘register_hook’ won’t only in two cases:
-
因此结合了自己的代码,修正了如下几个bug,然后就work了
- 最关键的修改:
for t in idx:# work了U_tail= torch.cat([U_head, f_Equi_ts(t, T,v, d, b, alpha, labda,device)])# 这样写不work,莫非属于原地修改?# U_tail = torch.tensor([f_Equi_ts(t, T,v, d, b, alpha, labda,device) for t in idx],requires_grad=True)
- 对于 需要loss function的代码重写,尽可能简单,易读性高,先不追求效率