PyTorch模型文件.pth浅析

保存模型

在pytorch进行模型保存的时候,一般有两种保存方式:

  • 一种是保存整个模型
  • 另一种是只保存模型的参数
1
2
torch.save(model.state_dict(), "my_model.pth")  # 只保存模型的参数
torch.save(model, "my_model.pth")  # 保存整个模型

后缀的格式

我们在训练模型的时候保存模型一般是保存.pth.pkl,有时候也用.pt,有什么区别呢?

其实,他们只是后缀名不同而已,格式没啥区别
一般惯例是使用.pth
另外,为什么会有 .pkl这种后缀名呢?因为Python有一个序列化模块pickle,使用它保存模型时,通常会起一个以.pkl为后缀名的文件。刚好torch.save()也是使用pickle来保存模型的。

模型文件浅析

查看一下模型文件

1
2
3
4
5
6
7
8
9
module_save = "model_save/module_attunet.pkl"

if os.path.exists(module_save):
    # net.load_state_dict(torch.load(module_save))
    a = torch.load(module_save)
    print(type(a))  # <class 'collections.OrderedDict'>
    print(len(a))   # 240
    for k in a.keys():
        print(k)    # 查看键

https://geoer666-1257264766.cos.ap-beijing.myqcloud.com/20221021100200.png

1
2
for k in net.keys():
         print(k, net[k].shape, sep="    ")

https://geoer666-1257264766.cos.ap-beijing.myqcloud.com/20221021100618.png

0%