最近用keras訓練模型時由於數據集很大,使用fit會佔用很多內存,所以用了下批量訓練接口fit_generator,首先我們先看下這兩個接口的基本調用方式:
從上述兩個例子可以看出來,fit是直接加載所有的訓練集數據,而fit_generator通過自定義的批加載函數每次只讀取batch_size大小的數據集進行訓練,假設有1W張圖片,fit則要直接把這1W張圖片加載進行迭代訓練,而fit_generator每次只加載100張圖像進行訓練,則其內存佔用只有fit的百分之一,就不用擔心內存過多了。
好了,對於大數據集使用fit_generator確實可以減少內存佔用,這個就不用測了,那麼速度上fit與fit_generator相比會怎樣呢?我們下面來進行測試對比一下:
model.fit性能
model.fit_generator性能
上面的測試是60000幅圖像,batch_size=100,epoch=2,從結果可以看出兩次測試速度一樣,大概都是38秒一輪,測試集準確率和誤差幾乎一致,這說明在性能方面其效果是一致的。
好了,總結一下就是對於大數據集來說,如果訓練集佔用內存過大的話可以使用fit_generator進行批訓練,對於小數據集的話就沒必要了,直接fit就行。