模型权重不匹配的解决思路——PyTorch预训练权重shape修改

前言

  • 最近太忙没时间细致的写,所以这篇讲的是思想,不是实际代码,大家意会一下。
  • 因为GPU有限,这篇文章讲的有些东西只是我自己的观点,缺乏实验证明,欢迎大家批评指正。
  • 今天题图在实验室用轨迹球画的,比较艺术一点。

为什么要修改预训练权重shape?

  • 预训练能大幅提升精度是毋庸置疑的,但间断加载预训练权重有效吗?
  • 我试过一些网络结构的修改,这些都会使得预训练权重shape mismatch,其中最简单的一个修改是将输入的(3, 224, 224)改为(4, 224, 224)(即conv(3, 768, 16, 16)变为conv(4, 768, 16, 16)),它只会使得第一层权重无法加载,但精度大幅降低
  • 微调通过冻结前面若干层网络权重,训练尾部网络,而不是冻结尾部训练头部,也表明了这点。
  • 网络是逐层计算的,权重梯度也是逐层传递,当某一层的预训练权重被修改后,后部的权重也理所应当失去了它应有的效果。

如何修改预训练权重shape?

  • 我个人认为根据实际修改方式修改预训练权重会有更好的效果。
  • 比如原来是一路的网络,现在把它拓展到四路,每路的shape与原来一致。
    • conv(3, 768, 16, 16)conv(3, 768*4, 16, 16)
    • 那权重直接改为 torch.cat((x, x, x, x), dim=0),因为权重的shape为(768, 3, 16, 16),即(C_out, C_in, H, W),所以在dim=0上改。
    • 这样conv出来的结果直接分成四份,每份的结果依然保持原来的结构
  • 如果是一路的网络,和上面一样拓展四倍,但不分成四路,那为了保持原来的权重结构,就应该用上采样的方法,而不是cat。
    • 但我没有找到一个合适的上采样函数能上采样(C1, C2, H, W)中的某一路,只有对(H, W)上采样的。
    • 不过当(H, W)中有某个维度=1时,可以permute, squeeze变为(H, C1, C2)来上采样
    • 如果(C1, C2, H, W)都不为1时,理论上将其permute(H, W, C1, C2),在分为H * (W, C1, C2) 后上采样应该也能保持较好的结构信息。
    • 但正如开篇说的那样,我没有实验能够证明结论,也没有实际探究过上采样的实现方式,所以只是给大家提供一个思路,实际还是以炼丹结果为导向。

在哪修改预训练权重shape?

  • 加载权重后,用 k, v for checkpoint.item()的形式取出对应键值对。
  • v 用上采样函数即可,同特征的上采样。
  • 修改完后,保存 k, new_v 到新的权重字典,保存字典。

版权声明:
作者:MWHLS
链接:https://mwhls.top/4034.html
来源:无镣之涯
文章版权归作者所有,未经允许请勿转载。

THE END
分享
二维码
打赏
< <上一篇
下一篇>>