Source | Zhihu
Author | 藥師
Link | https://zhuanlan.zhihu.com/p/87209990
Column | 非凸優化學習之路
Editor | 機器學習算法與自然語言處理
【PyTorch】The Optimizer: torch.optim.Optimizer
I previously wrote a source-code walkthrough of TensorFlow's AdamOptimizer (link: https://zhuanlan.zhihu.com/p/63500952); this time let's look at the source code of PyTorch's optimizers.
PyTorch's optimizers almost all inherit from "class Optimizer", the base class of every optimizer, and this post walks through its source code.
Overall, the Optimizer code in PyTorch is easier to read than TensorFlow's. Let's start with a simple example of how an optimizer is used in PyTorch.
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
First, when creating the optimizer object, the model's parameters are passed in together with hyperparameters such as the learning rate. zero_grad then resets the gradients to zero, backward runs backpropagation to compute the gradients, and finally the optimizer's step function updates the parameters.
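Put together, these calls form the inner loop of training. Below is a minimal sketch; the model, loss function and random data are placeholders, and only the optimizer calls mirror the example above.
import torch

# Hypothetical model, loss and data; only the optimizer calls mirror the example above.
model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

for step in range(3):
    input = torch.randn(4, 10)
    target = torch.randn(4, 1)
    optimizer.zero_grad()                      # clear gradients left over from the last step
    loss_fn(model(input), target).backward()   # compute fresh gradients
    optimizer.step()                           # update the parameters with SGD + momentum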
Take PyTorch's SGD Optimizer as an example; its __init__ function is shown below. The model parameters passed in are referred to as params, while the remaining arguments are packed into a dict named defaults. super(SGD, self).__init__(params, defaults) then hands both params and defaults to the parent class Optimizer's __init__ function.
def __init__(self, params, lr=required, momentum=0, dampening=0,
             weight_decay=0, nesterov=False):
    if lr is not required and lr < 0.0:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if momentum < 0.0:
        raise ValueError("Invalid momentum value: {}".format(momentum))
    if weight_decay < 0.0:
        raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                    weight_decay=weight_decay, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
        raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)
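As a quick sanity check (a sketch; the Linear layer is just a placeholder model), the packed defaults can be inspected on the constructed optimizer:
import torch

model = torch.nn.Linear(4, 2)   # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
print(opt.defaults)
# On the version quoted above this prints
# {'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}
# (newer releases add a few extra flags to the same dict).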
The benefit of this design is that processing shared by all subclasses can live in the parent's initializer and each subclass simply calls it. For example, self.param_groups, which the other methods of the SGD class rely on, is created in the parent's __init__ function.
def __init__(self, params, defaults):
    torch._C._log_api_usage_once("python.optimizer")
    self.defaults = defaults

    if isinstance(params, torch.Tensor):
        raise TypeError("params argument given to the optimizer should be "
                        "an iterable of Tensors or dicts, but got " +
                        torch.typename(params))

    self.state = defaultdict(dict)
    self.param_groups = []

    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]

    for param_group in param_groups:
        self.add_param_group(param_group)
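To see the resulting structure (again a sketch with a placeholder model), an optimizer can also be constructed directly from a list of dicts; each dict becomes one entry of self.param_groups, with per-group options overriding the defaults:
import torch

model = torch.nn.Linear(4, 2)   # placeholder model
opt = torch.optim.SGD(
    [{'params': [model.weight]},                # falls back to the default lr below
     {'params': [model.bias], 'lr': 0.01}],     # per-group override
    lr=0.1, momentum=0.9)

print(len(opt.param_groups))                                 # 2
print(opt.param_groups[0]['lr'], opt.param_groups[1]['lr'])  # 0.1 0.01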
The defaultdict used to create self.state behaves like a dict, except that looking up a missing key returns a default value instead of raising a KeyError; here defaultdict(dict) returns an empty dict for keys it has not seen. The last line calls self.add_param_group(param_group), where param_group is a dict whose 'params' key holds the list of model parameters built by param_groups = list(params).
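A minimal illustration of that defaultdict behavior:
from collections import defaultdict

state = defaultdict(dict)
print(state['never_seen_param'])   # {} -- an empty dict instead of a KeyError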
The main job of add_param_group is to put param_group into self.param_groups. It first reorganizes the network parameters into a list with param_group['params'] = list(params), then walks the key-value pairs of self.defaults and fills them into param_group for any option the group did not set itself, then compares the elements of self.param_groups with param_group to make sure no parameter appears in more than one group, and finally appends the dict param_group to the list self.param_groups. (Note: self.param_groups = [] is created in the __init__ function.)
def add_param_group(self, param_group):
    r"""Add a param group to the :class:`Optimizer` s `param_groups`.

    This can be useful when fine tuning a pre-trained network as frozen layers can be made
    trainable and added to the :class:`Optimizer` as training progresses.

    Arguments:
        param_group (dict): Specifies what Tensors should be optimized along with group
            specific optimization options.
    """
    assert isinstance(param_group, dict), "param group must be a dict"

    params = param_group['params']
    if isinstance(params, torch.Tensor):
        param_group['params'] = [params]
    elif isinstance(params, set):
        raise TypeError('optimizer parameters need to be organized in '
                        'ordered collections, but the ordering of tensors in sets '
                        'will change between runs. Please use a list instead.')
    else:
        param_group['params'] = list(params)

    for param in param_group['params']:
        if not isinstance(param, torch.Tensor):
            raise TypeError("optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param))
        if not param.is_leaf:
            raise ValueError("can't optimize a non-leaf Tensor")

    for name, default in self.defaults.items():
        if default is required and name not in param_group:
            raise ValueError("parameter group didn't specify a value of required "
                             "optimization parameter " + name)
        else:
            param_group.setdefault(name, default)

    param_set = set()
    for group in self.param_groups:
        param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
        raise ValueError("some parameters appear in more than one parameter group")

    self.param_groups.append(param_group)
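As the docstring suggests, this is handy when fine-tuning: layers that start out frozen can later be handed to the existing optimizer. A minimal sketch, with a hypothetical backbone/head pair standing in for a real pre-trained network:
import torch
import torch.nn as nn

backbone = nn.Linear(8, 8)   # stand-in for a frozen pre-trained backbone
head = nn.Linear(8, 2)       # stand-in for a freshly added head

opt = torch.optim.SGD(head.parameters(), lr=0.1)   # start by training only the head
# ... later in training, unfreeze the backbone with its own smaller learning rate
opt.add_param_group({'params': backbone.parameters(), 'lr': 0.01})
print(len(opt.param_groups))   # 2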
Next, the zero_grad function, shown below (it is defined on the base Optimizer and inherited by SGD). The operation is simple: every parameter's gradient is set to zero via p.grad.zero_(). detach_() detaches the tensor from the graph that created it, making it a leaf. Recall that self.param_groups is a list whose elements are dicts.
def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
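A small sketch of why this call matters: gradients accumulate across backward passes until they are cleared.
import torch

p = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([p], lr=0.1)

(3 * p).sum().backward()
print(p.grad)        # tensor([3.])
(3 * p).sum().backward()
print(p.grad)        # tensor([6.]) -- accumulated, not replaced

opt.zero_grad()
print(p.grad)        # zeroed by the code above (newer PyTorch may reset it to None instead)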
Updating the parameters is the job of the step function; in the parent class Optimizer, step consists of a single raise NotImplementedError, and SGD's implementation is shown below. As described above, the model parameters and the optimizer hyperparameters are both stored in the dicts that make up the list self.param_groups, so a double loop visits every model parameter p. Once the gradient is read with d_p = p.grad.data, it is adjusted according to whether momentum or nesterov is enabled, and the final p.data.add_(-group['lr'], d_p) applies the update to the parameter.
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        for p in group['params']:
            if p.grad is None:
                continue
            d_p = p.grad.data
            if weight_decay != 0:
                d_p.add_(weight_decay, p.data)
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(1 - dampening, d_p)
                if nesterov:
                    d_p = d_p.add(momentum, buf)
                else:
                    d_p = buf

            p.data.add_(-group['lr'], d_p)

    return loss
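A tiny sanity check of that update rule on a single scalar parameter (a sketch; the numbers simply follow the code above):
import torch

p = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)

(p ** 2).sum().backward()   # d_p = 2 * p = 2.0
opt.step()                  # first step: momentum_buffer = d_p, p <- 1.0 - 0.1 * 2.0
print(p.item())             # 0.8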
The main difference between PyTorch's Adam Optimizer and the SGD Optimizer is again the step function; Adam's version is shown below. For every model parameter, state['exp_avg'] and state['exp_avg_sq'] keep the moving averages of the gradient and of the squared gradient. On the first update there is no state yet, i.e. len(state) == 0, so both buffers are initialized to zero with torch.zeros_like(p.data); on later calls they are simply read from state, used, and updated in place. state['step'] records how many update iterations the optimizer has performed for that parameter. Finally, p.data.addcdiv_(-step_size, exp_avg, denom) updates the model parameter p.
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError('Adam does not support sparse gradients, '
                                   'please consider SparseAdam instead')
            amsgrad = group['amsgrad']

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_sq'] = torch.zeros_like(p.data)

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
            if amsgrad:
                max_exp_avg_sq = state['max_exp_avg_sq']
            beta1, beta2 = group['betas']

            state['step'] += 1
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']

            if group['weight_decay'] != 0:
                grad.add_(group['weight_decay'], p.data)

            # Decay the first and second moment running average coefficient
            exp_avg.mul_(beta1).add_(1 - beta1, grad)
            exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            else:
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

            step_size = group['lr'] / bias_correction1

            p.data.addcdiv_(-step_size, exp_avg, denom)

    return loss
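The same kind of single-parameter sanity check works here (a sketch with default Adam hyperparameters): on the very first step the bias corrections cancel and the update is roughly lr * grad / (|grad| + eps), i.e. about lr in magnitude.
import torch

p = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.Adam([p], lr=0.1)

(p ** 2).sum().backward()   # grad = 2.0
opt.step()                  # first step moves p by roughly lr * sign(grad)
print(p.item())             # approximately 0.9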
Note that the Adam Optimizer only handles dense gradients; to handle sparse gradients you need the SparseAdam Optimizer.
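A minimal sketch of that pairing, using a hypothetical embedding layer with sparse=True so that its gradient is a sparse tensor:
import torch
import torch.nn as nn

emb = nn.Embedding(1000, 16, sparse=True)                # produces sparse gradients
opt = torch.optim.SparseAdam(emb.parameters(), lr=1e-3)

idx = torch.tensor([1, 5, 42])
emb(idx).sum().backward()    # emb.weight.grad is now a sparse tensor
opt.step()                   # plain Adam would raise the RuntimeError shown above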
In addition, I have collected a number of Optimizer implementations in PyTorch; everyone is welcome to help maintain them.
Repository: https://github.com/201419/Optimizer-PyTorch