Pytorch中實現只導入部分模型參數的方式

發布時間: 2020-01-02 19:26:13 來源: 互聯網 欄目: python 點擊:

今天小編就為大家分享一篇Pytorch中實現只導入部分模型參數的方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

我們在做遷移學習,或者在分割,檢測等任務想使用預訓練好的模型,同時又有自己修改之后的結構,使得模型文件保存的參數,有一部分是不需要的(don't expected)。我們搭建的網絡對保存文件來說,有一部分參數也是沒有的(missed)。如果依舊使用torch.load(model.state_dict())的辦法,就會出現 xxx expected,xxx missed類似的錯誤。那么在這種情況下,該如何導入模型呢?

好在Pytorch中的模型參數使用字典保存的,鍵是參數的名稱,值是參數的具體數值。我們使用model.state_dict()獲得這個字典,之后就能利用參數名稱來實現導入。

請看下面的一個例子。

我們先搭建一個小小的網絡。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我們保存這個網絡的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

現在我們將Net修改一下,多加幾個卷積層,但并不加入到forward中,僅僅出于少些幾行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我們現在試著導入之前保存的模型參數。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出現了沒有在模型文件中找到error中的關鍵字的錯誤。

現在我們這樣導入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代碼,很容易弄明白。其中model_dict.update的作用是更新代碼中搭建的模型參數字典。為啥更新我其實并不清楚,但這一步驟是必須的,否則還會報錯。

為了弄清楚為什么要更新model_dict,我們不妨分別輸出state_dict和model_dict的關鍵值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

這個結果也是預料之中的,所以我猜測,update之后,model_dict和state_dict中具有相同鍵的值已經同步了。updata的目的就是使model_dict帶有state_dict中都具有的那一部分參數的值,對于model_dict中有的,但是save_dict中沒有的參數,值不改變,參數仍然使用初始值。

以上這篇Pytorch中實現只導入部分模型參數的方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持我們。

本文標題: Pytorch中實現只導入部分模型參數的方式
本文地址: http://www.1921352.live/jiaoben/python/296878.html

如果認為本文對您有所幫助請贊助本站

支付寶掃一掃贊助微信掃一掃贊助

  • 支付寶掃一掃贊助
  • 微信掃一掃贊助
  • 支付寶先領紅包再贊助
    聲明:凡注明"本站原創"的所有文字圖片等資料,版權均屬編程客棧所有,歡迎轉載,但務請注明出處。
    PyTorch中topk函數的用法詳解PyTorch和Keras計算模型參數的例子
    Top 网上挖矿机赚钱