使用PyTorch進行情侶幸福度測試指南

2021-01-11 人工智慧遇見磐創

DeepConnection模型框架

計算機視覺--圖像和視頻數據分析是深度學習目前最火的應用領域之一。因此,在學習深度學習的同時嘗試運用某些計算機視覺技術做些有趣的事情會很有意思,也會讓你發現些令人吃驚的事實。長話短說,我的搭檔(Maximiliane Uhlich)和我決定將深度學習應用於浪漫情侶的形象分類上,因為Maximiliane是一位關係研究員和情感治療師。具體來說,我們想知道我們是否可以準確地判斷圖像或視頻中描繪的情侶是否對他們的關係感到滿意? 事實證明,我們可以!我們的最終模型(我們稱之為DeepConnection)分類準確率接近97%,能夠準確地區分幸福與不幸福的情侶。大家可以在我們的論文預覽連結[1]裡閱讀完整介紹,上圖是我們為這個任務設計的框架草圖。

在數據集收集方面,我們使用這個Python腳本[2]進行網頁數據抽取(webscraping)來獲取幸福和不幸福的情侶數據。最後,我們整理出了大約包含1000張圖像的訓練集。這並不是特別多,所以我們使用數據增強與遷移學習來增強我們模型在數據集上的表現。數據增強--圖像方向的微小變化,色調和色彩強度以及許多其他因素都會增強模型的泛化能力,從而避免學習一些不相關信息。 例如,如果數據中幸福夫妻的圖像平均比不幸福夫妻的圖像更亮,我們並不希望我們的模型映射這種關聯。我們使用了強大的ImgAug庫[3]進行了相當多策略的數據擴充,以確保我們模型的魯棒性。基本上對於每個批次的每個圖像,我們至都至少應用多種數據增強技術。下圖是一張圖片應用了48種數據增強策略的示例。

圖像增強後數據示例

我們決定使用ResNet模型作為DeepConnection的基礎網絡,在大型數據集ImageNet上預先訓練。通過預訓練,模型已經具有了一定的識別能力。我們所有的模型都借用PyTorch實現,我們使用Google Colab上的免費GPU資源進行訓練和測試。這個基礎模型本身已經具備了良好的分類能力,但我們決定更進一步,用空間金字塔池化層(SPP)[4] 替換ResNet-34基礎模型的最後一個自適應池模塊。這裡,處理後的圖像數據被分成不同數量的正方形,並且僅傳遞最大值以進行進一步分析(最大池化)。這使得模型可以專注於重要的特徵,使其對不同大小的圖像具有魯棒性,並且不受圖像擾動的影響。之後,我們放置了一個均值變換(PMT)層[5],用數學函數轉換數據以引入非線性,使得DeepConnection可以從數據中捕獲更複雜的關係。這兩個模塊均提高了我們的分類準確度,我們在單獨的驗證集上得到了大約97%準確率。SPP / PMT和後續分類層的代碼如下所示:

class SPP(nn.Module): def __init__(self): super(SPP, self).__init__() ## features incoming from ResNet-34 (after SPP/PMT) self.lin1 = nn.Linear(2*43520, 100) self.relu = nn.ReLU() self.bn1 = nn.BatchNorm1d(100) self.dp1 = nn.Dropout(0.5) self.lin2 = nn.Linear(100, 2) def forward(self, x): # SPP x = spatial_pyramid_pool(x, x.shape[0], [x.shape[2], x.shape[3]], [8, 4, 2, 1]) # PMT x_1 = torch.sign(x)*torch.log(1 + abs(x)) x_2 = torch.sign(x)*(torch.log(1 + abs(x)))**2 x = torch.cat((x_1, x_2), dim = 1) # fully connected classification part x = self.lin1(x) x = self.bn1(self.relu(x)) #1 x1 = self.lin2(self.dp1(x)) #2 x2 = self.lin2(self.dp1(x)) #3 x3 = self.lin2(self.dp1(x)) #4 x4 = self.lin2(self.dp1(x)) #5 x5 = self.lin2(self.dp1(x)) #6 x6 = self.lin2(self.dp1(x)) #7 x7 = self.lin2(self.dp1(x)) #8 x8 = self.lin2(self.dp1(x)) x = torch.mean(torch.stack([x1, x2, x3, x4, x5, x6, x7, x8]), dim = 0) return x仔細觀察代碼可以看出,最終分類層上有八個變種。看似浪費了算力實際上恰恰相反。這個概念是最近提出的,叫做multi-sample dropout(多樣本隨機丟棄),它在訓練期間顯著加速了收斂[6]。它基本上是防止模型學習虛假關係(過度擬合)和試圖不丟棄丟失掩碼中的信息之間的折衷。

我們在項目中對這個方法進行了其他一些調整優化,具體參看我們在GitHub放出的項目代碼[7]以獲取更多信息。簡單地提一下:我們使用混合精度(使用Apex庫[8]實現)訓練模型,以大大降低內存使用率,使用早停(earlystopping)來防止過度擬合,並根據餘弦函數進行學習率退火。

在達到令人滿意的分類準確度(具有相應高的召回率和精確度)後,我們想知道我們是否可以從DeepConnection執行的分類中學到一些東西。因此,我們嘗試模型解釋性探索並使用梯度加權類激活映射技術(Grad-CAM)進行分析[9]。基本地,Grad-CAM獲取最終卷積層的輸入梯度以確定顯著區域,其可以被視為原始圖像之上的上採樣熱圖。具體實現與可視化結果如下:

熱度圖對比

## from https://github.com/eclique/pytorch-gradcam/blob/master/gradcam.ipynbdef GradCAM(img, c, features_fn, classifier_fn): feats = modulelist_conv(img.cuda().half()) feats = feats.cuda() _, N, H, W = feats.size() out = modulelist_fc(feats) c_score = out[0, c] grads = torch.autograd.grad(c_score, feats) w = grads[0][0].mean(-1).mean(-1) sal = torch.matmul(w, feats.view(N, H*W)) sal = sal.view(H, W).cpu().detach().numpy() sal = np.maximum(sal, 0) return sal我們在論文中對此進行了進一步討論,並將其嵌入到了現有的心理學研究中,但DeepConnection似乎主要關注面部區域。從研究的角度來看,這很有意義,因為面部表情會傳達溝通和情感。除了Grad-CAM獲得的視覺感知之外,我們還想看看我們是否可以通過模型解釋得出實際特徵。為此,我們創建了激活狀態圖,以顯示最終分類層的哪些神經元被哪些給定圖像區域激活。

不同幸福程度代表性激活狀態圖

與其他模型相比,DeepConnection還學習到了代表不幸福的特徵,並不僅僅將缺乏代表幸福的特徵的分類為不幸福。但是,我們需要進一步的研究才能將這些特徵實際映射到人類行為可解釋性方面。我們還嘗試過在未知的情侶視頻幀上使用DeepConnection,效果非常好。

總體而言,該模型的穩健性是其強大優勢之一。準確的分類同樣適用於同性戀伴侶,不同膚色人種,除情侶外包含其他人的視頻幀中,不能完整顯示情侶人臉的視頻幀中等等。對於圖像中存在其他人的情況,DeepConnection甚至可以識別其他人是否感到滿意,但仍然將其預測集中在這對情侶身上。

除了進一步的模型解釋之外,下一步的工作將是使用更大的訓練數據集,從而訓練更複雜的模型。使用DeepConnection作為情侶治療師的助手將會很有意思,可以在會話期間或之後對情侶的當前關係狀態進行實時反饋。此外,我建議您與女票/男票一起輸入你們的合照,看看DeepConnection對你們的關係有何看法!希望這會是一個好的開始!

1: https://psyarxiv.com/df25j/2: https://github.com/Bribak/DeepConnection3: https://github.com/aleju/imgaug4: https://arxiv.org/abs/1406.47295: https://www.sciencedirect.com/science/article/pii/S00313203183045036: https://arxiv.org/abs/1905.097887: https://github.com/Bribak/DeepConnection8: https://github.com/NVIDIA/apex9: https://arxiv.org/abs/1610.02391

相關焦點

  • 《PyTorch中文手冊》來了
    本書提供PyTorch快速入門指南並與最新版本保持一致,其中包含的 Pytorch 教程全部通過測試保證可以成功運行。PyTorch 是一個深度學習框架,旨在實現簡單靈活的實驗。這是一本開源的書籍,目標是幫助那些希望和使用 PyTorch 進行深度學習開發和研究的朋友快速入門,其中包含的 Pytorch 教程全部通過測試保證可以成功運行。
  • PyTorch中使用DistributedDataParallel進行多GPU分布式模型訓練
    這篇文章是使用torch.nn.parallel.DistributedDataParallel API在純PyTorch中進行分布式訓練的簡介。 我們會:討論一般的分布式訓練方式,尤其是數據並行化涵蓋torch.dist和DistributedDataParallel的相關功能,並舉例說明如何使用它們測試真實的訓練腳本,以節省時間什麼是分布式訓練?
  • 還不會使用PyTorch框架進行深度學習的小夥伴,看過來
    選自heartbeat.fritz.ai作者:Derrick Mwiti機器之心編譯參與:Geek AI、王淑婷這是一篇關於使用 PyTorch 框架進行深度學習的教程,讀完以後你可以輕鬆地將該框架應用於深度學習模型。
  • 福利,PyTorch中文版官方教程來了
    教程作者來自 pytorchchina.com。教程網站:http://pytorch123.com教程裡有什麼教程根據 PyTorch 官方版本目錄,完整地還原了所有的內容。教程的一部分內容,使用 torch.view 改變 tensor 的大小或形狀用教程設計一個聊天機器人,以上為部分對話。
  • 《泰拉瑞亞》1.4npc幸福度怎麼玩 npc幸福度玩法介紹
    泰拉瑞亞1.4npc幸福度怎麼玩?
  • 使用PyTorch 檢測眼部疾病
    我們要利用這些數據對圖像進行normalize操作。 現在我們使用 pytorch 加載數據。每幅圖像都是中心像素,大小為490x490像素(為了在每幅圖像之間保持統一大小) ,然後轉換為張量,再進行規範化。
  • 新版PyTorch 1.2 已發布:功能更多、兼容更全、操作更快!
    每項工具都進行了新的優化與改進,兼容性更強,使用起來也更加便捷。PyTorch 發布了相關文章介紹了每個工具的更新細節,雷鋒網 AI 開發者將其整理與編譯如下。PyTorch 簡介自 PyTorch 1.0 發布以來,我們的社區不斷在進行擴展、添加入新的工具。這些發展為 PyTorch Hub 中越來越多可用的模型做出了極大的貢獻,並不斷增加了其在研究和生產中的用途。
  • PyTorch 0.4:完全改變API,官方支持Windows
    除了GPU加速和內存使用的高效外,PyTorch受歡迎的主要因素是動態計算圖的使用。已經有其他一些不太知名的深度學習框架使用動態計算圖,例如Chainer。動態圖的優點在於,圖(graph)是由run定義(「define by run」),而不是傳統的「define and run」。
  • 使用PyTorch進行主動遷移學習:讓模型預測自身的錯誤
    通過對被正確預測的置信度最低的項進行抽樣,就是對那些本應由人類檢查的應用標籤的項目進行抽樣。這段代碼是免費 PyTorch 庫中的 advanced_active_learning.py 文件中的代碼的一個稍微簡化的版本:https://github.com/rmunro/pytorch_active_learning/blob/master/advanced_active_learning.py你可以使用以下命令立即在用例——識別與災難相關的消息上運行它
  • 大家心心念念的PyTorch Windows官方支持來了
    GitHub 發布地址:https://github.com/pytorch/pytorch/releasesPyTorch 官網:http://pytorch.org/在沒有官方支持前,Windows 上安裝 PyTorch 需要藉助其它開發者發布的第三方 conda 包,而現在我們可以直接在 PyTorch 首頁上獲取使用 conda 或 pip 安裝的命令行,或跟隨教程使用源文件安裝。
  • 60 題 PyTorch 簡易入門指南,做技術的弄潮兒
    >https://www.kesci.com/home/project/5e0038642823a10036ae9ebf如果你是新新新手,可以先學習以下教程:https://www.kesci.com/home/project/5e0036722823a10036ae9d1dhttps://pytorch-cn.readthedocs.io
  • 用正確方法對度量學習算法進行基準測試
    大多數論文聲稱應用以下變換:將圖像大小調整為 256 x 256,隨機裁剪為 227 x 227,並以 50% 的機率進行水平翻轉。但最近一些論文的官方開源實現表明,他們實際上使用的是 GoogleNet 論文中描述的更複雜的裁剪方法(見「訓練方法」)。3.性能提升技巧在論文中沒有提及。
  • 華為雲應用編排,手把手教您完成pytorch代碼部署
    其歷史可以追溯至 2002 年使用Lua語言的Torch框架,並由其幕後團隊一手打造。PyTorch作為Torch框架的繼任者,並不僅僅只是移植代碼並提供接口,而是深入支持了Python,對大量模塊進行了重構,並新增了最先進的變量自動求導系統,成為時下最流行的動態圖框架。在入門時,PyTorch提供了完整的文檔,並有著活躍的社區論壇,對於新手而言上手遇到的難關容易解決。
  • github 項目推薦:用 edge-connect 進行圖像修復
    邊緣生成器先描繪出圖像缺失區域(規則和不規則)的邊緣,圖像完成網絡先驗使用描繪出的邊緣填充缺失區域。論文對該系統進行了詳細的描述。圖像 這裡使用 Places2, CelebA 以及 Paris Street-View 數據集。從官網下載數據集,在整個數據集上訓練模型。 下載完成後,運行 scripts/flist.py 這個文件來生成訓練、測試和驗證集文件列表。
  • Deep CARs:使用Pytorch學習框架實現遷移學習
    用不可視數據測試模型導入庫這一步只是加載庫,確保GPU是打開的。由於將使用深層網絡的預訓練模型,所以對CPU進行訓練並不是個好的選擇,原因是需要它花費很長的時間。GPU與此同時執行線性代數計算,訓練速度會提高100倍。如果沒有運行GPU,使用的是Colab工具的情況下,那就在電腦上點擊編輯 =>電腦設置。
  • 高性能PyTorch是如何煉成的?整理的10條脫坑指南
    如果你不立即使用它們也可以。只需記住,其他人可能正在用它們來訓練模型,速度可能會比你快 5%、10%、15%-…… 最終可能會導致面向市場或者工作機會時候的不同結果。數據預處理幾乎每個訓練管道都以 Dataset 類開始。它負責提供數據樣本。任何必要的數據轉換和擴充都可能在此進行。簡而言之,Dataset 能報告其規模大小以及在給定索引時,給出數據樣本。
  • mmdetection使用目標檢測工具箱訓練,測試
    2、運行demo測試環境是否安裝成功因為博主之前使用別的博客的demo代碼的時候出現錯誤,找了半天不知道是什麼原因,而當我好好看官方說明的時候才知道這個代碼在說明中有,而且已經更新過,所以為了保險期間,這裡就不直接貼出代碼了,給地址你們自己去看。
  • Pytorch-Transformers 1.0 發布,支持六個預訓練框架,含 27 個預...
    Le6、Facebook的 XLM,論文:「 Cross-lingual Language Model Pretraining」,論文作者:Guillaume Lample,Alexis Conneau這些實現都在幾個數據集(參見示例腳本)上進行了測試,性能與原始實現相當,例如 BERT中文全詞覆蓋在 SQuAD數據集上的F1分數為93
  • 雲計算學習:用PyTorch實現一個簡單的分類器
    所以我總結了一下自己當初學習的路線,準備繼續深入鞏固自己的 pytorch 基礎;另一方面,也想從頭整理一個教程,從沒有接觸過 pytorch 開始,到完成一些最新論文裡面的工作。以自己的學習筆記整理為主線,大家可以針對參考。第一篇筆記,我們先完成一個簡單的分類器。
  • 使用Google Colab上的PyTorch YOLOv3
    weights = opt.weightsimg_size = opt.img_size# 初始化設備device = torch_utils.select_device(opt.device)# 初始化模型model = Darknet(opt.cfg, img_size)# 加載權重attempt_download(weights)if weights.endswith('.pt'): # pytorch