Remove TF codes
This commit is contained in:
parent 7d02870bf8
commit cba4741d10
@@ -70,7 +70,7 @@ def evaluate(api, weight_dir, data: str):
      ok += 1
      norms.append(cur_norm)
      # query the accuracy
-     info = meta_info.get_metrics(data, 'ori-test', iepoch=None, is_random=777)
+     info = meta_info.get_metrics(data, 'ori-test', iepoch=None, is_random=888 if isinstance(api, NASBench201API) else 777)
      accuracies.append(info['accuracy'])
      del net, meta_info
      # print the information
@@ -1,32 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from os import path as osp

__all__ = ['get_cell_based_tiny_net', 'get_search_spaces']


# the cell-based NAS models
def get_cell_based_tiny_net(config):
  group_names = ['GDAS', 'DARTS']
  if config.name in group_names:
    from .cell_searchs import nas_super_nets
    from .cell_operations import SearchSpaceNames
    if isinstance(config.space, str): search_space = SearchSpaceNames[config.space]
    else: search_space = config.space
    return nas_super_nets[config.name](
      config.C, config.N, config.max_nodes,
      config.num_classes, search_space, config.affine)
  else:
    raise ValueError('invalid network name : {:}'.format(config.name))


# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name):
  if xtype == 'cell':
    from .cell_operations import SearchSpaceNames
    assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
    return SearchSpaceNames[name]
  else:
    raise ValueError('invalid search-space type is {:}'.format(xtype))
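For orientation, a minimal usage sketch of the two helpers deleted above. The `config` namespace and its values are illustrative assumptions, not taken from this commit.

```python
from argparse import Namespace

# fields mirror what get_cell_based_tiny_net reads; the concrete values are made up
config = Namespace(name='GDAS', space='nas-bench-201', C=16, N=5,
                   max_nodes=4, num_classes=10, affine=False)

op_names = get_search_spaces('cell', 'nas-bench-201')  # list of operation names
net      = get_cell_based_tiny_net(config)             # builds a TinyNetworkGDAS supernet
```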
@@ -1,150 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import tensorflow as tf

__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']

OPS = {
  'none'        : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
  'avg_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg', affine),
  'nor_conv_1x1': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 1, stride, affine),
  'nor_conv_3x3': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 3, stride, affine),
  'nor_conv_5x5': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 5, stride, affine),
  'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride) if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine)
}

NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']

SearchSpaceNames = {
  'nas-bench-201': NAS_BENCH_201,
}


class POOLING(tf.keras.layers.Layer):

  def __init__(self, C_in, C_out, stride, mode, affine):
    super(POOLING, self).__init__()
    if C_in == C_out:
      self.preprocess = None
    else:
      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, affine)
    if mode == 'avg'  : self.op = tf.keras.layers.AvgPool2D((3,3), strides=stride, padding='same')
    elif mode == 'max': self.op = tf.keras.layers.MaxPool2D((3,3), strides=stride, padding='same')
    else              : raise ValueError('Invalid mode={:} in POOLING'.format(mode))

  def call(self, inputs, training):
    if self.preprocess: x = self.preprocess(inputs)
    else              : x = inputs
    return self.op(x)


class Identity(tf.keras.layers.Layer):
  def __init__(self, C_in, C_out, stride):
    super(Identity, self).__init__()
    if C_in != C_out or stride != 1:
      self.layer = tf.keras.layers.Conv2D(C_out, 3, stride, padding='same', use_bias=False)
    else:
      self.layer = None

  def call(self, inputs, training):
    x = inputs
    if self.layer is not None:
      x = self.layer(x)
    return x


class Zero(tf.keras.layers.Layer):
  def __init__(self, C_in, C_out, stride):
    super(Zero, self).__init__()
    if C_in != C_out:
      self.layer = tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False)
    elif stride != 1:
      self.layer = tf.keras.layers.AvgPool2D((stride,stride), None, padding="same")
    else:
      self.layer = None

  def call(self, inputs, training):
    x = tf.zeros_like(inputs)
    if self.layer is not None:
      x = self.layer(x)
    return x


class ReLUConvBN(tf.keras.layers.Layer):
  def __init__(self, C_in, C_out, kernel_size, strides, affine):
    super(ReLUConvBN, self).__init__()
    self.C_in = C_in
    self.relu = tf.keras.activations.relu
    self.conv = tf.keras.layers.Conv2D(C_out, kernel_size, strides, padding='same', use_bias=False)
    self.bn   = tf.keras.layers.BatchNormalization(center=affine, scale=affine)

  def call(self, inputs, training):
    x = self.relu(inputs)
    x = self.conv(x)
    x = self.bn(x, training)
    return x

class FactorizedReduce(tf.keras.layers.Layer):
  def __init__(self, C_in, C_out, stride, affine):
    super(FactorizedReduce, self).__init__()  # base-class initializer was missing
    assert C_out % 2 == 0, ('Need even number of filters when using this factorized reduction.')
    self.stride = stride  # was `self.stride == stride`, a comparison instead of an assignment
    self.relu = tf.keras.activations.relu
    if stride == 1:
      self.layer = tf.keras.Sequential([
        tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False),
        tf.keras.layers.BatchNormalization(center=affine, scale=affine)])
    elif stride == 2:
      # two strided 1x1 convolutions, each producing half of the output channels (data_format == 'NHWC')
      self.layer1 = tf.keras.layers.Conv2D(C_out//2, 1, stride, padding='same', use_bias=False)
      self.layer2 = tf.keras.layers.Conv2D(C_out//2, 1, stride, padding='same', use_bias=False)
      self.bn = tf.keras.layers.BatchNormalization(center=affine, scale=affine)
    else:
      raise ValueError('invalid stride={:}'.format(stride))

  def call(self, inputs, training):
    x = self.relu(inputs)
    if self.stride == 1:
      return self.layer(x, training)
    else:
      path1 = x
      path2 = tf.pad(x, [[0, 0], [0, 1], [0, 1], [0, 0]])[:, 1:, 1:, :] # data_format == 'NHWC'
      x1 = self.layer1(path1)
      x2 = self.layer2(path2)
      final_path = tf.concat(values=[x1, x2], axis=3)
      return self.bn(final_path, training)


class ResNetBasicblock(tf.keras.layers.Layer):

  def __init__(self, inplanes, planes, stride, affine=True):
    super(ResNetBasicblock, self).__init__()
    assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
    self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, affine)
    self.conv_b = ReLUConvBN(  planes, planes, 3,      1, affine)
    if stride == 2:
      self.downsample = tf.keras.Sequential([
        tf.keras.layers.AvgPool2D((stride,stride), None, padding="same"),
        tf.keras.layers.Conv2D(planes, 1, 1, padding='same', use_bias=False)])
    elif inplanes != planes:
      self.downsample = ReLUConvBN(inplanes, planes, 1, stride, affine)
    else:
      self.downsample = None
    self.addition = tf.keras.layers.Add()
    self.in_dim   = inplanes
    self.out_dim  = planes
    self.stride   = stride
    self.num_conv = 2

  def call(self, inputs, training):

    basicblock = self.conv_a(inputs, training)
    basicblock = self.conv_b(basicblock, training)

    if self.downsample is not None:
      residual = self.downsample(inputs)
    else:
      residual = inputs
    return self.addition([residual, basicblock])
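A hypothetical sketch of how one entry of the `OPS` table above was instantiated and applied; the input shape and channel counts are made-up values.

```python
import tensorflow as tf

op = OPS['nor_conv_3x3'](C_in=16, C_out=16, stride=1, affine=True)  # ReLU -> Conv2D -> BatchNorm
x  = tf.random.normal((2, 32, 32, 16))   # NHWC input batch
y  = op(x, training=False)               # same spatial size, 16 output channels
```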
@@ -1,8 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .search_model_gdas  import TinyNetworkGDAS
from .search_model_darts import TinyNetworkDARTS

nas_super_nets = {'GDAS' : TinyNetworkGDAS,
                  'DARTS': TinyNetworkDARTS}
@@ -1,50 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random
import tensorflow as tf
from copy import deepcopy
from ..cell_operations import OPS


class NAS201SearchCell(tf.keras.layers.Layer):

  def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False):
    super(NAS201SearchCell, self).__init__()

    self.op_names  = deepcopy(op_names)
    self.max_nodes = max_nodes
    self.in_dim    = C_in
    self.out_dim   = C_out
    self.edge_keys = []
    for i in range(1, max_nodes):
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        if j == 0:
          xlists = [OPS[op_name](C_in , C_out, stride, affine) for op_name in op_names]
        else:
          xlists = [OPS[op_name](C_in , C_out,      1, affine) for op_name in op_names]
        for k, op in enumerate(xlists):
          setattr(self, '{:}.{:}'.format(node_str, k), op)
        self.edge_keys.append( node_str )
    self.edge_keys  = sorted(self.edge_keys)
    self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
    self.num_edges  = len(self.edge_keys)

  def call(self, inputs, weightss, training):
    w_lst = tf.split(weightss, self.num_edges, 0)
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      inter_nodes = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        edge_idx = self.edge2index[node_str]
        op_outps = []
        for k, op_name in enumerate(self.op_names):
          op = getattr(self, '{:}.{:}'.format(node_str, k))
          op_outps.append( op(nodes[j], training) )
        stack_op_outs = tf.stack(op_outps, axis=-1)
        weighted_sums = tf.math.multiply(stack_op_outs, w_lst[edge_idx])
        inter_nodes.append( tf.math.reduce_sum(weighted_sums, axis=-1) )
      nodes.append( tf.math.add_n(inter_nodes) )
    return nodes[-1]
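A hypothetical sketch of driving the search cell above with per-edge architecture weights; the shapes and values are illustrative assumptions.

```python
import tensorflow as tf

ops  = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
cell = NAS201SearchCell(C_in=16, C_out=16, stride=1, max_nodes=4, op_names=ops)

x   = tf.random.normal((2, 32, 32, 16))                                    # NHWC batch
w   = tf.nn.softmax(tf.random.normal((cell.num_edges, len(ops))), axis=1)  # one row per edge
out = cell(x, w, training=False)   # each edge mixes its candidate ops by the row of weights
```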
@@ -1,83 +0,0 @@
import tensorflow as tf
import numpy as np
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell


class TinyNetworkDARTS(tf.keras.Model):

  def __init__(self, C, N, max_nodes, num_classes, search_space, affine):
    super(TinyNetworkDARTS, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = tf.keras.Sequential([
      tf.keras.layers.Conv2D(16, 3, 1, padding='same', use_bias=False),
      tf.keras.layers.BatchNormalization()], name='stem')

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      cell_prefix = 'cell-{:03d}'.format(index)
      #with tf.name_scope(cell_prefix) as scope:
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      C_prev = cell.out_dim
      setattr(self, cell_prefix, cell)
    self.num_layers = len(layer_reductions)
    self.op_names   = deepcopy( search_space )
    self.edge2index = edge2index
    self.num_edge   = num_edge
    self.lastact = tf.keras.Sequential([
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.ReLU(),
      tf.keras.layers.GlobalAvgPool2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact')
    #self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
    arch_init = tf.random_normal_initializer(mean=0, stddev=0.001)
    self.arch_parameters = tf.Variable(initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'), trainable=True, name='arch-encoding')

  def get_alphas(self):
    xlist = self.trainable_variables
    return [x for x in xlist if 'arch-encoding' in x.name]

  def get_weights(self):
    xlist = self.trainable_variables
    return [x for x in xlist if 'arch-encoding' not in x.name]

  def get_np_alphas(self):
    arch_nps = self.arch_parameters.numpy()
    arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True)
    return arch_ops

  def genotype(self):
    genotypes, arch_nps = [], self.arch_parameters.numpy()
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = arch_nps[ self.edge2index[node_str] ]
        op_name  = self.op_names[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return genotypes

  def call(self, inputs, training):
    weightss = tf.nn.softmax(self.arch_parameters, axis=1)
    feature  = self.stem(inputs, training)
    for idx in range(self.num_layers):
      cell = getattr(self, 'cell-{:03d}'.format(idx))
      if isinstance(cell, SearchCell):
        feature = cell.call(feature, weightss, training)
      else:
        feature = cell(feature, training)
    logits = self.lastact(feature, training)
    return logits
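A hypothetical sketch of how this DARTS supernet was driven: the forward pass applies a softmax over `arch_parameters` internally, and `genotype()` reads out the current argmax operation per edge. The constructor arguments and batch below are illustrative assumptions.

```python
import tensorflow as tf

space = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
net   = TinyNetworkDARTS(C=16, N=5, max_nodes=4, num_classes=10,
                         search_space=space, affine=False)

images = tf.random.normal((2, 32, 32, 3))   # CIFAR-like NHWC batch, values made up
logits = net(images, training=False)        # soft mixture over ops via softmax(arch_parameters)
arch   = net.genotype()                     # [(op_name, predecessor_index), ...] per node
```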
@@ -1,99 +0,0 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import tensorflow as tf
import numpy as np
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell


def sample_gumbel(shape, eps=1e-20):
  U = tf.random.uniform(shape, minval=0, maxval=1)
  return -tf.math.log(-tf.math.log(U + eps) + eps)


def gumbel_softmax(logits, temperature):
  gumbel_softmax_sample = logits + sample_gumbel(tf.shape(logits))
  y = tf.nn.softmax(gumbel_softmax_sample / temperature)
  return y


class TinyNetworkGDAS(tf.keras.Model):

  def __init__(self, C, N, max_nodes, num_classes, search_space, affine):
    super(TinyNetworkGDAS, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = tf.keras.Sequential([
      tf.keras.layers.Conv2D(16, 3, 1, padding='same', use_bias=False),
      tf.keras.layers.BatchNormalization()], name='stem')

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      cell_prefix = 'cell-{:03d}'.format(index)
      #with tf.name_scope(cell_prefix) as scope:
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      C_prev = cell.out_dim
      setattr(self, cell_prefix, cell)
    self.num_layers = len(layer_reductions)
    self.op_names   = deepcopy( search_space )
    self.edge2index = edge2index
    self.num_edge   = num_edge
    self.lastact = tf.keras.Sequential([
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.ReLU(),
      tf.keras.layers.GlobalAvgPool2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact')
    #self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
    arch_init = tf.random_normal_initializer(mean=0, stddev=0.001)
    self.arch_parameters = tf.Variable(initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'), trainable=True, name='arch-encoding')

  def get_alphas(self):
    xlist = self.trainable_variables
    return [x for x in xlist if 'arch-encoding' in x.name]

  def get_weights(self):
    xlist = self.trainable_variables
    return [x for x in xlist if 'arch-encoding' not in x.name]

  def get_np_alphas(self):
    arch_nps = self.arch_parameters.numpy()
    arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True)
    return arch_ops

  def genotype(self):
    genotypes, arch_nps = [], self.arch_parameters.numpy()
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = arch_nps[ self.edge2index[node_str] ]
        op_name  = self.op_names[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return genotypes

  #
  def call(self, inputs, tau, training):
    weightss = tf.cond(tau < 0, lambda: tf.nn.softmax(self.arch_parameters, axis=1),
                                lambda: gumbel_softmax(tf.math.log_softmax(self.arch_parameters, axis=1), tau))
    feature = self.stem(inputs, training)
    for idx in range(self.num_layers):
      cell = getattr(self, 'cell-{:03d}'.format(idx))
      if isinstance(cell, SearchCell):
        feature = cell.call(feature, weightss, training)
      else:
        feature = cell(feature, training)
    logits = self.lastact(feature, training)
    return logits
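The GDAS variant above differs from DARTS only in how the edge weights are produced: with a non-negative `tau` it draws a Gumbel-softmax sample instead of a plain softmax. A small sketch of the temperature's effect, using made-up logits and temperatures:

```python
import tensorflow as tf

logits = tf.math.log([[0.7, 0.2, 0.1]])            # illustrative per-edge logits
soft = gumbel_softmax(logits, temperature=10.0)    # high tau: close to a uniform mixture
hard = gumbel_softmax(logits, temperature=0.1)     # low tau: close to a one-hot sample
# both results sum to 1 along the last axis; only the sharpness differs
```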
@@ -1 +0,0 @@
from .weight_decay_optimizers import AdamW, SGDW
@@ -1,422 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class to make optimizers weight decay ready."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

class DecoupledWeightDecayExtension(object):
    """This class allows to extend optimizers with decoupled weight decay.

    It implements the decoupled weight decay described by Loshchilov & Hutter
    (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
    decoupled from the optimization steps w.r.t. to the loss function.
    For SGD variants, this simplifies hyperparameter search since it decouples
    the settings of weight decay and learning rate.
    For adaptive gradient algorithms, it regularizes variables with large
    gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    This class alone is not an optimizer but rather extends existing
    optimizers with decoupled weight decay. We explicitly define the two
    examples used in the above paper (SGDW and AdamW), but in general this
    can extend any OptimizerX by using
    `extend_with_decoupled_weight_decay(
        OptimizerX, weight_decay=weight_decay)`.
    In order for it to work, it must be the first class the Optimizer with
    weight decay inherits from, e.g.

    ```python
    class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
        def __init__(self, weight_decay, *args, **kwargs):
            super(AdamW, self).__init__(weight_decay, *args, **kwargs)
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of 'var' in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

|
||||
"""Extension class that adds weight decay to an optimizer.
|
||||
|
||||
Args:
|
||||
weight_decay: A `Tensor` or a floating point value, the factor by
|
||||
which a variable is decayed in the update step.
|
||||
**kwargs: Optional list or tuple or set of `Variable` objects to
|
||||
decay.
|
||||
"""
|
||||
wd = kwargs.pop('weight_decay', weight_decay)
|
||||
super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
|
||||
self._decay_var_list = None # is set in minimize or apply_gradients
|
||||
self._set_hyper('weight_decay', wd)
|
||||
|
||||
def get_config(self):
|
||||
config = super(DecoupledWeightDecayExtension, self).get_config()
|
||||
config.update({
|
||||
'weight_decay':
|
||||
self._serialize_hyperparameter('weight_decay'),
|
||||
})
|
||||
return config
|
||||
|
||||
def minimize(self,
|
||||
loss,
|
||||
var_list,
|
||||
grad_loss=None,
|
||||
name=None,
|
||||
decay_var_list=None):
|
||||
"""Minimize `loss` by updating `var_list`.
|
||||
|
||||
This method simply computes gradient using `tf.GradientTape` and calls
|
||||
`apply_gradients()`. If you want to process the gradient before
|
||||
applying then call `tf.GradientTape` and `apply_gradients()` explicitly
|
||||
instead of using this function.
|
||||
|
||||
Args:
|
||||
loss: A callable taking no arguments which returns the value to
|
||||
minimize.
|
||||
var_list: list or tuple of `Variable` objects to update to
|
||||
minimize `loss`, or a callable returning the list or tuple of
|
||||
`Variable` objects. Use callable when the variable list would
|
||||
otherwise be incomplete before `minimize` since the variables
|
||||
are created at the first time `loss` is called.
|
||||
grad_loss: Optional. A `Tensor` holding the gradient computed for
|
||||
`loss`.
|
||||
decay_var_list: Optional list of variables to be decayed. Defaults
|
||||
to all variables in var_list.
|
||||
name: Optional name for the returned operation.
|
||||
Returns:
|
||||
An Operation that updates the variables in `var_list`. If
|
||||
`global_step` was not `None`, that operation also increments
|
||||
`global_step`.
|
||||
Raises:
|
||||
ValueError: If some of the variables are not `Variable` objects.
|
||||
"""
|
||||
self._decay_var_list = set(decay_var_list) if decay_var_list else False
|
||||
return super(DecoupledWeightDecayExtension, self).minimize(
|
||||
loss, var_list=var_list, grad_loss=grad_loss, name=name)
|
||||
|
||||
def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
|
||||
"""Apply gradients to variables.
|
||||
|
||||
This is the second part of `minimize()`. It returns an `Operation` that
|
||||
applies gradients.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs.
|
||||
name: Optional name for the returned operation. Default to the
|
||||
name passed to the `Optimizer` constructor.
|
||||
decay_var_list: Optional list of variables to be decayed. Defaults
|
||||
to all variables in var_list.
|
||||
Returns:
|
||||
An `Operation` that applies the specified gradients. If
|
||||
`global_step` was not None, that operation also increments
|
||||
`global_step`.
|
||||
Raises:
|
||||
TypeError: If `grads_and_vars` is malformed.
|
||||
ValueError: If none of the variables have gradients.
|
||||
"""
|
||||
self._decay_var_list = set(decay_var_list) if decay_var_list else False
|
||||
return super(DecoupledWeightDecayExtension, self).apply_gradients(
|
||||
grads_and_vars, name=name)
|
||||
|
||||
def _decay_weights_op(self, var):
|
||||
if not self._decay_var_list or var in self._decay_var_list:
|
||||
return var.assign_sub(
|
||||
self._get_hyper('weight_decay', var.dtype) * var,
|
||||
self._use_locking)
|
||||
return tf.no_op()
|
||||
|
||||
def _decay_weights_sparse_op(self, var, indices):
|
||||
if not self._decay_var_list or var in self._decay_var_list:
|
||||
update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather(
|
||||
var, indices))
|
||||
return self._resource_scatter_add(var, indices, update)
|
||||
return tf.no_op()
|
||||
|
||||
# Here, we overwrite the apply functions that the base optimizer calls.
|
||||
# super().apply_x resolves to the apply_x function of the BaseOptimizer.
|
||||
|
||||
def _resource_apply_dense(self, grad, var):
|
||||
with tf.control_dependencies([self._decay_weights_op(var)]):
|
||||
return super(DecoupledWeightDecayExtension,
|
||||
self)._resource_apply_dense(grad, var)
|
||||
|
||||
def _resource_apply_sparse(self, grad, var, indices):
|
||||
decay_op = self._decay_weights_sparse_op(var, indices)
|
||||
with tf.control_dependencies([decay_op]):
|
||||
return super(DecoupledWeightDecayExtension,
|
||||
self)._resource_apply_sparse(grad, var, indices)
|
||||
|
||||
|
||||
def extend_with_decoupled_weight_decay(base_optimizer):
|
||||
"""Factory function returning an optimizer class with decoupled weight
|
||||
decay.
|
||||
|
||||
Returns an optimizer class. An instance of the returned class computes the
|
||||
update step of `base_optimizer` and additionally decays the weights.
|
||||
E.g., the class returned by
|
||||
`extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
|
||||
equivalent to `tfa.optimizers.AdamW`.
|
||||
|
||||
The API of the new optimizer class slightly differs from the API of the
|
||||
base optimizer:
|
||||
- The first argument to the constructor is the weight decay rate.
|
||||
- `minimize` and `apply_gradients` accept the optional keyword argument
|
||||
`decay_var_list`, which specifies the variables that should be decayed.
|
||||
If `None`, all variables that are optimized are decayed.
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
# MyAdamW is a new class
|
||||
MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
|
||||
# Create a MyAdamW object
|
||||
optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
|
||||
# update var1, var2 but only decay var1
|
||||
optimizer.minimize(loss, var_list=[var1, var2], decay_variables=[var1])
|
||||
|
||||
Note: this extension decays weights BEFORE applying the update based
|
||||
on the gradient, i.e. this extension only has the desired behaviour for
|
||||
optimizers which do not depend on the value of 'var' in the update step!
|
||||
|
||||
Note: when applying a decay to the learning rate, be sure to manually apply
|
||||
the decay to the `weight_decay` as well. For example:
|
||||
|
||||
```python
|
||||
step = tf.Variable(0, trainable=False)
|
||||
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
|
||||
[10000, 15000], [1e-0, 1e-1, 1e-2])
|
||||
# lr and wd can be a function or a tensor
|
||||
lr = 1e-1 * schedule(step)
|
||||
wd = lambda: 1e-4 * schedule(step)
|
||||
|
||||
# ...
|
||||
|
||||
optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
|
||||
```
|
||||
|
||||
Note: you might want to register your own custom optimizer using
|
||||
`tf.keras.utils.get_custom_objects()`.
|
||||
|
||||
Args:
|
||||
base_optimizer: An optimizer class that inherits from
|
||||
tf.optimizers.Optimizer.
|
||||
|
||||
Returns:
|
||||
A new optimizer class that inherits from DecoupledWeightDecayExtension
|
||||
and base_optimizer.
|
||||
"""
|
||||
|
||||
class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
|
||||
base_optimizer):
|
||||
"""Base_optimizer with decoupled weight decay.
|
||||
|
||||
This class computes the update step of `base_optimizer` and
|
||||
additionally decays the variable with the weight decay being
|
||||
decoupled from the optimization steps w.r.t. to the loss
|
||||
function, as described by Loshchilov & Hutter
|
||||
(https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this
|
||||
simplifies hyperparameter search since it decouples the settings
|
||||
of weight decay and learning rate. For adaptive gradient
|
||||
algorithms, it regularizes variables with large gradients more
|
||||
than L2 regularization would, which was shown to yield better
|
||||
training loss and generalization error in the paper above.
|
||||
"""
|
||||
|
||||
def __init__(self, weight_decay, *args, **kwargs):
|
||||
# super delegation is necessary here
|
||||
super(OptimizerWithDecoupledWeightDecay, self).__init__(
|
||||
weight_decay, *args, **kwargs)
|
||||
|
||||
return OptimizerWithDecoupledWeightDecay
|
||||
|
||||
|
||||
class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
    """Optimizer that implements the Momentum algorithm with weight_decay.

    This is an implementation of the SGDW optimizer described in "Decoupled
    Weight Decay Regularization" by Loshchilov & Hutter
    (https://arxiv.org/abs/1711.05101)
    ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
    It computes the update step of `tf.keras.optimizers.SGD` and additionally
    decays the variable. Note that this is different from adding
    L2 regularization on the variables to the loss. Decoupling the weight decay
    from other hyperparameters (in particular the learning rate) simplifies
    hyperparameter search.

    For further information see the documentation of the SGD Optimizer.

    This optimizer can also be instantiated as
    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD,
                                       weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.SGDW(
        learning_rate=lr, weight_decay=wd, momentum=0.9)
    ```
    """

    def __init__(self,
                 weight_decay,
                 learning_rate=0.001,
                 momentum=0.0,
                 nesterov=False,
                 name='SGDW',
                 **kwargs):
        """Construct a new SGDW optimizer.

        For further information see the documentation of the SGD Optimizer.

        Args:
            learning_rate: float hyperparameter >= 0. Learning rate.
            momentum: float hyperparameter >= 0 that accelerates SGD in the
                relevant direction and dampens oscillations.
            nesterov: boolean. Whether to apply Nesterov momentum.
            name: Optional name prefix for the operations created when applying
                gradients. Defaults to 'SGD'.
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse decay
                of learning rate. `lr` is included for backward compatibility,
                recommended to use `learning_rate` instead.
        """
        super(SGDW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            momentum=momentum,
            nesterov=nesterov,
            name=name,
            **kwargs)


class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
    """Optimizer that implements the Adam algorithm with weight decay.

    This is an implementation of the AdamW optimizer described in "Decoupled
    Weight Decay Regularization" by Loshchilov & Hutter
    (https://arxiv.org/abs/1711.05101)
    ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).

    It computes the update step of `tf.keras.optimizers.Adam` and additionally
    decays the variable. Note that this is different from adding L2
    regularization on the variables to the loss: it regularizes variables with
    large gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    For further information see the documentation of the Adam Optimizer.

    This optimizer can also be instantiated as
    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam,
                                       weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    def __init__(self,
                 weight_decay,
                 learning_rate=0.001,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-07,
                 amsgrad=False,
                 name="AdamW",
                 **kwargs):
        """Construct a new AdamW optimizer.

        For further information see the documentation of the Adam Optimizer.

        Args:
            weight_decay: A Tensor or a floating point value. The weight decay.
            learning_rate: A Tensor or a floating point value. The learning
                rate.
            beta_1: A float value or a constant float tensor. The exponential
                decay rate for the 1st moment estimates.
            beta_2: A float value or a constant float tensor. The exponential
                decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability. This epsilon is
                "epsilon hat" in the Kingma and Ba paper (in the formula just
                before Section 2.1), not the epsilon in Algorithm 1 of the
                paper.
            amsgrad: boolean. Whether to apply AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and
                beyond".
            name: Optional name for the operations created when applying
                gradients. Defaults to "AdamW".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse decay
                of learning rate. `lr` is included for backward compatibility,
                recommended to use `learning_rate` instead.
        """
        super(AdamW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs)
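To close, a hypothetical usage sketch for the optimizers deleted above, following the `minimize(..., decay_var_list=...)` signature they define; the variables and loss function are made up for illustration.

```python
import tensorflow as tf

var1 = tf.Variable(tf.ones((3, 3)))
var2 = tf.Variable(tf.zeros((3,)))
loss = lambda: tf.reduce_sum(var1 ** 2) + tf.reduce_sum(var2 ** 2)

opt = AdamW(weight_decay=1e-4, learning_rate=1e-3)
# decoupled decay is applied only to var1; both variables still receive Adam updates
opt.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
```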