卷積神經網絡(CNN)在計算機視覺任務中有著廣泛的應用,然而它的運算量非常巨大,這使得我們很難將CNN直接運用到計算資源受限的行動裝置上。為了減少CNN的計算代價,許多模型壓縮和加速的方法被提出。
其中AdderNet就是一種從新角度對模型進行加速的方法,以往的模型加速方法通過減少CNN的參數,AdderNet通過重新定義卷積計算,將卷積中的乘法替換為了加法。我們知道,乘法的計算代價要遠遠大於加法,AdderNet通過這種方式減少了計算量。
圖1 加法和乘法計算量對比
CNN卷積計算:
AdderNet計算:
1
代碼解讀
AdderNet的訓練代碼已經在github上開源(https://github.com/huawei-noah/AdderNet),接下來我們對代碼進行分析和解讀。
AdderNet的訓練代碼主要分為幾個文件:
adder.py
main.py
resnet20.py
resnet50.py
test.py
其中adder.py定義了AdderNet的基礎算子,main.py是訓練AdderNet的文件,test.py是測試文件,resnet20.py和resnet50.py定義了網絡結構。
由於訓練和測試的代碼以及網絡結構的代碼和正常的卷積神經網絡一樣,這裡我們不對它們做解析,我們主要解讀定義adder算子的adder.py文件。
adder.py中共含有兩個類和一個函數,兩個類分別是adder2d和adder,一個函數為adder2d_function。我們首先來看adder2d這個類。
class adder2d(nn.Module):
def __init__(self, input_channels, output_channels, kernel_size, stride=1, padding=0, bias=False):
super(adder2d, self).__init__()
self.stride = stride
self.padding = padding
self.input_channel=output_channel
self.kernel_size=kernel_size
self.adder=torch.nn.Parameter(nn.init.normal_(torch.randn(output_channel,input_channel,kernel_size,kernel_size)))
self.bias=bias
if bias:
self.b = torch.nn.Parameter(nn.init.uniform_(torch.zeros(output_channel)))
def forward(self, x):
output = adder2d_function(x,self.adder, self.stride, self.padding)
if self.bias:
output += self.b.unsqueeze(0).unsqueeze(2).unsqueeze(3)
return output
可以看到,adder2d這個類定義了adder算子,是繼承於nn.module的,所以在網絡定義時可以直接使用adder2d來定義一個adder層。例如resnet20.py中就如下定義一個3*3 kernel大小的adder層:
def conv3x3(in_planes, out_planes, stride=1):
" 3x3 convolution with padding "
return adder.adder2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
可以看到adder2d的使用方式和nn.Conv2d基本完全一樣。
接下來我們進一步解讀adder2d包含的屬性,和卷積算子相同,adder算子包括幾個屬性:stride,padding,input_channel,output_channel和kernel_size,給定這幾個屬性後,adder2d就會根據這些屬性定義adder filter和bias
self.adder= torch.nn.Parameter(nn.init.normal_(torch.randn(output_channel,input_channel,kernel_size,kernel_size)))
self.b = torch.nn.Parameter(nn.init.uniform_(torch.zeros(output_channel)))
最後對前向傳播使用函數adder2d_function來得到結果
output= adder2d_function(x,self.adder, self.stride, self.padding)
所以接下來我們進一步分析adder2d_function這個函數是如何進行adder算子的運算的:
def adder2d_function(X, W, stride=1, padding=0):
n_filters, d_filter, h_filter, w_filter = W.size()
n_x, d_x, h_x, w_x = X.size()
h_out = (h_x - h_filter + 2 * padding) / stride + 1
w_out = (w_x - w_filter + 2 * padding) / stride + 1 h_out, w_out = int(h_out), int(w_out)
X_col = torch.nn.functional.unfold(X.view(1, -1, h_x, w_x), h_filter, dilation=1, padding=padding, stride=stride).view(n_x, -1, h_out*w_out)
X_col = X_col.permute(1,2,0).contiguous().view(X_col.size(1),-1)
W_col = W.view(n_filters, -1)
out = adder.apply(W_col,X_col)
out = out.view(n_filters, h_out, w_out, n_x)
out = out.permute(3, 0, 1, 2).contiguous()
return out:
可以看到,adder2d_function將輸入X和卷積核W先進行了一系列變換,變換為W_col和X_col兩個矩陣後再進行計算,這和卷積的計算十分類似,在卷積中,我們通常將輸入圖片通過im2col變換變為矩陣,將卷積核reshape成矩陣,將卷積計算轉換為矩陣乘法運算進行。這裡adder的計算也是同樣的:
X_col = torch.nn.functional.unfold(X.view(1, -1, h_x, w_x), h_filter, dilation=1, padding=padding, stride=stride).view(n_x, -1, h_out*w_out)
X_col = X_col.permute(1,2,0).contiguous().view(X_col.size(1),-1)
上面兩行代碼就是將輸入的X進行im2col變成二維矩陣
W_col = W.view(n_filters, -1)
同樣的W也reshape成二維矩陣。
接下來如果我們要進行卷積,就將這兩個矩陣進行矩陣乘法的運算。然而我們現在進行的是adder運算,相當於將卷積中的乘法改為加法,所以需要重新定義這個矩陣運算:
out = adder.apply(W_col,X_col)
可以看到adder.apply就是重新定義的對應加法神經網絡的矩陣運算。
out = out.view(n_filters, h_out, w_out, n_x)
out = out.permute(3, 0, 1, 2).contiguous()
最後得到的output矩陣同樣通過reshape變回4維。
接下來我們仔細分析這個adder運算是如何實現的。
class adder(Function):
@staticmethod
def forward(ctx, W_col, X_col):
ctx.save_for_backward(W_col,X_col)
output = -(W_col.unsqueeze(2)-X_col.unsqueeze(0)).abs().sum(1)
return output
@staticmethod
def backward(ctx,grad_output):
W_col,X_col = ctx.saved_tensors
grad_W_col = ((X_col.unsqueeze(0)-W_col.unsqueeze(2))*grad_output.unsqueeze(1)).sum(2)
grad_W_col = grad_W_col/grad_W_col.norm(p=2).clamp(min=1e-12)*math.sqrt(W_col.size(1)*W_col.size(0))/5
grad_X_col = (-(X_col.unsqueeze(0)-W_col.unsqueeze(2)).clamp(-1,1)*grad_output.unsqueeze(1)).sum(0)
return grad_W_col, grad_X_col
這個adder運算分為兩部分:前向傳播和反向傳播。
我們先來看前向傳播的部分,只用了很簡單的一句代碼來實現:
output = -(W_col.unsqueeze(2)-X_col.unsqueeze(0)).abs().sum(1)
實際上這個代碼就是將矩陣乘法中的乘法運算用減法和絕對值來代替,我們回顧矩陣乘法,其實就是將兩個矩陣的中間維度進行對應點的相乘後再相加,假設是m*n的矩陣A和n*k的矩陣B相乘,可以將A在第三個維度複製k份,將B在第零個維度複製m份,得到m*n*k大小的矩陣A和B,將這兩個矩陣每個點相乘,最後再第二個維度求和,就得到了m*k的矩陣,也就是矩陣乘法的輸出結果,這其實就是上面代碼的實現過程,將W在第三個維度擴充,將X在第一個維度擴充,然後相減取絕對值,在第二個維度求和,就得到了adder的矩陣運算結果。
我們知道,在pytorch如果你定義好前向傳播,pytorch是會對它進行自動求導的,然而在AdderNet裡,反向傳播的梯度和真實梯度不一樣,所以我們要自己定義這個反向傳播的梯度。
AdderNet中真實梯度為:
梯度被修改為:
和
所以,和上面前向傳播類似的矩陣計算方法,可以用以下代碼計算反向傳播的值,再乘上鏈式法則中輸出的梯度,就得到了W和X的梯度。
grad_W_col= ((X_col.unsqueeze(0)-W_col.unsqueeze(2))*grad_output.unsqueeze(1)).sum(2)
grad_X_col = (-(X_col.unsqueeze(0)-W_col.unsqueeze(2)).clamp(-1,1)*grad_output.unsqueeze(1)).sum(0)
最後再加上論文中提到的adaptive learning rate:
代碼可以表示為:
grad_W_col=grad_W_col/grad_W_col.norm(p=2).clamp(min=1e-12)*math.sqrt(W_col.size(1)*W_col.size(0))/5
以上就是對AdderNet開原始碼的完整解讀。
2
結果
我們最後來看看AdderNet的實驗結果。
可以發現,AdderNet在CIFAR-10和ImageNet數據集上都取得了和CNN相似準確率的結果,並且基本不需要任何乘法,使用github開源的代碼就可以復現以上的結果。
當然,目前AdderNet的訓練還是十分慢的,作者說這是因為AdderNet沒有cuda實現加速,主要的運行速度在於adder這個矩陣計算函數。我們在這裡提供一個簡單的思路來實現cuda加速,我們先參考矩陣乘法的cuda實現https://github.com/NVIDIA/cuda-samples/blob/master/Samples/matrixMul/matrixMul.cu,將矩陣乘法中的乘改為減法和絕對值就可以了,最後,我們可以通過pytorch自帶的cuda extension來編譯cuda代碼(https://pytorch.org/tutorials/advanced/cpp_extension.html),就可以完成AdderNet的cuda加速了。
該論文已被CVPR 2020接收。
論文一作:
陳漢亭,北京大學智能科學系碩博連讀三年級在讀,同濟大學學士,師從北京大學許超教授,在華為諾亞方舟實驗室實習。研究興趣主要包括計算機視覺、機器學習和深度學習。在 ICCV,AAAI,CVPR 等會議發表論文數篇,目前主要研究方向為神經網絡模型小型化。
論文二作:
王雲鶴,在華為諾亞方舟實驗室從事邊緣計算領域的算法開發和工程落地,研究領域包含深度神經網絡的模型裁剪、量化、蒸餾和自動搜索等。王雲鶴博士畢業於北京大學,在相關領域發表學術論文40餘篇,包含NeurIPS、ICML、CVPR、ICCV、TPAMI、AAAI、IJCAI等。
論文地址:https://arxiv.org/pdf/1912.13200.pdf
Github 代碼地址:https://github.com/huawei-noah/AdderNet
招 聘
AI 科技評論希望能夠招聘 科技編輯/記者
辦公地點:北京/深圳
職務:以跟蹤學術熱點、人物專訪為主
工作內容:
1、關注學術領域熱點事件,並及時跟蹤報導;
2、採訪人工智慧領域學者或研發人員;
3、參加各種人工智慧學術會議,並做會議內容報導。
要求:
1、熱愛人工智慧學術研究內容,擅長與學者或企業工程人員打交道;
2、有一定的理工科背景,對人工智慧技術有所了解者更佳;
3、英語能力強(工作內容涉及大量英文資料);
4、學習能力強,對人工智慧前沿技術有一定的了解,並能夠逐漸形成自己的觀點。
感興趣者,可將簡歷發送到郵箱:jiangbaoshang@yanxishe.com
點擊播放 GIF 0.0M
點
擊"閱讀原文",直達「CVPR 交流小組」了解更多會議信息。