Batch、Mini-batch和隨機梯度下降的區別和Python示例

2021-01-11 不靠譜的貓

在研究機器學習和深度學習時出現的主要問題之一是梯度下降的幾種類型。在梯度下降的三種類型(Batch梯度下降、Mini-batch梯度下降和隨機梯度下降)中,我應該使用哪一種呢?在這篇文章中,我們將了解這些概念之間的區別,並從梯度下降的代碼實現來闡明這些方法。

梯度下降

梯度下降是幫助神經網絡獲得正確的權重值和偏差值的最常見算法之一。梯度下降法(GD)是在每一步中最小化成本函數J(W,b)的一種算法。它迭代地更新權重和偏差,以嘗試在成本函數中達到全局最小值。

最小化成本函數,梯度下降圖

在我們計算GD之前,首先獲取輸入並通過神經網絡的所有節點,然後計算輸入、權重和偏差的加權和。這是計算梯度下降的主要步驟之一,稱為正向傳播。一旦我們有了一個輸出,我們將這個輸出與預期的輸出進行比較,並計算出它們之間的差異,即誤差。有了這個誤差,我們現在可以反向傳播它,更新每個權重和偏差,並嘗試最小化這個誤差。正如你所預料的,這部分被稱為反向傳播。反向傳播步驟是使用導數計算得出的,並返回「梯度」,這個值告訴我們應遵循哪個方向以最小化成本函數。

現在我們準備更新權重矩陣W和偏差向量b了。梯度下降規則如下:

換句話說,新的權重/偏差值將是最後一個權重/偏差值減去梯度的值,使其接近成本函數的全局最小值。我們還將這個梯度乘以一個學習率,它控制著步長。

這種經典的梯度下降法也稱為Batch梯度下降法。在這種方法中,每個epoch遍歷所有訓練數據,然後計算損失並更新W和b值。該方法雖然具有穩定的收斂性和穩定的誤差,但是該方法使用了整個機器學習訓練集,因此,對於大型機器學習數據集會非常慢。

Mini-batch梯度下降

想像一下,將您的數據集分成幾個batches。這樣,它就不必等到算法遍歷整個數據集後才更新權重和偏差,而是在每個所謂的Mini-batch結束時進行更新。這使得我們能夠快速將成本函數移至全局最小值,並在每個epoch中多次更新權重和偏差。最常見的Mini-batch大小是16、32、64、128、256和512。大多數項目使用Mini-batch梯度下降,因為它在較大的機器學習數據集中速度更快。

Batch梯度下降

如前所述,在此梯度下降中,每個Batch等於整個數據集。那是:

其中{1}表示Mini-batch中的第一批次。缺點是每次迭代花費的時間太長。此方法可用於訓練少於2000個樣本的機器學習數據集。

隨機梯度下降

在這種方法中,每個batch等於訓練集中的一個實例。

其中(1)表示第一個訓練實例。這裡的缺點是它失去了向量化的優勢,有更多的振蕩但收斂得更快。

最後

理解這些優化算法之間的區別是很重要的,因為它們構成了神經網絡的關鍵功能。綜上所述,Batch梯度下降雖然比隨機梯度下降具有更高的準確度,但是隨機梯度下降的速度更快。Mini-batch梯度下降很好地結合了兩者,從而提供了良好的準確性和性能。

可以僅使用Mini-batch梯度下降代碼來實現所有版本的梯度下降,對於隨機梯度下降可以將mini_batch_size設置為1,對於Batch梯度下降可以將mini_batch_size設置為數據集中的實例數。因此,Batch、Mini-batch和隨機梯度下降之間的主要區別是每個epoch使用的實例數以及達到成本函數的全局最小值所需的時間。

相關焦點

  • 梯度下降—Python實現
    梯度下降是數據科學的基礎,無論是深度學習還是機器學習。深入了解梯度下降原理一定會對你今後的工作有所幫助。你將真正了解這些超參數的作用以及處理使用此算法可能遇到的問題。然而,梯度下降並不局限於一種算法。另外兩種流行的梯度下降(隨機和小批量梯度下降)建立在主要算法的基礎上,你可能會看到比普通批量梯度下降更多的算法。
  • 如何理解深度學習分布式訓練中的large batch size與learning rate...
    (1)理解SGD、minibatch-SGD和GD在機器學習優化算法中,GD(gradient descent)是最常用的方法之一,簡單來說就是在整個訓練集中計算當前的梯度,選定一個步長進行更新。GD的優點是,基於整個數據集得到的梯度,梯度估計相對較準,更新過程更準確。
  • 神經網絡算法Batch Normalization的分析與展望 | 大牛講堂
    BN的動機和成功的原因訓練神經網絡,通常是用反向傳播算法(BP)+隨機梯度下降(SGD),具體到一個具體的第L層連接神經元i和神經元j的參數的更新就是,通常我們會希望這個更新比較穩定。BN的結果之所以更好,可能作對了兩個地方,第一,正像我們之前講的,一堆隨機數的和()更接近一個正態分布,在這個和上來做標準化更加容易使得通過非線性之後依然保持穩定,對於使用sigmoid非線性的時候,對於一個mini batch而言大概接近68%的值在[0.27 0.73]之間,95%的值在[0.12 0.88]之間,這當然很好,基本都在sigmoid函數的線性部分,不會出現飽和的情況(這裡的梯度不會接近
  • 神經網絡算法BatchNormalization的分析與展望|大牛講堂
    BN的動機和成功的原因訓練神經網絡,通常是用反向傳播算法(BP)+隨機梯度下降(SGD),具體到一個具體的第L層連接神經元i和神經元j的參數的更新就是,通常我們會希望這個更新比較穩定。BN的結果之所以更好,可能作對了兩個地方,第一,正像我們之前講的,一堆隨機數的和()更接近一個正態分布,在這個和上來做標準化更加容易使得通過非線性之後依然保持穩定,對於使用sigmoid非線性的時候,對於一個mini batch而言大概接近68%的值在[0.27 0.73]之間,95%的值在[0.12 0.88]之間,這當然很好,基本都在sigmoid函數的線性部分,不會出現飽和的情況(這裡的梯度不會接近
  • 機器學習:隨機梯度下降和批量梯度下降算法介紹
    機器學習:隨機梯度下降和批量梯度下降算法介紹 佚名 發表於 2017-11-28 04:00:28 隨機梯度下降(Stochastic gradient descent)
  • 梯度下降兩大痛點:陷入局部極小值和過擬合
    打開APP 梯度下降兩大痛點:陷入局部極小值和過擬合 胡薇 發表於 2018-04-27 17:01:36 介紹 基於梯度下降訓練神經網絡時,我們將冒網絡落入局部極小值的風險,網絡在誤差平面上停止的位置並非整個平面的最低點。
  • 批歸一化Batch Normalization的原理及算法
    隨機梯度下降法是訓練深度網絡的首選。儘管隨機梯度下降法對於訓練深度網絡簡單高效,但是需要我們人為的去選擇參數,比如學習速率、初始化參數、權重衰減係數、Drop out比例,等等。這些參數的選擇對訓練結果至關重要,以至於我們很多時間都浪費在這些的調參上。
  • 【乾貨】機器學習最常用優化之一——梯度下降優化算法綜述
    而隨機梯度下降算法每次只隨機選擇一個樣本來更新模型參數,因此每次的學習是非常快速的,並且可以進行在線更新。 Mini-batch 梯度下降綜合了 batch 梯度下降與 stochastic 梯度下降,在每次更新速度與更新次數中間取得一個平衡,其每次更新從訓練集中隨機選擇 m,m<n 個樣本進行學習,即: θ=θ−η⋅∇θJ(θ;xi:i+m;yi:i+m) 其代碼如下:
  • 一文讀懂線性回歸和梯度下降
    我們又兩種方式將只有一個樣本的數學表達轉化為樣本為多個的情況:梯度下降(gradient descent)和正則方程(The normal equations)。這裡我們重點講梯度下降。 梯度下降批梯度下降(batch gradient descent)    如下公式是處理一個樣本的表達式:
  • 一文詳解神經網絡 BP 算法原理及 Python 實現
    什麼是梯度下降和鏈式求導法則 假設我們有一個函數 J(w),如下圖所示。 梯度下降示意圖 現在,我們要求當 w 等於什麼的時候,J(w) 能夠取到最小值。從圖中我們知道最小值在初始位置的左邊,也就意味著如果想要使 J(w) 最小,w的值需要減小。
  • 從頭開始:用Python實現帶隨機梯度下降的Logistic回歸
    隨機梯度下降梯度下降是通過順著成本函數(cost function)的梯度來最小化函數的過程。這涉及到成本函數的形式及導數,使得從任意給定點能推算梯度並在該方向上移動,例如,沿坡向下(downhill)直到最小值。
  • batchplot批量列印怎麼用?Batchplot功能特色
    能夠批量列印 AutoCAD 和其他軟體生成的DWG/DXF 文件。2. 支持智能識別圖紙的列印區域(圖框)及繪圖比例,無需逐個設置圖紙的列印區域。3. 支持按目錄添加圖紙,高效率添加圖紙文件。4. 支持一次選擇多個圖紙文件批量列印,也能夠智能識別一個文件中的多張圖紙。5. 圖紙排版算法國際領先,最大限度地節省紙張和耗材。6.
  • 零基礎入門深度學習(六):圖像分類任務之LeNet和AlexNet
    從本課程中,你將學習到:深度學習基礎知識numpy實現神經網絡構建和梯度下降算法計算機視覺領域主要方向的原理、實踐自然語言處理領域主要方向的原理、實踐(img)batch_labels.append(label)if len(batch_imgs) == batch_size:# 當數據列表的長度等於batch_size的時候,# 把這些數據當作一個mini-batch,並作為數據生成器的一個輸出
  • 機器學習 101:一文帶你讀懂梯度下降
    首先,我們使用pandas在python中加載數據,並分離房屋大小和價格特徵。之後,我們對數據進行標準化,以防止某些特徵的大小範圍與其他特徵不同。而且,標準化過的數據在進行梯度下降時,收斂速度比其他方法快得多。
  • 乾貨分享|使用JAX創建神經網絡的對抗性示例(附詳細代碼)
    在本教程中,我們將看到如何創建使用JAX訓練神經網絡的對抗示例。首先,讓我們看一些定義。有哪些例子?簡而言之,對抗性示例是神經網絡的輸入,這些輸入經過優化以欺騙算法,即導致目標變量分類錯誤。通過向目標變量添加「適當的」噪聲,我們可以對目標變量進行錯誤分類。下圖演示了該概念。本教程的重點是演示如何創建對抗示例。我們將使用快速梯度符號法生成。