Text classification is one of the important applications of natural language processing. There are many machine learning methods for classifying text, but most of these techniques require heavy preprocessing and large amounts of computational resources. In this article, we use PyTorch for multi-class text classification because it offers the following advantages:
PyTorch provides a powerful way to implement complex model architectures and algorithms with relatively little preprocessing and a comparatively low consumption of computational resources, including execution time. The basic unit of PyTorch is the tensor, and it offers the advantages of changing the network architecture at runtime and distributing training across GPUs. PyTorch also provides a powerful library called TorchText, which contains scripts for preprocessing text and loaders for some popular NLP datasets.
In this article, we will demonstrate multi-class text classification using TorchText, a powerful natural language processing library in PyTorch.
For this classification, we will use a model composed of an EmbeddingBag layer and a linear layer. EmbeddingBag handles text entries of variable length by computing the mean of the embeddings.
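To make that pooling behaviour concrete, here is a minimal sketch (with made-up sizes and token ids, not part of the tutorial code): nn.EmbeddingBag takes a flat tensor of token ids plus an offsets tensor marking where each entry starts, and returns one averaged vector per entry.

import torch
import torch.nn as nn

# Toy setup: a vocabulary of 10 tokens, 4-dimensional embeddings, mean pooling (the default mode).
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='mean')

# Two entries of different lengths packed into one flat tensor:
# entry 1 = tokens [1, 2, 4], entry 2 = tokens [4, 3].
text = torch.tensor([1, 2, 4, 4, 3])
offsets = torch.tensor([0, 3])  # start position of each entry in `text`

pooled = bag(text, offsets)
print(pooled.shape)  # torch.Size([2, 4]) -- one averaged vector per entry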
The model will be trained on the DBpedia dataset, in which the texts belong to 14 classes. Once training succeeds, the model will predict the class label for an input text.
The DBpedia Dataset
DBpedia is a popular benchmark dataset in natural language processing. It contains texts from 14 categories, such as company, educational institution, artist, film, and so on.
It is actually a collection of structured content extracted from the information created in the Wikipedia project. The DBpedia dataset provided by TorchText has 630,000 text instances belonging to 14 classes. It includes 560,000 training instances and 70,000 test instances.
Implementing Text Classification with TorchText
First, we need to install the required version of TorchText.
!pip install torchtext==0.4
After that, we will import all the required libraries.
import torch
import torchtext
from torchtext.datasets import text_classification
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import time
from torch.utils.data.dataset import random_split
import re
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer

In the next step, we will define the ngrams and the batch size. The ngrams feature is used to capture important information about local word order.
We use bigrams, so each example text in the dataset will be a list of single words plus the bigram strings.
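As a quick illustration of what that means (the tokens here are made up, not from the dataset), torchtext's ngrams_iterator yields the original tokens followed by their bigram strings:

from torchtext.data.utils import ngrams_iterator

tokens = ['the', 'church', 'is', 'white']
print(list(ngrams_iterator(tokens, 2)))
# ['the', 'church', 'is', 'white', 'the church', 'church is', 'is white']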
NGRAMS = 2
BATCH_SIZE = 16

Now, we will load the DBpedia dataset provided by TorchText.
if not os.path.isdir('./.data'):
    os.mkdir('./.data')

train_dataset, test_dataset = text_classification.DATASETS['DBpedia'](
    root='./.data', ngrams=NGRAMS, vocab=None)
After downloading the dataset, we will verify its length and the number of labels.
print(len(train_dataset))
print(len(test_dataset))
print(len(train_dataset.get_labels()))
print(len(test_dataset.get_labels()))
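Each entry of these datasets is a pair of an integer class label and a tensor of vocabulary indices (unigrams plus bigrams); the batching function defined later relies on this format. A minimal inspection sketch, assuming the download above succeeded:

# Look at one training example: (label, tensor of token ids).
label, token_ids = train_dataset[0]
print(label)           # integer class index
print(token_ids[:10])  # first few vocabulary indices (unigrams and bigrams)
print(len(token_ids))  # length varies from example to example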
We will use CUDA, if it is available, to speed up execution.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
In the next step, we will define the model for classification.
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
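As a quick shape check with made-up sizes (a toy vocabulary of 100 tokens, 8-dimensional embeddings, 14 classes), the forward pass takes the packed token tensor and the offsets and returns one row of class scores per entry:

# Hypothetical shape check; not part of the training pipeline.
tiny = TextSentiment(vocab_size=100, embed_dim=8, num_class=14)
text = torch.tensor([3, 17, 42, 5, 9])  # two entries packed together
offsets = torch.tensor([0, 3])          # entry boundaries
print(tiny(text, offsets).shape)        # torch.Size([2, 14])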
Now, we will initialize the hyperparameters and define a function that generates training batches.
VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUM_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)
print(model)

def generate_batch(batch):
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label

A quick toy check of generate_batch is sketched below. In the step after that, we will define the functions used to train and test the model.
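Here is that toy check on a made-up two-entry batch (the labels and token ids are arbitrary); it simply shows the packed format that the EmbeddingBag layer expects:

# Hypothetical batch of two (label, token-id tensor) entries.
fake_batch = [(3, torch.tensor([5, 9, 2])), (0, torch.tensor([7, 1]))]
text, offsets, label = generate_batch(fake_batch)
print(text)     # tensor([5, 9, 2, 7, 1]) -- all tokens concatenated
print(offsets)  # tensor([0, 3]) -- start position of each entry
print(label)    # tensor([3, 0])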
def train_func(sub_train_):

    # Train the model
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate
    scheduler.step()

    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            batch_loss = criterion(output, cls)
            loss += batch_loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)

We will train the model for 5 epochs.
N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    mins = secs / 60
    secs = secs % 60

    print('Epoch: %d' % (epoch + 1), " | time in %d minutes, %d seconds" % (mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
In the next step, we will test our model on the test dataset and check its accuracy.
print('Checking the results of test dataset...')
test_loss, test_acc = test(test_dataset)
print(f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)')
Now, we will test our model on single text strings and predict the class label for each given text.
DBpedia_label = {1: 'Company',
                 2: 'EducationalInstitution',
                 3: 'Artist',
                 4: 'Athlete',
                 5: 'OfficeHolder',
                 6: 'MeanOfTransportation',
                 7: 'Building',
                 8: 'NaturalPlace',
                 9: 'Village',
                 10: 'Animal',
                 11: 'Plant',
                 12: 'Album',
                 13: 'Film',
                 14: 'WrittenWork'}

def predict(text, model, vocab, ngrams):
    tokenizer = get_tokenizer("basic_english")
    with torch.no_grad():
        text = torch.tensor([vocab[token]
                             for token in ngrams_iterator(tokenizer(text), ngrams)])
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1

vocab = train_dataset.get_vocab()
model = model.to("cpu")

Now, we will take some text at random from the test data and check the predicted class labels.
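Before the three fixed examples below, here is an optional sketch of such a random spot check taken directly from the already-tokenized test set (an illustration under the entry format and label mapping described above, not part of the original walkthrough):

import random

# Pick a random test example and compare the prediction with the true label.
idx = random.randrange(len(test_dataset))
true_label, token_ids = test_dataset[idx]
with torch.no_grad():
    output = model(token_ids, torch.tensor([0]))
pred_label = output.argmax(1).item()
print("true:", DBpedia_label[true_label + 1], "| predicted:", DBpedia_label[pred_label + 1])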
First prediction:
ex_text_str = "Brekke Church (Norwegian: Brekke kyrkje) is a parish church in Gulen Municipality in Sogn og Fjordane county, Norway. It is located in the village of Brekke. The church is part of the Brekke parish in the Nordhordland deanery in the Diocese of Bjrgvin. The white, wooden church, which has 390 seats, was consecrated on 19 November 1862 by the local Dean Thomas Erichsen. The architect Christian Henrik Grosch made the designs for the church, which is the third church on the site."

print("This is a %s news" % DBpedia_label[predict(ex_text_str, model, vocab, 2)])
Second prediction:
ex_text_str2 = "Cerithiella superba is a species of very small sea snail, a marine gastropod mollusk in the family Newtoniellidae. This species is known from European waters. It was described by Thiele, 1912."

print("This text belongs to %s class" % DBpedia_label[predict(ex_text_str2, model, vocab, 2)])
Third prediction:
ex_text_str3 = "Nithari is a village in the western part of the state of Uttar Pradesh India bordering on New Delhi. Nithari forms part of the New Okhla Industrial Development Authority's planned industrial city Noida falling in Sector 31. Nithari made international news headlines in December 2006 when the skeletons of a number of apparently murdered women and children were unearthed in the village."

print("This text belongs to %s class" % DBpedia_label[predict(ex_text_str3, model, vocab, 2)])
In this way, we have implemented multi-class text classification using TorchText.
This is a simple and easy approach to text classification that requires very little preprocessing when using this PyTorch library. Training the model on the 560,000 training instances took less than 5 minutes.
Try re-running this code after changing ngrams from 2 to 3, as sketched below, and see whether the results improve. The same implementation can also be applied to the other datasets provided by TorchText.
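A minimal sketch of that change follows; it assumes you re-download the dataset with ngrams=3, rebuild the model with the new vocabulary size, and retrain before predicting:

# Hypothetical trigram re-run: only the n-gram setting changes.
NGRAMS = 3
train_dataset, test_dataset = text_classification.DATASETS['DBpedia'](
    root='./.data', ngrams=NGRAMS, vocab=None)
# Rebuild the vocabulary, model, optimizer, and scheduler as above, retrain,
# and call predict with ngrams=3, e.g.:
# print(DBpedia_label[predict(ex_text_str, model, vocab, 3)])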
References:
'Text Classification with TorchText', PyTorch tutorial
Allen Nie, 'A Tutorial on TorchText'