# autodl-projects/lib/tf_optimizers/weight_decay_optimizers.py
# 2020-01-05 22:19:38 +11:00
#
# 423 lines
# 18 KiB
# Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class to make optimizers weight decay ready."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class DecoupledWeightDecayExtension(object):
"""This class allows to extend optimizers with decoupled weight decay.
It implements the decoupled weight decay described by Loshchilov & Hutter
(https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
decoupled from the optimization steps w.r.t. to the loss function.
For SGD variants, this simplifies hyperparameter search since it decouples
the settings of weight decay and learning rate.
For adaptive gradient algorithms, it regularizes variables with large
gradients more than L2 regularization would, which was shown to yield
better training loss and generalization error in the paper above.
This class alone is not an optimizer but rather extends existing
optimizers with decoupled weight decay. We explicitly define the two
examples used in the above paper (SGDW and AdamW), but in general this
can extend any OptimizerX by using
`extend_with_decoupled_weight_decay(
OptimizerX, weight_decay=weight_decay)`.
In order for it to work, it must be the first class the Optimizer with
weight decay inherits from, e.g.
```python
class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
def __init__(self, weight_decay, *args, **kwargs):
super(AdamW, self).__init__(weight_decay, *args, **kwargs).
```
Note: this extension decays weights BEFORE applying the update based
on the gradient, i.e. this extension only has the desired behaviour for
optimizers which do not depend on the value of'var' in the update step!
Note: when applying a decay to the learning rate, be sure to manually apply
the decay to the `weight_decay` as well. For example:
```python
step = tf.Variable(0, trainable=False)
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
[10000, 15000], [1e-0, 1e-1, 1e-2])
# lr and wd can be a function or a tensor
lr = 1e-1 * schedule(step)
wd = lambda: 1e-4 * schedule(step)
# ...
optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
```
"""
def __init__(self, weight_decay, **kwargs):
"""Extension class that adds weight decay to an optimizer.
Args:
weight_decay: A `Tensor` or a floating point value, the factor by
which a variable is decayed in the update step.
**kwargs: Optional list or tuple or set of `Variable` objects to
decay.
"""
wd = kwargs.pop('weight_decay', weight_decay)
super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
self._decay_var_list = None # is set in minimize or apply_gradients
self._set_hyper('weight_decay', wd)
def get_config(self):
config = super(DecoupledWeightDecayExtension, self).get_config()
config.update({
'weight_decay':
self._serialize_hyperparameter('weight_decay'),
})
return config
def minimize(self,
loss,
var_list,
grad_loss=None,
name=None,
decay_var_list=None):
"""Minimize `loss` by updating `var_list`.
This method simply computes gradient using `tf.GradientTape` and calls
`apply_gradients()`. If you want to process the gradient before
applying then call `tf.GradientTape` and `apply_gradients()` explicitly
instead of using this function.
Args:
loss: A callable taking no arguments which returns the value to
minimize.
var_list: list or tuple of `Variable` objects to update to
minimize `loss`, or a callable returning the list or tuple of
`Variable` objects. Use callable when the variable list would
otherwise be incomplete before `minimize` since the variables
are created at the first time `loss` is called.
grad_loss: Optional. A `Tensor` holding the gradient computed for
`loss`.
decay_var_list: Optional list of variables to be decayed. Defaults
to all variables in var_list.
name: Optional name for the returned operation.
Returns:
An Operation that updates the variables in `var_list`. If
`global_step` was not `None`, that operation also increments
`global_step`.
Raises:
ValueError: If some of the variables are not `Variable` objects.
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).minimize(
loss, var_list=var_list, grad_loss=grad_loss, name=name)
def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
"""Apply gradients to variables.
This is the second part of `minimize()`. It returns an `Operation` that
applies gradients.
Args:
grads_and_vars: List of (gradient, variable) pairs.
name: Optional name for the returned operation. Default to the
name passed to the `Optimizer` constructor.
decay_var_list: Optional list of variables to be decayed. Defaults
to all variables in var_list.
Returns:
An `Operation` that applies the specified gradients. If
`global_step` was not None, that operation also increments
`global_step`.
Raises:
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).apply_gradients(
grads_and_vars, name=name)
def _decay_weights_op(self, var):
if not self._decay_var_list or var in self._decay_var_list:
return var.assign_sub(
self._get_hyper('weight_decay', var.dtype) * var,
self._use_locking)
return tf.no_op()
def _decay_weights_sparse_op(self, var, indices):
if not self._decay_var_list or var in self._decay_var_list:
update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather(
var, indices))
return self._resource_scatter_add(var, indices, update)
return tf.no_op()
# Here, we overwrite the apply functions that the base optimizer calls.
# super().apply_x resolves to the apply_x function of the BaseOptimizer.
def _resource_apply_dense(self, grad, var):
with tf.control_dependencies([self._decay_weights_op(var)]):
return super(DecoupledWeightDecayExtension,
self)._resource_apply_dense(grad, var)
def _resource_apply_sparse(self, grad, var, indices):
decay_op = self._decay_weights_sparse_op(var, indices)
with tf.control_dependencies([decay_op]):
return super(DecoupledWeightDecayExtension,
self)._resource_apply_sparse(grad, var, indices)
def extend_with_decoupled_weight_decay(base_optimizer):
    """Factory function returning an optimizer class with decoupled weight
    decay.

    Returns an optimizer class. An instance of the returned class computes the
    update step of `base_optimizer` and additionally decays the weights.
    E.g., the class returned by
    `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` behaves
    like the `AdamW` class defined in this module.

    The API of the new optimizer class slightly differs from the API of the
    base optimizer:
    - The first argument to the constructor is the weight decay rate.
    - `minimize` and `apply_gradients` accept the optional keyword argument
      `decay_var_list`, which specifies the variables that should be decayed.
      If `None` (or empty), all variables that are optimized are decayed.

    Usage example:

    ```python
    # MyAdamW is a new class
    MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
    # Create a MyAdamW object
    optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
    # update var1, var2 but only decay var1
    optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of `var` in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = AdamW(learning_rate=lr, weight_decay=wd)
    ```

    Note: you might want to register your own custom optimizer using
    `tf.keras.utils.get_custom_objects()`.

    Args:
        base_optimizer: An optimizer class that inherits from
            tf.optimizers.Optimizer.

    Returns:
        A new optimizer class that inherits from DecoupledWeightDecayExtension
        and base_optimizer.
    """

    class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
                                            base_optimizer):
        """Base_optimizer with decoupled weight decay.

        This class computes the update step of `base_optimizer` and
        additionally decays the variable with the weight decay being
        decoupled from the optimization steps w.r.t. the loss
        function, as described by Loshchilov & Hutter
        (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this
        simplifies hyperparameter search since it decouples the settings
        of weight decay and learning rate. For adaptive gradient
        algorithms, it regularizes variables with large gradients more
        than L2 regularization would, which was shown to yield better
        training loss and generalization error in the paper above.
        """

        def __init__(self, weight_decay, *args, **kwargs):
            # super delegation is necessary here: the extension's __init__
            # consumes `weight_decay` and forwards the rest to base_optimizer.
            super(OptimizerWithDecoupledWeightDecay, self).__init__(
                weight_decay, *args, **kwargs)

    return OptimizerWithDecoupledWeightDecay
class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
    """Optimizer that implements the Momentum algorithm with weight_decay.

    This is an implementation of the SGDW optimizer described in "Decoupled
    Weight Decay Regularization" by Loshchilov & Hutter
    (https://arxiv.org/abs/1711.05101)
    ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
    It computes the update step of `tf.keras.optimizers.SGD` and additionally
    decays the variable. Note that this is different from adding
    L2 regularization on the variables to the loss. Decoupling the weight decay
    from other hyperparameters (in particular the learning rate) simplifies
    hyperparameter search.

    For further information see the documentation of the SGD Optimizer.

    This optimizer can also be instantiated as

    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = SGDW(learning_rate=lr, weight_decay=wd, momentum=0.9)
    ```
    """

    def __init__(self,
                 weight_decay,
                 learning_rate=0.001,
                 momentum=0.0,
                 nesterov=False,
                 name='SGDW',
                 **kwargs):
        """Construct a new SGDW optimizer.

        For further information see the documentation of the SGD Optimizer.

        Args:
            weight_decay: A `Tensor`, a floating point value, or a callable
                returning one; the factor by which variables are decayed in
                each update step.
            learning_rate: float hyperparameter >= 0. Learning rate.
            momentum: float hyperparameter >= 0 that accelerates SGD in the
                relevant direction and dampens oscillations.
            nesterov: boolean. Whether to apply Nesterov momentum.
            name: Optional name prefix for the operations created when applying
                gradients. Defaults to 'SGDW'.
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse decay
                of learning rate. `lr` is included for backward compatibility,
                recommended to use `learning_rate` instead.
        """
        super(SGDW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            momentum=momentum,
            nesterov=nesterov,
            name=name,
            **kwargs)
class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
    """Optimizer that implements the Adam algorithm with weight decay.

    This is an implementation of the AdamW optimizer described in "Decoupled
    Weight Decay Regularization" by Loshchilov & Hutter
    (https://arxiv.org/abs/1711.05101)
    ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).

    It computes the update step of `tf.keras.optimizers.Adam` and additionally
    decays the variable. Note that this is different from adding L2
    regularization on the variables to the loss: it regularizes variables with
    large gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    For further information see the documentation of the Adam Optimizer.

    This optimizer can also be instantiated as

    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    def __init__(self,
                 weight_decay,
                 learning_rate=0.001,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-07,
                 amsgrad=False,
                 name="AdamW",
                 **kwargs):
        """Construct a new AdamW optimizer.

        For further information see the documentation of the Adam Optimizer.

        Args:
            weight_decay: A Tensor or a floating point value. The weight decay.
            learning_rate: A Tensor or a floating point value. The learning
                rate.
            beta_1: A float value or a constant float tensor. The exponential
                decay rate for the 1st moment estimates.
            beta_2: A float value or a constant float tensor. The exponential
                decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability. This epsilon is
                "epsilon hat" in the Kingma and Ba paper (in the formula just
                before Section 2.1), not the epsilon in Algorithm 1 of the
                paper.
            amsgrad: boolean. Whether to apply AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and
                beyond".
            name: Optional name for the operations created when applying
                gradients. Defaults to "AdamW".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse decay
                of learning rate. `lr` is included for backward compatibility,
                recommended to use `learning_rate` instead.
        """
        super(AdamW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs)