# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class to make optimizers weight decay ready."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class DecoupledWeightDecayExtension(object):
    """This class allows to extend optimizers with decoupled weight decay.

    It implements the decoupled weight decay described by Loshchilov & Hutter
    (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
    decoupled from the optimization steps w.r.t. the loss function.
    For SGD variants, this simplifies hyperparameter search since it decouples
    the settings of weight decay and learning rate.
    For adaptive gradient algorithms, it regularizes variables with large
    gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    This class alone is not an optimizer but rather extends existing
    optimizers with decoupled weight decay. We explicitly define the two
    examples used in the above paper (SGDW and AdamW), but in general this
    can extend any OptimizerX by using
    `extend_with_decoupled_weight_decay(OptimizerX)`; the returned class
    takes `weight_decay` as the first argument of its constructor.
    In order for it to work, it must be the first class the Optimizer with
    weight decay inherits from, e.g.

    ```python
    class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
        def __init__(self, weight_decay, *args, **kwargs):
            super(AdamW, self).__init__(weight_decay, *args, **kwargs)
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of 'var' in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    def __init__(self, weight_decay, **kwargs):
        """Extension class that adds weight decay to an optimizer.

        Args:
            weight_decay: A `Tensor` or a floating point value, the factor by
                which a variable is decayed in the update step.
            **kwargs: Optional keyword arguments forwarded to the constructor
                of the base optimizer.
        """
        wd = kwargs.pop('weight_decay', weight_decay)
        super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
        self._decay_var_list = None  # is set in minimize or apply_gradients
        self._set_hyper('weight_decay', wd)

    def get_config(self):
        config = super(DecoupledWeightDecayExtension, self).get_config()
        config.update({
            'weight_decay':
                self._serialize_hyperparameter('weight_decay'),
        })
        return config

    def minimize(self,
                 loss,
                 var_list,
                 grad_loss=None,
                 name=None,
                 decay_var_list=None):
        """Minimize `loss` by updating `var_list`.

        This method simply computes gradient using `tf.GradientTape` and calls
        `apply_gradients()`. If you want to process the gradient before
        applying then call `tf.GradientTape` and `apply_gradients()` explicitly
        instead of using this function.

        Args:
            loss: A callable taking no arguments which returns the value to
                minimize.
            var_list: list or tuple of `Variable` objects to update to
                minimize `loss`, or a callable returning the list or tuple of
                `Variable` objects. Use callable when the variable list would
                otherwise be incomplete before `minimize` since the variables
                are created at the first time `loss` is called.
            grad_loss: Optional. A `Tensor` holding the gradient computed for
                `loss`.
            name: Optional name for the returned operation.
            decay_var_list: Optional list of variables to be decayed. Defaults
                to all variables in `var_list`.

        Returns:
            An Operation that updates the variables in `var_list`. If
            `global_step` was not `None`, that operation also increments
            `global_step`.

        Raises:
            ValueError: If some of the variables are not `Variable` objects.
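
        For example (a minimal sketch; `loss_fn`, `var1`, and `var2` are
        assumed to be defined by the caller and are not part of this module):

        ```python
        # Update var1 and var2 from the gradients, but only decay var1.
        optimizer.minimize(loss_fn, var_list=[var1, var2],
                           decay_var_list=[var1])
        ```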
        """
        self._decay_var_list = set(decay_var_list) if decay_var_list else False
        return super(DecoupledWeightDecayExtension, self).minimize(
            loss, var_list=var_list, grad_loss=grad_loss, name=name)

    def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
        """Apply gradients to variables.

        This is the second part of `minimize()`. It returns an `Operation` that
        applies gradients.

        Args:
            grads_and_vars: List of (gradient, variable) pairs.
            name: Optional name for the returned operation. Defaults to the
                name passed to the `Optimizer` constructor.
            decay_var_list: Optional list of variables to be decayed. Defaults
                to all variables in `grads_and_vars`.

        Returns:
            An `Operation` that applies the specified gradients. If
            `global_step` was not None, that operation also increments
            `global_step`.

        Raises:
            TypeError: If `grads_and_vars` is malformed.
            ValueError: If none of the variables have gradients.
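
        For example (a minimal sketch; `tape`, `loss`, `var1`, and `var2` are
        assumed to be defined by the caller and are not part of this module):

        ```python
        grads = tape.gradient(loss, [var1, var2])
        # Apply gradients to both variables, but only decay var1.
        optimizer.apply_gradients(zip(grads, [var1, var2]),
                                  decay_var_list=[var1])
        ```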
        """
        self._decay_var_list = set(decay_var_list) if decay_var_list else False
        return super(DecoupledWeightDecayExtension, self).apply_gradients(
            grads_and_vars, name=name)

    def _decay_weights_op(self, var):
        # Decoupled weight decay: var <- var - weight_decay * var, applied
        # before the base optimizer's gradient update.
        if not self._decay_var_list or var in self._decay_var_list:
            return var.assign_sub(
                self._get_hyper('weight_decay', var.dtype) * var,
                self._use_locking)
        return tf.no_op()

    def _decay_weights_sparse_op(self, var, indices):
        # Sparse variant: only the rows selected by `indices` are decayed.
        if not self._decay_var_list or var in self._decay_var_list:
            update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather(
                var, indices))
            return self._resource_scatter_add(var, indices, update)
        return tf.no_op()

    # Here, we overwrite the apply functions that the base optimizer calls.
    # super().apply_x resolves to the apply_x function of the BaseOptimizer.

    def _resource_apply_dense(self, grad, var):
        with tf.control_dependencies([self._decay_weights_op(var)]):
            return super(DecoupledWeightDecayExtension,
                         self)._resource_apply_dense(grad, var)

    def _resource_apply_sparse(self, grad, var, indices):
        decay_op = self._decay_weights_sparse_op(var, indices)
        with tf.control_dependencies([decay_op]):
            return super(DecoupledWeightDecayExtension,
                         self)._resource_apply_sparse(grad, var, indices)


def extend_with_decoupled_weight_decay(base_optimizer):
    """Factory function returning an optimizer class with decoupled weight
    decay.

    Returns an optimizer class. An instance of the returned class computes the
    update step of `base_optimizer` and additionally decays the weights.
    E.g., the class returned by
    `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
    equivalent to `tfa.optimizers.AdamW`.

    The API of the new optimizer class slightly differs from the API of the
    base optimizer:
    - The first argument to the constructor is the weight decay rate.
    - `minimize` and `apply_gradients` accept the optional keyword argument
      `decay_var_list`, which specifies the variables that should be decayed.
      If `None`, all variables that are optimized are decayed.

    Usage example:
    ```python
    # MyAdamW is a new class
    MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
    # Create a MyAdamW object
    optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
    # update var1, var2 but only decay var1
    optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of 'var' in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```

    Note: you might want to register your own custom optimizer using
    `tf.keras.utils.get_custom_objects()`.
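
    A minimal sketch of such a registration (the name `'MyAdamW'` is an
    arbitrary choice, not something this module defines):

    ```python
    MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
    tf.keras.utils.get_custom_objects()['MyAdamW'] = MyAdamW
    # Models saved with MyAdamW can then be reloaded without passing
    # custom_objects to tf.keras.models.load_model explicitly.
    ```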

    Args:
        base_optimizer: An optimizer class that inherits from
            tf.optimizers.Optimizer.

    Returns:
        A new optimizer class that inherits from DecoupledWeightDecayExtension
        and base_optimizer.
    """

    class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
                                            base_optimizer):
        """Base_optimizer with decoupled weight decay.

        This class computes the update step of `base_optimizer` and
        additionally decays the variable with the weight decay being
        decoupled from the optimization steps w.r.t. the loss
        function, as described by Loshchilov & Hutter
        (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this
        simplifies hyperparameter search since it decouples the settings
        of weight decay and learning rate. For adaptive gradient
        algorithms, it regularizes variables with large gradients more
        than L2 regularization would, which was shown to yield better
        training loss and generalization error in the paper above.
        """

        def __init__(self, weight_decay, *args, **kwargs):
            # super delegation is necessary here
            super(OptimizerWithDecoupledWeightDecay, self).__init__(
                weight_decay, *args, **kwargs)

    return OptimizerWithDecoupledWeightDecay


class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
    """Optimizer that implements the Momentum algorithm with weight decay.

    This is an implementation of the SGDW optimizer described in "Decoupled
    Weight Decay Regularization" by Loshchilov & Hutter
    (https://arxiv.org/abs/1711.05101)
    ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
    It computes the update step of `tf.keras.optimizers.SGD` and additionally
    decays the variable. Note that this is different from adding
    L2 regularization on the variables to the loss. Decoupling the weight decay
    from other hyperparameters (in particular the learning rate) simplifies
    hyperparameter search.

    For further information see the documentation of the SGD Optimizer.

    This optimizer can also be instantiated as
    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.SGDW(
        learning_rate=lr, weight_decay=wd, momentum=0.9)
    ```
    """

    def __init__(self,
                 weight_decay,
                 learning_rate=0.001,
                 momentum=0.0,
                 nesterov=False,
                 name='SGDW',
                 **kwargs):
        """Construct a new SGDW optimizer.

        For further information see the documentation of the SGD Optimizer.

        Args:
            weight_decay: A `Tensor` or a floating point value, the factor by
                which a variable is decayed in the update step.
            learning_rate: float hyperparameter >= 0. Learning rate.
            momentum: float hyperparameter >= 0 that accelerates SGD in the
                relevant direction and dampens oscillations.
            nesterov: boolean. Whether to apply Nesterov momentum.
            name: Optional name prefix for the operations created when applying
                gradients. Defaults to 'SGDW'.
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse decay
                of learning rate. `lr` is included for backward compatibility,
                recommended to use `learning_rate` instead.
        """
        super(SGDW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            momentum=momentum,
            nesterov=nesterov,
            name=name,
            **kwargs)


class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
    """Optimizer that implements the Adam algorithm with weight decay.

    This is an implementation of the AdamW optimizer described in "Decoupled
    Weight Decay Regularization" by Loshchilov & Hutter
    (https://arxiv.org/abs/1711.05101)
    ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).

    It computes the update step of `tf.keras.optimizers.Adam` and additionally
    decays the variable. Note that this is different from adding L2
    regularization on the variables to the loss: it regularizes variables with
    large gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    For further information see the documentation of the Adam Optimizer.

    This optimizer can also be instantiated as
    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    def __init__(self,
                 weight_decay,
                 learning_rate=0.001,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-07,
                 amsgrad=False,
                 name="AdamW",
                 **kwargs):
        """Construct a new AdamW optimizer.

        For further information see the documentation of the Adam Optimizer.

        Args:
            weight_decay: A Tensor or a floating point value. The weight decay.
            learning_rate: A Tensor or a floating point value. The learning
                rate.
            beta_1: A float value or a constant float tensor. The exponential
                decay rate for the 1st moment estimates.
            beta_2: A float value or a constant float tensor. The exponential
                decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability. This epsilon is
                "epsilon hat" in the Kingma and Ba paper (in the formula just
                before Section 2.1), not the epsilon in Algorithm 1 of the
                paper.
            amsgrad: boolean. Whether to apply the AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and
                Beyond".
            name: Optional name for the operations created when applying
                gradients. Defaults to "AdamW".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse decay
                of learning rate. `lr` is included for backward compatibility,
                recommended to use `learning_rate` instead.
        """
        super(AdamW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs)