update TF

This commit is contained in:
D-X-Y 2020-01-20 15:24:12 +11:00
parent 1ce3249a5a
commit 533a508444
5 changed files with 339 additions and 2 deletions

206
exps-tf/one-shot-nas.py Normal file
View File

@ -0,0 +1,206 @@
# [D-X-Y]
# Run DARTS
# CUDA_VISIBLE_DEVICES=0 python exps-tf/one-shot-nas.py --epochs 50
#
import os, sys, math, time, random, argparse
import tensorflow as tf
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
# self-lib
from tf_models import get_cell_based_tiny_net
from tf_optimizers import SGDW, AdamW
from config_utils import dict2config
from log_utils import time_string
from models import CellStructure
def pre_process(image_a, label_a, image_b, label_b):
def standard_func(image):
x = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
x = tf.image.random_crop(x, [32, 32, 3])
x = tf.image.random_flip_left_right(x)
return x
return standard_func(image_a), label_a, standard_func(image_b), label_b
class CosineAnnealingLR(object):
def __init__(self, warmup_epochs, epochs, initial_lr, min_lr):
self.warmup_epochs = warmup_epochs
self.epochs = epochs
self.initial_lr = initial_lr
self.min_lr = min_lr
def get_lr(self, epoch):
if epoch < self.warmup_epochs:
lr = self.min_lr + (epoch/self.warmup_epochs) * (self.initial_lr-self.min_lr)
elif epoch >= self.epochs:
lr = self.min_lr
else:
lr = self.min_lr + (self.initial_lr-self.min_lr) * 0.5 * (1 + math.cos(math.pi * epoch / self.epochs))
return lr
def main(xargs):
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_test = x_train.astype('float32'), x_test.astype('float32')
y_train, y_test = y_train.reshape(-1), y_test.reshape(-1)
# Add a channels dimension
all_indexes = list(range(x_train.shape[0]))
random.shuffle(all_indexes)
s_train_idxs, s_valid_idxs = all_indexes[::2], all_indexes[1::2]
search_train_x, search_train_y = x_train[s_train_idxs], y_train[s_train_idxs]
search_valid_x, search_valid_y = x_train[s_valid_idxs], y_train[s_valid_idxs]
#x_train, x_test = x_train[..., tf.newaxis], x_test[..., tf.newaxis]
# Use tf.data
#train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
search_ds = tf.data.Dataset.from_tensor_slices((search_train_x, search_train_y, search_valid_x, search_valid_y))
search_ds = search_ds.map(pre_process).shuffle(1000).batch(64)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# Create an instance of the model
config = dict2config({'name': 'DARTS',
'C' : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes,
'num_classes': 10, 'space': 'nas-bench-201', 'affine': True}, None)
model = get_cell_based_tiny_net(config)
num_iters_per_epoch = int(tf.data.experimental.cardinality(search_ds).numpy())
#lr_schedular = tf.keras.experimental.CosineDecay(xargs.w_lr_max, num_iters_per_epoch*xargs.epochs, xargs.w_lr_min / xargs.w_lr_max)
lr_schedular = CosineAnnealingLR(0, xargs.epochs, xargs.w_lr_max, xargs.w_lr_min)
# Choose optimizer
loss_object = tf.keras.losses.CategoricalCrossentropy()
w_optimizer = SGDW(learning_rate=xargs.w_lr_max, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True)
a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate, weight_decay=xargs.arch_weight_decay, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
#w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True)
#a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
####
# metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
valid_loss = tf.keras.metrics.Mean(name='valid_loss')
valid_accuracy = tf.keras.metrics.CategoricalAccuracy(name='valid_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
@tf.function
def search_step(train_images, train_labels, valid_images, valid_labels):
# optimize weights
with tf.GradientTape() as tape:
predictions = model(train_images, True)
w_loss = loss_object(train_labels, predictions)
net_w_param = model.get_weights()
gradients = tape.gradient(w_loss, net_w_param)
w_optimizer.apply_gradients(zip(gradients, net_w_param))
train_loss(w_loss)
train_accuracy(train_labels, predictions)
# optimize alphas
with tf.GradientTape() as tape:
predictions = model(valid_images, True)
a_loss = loss_object(valid_labels, predictions)
net_a_param = model.get_alphas()
gradients = tape.gradient(a_loss, net_a_param)
a_optimizer.apply_gradients(zip(gradients, net_a_param))
valid_loss(a_loss)
valid_accuracy(valid_labels, predictions)
# IFT with Neumann approximation
@tf.function
def search_step_IFTNA(train_images, train_labels, valid_images, valid_labels, max_step):
# optimize weights
with tf.GradientTape() as tape:
predictions = model(train_images, True)
w_loss = loss_object(train_labels, predictions)
# get the weights
net_w_param = model.get_weights()
net_a_param = model.get_alphas()
gradients = tape.gradient(w_loss, net_w_param)
w_optimizer.apply_gradients(zip(gradients, net_w_param))
train_loss(w_loss)
train_accuracy(train_labels, predictions)
# optimize alphas
with tf.GradientTape(persistent=True) as tape:
predictions = model(valid_images, True)
val_loss = loss_object(valid_labels, predictions)
predictions = model(train_images, True)
trn_loss = loss_object(train_labels, predictions)
# ----
dV_dW = tape.gradient(val_loss, net_w_param)
# approxInverseHVP to calculate v2
sum_p = v1 = dV_dW
dT_dW = tape.gradient(trn_loss, net_w_param)
for j in range(1, max_step):
temp_dot = tape.gradient(dT_dW, net_w_param, output_gradients=v1)
v1 = [tf.subtract(A, B) for A, B in zip(v1, temp_dot)]
sum_p = [tf.add(A, B) for A, B in zip(sum_p, v1)]
# calculate v3
dT_dl = tape.gradient(trn_loss, net_a_param)
import pdb; pdb.set_trace()
v3 = tape.gradient(dT_dl, net_w_param, output_gradients=sum_p)
dV_dl = tape.gradient(val_loss, net_a_param)
a_gradients = [tf.subtract(A, B) for A, B in zip(dV_dl, v3)]
import pdb; pdb.set_trace()
print('--')
# TEST
@tf.function
def test_step(images, labels):
predictions = model(images)
t_loss = loss_object(labels, predictions)
test_loss(t_loss)
test_accuracy(labels, predictions)
print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, num_iters_per_epoch))
for epoch in range(xargs.epochs):
# Reset the metrics at the start of the next epoch
train_loss.reset_states() ; train_accuracy.reset_states()
test_loss.reset_states() ; test_accuracy.reset_states()
cur_lr = lr_schedular.get_lr(epoch)
tf.keras.backend.set_value(w_optimizer.lr, cur_lr)
for trn_imgs, trn_labels, val_imgs, val_labels in search_ds:
#search_step(trn_imgs, trn_labels, val_imgs, val_labels)
trn_labels, val_labels = tf.one_hot(trn_labels, 10), tf.one_hot(val_labels, 10)
search_step_IFTNA(trn_imgs, trn_labels, val_imgs, val_labels, 5)
genotype = model.genotype()
genotype = CellStructure(genotype)
#for test_images, test_labels in test_ds:
# test_step(test_images, test_labels)
cur_lr = float(tf.keras.backend.get_value(w_optimizer.lr))
template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | lr={:.6f}'
print(template.format(time_string(), epoch+1, xargs.epochs,
train_loss.result(),
train_accuracy.result()*100,
valid_loss.result(),
valid_accuracy.result()*100,
cur_lr))
print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas()))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# training details
parser.add_argument('--epochs' , type=int , default= 250 , help='')
parser.add_argument('--w_lr_max' , type=float, default= 0.025, help='')
parser.add_argument('--w_lr_min' , type=float, default= 0.001, help='')
parser.add_argument('--w_weight_decay' , type=float, default=0.0005, help='')
parser.add_argument('--w_momentum' , type=float, default= 0.9 , help='')
parser.add_argument('--arch_learning_rate', type=float, default=0.0003, help='')
parser.add_argument('--arch_weight_decay' , type=float, default=0.001, help='')
# marco structure
parser.add_argument('--channel' , type=int , default=16, help='')
parser.add_argument('--num_cells' , type=int , default= 5, help='')
parser.add_argument('--max_nodes' , type=int , default= 4, help='')
args = parser.parse_args()
main( args )

46
exps-tf/test-invH.py Normal file
View File

@ -0,0 +1,46 @@
import os, sys, math, time, random, argparse
import tensorflow as tf
from pathlib import Path
def test_a():
x = tf.Variable([[1.], [2.], [4.0]])
with tf.GradientTape(persistent=True) as g:
trn = tf.math.exp(tf.math.reduce_sum(x))
val = tf.math.cos(tf.math.reduce_sum(x))
dT_dx = g.gradient(trn, x)
dV_dx = g.gradient(val, x)
hess_vector = g.gradient(dT_dx, x, output_gradients=dV_dx)
print ('calculate ok : {:}'.format(hess_vector))
def test_b():
cce = tf.keras.losses.SparseCategoricalCrossentropy()
L1 = tf.convert_to_tensor([0, 1, 2])
L2 = tf.convert_to_tensor([2, 0, 1])
B = tf.Variable([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])
with tf.GradientTape(persistent=True) as g:
trn = cce(L1, B)
val = cce(L2, B)
dT_dx = g.gradient(trn, B)
dV_dx = g.gradient(val, B)
hess_vector = g.gradient(dT_dx, B, output_gradients=dV_dx)
print ('calculate ok : {:}'.format(hess_vector))
def test_c():
cce = tf.keras.losses.CategoricalCrossentropy()
L1 = tf.convert_to_tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
L2 = tf.convert_to_tensor([[0., 0., 1.], [0., 1., 0.], [1., 0., 0.]])
B = tf.Variable([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])
with tf.GradientTape(persistent=True) as g:
trn = cce(L1, B)
val = cce(L2, B)
dT_dx = g.gradient(trn, B)
dV_dx = g.gradient(val, B)
hess_vector = g.gradient(dT_dx, B, output_gradients=dV_dx)
print ('calculate ok : {:}'.format(hess_vector))
if __name__ == '__main__':
print(tf.__version__)
test_c()
#test_b()
#test_a()

View File

@ -9,7 +9,7 @@ __all__ = ['get_cell_based_tiny_net', 'get_search_spaces']
# the cell-based NAS models # the cell-based NAS models
def get_cell_based_tiny_net(config): def get_cell_based_tiny_net(config):
group_names = ['GDAS'] group_names = ['GDAS', 'DARTS']
if config.name in group_names: if config.name in group_names:
from .cell_searchs import nas_super_nets from .cell_searchs import nas_super_nets
from .cell_operations import SearchSpaceNames from .cell_operations import SearchSpaceNames

View File

@ -2,5 +2,7 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
################################################## ##################################################
from .search_model_gdas import TinyNetworkGDAS from .search_model_gdas import TinyNetworkGDAS
from .search_model_darts import TinyNetworkDARTS
nas_super_nets = {'GDAS': TinyNetworkGDAS} nas_super_nets = {'GDAS' : TinyNetworkGDAS,
'DARTS': TinyNetworkDARTS}

View File

@ -0,0 +1,83 @@
import tensorflow as tf
import numpy as np
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
class TinyNetworkDARTS(tf.keras.Model):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine):
super(TinyNetworkDARTS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = tf.keras.Sequential([
tf.keras.layers.Conv2D(16, 3, 1, padding='same', use_bias=False),
tf.keras.layers.BatchNormalization()], name='stem')
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell_prefix = 'cell-{:03d}'.format(index)
#with tf.name_scope(cell_prefix) as scope:
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
C_prev = cell.out_dim
setattr(self, cell_prefix, cell)
self.num_layers = len(layer_reductions)
self.op_names = deepcopy( search_space )
self.edge2index = edge2index
self.num_edge = num_edge
self.lastact = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(),
tf.keras.layers.ReLU(),
tf.keras.layers.GlobalAvgPool2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact')
#self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
arch_init = tf.random_normal_initializer(mean=0, stddev=0.001)
self.arch_parameters = tf.Variable(initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'), trainable=True, name='arch-encoding')
def get_alphas(self):
xlist = self.trainable_variables
return [x for x in xlist if 'arch-encoding' in x.name]
def get_weights(self):
xlist = self.trainable_variables
return [x for x in xlist if 'arch-encoding' not in x.name]
def get_np_alphas(self):
arch_nps = self.arch_parameters.numpy()
arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True)
return arch_ops
def genotype(self):
genotypes, arch_nps = [], self.arch_parameters.numpy()
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = arch_nps[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return genotypes
def call(self, inputs, training):
weightss = tf.nn.softmax(self.arch_parameters, axis=1)
feature = self.stem(inputs, training)
for idx in range(self.num_layers):
cell = getattr(self, 'cell-{:03d}'.format(idx))
if isinstance(cell, SearchCell):
feature = cell.call(feature, weightss, training)
else:
feature = cell(feature, training)
logits = self.lastact(feature, training)
return logits