模型权重不匹配的解决思路——PyTorch预训练权重shape修改
前言
- 最近太忙没时间细致的写,所以这篇讲的是思想,不是实际代码,大家意会一下。
- 因为GPU有限,这篇文章讲的有些东西只是我自己的观点,缺乏实验证明,欢迎大家批评指正。
- 今天题图在实验室用轨迹球画的,比较艺术一点。
为什么要修改预训练权重shape?
- 预训练能大幅提升精度是毋庸置疑的,但间断加载预训练权重有效吗?
- 我试过一些网络结构的修改,这些都会使得预训练权重shape mismatch,其中最简单的一个修改是将输入的
(3, 224, 224)
改为(4, 224, 224)
(即conv(3, 768, 16, 16)
变为conv(4, 768, 16, 16)
),它只会使得第一层权重无法加载,但精度大幅降低 - 微调通过冻结前面若干层网络权重,训练尾部网络,而不是冻结尾部训练头部,也表明了这点。
- 网络是逐层计算的,权重梯度也是逐层传递,当某一层的预训练权重被修改后,后部的权重也理所应当失去了它应有的效果。
如何修改预训练权重shape?
- 我个人认为根据实际修改方式修改预训练权重会有更好的效果。
- 比如原来是一路的网络,现在把它拓展到四路,每路的shape与原来一致。
- 如
conv(3, 768, 16, 16)
到conv(3, 768*4, 16, 16)
- 那权重直接改为
torch.cat((x, x, x, x), dim=0)
,因为权重的shape为(768, 3, 16, 16)
,即(C_out, C_in, H, W)
,所以在dim=0
上改。 - 这样conv出来的结果直接分成四份,每份的结果依然保持原来的结构
- 如
- 如果是一路的网络,和上面一样拓展四倍,但不分成四路,那为了保持原来的权重结构,就应该用上采样的方法,而不是cat。
- 但我没有找到一个合适的上采样函数能上采样
(C1, C2, H, W)
中的某一路,只有对(H, W)
上采样的。 - 不过当
(H, W)
中有某个维度=1时,可以permute, squeeze
变为(H, C1, C2)
来上采样 - 如果
(C1, C2, H, W)
都不为1时,理论上将其permute
为(H, W, C1, C2)
,在分为H * (W, C1, C2)
后上采样应该也能保持较好的结构信息。 - 但正如开篇说的那样,我没有实验能够证明结论,也没有实际探究过上采样的实现方式,所以只是给大家提供一个思路,实际还是以炼丹结果为导向。
- 但我没有找到一个合适的上采样函数能上采样
在哪修改预训练权重shape?
- 加载权重后,用
k, v for checkpoint.item()
的形式取出对应键值对。 - 对
v
用上采样函数即可,同特征的上采样。 - 修改完后,保存
k, new_v
到新的权重字典,保存字典。
一之涯
兄弟你这个评论有bug啊,手机没法通过评论的安全验证。
MWHLS@一之涯
!
那我只能忍痛失去手机用户了,这玩意能把所有的机器人都挡下来,我可不换
一之涯
666