最近花了一個月,調了一個文本分類的模型,調參過程中從網上學到了很多技巧,在這裡總結整理並附上相關的代碼方便大家參考。
整體分為三部分數據集,模型結構/初始化,模型超參數
一.數據集
1.數據集的選取
為了儘快調整參數和模型,應當選取小數據集進行測試和調參
二.模型結構/初始化
1.模型初始化
作用:參數初始化很重要,它決定了模型的訓練速度與是否可以躲開局部極小
1.lstm的h用orthogonal
2.relu作激活函數初始化使用He normal==kaiming_normal,tanh作激活函數初始化使用Glorot normal=Xavier normal初始化
代碼示例:
parameter=nn.Parameter(nn.init.xavier_normal_(torch.Tensor(128, 64)), requires_grad=True)
2..Dropout
作用:防止過擬合,正則化
tip:
1.dropout對於具有大量參數的全連接效果最好,而CNN的卷積層不是全連接,參數不是很多,所以效果不明顯
2.最好是在開始的層加dropout,越往後的層,越是要小心加dropout
3.在很深的網絡裡可以每層之間都加
加在哪裡:
1.NLP embedding的位置
2.全連接層的激活函數層之後
代碼示例:
dropout = nn.Dropout(p=0.5)
embedding=dropout(embedding)
3..BatchNorm
作用:
1.歸一化防止過擬合
2.加快訓練速度,改進優化
tip:
1.batch size可能會影響它的結果,BN一般要求將訓練集完全打亂,並用一個較大的batch值,否則,一個batch的數據無法較好得代表訓練集的分布,會影響模型訓練的效果
2.對初始化不敏感,可以使用大學習率
加在哪裡:
1.BN+relu
2.Sigmoid+BN
代碼示例:BN+relu+dropout
dropout = nn.Dropout(p=0.5)
mlp1 = nn.Linear(256, 128)
mlp2 = nn.Linear(128, 64)
#假設有64個類別
bn = nn.BatchNorm1d(64)
relu = nn.ReLU()
output=mlp2(dropout(relu(bn(mlp1(document_vector)))))
三.模型超參數
1.學習率lr+優化器
定義:模型每次迭代更新時的變化程度
tip:
1.學習率lr一般嘗試 1e-3 1e-4
2.優化器adam和adam+momentum比較
2.1 adam收斂快但效果沒有sgd+momentum的解好
2.2 adam不需要特別調lr,sgd需要調lr和初始化權重
2.3 建議一開始使用adam(簡單方便),實在不行再嘗試sgd
3.Irscheduler控制學習率衰減
代碼示例:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler =torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=32) #T_max設32或64
optimizer.zero_grad()
#梯度清0
loss.backward()
#反向傳播
optimizer.step()
#優化器更新
scheduler.step()
2.batch size
定義:每次訓練模型時使用的數據數量
tip:
1.batch size越大,每個epoch更新的次數越少,所以需要更大的學習率lr;反之,batch size越小,因為每個epoch更新次數變多,需要更小的學習率lr
2.batch size太大爆顯存時可以考慮多卡並行計算,這樣會大大提高計算速度
3.一般小於等於128
3.embedding szie
定義:輸入向量的維數
tip:
1.embedding size (64 or 128)
2.LSTM/CNN hiddensize (256 or 512)
4.參數正則化
tip:
1.pytorch中Adam已經自帶了參數正則化,需要再調整可以自己輸入指定值
代碼示例:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,eps=1e-7,weight_decay=0.0001)
https://zhuanlan.zhihu.com/p/350763434?utm_source=wechat_session&utm_medium=social&utm_oi=641551032598138880八大模塊一步到位!學習企業急需的數據課程內容!課程內容貼近企業真實工作環境,培養實戰型數據科學人才!掃下圖二維碼直達課程!數據科學時代,網際網路人必備技能!學成可降低跨部門溝通成本。運用數據科學與AI技術,幫助設計更好的網際網路產品,提高業務水平。