目錄
1. PyTorch 預訓練模型Pytorch 提供了許多 Pre-Trained Model on ImageNet,僅需調用 torchvision.models 即可,具體細節(jié)可查看官方文檔。 往往我們需要對 Pre-Trained Model 進行相應的修改,以適應我們的任務。這種情況下,我們可以先輸出 Pre-Trained Model 的結構,確定好對哪些層修改,或者添加哪些層,接著,再將其修改即可。 比如,我需要將 ResNet-50 的 Layer 3 后的所有層去掉,在分別連接十個分類器,分類器由 ResNet-50.layer4 和 AvgPool Layer 和 FC Layer 構成。這里就需要用到 torch.nn.ModuleList 了,比如:: 代碼中的 [nn.Linear(10, 10) for i in range(10)] 是一個python列表,必須要把它轉換成一個Module Llist列表才可以被 PyTorch 使用,否則在運行的時候會報錯: RuntimeError: Input type (CUDAFloatTensor) and weight type (CPUFloatTensor) should be the same 2. 保存模型參數PyTorch 中保存模型的方式有許多種: # 保存整個網絡torch.save(model, PATH) # 保存網絡中的參數, 速度快,占空間少torch.save(model.state_dict(),PATH)# 選擇保存網絡中的一部分參數或者額外保存其余的參數torch.save({'state_dict': model.state_dict(), 'fc_dict':model.fc.state_dict(), 'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma}, PATH)3. 讀取模型參數同樣的,PyTorch 中讀取模型參數的方式也有許多種:
4. 凍結部分模型參數,進行 fine-tuning加載完 Pre-Trained Model 后,我們需要對其進行 Finetune。但是在此之前,我們往往需要凍結一部分的模型參數: # 第一種方式for p in freeze.parameters(): # 將需要凍結的參數的 requires_grad 設置為 False p.requires_grad = Falsefor p in no_freeze.parameters(): # 將fine-tuning 的參數的 requires_grad 設置為 True p.requires_grad = True# 將需要 fine-tuning 的參數放入optimizer 中optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)# 第二種方式optim_param = []for p in freeze.parameters(): # 將需要凍結的參數的 requires_grad 設置為 False p.requires_grad = Falsefor p in no_freeze.parameters(): # 將fine-tuning 的參數的 requires_grad 設置為 True p.requires_grad = True optim_param.append(p)optimizer.SGD(optim_param, lr=1e-3) # 將需要 fine-tuning 的參數放入optimizer 中5. 模型訓練與測試的設置訓練時,應調用 model.train() ;測試時,應調用 model.eval(),以及 with torch.no_grad(): model.train():使 model 變成訓練模式,此時 dropout 和 batch normalization 的操作在訓練起到防止網絡過擬合的問題。 model.eval():PyTorch會自動把 BN 和 DropOut 固定住,不會取平均,而是用訓練好的值。不然的話,一旦測試集的 Batch Size 過小,很容易就會被 BN 層導致生成圖片顏色失真極大。 with torch.no_grad():PyTorch 將不再計算梯度,這將使得模型 forward 的時候,顯存的需求大幅減少,速度大幅提高。 注意:若模型中具有 Batch Normalization 操作,想固定該操作進行訓練時,需調用對應的 module 的 eval() 函數。這是因為 BN Module 除了參數以外,還會對輸入的數據進行統計,若不調用 eval(),統計量將發(fā)生改變!具體代碼可以這樣寫:
在其他地方看到的解釋:
6. 利用 torch.nn.DataParallel 進行多 GPU 訓練
-完- |
|
|
來自: 520jefferson > 《待分類》