Remove TF codes

This commit is contained in:
D-X-Y 2020-07-05 23:19:25 +00:00
parent 7d02870bf8
commit cba4741d10
9 changed files with 1 additions and 846 deletions

View File

@ -70,7 +70,7 @@ def evaluate(api, weight_dir, data: str):
ok += 1
norms.append(cur_norm)
# query the accuracy
info = meta_info.get_metrics(data, 'ori-test', iepoch=None, is_random=777)
info = meta_info.get_metrics(data, 'ori-test', iepoch=None, is_random=888 if isinstance(api, NASBench201API) else 777)
accuracies.append(info['accuracy'])
del net, meta_info
# print the information

View File

@ -1,32 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from os import path as osp
__all__ = ['get_cell_based_tiny_net', 'get_search_spaces']
# the cell-based NAS models
def get_cell_based_tiny_net(config):
    """Build the cell-based super network selected by ``config.name``.

    Args:
        config: object exposing ``name``; for the supported names ('GDAS',
            'DARTS') it must also expose ``space``, ``C``, ``N``,
            ``max_nodes``, ``num_classes`` and ``affine``. ``config.space`` is
            either a key of ``SearchSpaceNames`` (a string) or the search
            space itself.

    Returns:
        The network instance created by the matching ``nas_super_nets`` entry.

    Raises:
        ValueError: if ``config.name`` is not a supported network name.
    """
    supported_names = ['GDAS', 'DARTS']
    if config.name not in supported_names:
        raise ValueError('invalid network name : {:}'.format(config.name))
    # Imported lazily: only needed once a supported network is requested.
    from .cell_searchs import nas_super_nets
    from .cell_operations import SearchSpaceNames
    space = config.space
    if isinstance(space, str):
        # a string names a registered search space
        space = SearchSpaceNames[space]
    builder = nas_super_nets[config.name]
    return builder(
        config.C, config.N, config.max_nodes,
        config.num_classes, space, config.affine)
# obtain the search space, i.e., the collection of candidate operations registered under the given name in SearchSpaceNames
def get_search_spaces(xtype, name):
    """Return the search space registered under ``name``.

    Args:
        xtype: kind of search space; only ``'cell'`` is supported.
        name: key looked up in ``SearchSpaceNames``.

    Returns:
        ``SearchSpaceNames[name]``.

    Raises:
        ValueError: if ``xtype`` is not ``'cell'``.
        AssertionError: if ``name`` is not a registered search space.
    """
    if xtype != 'cell':
        raise ValueError('invalid search-space type is {:}'.format(xtype))
    # Imported lazily, mirroring the original function-scope import.
    from .cell_operations import SearchSpaceNames
    assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
    return SearchSpaceNames[name]

View File

@ -1,150 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import tensorflow as tf
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
# Registry of candidate operations: operation name -> factory taking
# (C_in, C_out, stride, affine) and returning the corresponding Keras layer.
OPS = {
    'none': lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
    'avg_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(
        C_in, C_out, stride, 'avg', affine),
    'nor_conv_1x1': lambda C_in, C_out, stride, affine: ReLUConvBN(
        C_in, C_out, 1, stride, affine),
    'nor_conv_3x3': lambda C_in, C_out, stride, affine: ReLUConvBN(
        C_in, C_out, 3, stride, affine),
    'nor_conv_5x5': lambda C_in, C_out, stride, affine: ReLUConvBN(
        C_in, C_out, 5, stride, affine),
    # A plain identity is only possible when the tensor shape is preserved;
    # otherwise the skip connection goes through a factorized reduction.
    'skip_connect': lambda C_in, C_out, stride, affine: (
        Identity(C_in, C_out, stride)
        if stride == 1 and C_in == C_out
        else FactorizedReduce(C_in, C_out, stride, affine)),
}

# Operation names that make up the NAS-Bench-201 search space.
NAS_BENCH_201 = [
    'none',
    'skip_connect',
    'nor_conv_1x1',
    'nor_conv_3x3',
    'avg_pool_3x3',
]

# Search-space name -> list of operation names (keys of OPS).
SearchSpaceNames = {'nas-bench-201': NAS_BENCH_201}
class POOLING(tf.keras.layers.Layer):
    """3x3 average or max pooling with an optional channel projection.

    When ``C_in`` differs from ``C_out`` the input first goes through a 1x1
    ``ReLUConvBN`` (stride 1) so that the pooled output has ``C_out`` channels.
    """

    def __init__(self, C_in, C_out, stride, mode, affine):
        """Create the pooling op.

        Args:
            C_in: number of input channels.
            C_out: number of output channels.
            stride: stride of the pooling window.
            mode: ``'avg'`` or ``'max'``.
            affine: forwarded to the optional ``ReLUConvBN`` projection.

        Raises:
            ValueError: if ``mode`` is neither ``'avg'`` nor ``'max'``.
        """
        super(POOLING, self).__init__()
        # A projection is only needed when the channel count changes.
        self.preprocess = None if C_in == C_out else ReLUConvBN(C_in, C_out, 1, 1, affine)
        if mode == 'avg':
            pool_cls = tf.keras.layers.AvgPool2D
        elif mode == 'max':
            pool_cls = tf.keras.layers.MaxPool2D
        else:
            raise ValueError('Invalid mode={:} in POOLING'.format(mode))
        self.op = pool_cls((3, 3), strides=stride, padding='same')

    def call(self, inputs, training):
        """Apply the optional projection followed by the pooling op."""
        projected = self.preprocess(inputs) if self.preprocess else inputs
        return self.op(projected)
class Identity(tf.keras.layers.Layer):
    """Pass-through layer that falls back to a 3x3 conv when shapes differ."""

    def __init__(self, C_in, C_out, stride):
        """Create the layer.

        Args:
            C_in: number of input channels.
            C_out: number of output channels.
            stride: spatial stride; anything other than 1 requires the conv.
        """
        super(Identity, self).__init__()
        shape_preserved = C_in == C_out and stride == 1
        # No layer at all when the input can be returned untouched.
        self.layer = None if shape_preserved else tf.keras.layers.Conv2D(
            C_out, 3, stride, padding='same', use_bias=False)

    def call(self, inputs, training):
        """Return ``inputs`` unchanged, or convolved when a conv was created."""
        if self.layer is None:
            return inputs
        return self.layer(inputs)
class Zero(tf.keras.layers.Layer):
    """The 'none' operation: outputs zeros derived from the input tensor.

    The zero tensor is passed through a 1x1 conv when the channel count
    changes, or through an average pooling when only the stride differs from 1.
    """

    def __init__(self, C_in, C_out, stride):
        """Create the layer.

        Args:
            C_in: number of input channels.
            C_out: number of output channels.
            stride: spatial stride of the operation.
        """
        super(Zero, self).__init__()
        if C_in != C_out:
            reshaper = tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False)
        elif stride != 1:
            reshaper = tf.keras.layers.AvgPool2D((stride, stride), None, padding="same")
        else:
            reshaper = None
        self.layer = reshaper

    def call(self, inputs, training):
        """Return zeros shaped like ``inputs``, passed through the optional layer."""
        zeros = tf.zeros_like(inputs)
        return zeros if self.layer is None else self.layer(zeros)
class ReLUConvBN(tf.keras.layers.Layer):
    """ReLU -> Conv2D (no bias, 'same' padding) -> BatchNormalization."""

    def __init__(self, C_in, C_out, kernel_size, strides, affine):
        """Create the three-stage block.

        Args:
            C_in: number of input channels (stored only; not passed to the conv).
            C_out: number of convolution filters.
            kernel_size: convolution kernel size.
            strides: convolution stride.
            affine: enables both ``center`` and ``scale`` of the batch norm.
        """
        super(ReLUConvBN, self).__init__()
        self.C_in = C_in
        self.relu = tf.keras.activations.relu
        self.conv = tf.keras.layers.Conv2D(
            C_out, kernel_size, strides, padding='same', use_bias=False)
        self.bn = tf.keras.layers.BatchNormalization(center=affine, scale=affine)

    def call(self, inputs, training):
        """Run the block; ``training`` is forwarded to the batch norm."""
        activated = self.relu(inputs)
        return self.bn(self.conv(activated), training)
class FactorizedReduce(tf.keras.layers.Layer):
    """Skip connection used when the shape changes (NHWC data format).

    * ``stride == 1``: ReLU -> 1x1 conv -> batch norm.
    * ``stride == 2``: ReLU, then two strided 1x1 convs with ``C_out // 2``
      filters each; the second one sees the input shifted by one pixel in
      both spatial dimensions. Their outputs are concatenated on the channel
      axis and batch-normalized.
    """

    def __init__(self, C_in, C_out, stride, affine):
        """Create the reduction.

        Args:
            C_in: number of input channels (unused; Keras infers it).
            C_out: number of output channels; must be even.
            stride: 1 or 2.
            affine: enables both ``center`` and ``scale`` of the batch norm.

        Raises:
            AssertionError: if ``C_out`` is odd.
            ValueError: if ``stride`` is neither 1 nor 2.
        """
        # The base initializer must run before any attribute is assigned.
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0, ('Need even number of filters when using this factorized reduction.')
        self.stride = stride
        self.relu = tf.keras.activations.relu
        if stride == 1:
            self.layer = tf.keras.Sequential([
                tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False),
                tf.keras.layers.BatchNormalization(center=affine, scale=affine)])
        elif stride == 2:
            self.layer1 = tf.keras.layers.Conv2D(C_out // 2, 1, stride, padding='same', use_bias=False)
            self.layer2 = tf.keras.layers.Conv2D(C_out // 2, 1, stride, padding='same', use_bias=False)
            self.bn = tf.keras.layers.BatchNormalization(center=affine, scale=affine)
        else:
            raise ValueError('invalid stride={:}'.format(stride))

    def call(self, inputs, training):
        """Apply the reduction; ``training`` is forwarded to the batch norm."""
        x = self.relu(inputs)
        if self.stride == 1:
            return self.layer(x, training)
        path1 = x
        # Pad bottom/right by one pixel, then drop the first row/column, so the
        # second conv samples the pixels the first one skips (NHWC layout).
        path2 = tf.pad(x, [[0, 0], [0, 1], [0, 1], [0, 0]])[:, 1:, 1:, :]
        x1 = self.layer1(path1)
        x2 = self.layer2(path2)
        final_path = tf.concat(values=[x1, x2], axis=3)
        return self.bn(final_path, training)
class ResNetBasicblock(tf.keras.layers.Layer):
    """Residual block made of two 3x3 ``ReLUConvBN`` layers plus a shortcut.

    The shortcut is an average pooling followed by a 1x1 conv when
    ``stride == 2``, a 1x1 ``ReLUConvBN`` when only the channel count changes,
    and the identity otherwise.
    """

    def __init__(self, inplanes, planes, stride, affine=True):
        """Create the block.

        Args:
            inplanes: number of input channels.
            planes: number of output channels.
            stride: 1 or 2; applied by the first conv and the shortcut.
            affine: forwarded to the ``ReLUConvBN`` layers.

        Raises:
            AssertionError: if ``stride`` is neither 1 nor 2.
        """
        super(ResNetBasicblock, self).__init__()
        assert stride in (1, 2), 'invalid stride {:}'.format(stride)
        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, affine)
        self.conv_b = ReLUConvBN(planes, planes, 3, 1, affine)
        if stride == 2:
            shortcut = tf.keras.Sequential([
                tf.keras.layers.AvgPool2D((stride, stride), None, padding="same"),
                tf.keras.layers.Conv2D(planes, 1, 1, padding='same', use_bias=False)])
        elif inplanes != planes:
            shortcut = ReLUConvBN(inplanes, planes, 1, stride, affine)
        else:
            shortcut = None
        self.downsample = shortcut
        self.addition = tf.keras.layers.Add()
        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def call(self, inputs, training):
        """Return ``shortcut(inputs) + conv_b(conv_a(inputs))``."""
        main_path = self.conv_b(self.conv_a(inputs, training), training)
        if self.downsample is None:
            residual = inputs
        else:
            residual = self.downsample(inputs)
        return self.addition([residual, main_path])

View File

@ -1,8 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .search_model_gdas import TinyNetworkGDAS
from .search_model_darts import TinyNetworkDARTS
# Registry: search-algorithm name -> super-network class implementing it.
nas_super_nets = dict(GDAS=TinyNetworkGDAS, DARTS=TinyNetworkDARTS)

View File

@ -1,50 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random
import tensorflow as tf
from copy import deepcopy
from ..cell_operations import OPS
class NAS201SearchCell(tf.keras.layers.Layer):
    """Densely connected search cell with a weighted mix of ops on each edge.

    Node ``i`` (``1 <= i < max_nodes``) receives one edge from every earlier
    node ``j``; each edge holds one instance of every candidate operation,
    stored as the attribute ``'<i><-<j>.<k>'`` for operation index ``k``.
    """

    def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False):
        """Create all candidate operations for every edge.

        Args:
            C_in: number of input channels passed to every operation factory.
            C_out: number of output channels passed to every operation factory.
            stride: stride used on edges leaving node 0; other edges use 1.
            max_nodes: number of nodes in the cell, including the input node.
            op_names: names of the candidate operations (keys of ``OPS``).
            affine: forwarded to every operation factory.
        """
        super(NAS201SearchCell, self).__init__()
        self.op_names = deepcopy(op_names)
        self.max_nodes = max_nodes
        self.in_dim = C_in
        self.out_dim = C_out
        edge_names = []
        for dst in range(1, max_nodes):
            for src in range(dst):
                edge = '{:}<-{:}'.format(dst, src)
                # Only edges that read the cell input apply the cell stride.
                edge_stride = stride if src == 0 else 1
                for op_idx, op_name in enumerate(op_names):
                    candidate = OPS[op_name](C_in, C_out, edge_stride, affine)
                    setattr(self, '{:}.{:}'.format(edge, op_idx), candidate)
                edge_names.append(edge)
        self.edge_keys = sorted(edge_names)
        self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
        self.num_edges = len(self.edge_keys)

    def call(self, inputs, weightss, training):
        """Compute the cell output as the value of the last node.

        ``weightss`` is split along axis 0 into ``num_edges`` pieces; the piece
        for an edge multiplies that edge's operation outputs, which are stacked
        on a new last axis and then summed over it.
        """
        per_edge_weights = tf.split(weightss, self.num_edges, 0)
        nodes = [inputs]
        num_ops = len(self.op_names)
        for dst in range(1, self.max_nodes):
            incoming = []
            for src in range(dst):
                edge = '{:}<-{:}'.format(dst, src)
                edge_weight = per_edge_weights[self.edge2index[edge]]
                candidates = [
                    getattr(self, '{:}.{:}'.format(edge, op_idx))(nodes[src], training)
                    for op_idx in range(num_ops)]
                mixed = tf.math.multiply(tf.stack(candidates, axis=-1), edge_weight)
                incoming.append(tf.math.reduce_sum(mixed, axis=-1))
            nodes.append(tf.math.add_n(incoming))
        return nodes[-1]

View File

@ -1,83 +0,0 @@
import tensorflow as tf
import numpy as np
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
class TinyNetworkDARTS(tf.keras.Model):
    """DARTS-style super network built from NAS201 search cells.

    Layout: a conv+BN stem, three stages of ``N`` search cells with ``C``,
    ``2C`` and ``4C`` channels separated by two stride-2 ``ResNetBasicblock``
    reductions, and a BN-ReLU-pool-dense head with a softmax activation.
    The architecture encoding is a trainable ``(num_edge, num_ops)`` variable
    named ``'arch-encoding'``.
    """

    def __init__(self, C, N, max_nodes, num_classes, search_space, affine):
        """Build the super network.

        Args:
            C: number of channels of the stem and the first stage.
            N: number of search cells per stage.
            max_nodes: number of nodes in every search cell.
            num_classes: size of the classification output.
            search_space: candidate operation names (keys of ``OPS``).
            affine: forwarded to the search cells.
        """
        super(TinyNetworkDARTS, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        # The stem must emit C channels: the first search cell is created with
        # C_in == C_out == C, so its channel-preserving ops pass the stem width
        # through unchanged while its convs emit C channels. A fixed width of
        # 16 therefore only worked when C == 16.
        self.stem = tf.keras.Sequential([
            tf.keras.layers.Conv2D(C, 3, 1, padding='same', use_bias=False),
            tf.keras.layers.BatchNormalization()], name='stem')

        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev, num_edge, edge2index = C, None, None
        for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            cell_prefix = 'cell-{:03d}'.format(index)
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine)
                # Every search cell must share the same edge layout, because a
                # single architecture encoding is used for all of them.
                if num_edge is None:
                    num_edge, edge2index = cell.num_edges, cell.edge2index
                else:
                    assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            C_prev = cell.out_dim
            # Cells are stored as attributes so that Keras tracks their variables.
            setattr(self, cell_prefix, cell)
        self.num_layers = len(layer_reductions)
        self.op_names = deepcopy(search_space)
        self.edge2index = edge2index
        self.num_edge = num_edge
        self.lastact = tf.keras.Sequential([
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.GlobalAvgPool2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact')
        arch_init = tf.random_normal_initializer(mean=0, stddev=0.001)
        self.arch_parameters = tf.Variable(
            initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'),
            trainable=True, name='arch-encoding')

    def get_alphas(self):
        """Return the trainable variables whose name contains 'arch-encoding'."""
        xlist = self.trainable_variables
        return [x for x in xlist if 'arch-encoding' in x.name]

    def get_weights(self):
        """Return the trainable variables whose name lacks 'arch-encoding'."""
        xlist = self.trainable_variables
        return [x for x in xlist if 'arch-encoding' not in x.name]

    def get_np_alphas(self):
        """Return the row-wise softmax of the architecture encoding as a NumPy array."""
        arch_nps = self.arch_parameters.numpy()
        arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True)
        return arch_ops

    def genotype(self):
        """Return the highest-scoring operation of every edge.

        Returns:
            A list with one tuple per node ``i >= 1``; each tuple holds one
            ``(op_name, j)`` pair per incoming edge ``i <- j``.
        """
        genotypes, arch_nps = [], self.arch_parameters.numpy()
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                weights = arch_nps[self.edge2index[node_str]]
                op_name = self.op_names[weights.argmax().item()]
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        return genotypes

    def call(self, inputs, training):
        """Run the network with softmax-normalized architecture weights.

        Returns:
            The output of the head, whose last layer applies a softmax.
        """
        weightss = tf.nn.softmax(self.arch_parameters, axis=1)
        feature = self.stem(inputs, training)
        for idx in range(self.num_layers):
            cell = getattr(self, 'cell-{:03d}'.format(idx))
            if isinstance(cell, SearchCell):
                # Search cells take the extra architecture-weight argument.
                feature = cell.call(feature, weightss, training)
            else:
                feature = cell(feature, training)
        logits = self.lastact(feature, training)
        return logits

View File

@ -1,99 +0,0 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import tensorflow as tf
import numpy as np
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
def sample_gumbel(shape, eps=1e-20):
    """Sample Gumbel noise of the given shape as ``-log(-log(U))``.

    ``U`` is drawn uniformly from [0, 1); ``eps`` is added inside both
    logarithms to keep their arguments away from zero.
    """
    uniform = tf.random.uniform(shape, minval=0, maxval=1)
    neg_log_uniform = -tf.math.log(uniform + eps)
    return -tf.math.log(neg_log_uniform + eps)
def gumbel_softmax(logits, temperature):
    """Return ``softmax((logits + gumbel_noise) / temperature)``.

    The noise has the same shape as ``logits`` and comes from ``sample_gumbel``.
    """
    noisy_logits = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax(noisy_logits / temperature)
class TinyNetworkGDAS(tf.keras.Model):
    """GDAS super network built from NAS201 search cells.

    Layout: a conv+BN stem, three stages of ``N`` search cells with ``C``,
    ``2C`` and ``4C`` channels separated by two stride-2 ``ResNetBasicblock``
    reductions, and a BN-ReLU-pool-dense head with a softmax activation.
    The architecture encoding is a trainable ``(num_edge, num_ops)`` variable
    named ``'arch-encoding'``.
    """

    def __init__(self, C, N, max_nodes, num_classes, search_space, affine):
        """Build the super network.

        Args:
            C: number of channels of the stem and the first stage.
            N: number of search cells per stage.
            max_nodes: number of nodes in every search cell.
            num_classes: size of the classification output.
            search_space: candidate operation names (keys of ``OPS``).
            affine: forwarded to the search cells.
        """
        super(TinyNetworkGDAS, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        # The stem must emit C channels: the first search cell is created with
        # C_in == C_out == C, so its channel-preserving ops pass the stem width
        # through unchanged while its convs emit C channels. A fixed width of
        # 16 therefore only worked when C == 16.
        self.stem = tf.keras.Sequential([
            tf.keras.layers.Conv2D(C, 3, 1, padding='same', use_bias=False),
            tf.keras.layers.BatchNormalization()], name='stem')

        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev, num_edge, edge2index = C, None, None
        for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            cell_prefix = 'cell-{:03d}'.format(index)
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine)
                # Every search cell must share the same edge layout, because a
                # single architecture encoding is used for all of them.
                if num_edge is None:
                    num_edge, edge2index = cell.num_edges, cell.edge2index
                else:
                    assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            C_prev = cell.out_dim
            # Cells are stored as attributes so that Keras tracks their variables.
            setattr(self, cell_prefix, cell)
        self.num_layers = len(layer_reductions)
        self.op_names = deepcopy(search_space)
        self.edge2index = edge2index
        self.num_edge = num_edge
        self.lastact = tf.keras.Sequential([
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.GlobalAvgPool2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact')
        arch_init = tf.random_normal_initializer(mean=0, stddev=0.001)
        self.arch_parameters = tf.Variable(
            initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'),
            trainable=True, name='arch-encoding')

    def get_alphas(self):
        """Return the trainable variables whose name contains 'arch-encoding'."""
        xlist = self.trainable_variables
        return [x for x in xlist if 'arch-encoding' in x.name]

    def get_weights(self):
        """Return the trainable variables whose name lacks 'arch-encoding'."""
        xlist = self.trainable_variables
        return [x for x in xlist if 'arch-encoding' not in x.name]

    def get_np_alphas(self):
        """Return the row-wise softmax of the architecture encoding as a NumPy array."""
        arch_nps = self.arch_parameters.numpy()
        arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True)
        return arch_ops

    def genotype(self):
        """Return the highest-scoring operation of every edge.

        Returns:
            A list with one tuple per node ``i >= 1``; each tuple holds one
            ``(op_name, j)`` pair per incoming edge ``i <- j``.
        """
        genotypes, arch_nps = [], self.arch_parameters.numpy()
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                weights = arch_nps[self.edge2index[node_str]]
                op_name = self.op_names[weights.argmax().item()]
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        return genotypes

    def call(self, inputs, tau, training):
        """Run the network.

        Args:
            inputs: input batch fed to the stem.
            tau: temperature; a negative value selects a plain softmax over
                the architecture encoding, otherwise a Gumbel-softmax of its
                log-softmax with temperature ``tau`` is used.
            training: forwarded to the stem, the cells and the head.

        Returns:
            The output of the head, whose last layer applies a softmax.
        """
        weightss = tf.cond(
            tau < 0,
            lambda: tf.nn.softmax(self.arch_parameters, axis=1),
            lambda: gumbel_softmax(tf.math.log_softmax(self.arch_parameters, axis=1), tau))
        feature = self.stem(inputs, training)
        for idx in range(self.num_layers):
            cell = getattr(self, 'cell-{:03d}'.format(idx))
            if isinstance(cell, SearchCell):
                # Search cells take the extra architecture-weight argument.
                feature = cell.call(feature, weightss, training)
            else:
                feature = cell(feature, training)
        logits = self.lastact(feature, training)
        return logits

View File

@ -1 +0,0 @@
from .weight_decay_optimizers import AdamW, SGDW

View File

@ -1,422 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class to make optimizers weight decay ready."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class DecoupledWeightDecayExtension(object):
    """This class allows to extend optimizers with decoupled weight decay.

    It implements the decoupled weight decay described by Loshchilov & Hutter
    (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
    decoupled from the optimization steps w.r.t. to the loss function.
    For SGD variants, this simplifies hyperparameter search since it decouples
    the settings of weight decay and learning rate.
    For adaptive gradient algorithms, it regularizes variables with large
    gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    This class alone is not an optimizer but rather extends existing
    optimizers with decoupled weight decay. We explicitly define the two
    examples used in the above paper (SGDW and AdamW), but in general this
    can extend any OptimizerX by using
    `extend_with_decoupled_weight_decay(OptimizerX)`.
    In order for it to work, it must be the first class the Optimizer with
    weight decay inherits from, e.g.

    ```python
    class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
      def __init__(self, weight_decay, *args, **kwargs):
        super(AdamW, self).__init__(weight_decay, *args, **kwargs)
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of 'var' in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    def __init__(self, weight_decay, **kwargs):
        """Extension class that adds weight decay to an optimizer.

        Args:
            weight_decay: A `Tensor` or a floating point value, the factor by
                which a variable is decayed in the update step.
            **kwargs: Keyword arguments forwarded to the constructor of the
                base optimizer this extension is mixed into.
        """
        wd = kwargs.pop('weight_decay', weight_decay)
        super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
        self._decay_var_list = None  # is set in minimize or apply_gradients
        self._set_hyper('weight_decay', wd)

    @staticmethod
    def _decay_var_ref(var):
        """Return a hashable key identifying `var`.

        Variables cannot be hashed directly when tensor equality is enabled
        (the TF2 default in eager mode), so membership in the decay set is
        tracked through their reference objects when those are available.
        """
        if hasattr(var, 'ref'):
            return var.ref()
        if hasattr(var, 'experimental_ref'):
            return var.experimental_ref()
        return var

    def _make_decay_var_set(self, decay_var_list):
        """Return the set of keys for `decay_var_list`, or False if it is empty/None."""
        if not decay_var_list:
            return False
        return set(self._decay_var_ref(v) for v in decay_var_list)

    def get_config(self):
        """Return the base optimizer config extended with `weight_decay`."""
        config = super(DecoupledWeightDecayExtension, self).get_config()
        config.update({
            'weight_decay':
            self._serialize_hyperparameter('weight_decay'),
        })
        return config

    def minimize(self,
                 loss,
                 var_list,
                 grad_loss=None,
                 name=None,
                 decay_var_list=None):
        """Minimize `loss` by updating `var_list`.

        This method simply computes gradient using `tf.GradientTape` and calls
        `apply_gradients()`. If you want to process the gradient before
        applying then call `tf.GradientTape` and `apply_gradients()` explicitly
        instead of using this function.

        Args:
            loss: A callable taking no arguments which returns the value to
                minimize.
            var_list: list or tuple of `Variable` objects to update to
                minimize `loss`, or a callable returning the list or tuple of
                `Variable` objects. Use callable when the variable list would
                otherwise be incomplete before `minimize` since the variables
                are created at the first time `loss` is called.
            grad_loss: Optional. A `Tensor` holding the gradient computed for
                `loss`.
            name: Optional name for the returned operation.
            decay_var_list: Optional list of variables to be decayed. Defaults
                to all variables in var_list.

        Returns:
            An Operation that updates the variables in `var_list`. If
            `global_step` was not `None`, that operation also increments
            `global_step`.

        Raises:
            ValueError: If some of the variables are not `Variable` objects.
        """
        self._decay_var_list = self._make_decay_var_set(decay_var_list)
        return super(DecoupledWeightDecayExtension, self).minimize(
            loss, var_list=var_list, grad_loss=grad_loss, name=name)

    def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
        """Apply gradients to variables.

        This is the second part of `minimize()`. It returns an `Operation` that
        applies gradients.

        Args:
            grads_and_vars: List of (gradient, variable) pairs.
            name: Optional name for the returned operation. Default to the
                name passed to the `Optimizer` constructor.
            decay_var_list: Optional list of variables to be decayed. Defaults
                to all variables in var_list.

        Returns:
            An `Operation` that applies the specified gradients. If
            `global_step` was not None, that operation also increments
            `global_step`.

        Raises:
            TypeError: If `grads_and_vars` is malformed.
            ValueError: If none of the variables have gradients.
        """
        self._decay_var_list = self._make_decay_var_set(decay_var_list)
        return super(DecoupledWeightDecayExtension, self).apply_gradients(
            grads_and_vars, name=name)

    def _should_decay(self, var):
        """Return True when `var` must be decayed (no list given, or listed)."""
        return (not self._decay_var_list
                or self._decay_var_ref(var) in self._decay_var_list)

    def _decay_weights_op(self, var):
        """Return the op subtracting `weight_decay * var` from `var`, or a no-op."""
        if self._should_decay(var):
            return var.assign_sub(
                self._get_hyper('weight_decay', var.dtype) * var,
                self._use_locking)
        return tf.no_op()

    def _decay_weights_sparse_op(self, var, indices):
        """Return the op decaying only the rows of `var` at `indices`, or a no-op."""
        if self._should_decay(var):
            update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather(
                var, indices))
            return self._resource_scatter_add(var, indices, update)
        return tf.no_op()

    # Here, we overwrite the apply functions that the base optimizer calls.
    # super().apply_x resolves to the apply_x function of the BaseOptimizer.

    def _resource_apply_dense(self, grad, var):
        """Decay `var`, then run the base optimizer's dense update."""
        with tf.control_dependencies([self._decay_weights_op(var)]):
            return super(DecoupledWeightDecayExtension,
                         self)._resource_apply_dense(grad, var)

    def _resource_apply_sparse(self, grad, var, indices):
        """Decay the addressed rows of `var`, then run the base sparse update."""
        decay_op = self._decay_weights_sparse_op(var, indices)
        with tf.control_dependencies([decay_op]):
            return super(DecoupledWeightDecayExtension,
                         self)._resource_apply_sparse(grad, var, indices)
def extend_with_decoupled_weight_decay(base_optimizer):
    """Factory function returning an optimizer class with decoupled weight
    decay.

    Returns an optimizer class. An instance of the returned class computes the
    update step of `base_optimizer` and additionally decays the weights.
    E.g., the class returned by
    `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
    equivalent to `tfa.optimizers.AdamW`.

    The API of the new optimizer class slightly differs from the API of the
    base optimizer:
    - The first argument to the constructor is the weight decay rate.
    - `minimize` and `apply_gradients` accept the optional keyword argument
      `decay_var_list`, which specifies the variables that should be decayed.
      If `None`, all variables that are optimized are decayed.

    Usage example:

    ```python
    # MyAdamW is a new class
    MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
    # Create a MyAdamW object
    optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
    # update var1, var2 but only decay var1
    optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of 'var' in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```

    Note: you might want to register your own custom optimizer using
    `tf.keras.utils.get_custom_objects()`.

    Args:
        base_optimizer: An optimizer class that inherits from
            tf.optimizers.Optimizer.

    Returns:
        A new optimizer class that inherits from DecoupledWeightDecayExtension
        and base_optimizer.
    """

    class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
                                            base_optimizer):
        """Base_optimizer with decoupled weight decay.

        This class computes the update step of `base_optimizer` and
        additionally decays the variable with the weight decay being
        decoupled from the optimization steps w.r.t. to the loss
        function, as described by Loshchilov & Hutter
        (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this
        simplifies hyperparameter search since it decouples the settings
        of weight decay and learning rate. For adaptive gradient
        algorithms, it regularizes variables with large gradients more
        than L2 regularization would, which was shown to yield better
        training loss and generalization error in the paper above.
        """

        def __init__(self, weight_decay, *args, **kwargs):
            # super delegation is necessary here
            super(OptimizerWithDecoupledWeightDecay, self).__init__(
                weight_decay, *args, **kwargs)

    return OptimizerWithDecoupledWeightDecay
class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
    """Optimizer that implements the Momentum algorithm with weight_decay.

    This is an implementation of the SGDW optimizer described in "Decoupled
    Weight Decay Regularization" by Loshchilov & Hutter
    (https://arxiv.org/abs/1711.05101)
    ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
    It computes the update step of `tf.keras.optimizers.SGD` and additionally
    decays the variable. Note that this is different from adding
    L2 regularization on the variables to the loss. Decoupling the weight decay
    from other hyperparameters (in particular the learning rate) simplifies
    hyperparameter search.

    For further information see the documentation of the SGD Optimizer.

    An equivalent optimizer can also be instantiated as

    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.SGDW(
        learning_rate=lr, weight_decay=wd, momentum=0.9)
    ```
    """

    def __init__(self,
                 weight_decay,
                 learning_rate=0.001,
                 momentum=0.0,
                 nesterov=False,
                 name='SGDW',
                 **kwargs):
        """Construct a new SGDW optimizer.

        For further information see the documentation of the SGD Optimizer.

        Args:
            weight_decay: A Tensor or a floating point value. The weight decay.
            learning_rate: float hyperparameter >= 0. Learning rate.
            momentum: float hyperparameter >= 0 that accelerates SGD in the
                relevant direction and dampens oscillations.
            nesterov: boolean. Whether to apply Nesterov momentum.
            name: Optional name prefix for the operations created when applying
                gradients. Defaults to 'SGDW'.
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse decay
                of learning rate. `lr` is included for backward compatibility,
                recommended to use `learning_rate` instead.
        """
        super(SGDW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            momentum=momentum,
            nesterov=nesterov,
            name=name,
            **kwargs)
class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
    """Optimizer that implements the Adam algorithm with weight decay.

    This is an implementation of the AdamW optimizer described in "Decoupled
    Weight Decay Regularization" by Loshchilov & Hutter
    (https://arxiv.org/abs/1711.05101)
    ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).

    It computes the update step of `tf.keras.optimizers.Adam` and additionally
    decays the variable. Note that this is different from adding L2
    regularization on the variables to the loss: it regularizes variables with
    large gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    For further information see the documentation of the Adam Optimizer.

    An equivalent optimizer can also be instantiated as

    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    def __init__(self,
                 weight_decay,
                 learning_rate=0.001,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-07,
                 amsgrad=False,
                 name="AdamW",
                 **kwargs):
        """Construct a new AdamW optimizer.

        For further information see the documentation of the Adam Optimizer.

        Args:
            weight_decay: A Tensor or a floating point value. The weight decay.
            learning_rate: A Tensor or a floating point value. The learning
                rate.
            beta_1: A float value or a constant float tensor. The exponential
                decay rate for the 1st moment estimates.
            beta_2: A float value or a constant float tensor. The exponential
                decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability. This epsilon is
                "epsilon hat" in the Kingma and Ba paper (in the formula just
                before Section 2.1), not the epsilon in Algorithm 1 of the
                paper.
            amsgrad: boolean. Whether to apply AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and
                beyond".
            name: Optional name for the operations created when applying
                gradients. Defaults to "AdamW".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse decay
                of learning rate. `lr` is included for backward compatibility,
                recommended to use `learning_rate` instead.
        """
        super(AdamW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs)