基於TorchText的PyTorch文本分類

2021-01-11 人工智慧遇見磐創

文本分類是自然語言處理的重要應用之一。在機器學習中有多種方法可以對文本進行分類。但是這些分類技術大多需要大量的預處理和大量的計算資源。在這篇文章中,我們使用PyTorch來進行多類文本分類,因為它有如下優點:

PyTorch提供了一種強大的方法來實現複雜的模型體系結構和算法,其預處理量相對較少,計算資源(包括執行時間)的消耗也較少。PyTorch的基本單元是張量,它具有在運行時改變架構和跨gpu分布訓練的優點。PyTorch提供了一個名為TorchText的強大庫,其中包含用於預處理文本的腳本和一些流行的NLP數據集的原始碼。

在本文中,我們將使用TorchText演示多類文本分類,TorchText是PyTorch中一個強大的自然語言處理庫。

對於這種分類,將使用由EmbeddingBag層和線性層組成的模型。EmbeddingBag通過計算嵌入的平均值來處理長度可變的文本條目。

這個模型將在DBpedia數據集上進行訓練,其中文本屬於14個類。訓練成功後,模型將預測輸入文本的類標籤。

DBpedia數據集

DBpedia是自然語言處理領域中流行的基準數據集。它包含14個類別的文本,如公司、教育機構、藝術家、電影等。

它實際上是從維基百科項目創建的信息中提取的結構化內容集。TorchText提供的DBpedia數據集有63000個屬於14個類的文本實例。它包括5600個訓練實例和70000個測試實例。

用TorchText實現文本分類

首先,我們需要安裝最新版本的TorchText。

!pip install torchtext==0.4

之後,我們將導入所有必需的庫。

import torchimport torchtextfrom torchtext.datasets import text_classificationimport osimport torch.nn as nnimport torch.nn.functional as Ffrom torch.utils.data import DataLoaderimport timefrom torch.utils.data.dataset import random_splitimport refrom torchtext.data.utils import ngrams_iteratorfrom torchtext.data.utils import get_tokenizer在下一步中,我們將定義ngrams和batch大小。ngrams特徵用於捕獲有關本地語序的重要信息。

我們使用bigram,數據集中的示例文本將是單個單詞加上bigrams字符串的列表。

NGRAMS = 2BATCH_SIZE = 16現在,我們將讀取TorchText提供的DBpedia數據集。

if not os.path.isdir('./.data'): os.mkdir('./.data')train_dataset, test_dataset = text_classification.DATASETS['DBpedia']( root='./.data', ngrams=NGRAMS, vocab=None)

下載數據集後,我們將驗證下載數據集的長度和標籤數量。

print(len(train_dataset))print(len(test_dataset))

print(len(train_dataset.get_labels()))print(len(test_dataset.get_labels()))

我們將使用CUDA架構來加速運行和執行。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device

在下一步中,我們將定義分類的模型。

class TextSentiment(nn.Module): def __init__(self, vocab_size, embed_dim, num_class): super().__init__() self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True) self.fc = nn.Linear(embed_dim, num_class) self.init_weights() def init_weights(self): initrange = 0.5 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() def forward(self, text, offsets): embedded = self.embedding(text, offsets) return self.fc(embedded)print(model)

現在,我們將初始化超參數並定義函數以生成訓練batch。

VOCAB_SIZE = len(train_dataset.get_vocab())EMBED_DIM = 32NUN_CLASS = len(train_dataset.get_labels())model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)def generate_batch(batch): label = torch.tensor([entry[0] for entry in batch]) text = [entry[1] for entry in batch] offsets = [0] + [len(entry) for entry in text] offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) text = torch.cat(text) return text, offsets, label在下一步中,我們將定義用於訓練和測試模型的函數。

def train_func(sub_train_): # 訓練模型 train_loss = 0 train_acc = 0 data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch) for i, (text, offsets, cls) in enumerate(data): optimizer.zero_grad() text, offsets, cls = text.to(device), offsets.to(device), cls.to(device) output = model(text, offsets) loss = criterion(output, cls) train_loss += loss.item() loss.backward() optimizer.step() train_acc += (output.argmax(1) == cls).sum().item() # 調整學習率 scheduler.step() return train_loss / len(sub_train_), train_acc / len(sub_train_)def test(data_): loss = 0 acc = 0 data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch) for text, offsets, cls in data: text, offsets, cls = text.to(device), offsets.to(device), cls.to(device) with torch.no_grad(): output = model(text, offsets) loss = criterion(output, cls) loss += loss.item() acc += (output.argmax(1) == cls).sum().item() return loss / len(data_), acc / len(data_)我們將用5個epoch訓練模型。

N_EPOCHS = 5min_valid_loss = float('inf')criterion = torch.nn.CrossEntropyLoss().to(device)optimizer = torch.optim.SGD(model.parameters(), lr=4.0)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)train_len = int(len(train_dataset) * 0.95)sub_train_, sub_valid_ = \ random_split(train_dataset, [train_len, len(train_dataset) - train_len])for epoch in range(N_EPOCHS): start_time = time.time() train_loss, train_acc = train_func(sub_train_) valid_loss, valid_acc = test(sub_valid_) secs = int(time.time() - start_time) mins = secs / 60 secs = secs % 60 print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs)) print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)') print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')

下一步,我們將在測試數據集上測試我們的模型,並檢查模型的準確性。

print('Checking the results of test dataset...')test_loss, test_acc = test(test_dataset)print(f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)')

現在,我們將在單個新聞文本字符串上測試我們的模型,並預測給定新聞文本的類標籤。

DBpedia_label = {0: 'Company', 1: 'EducationalInstitution', 2: 'Artist', 3: 'Athlete', 4: 'OfficeHolder', 5: 'MeanOfTransportation', 6: 'Building', 7: 'NaturalPlace', 8: 'Village', 9: 'Animal', 10: 'Plant', 11: 'Album', 12: 'Film', 13: 'WrittenWork'}def predict(text, model, vocab, ngrams): tokenizer = get_tokenizer("basic_english") with torch.no_grad(): text = torch.tensor([vocab[token] for token in ngrams_iterator(tokenizer(text), ngrams)]) output = model(text, torch.tensor([0])) return output.argmax(1).item() + 1vocab = train_dataset.get_vocab()model = model.to("cpu")現在,我們將從測試數據中隨機抽取一些文本並檢查預測的類標籤。

第一個預測:

ex_text_str = "Brekke Church (Norwegian: Brekke kyrkje) is a parish church in Gulen Municipality in Sogn og Fjordane county, Norway. It is located in the village of Brekke. The church is part of the Brekke parish in the Nordhordland deanery in the Diocese of Bjrgvin. The white, wooden church, which has 390 seats, was consecrated on 19 November 1862 by the local Dean Thomas Erichsen. The architect Christian Henrik Grosch made the designs for the church, which is the third church on the site."print("This is a %s news" %DBpedia_label[predict(ex_text_str, model, vocab, 2)])

第二個預測:

ex_text_str2 = "Cerithiella superba is a species of very small sea snail, a marine gastropod mollusk in the family Newtoniellidae. This species is known from European waters. It was described by Thiele, 1912."print("This text belongs to %s class" %DBpedia_label[predict(ex_text_str2, model, vocab, 2)])

第三個預測:

ex_text_str3 = " Nithari is a village in the western part of the state of Uttar Pradesh India bordering on New Delhi. Nithari forms part of the New Okhla Industrial Development Authority's planned industrial city Noida falling in Sector 31. Nithari made international news headlines in December 2006 when the skeletons of a number of apparently murdered women and children were unearthed in the village."print("This text belongs to %s class" %DBpedia_label[predict(ex_text_str3, model, vocab, 2)])

因此,通過這種方式,我們使用TorchText實現了多類文本分類。

這是一種簡單易行的文本分類方法,使用這個PyTorch庫只需很少的預處理量。在5600個訓練實例上訓練模型只花了不到5分鐘。

通過將ngram從2更改為3來重新運行這些代碼並查看結果是否有改進。同樣的實現也可以在TorchText提供的其他數據集上實現。

參考文獻:

『Text Classification with TorchText』, PyTorch tutorialAllen Nie, 『A Tutorial on TorchText』

相關焦點

  • textCNN論文與原理——短文本分類(基於pytorch和torchtext)
    前言之前書寫了使用pytorch進行短文本分類,其中的數據處理方式比較簡單粗暴。自然語言處理領域包含很多任務,很多的數據向之前那樣處理的話未免有點繁瑣和耗時。在pytorch中眾所周知的數據處理包是處理圖片的torchvision,而處理文本的少有提及,快速處理文本數據的包也是有的,那就是torchtext[1]。
  • 【深度學習】textCNN論文與原理——短文本分類(基於pytorch)
    前言前文已經介紹了TextCNN的基本原理,如果還不熟悉的建議看看原理:【深度學習】textCNN論文與原理[1]及一個簡單的基於pytorch的圖像分類案例:【深度學習】卷積神經網絡-圖片分類案例(pytorch實現)[2]。
  • pytorch編程之使用 TorchText 進行文本分類
    本教程介紹了如何使用torchtext中的文本分類數據集,包括- AG_NEWS,- SogouNews,-
  • [DL] PyTorch 折桂 13:TorchText
    往期匯總:TorchText 是 PyTorch 的一個擴展功能包,主要提供文本數據讀取、創建迭代器的的功能與語料庫、詞向量的信息,分別對應了 torchtext.data、torchtext.datasets 和 torchtext.vocab 三個子模塊。本文參考了三篇文章[1][2][3]。1.
  • 【PyTorch實戰】手把手教你用torchtext處理文本數據
    如何將一個純文本數據(比如一個 txt 文本), 變成一個模型可接受的數據(比如一個 embedding 序列)呢?如果你是 pytorch 的用戶,你可能已經非常熟悉 torchvision 了,因為它已經比較穩定,而且官方也為它出了教程。torchtext 跟 torchvision 一樣,是為了處理特定的數據和數據集而存在的。
  • 使用torchtext導入NLP數據集
    如果你是pytorch的用戶,可能你會很熟悉pytorch生態圈中專門預處理圖像數據集的torchvision庫。
  • 新版PyTorch 1.2 已發布:功能更多、兼容更全、操作更快!
    帶有監督學習數據集的 TORCHTEXT 0.4torchtext 的一個關鍵重點領域是提供有助於加速 NLP 研究的基本要素。其中包括輕鬆訪問常用數據集和基本預處理流程,用以處理基於原始文本的數據。torch
  • 【Python】RST文件打開——以torchtext官方github文檔為例
    ,pytorch官方也打不開,於是在github下載了源碼,看看github有沒有相關的官方文檔。查閱資料有發現,rst文件也是一種標記文本,與md類似,其全稱是:reStructuredText,更多關於rst信息可參考reStructuredText(rst)快速入門語法說明[1]rst文件是Python程式語言的Docutils項目的一部分,Python Doc-SIG (Documentation Special Interest
  • 【乾貨】史上最全的PyTorch學習資源匯總
    · 簡單易上手的PyTorch中文文檔(https://github.com/fendouai/pytorch1.0-cn):非常適合新手學習。該文檔從介紹什麼是PyTorch開始,到神經網絡、PyTorch的安裝,再到圖像分類器、數據並行處理,非常詳細的介紹了PyTorch的知識體系,適合新手的學習入門。
  • 獨家 :教你用Pytorch建立你的第一個文本分類模型!
    學習如何使用PyTorch實現文本分類理解文本分類中的關鍵點學習使用壓縮填充方法在我的編程歷程中,我總是求助於最先進的架構。現在得益於深度學習框架,比如說PyTorch,Keras和 TensorFlow,實現先進的架構已經變得更簡單了。這些深度學習框架提供了一種實現複雜模型架構和算法的簡單方式,不需要你掌握大量的專業知識和編程技能。
  • 新手必備 | 史上最全的PyTorch學習資源匯總
    (5)最後,為大家推薦一個簡單易上手的PyTorch中文文檔,非常適合新手學習:https://github.com/fendouai/pytorch1.0-cn。該文檔從介紹什麼是PyTorch開始,到神經網絡、PyTorch的安裝,再到圖像分類器、數據並行處理,非常詳細的介紹了PyTorch的知識體系,適合新手的學習入門。
  • 使用PyTorch建立你的第一個文本分類模型
    概述學習如何使用PyTorch執行文本分類理解解決文本分類時所涉及的要點學習使用包填充(Pack Padding)特性介紹我總是使用最先進的架構來在一些比賽提交模型結果。因此,在本文中,我們將介紹解決文本分類問題的關鍵點。然後我們將在PyTorch中實現第一個文本分類器!目錄為什麼使用PyTorch進行文本分類?
  • 獨家 | 教你用Pytorch建立你的第一個文本分類模型!
    學習如何使用PyTorch實現文本分類理解文本分類中的關鍵點學習使用壓縮填充方法在我的編程歷程中,我總是求助於最先進的架構。現在得益於深度學習框架,比如說PyTorch,Keras和 TensorFlow,實現先進的架構已經變得更簡單了。這些深度學習框架提供了一種實現複雜模型架構和算法的簡單方式,不需要你掌握大量的專業知識和編程技能。
  • 資料|【乾貨】PyTorch學習資源匯總
    該文檔從介紹什麼是PyTorch開始,到神經網絡、PyTorch的安裝,再到圖像分類器、數據並行處理,非常詳細的介紹了PyTorch的知識體系,適合新手的學習入門。(小編就不翻譯了)還帶了說明文檔,庫和說明文檔的地址為:NLP&PyTorch實戰Pytorch text:Torchtext是一個非常好用的庫,可以幫助我們很好的解決文本的預處理問題。
  • 【乾貨】基於pytorch的CNN、LSTM神經網絡模型調參小結
    對於沒有學習過pytorch的初學者,可以先看一下官網發行的60分鐘入門pytorch,參考地址 :http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html卷積神經網絡CNN理解參考(https://www.zybuluo.com/hanbingtao/note/485480)
  • pytorch編程之使用 TorchText 進行語言翻譯
    它基於 PyTorch 社區成員 Ben Trevett 的本教程,並由 Seth Weidman 在 Ben 的允許下創建。在本教程結束時,您將能夠:<cite>欄位</cite>和 <cite>TranslationDataset</cite>torchtext具有用於創建數據集的實用程序,可以輕鬆地對其進行迭代,以創建語言翻譯模型。
  • 文本分類的14種算法
    1)伯努利貝葉斯即特徵的取值只有取和不取兩類(0和1),對應樸素貝葉斯公式中,p(yi)=標籤為yi的文本數(句子數)/文本總數(句子總數)p(xj|yi)=(標籤為yi的文本中出現了單詞xj的文本數+1)/(標籤為yi的文本數+2)。
  • 基於Text-CNN模型的中文文本分類實戰
    本文介紹NLP中文本分類任務中核心流程進行了系統的介紹,文末給出一個基於Text-CNN模型在搜狗新聞數據集上二分類的Demo。文本分類是自然語言處理領域最活躍的研究方向之一,從樣本數據的分類標籤是否互斥上來說,可以分為文本多分類與文本多標籤分類。
  • Github 2.2K星的超全PyTorch資源列表
    該部分項目涉及語音識別、多說話人語音處理、機器翻譯、共指消解、情感分類、詞嵌入/表徵、語音生成、文本語音轉換、視覺問答等任務,其中有一些是具體論文的 PyTorch 復現,此外還包括一些任務更廣泛的庫、工具集、框架。