diff --git a/exps-tf/one-shot-nas.py b/exps-tf/one-shot-nas.py new file mode 100644 index 0000000..181d9d8 --- /dev/null +++ b/exps-tf/one-shot-nas.py @@ -0,0 +1,206 @@ +# [D-X-Y] +# Run DARTS +# CUDA_VISIBLE_DEVICES=0 python exps-tf/one-shot-nas.py --epochs 50 +# +import os, sys, math, time, random, argparse +import tensorflow as tf +from pathlib import Path + +lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +# self-lib +from tf_models import get_cell_based_tiny_net +from tf_optimizers import SGDW, AdamW +from config_utils import dict2config +from log_utils import time_string +from models import CellStructure + + +def pre_process(image_a, label_a, image_b, label_b): + def standard_func(image): + x = tf.pad(image, [[4, 4], [4, 4], [0, 0]]) + x = tf.image.random_crop(x, [32, 32, 3]) + x = tf.image.random_flip_left_right(x) + return x + return standard_func(image_a), label_a, standard_func(image_b), label_b + + +class CosineAnnealingLR(object): + def __init__(self, warmup_epochs, epochs, initial_lr, min_lr): + self.warmup_epochs = warmup_epochs + self.epochs = epochs + self.initial_lr = initial_lr + self.min_lr = min_lr + + def get_lr(self, epoch): + if epoch < self.warmup_epochs: + lr = self.min_lr + (epoch/self.warmup_epochs) * (self.initial_lr-self.min_lr) + elif epoch >= self.epochs: + lr = self.min_lr + else: + lr = self.min_lr + (self.initial_lr-self.min_lr) * 0.5 * (1 + math.cos(math.pi * epoch / self.epochs)) + return lr + + + +def main(xargs): + cifar10 = tf.keras.datasets.cifar10 + + (x_train, y_train), (x_test, y_test) = cifar10.load_data() + x_train, x_test = x_train / 255.0, x_test / 255.0 + x_train, x_test = x_train.astype('float32'), x_test.astype('float32') + y_train, y_test = y_train.reshape(-1), y_test.reshape(-1) + + # Add a channels dimension + all_indexes = list(range(x_train.shape[0])) + random.shuffle(all_indexes) + s_train_idxs, s_valid_idxs = all_indexes[::2], all_indexes[1::2] + search_train_x, search_train_y = x_train[s_train_idxs], y_train[s_train_idxs] + search_valid_x, search_valid_y = x_train[s_valid_idxs], y_train[s_valid_idxs] + #x_train, x_test = x_train[..., tf.newaxis], x_test[..., tf.newaxis] + + # Use tf.data + #train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64) + search_ds = tf.data.Dataset.from_tensor_slices((search_train_x, search_train_y, search_valid_x, search_valid_y)) + search_ds = search_ds.map(pre_process).shuffle(1000).batch(64) + + test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32) + + # Create an instance of the model + config = dict2config({'name': 'DARTS', + 'C' : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes, + 'num_classes': 10, 'space': 'nas-bench-201', 'affine': True}, None) + model = get_cell_based_tiny_net(config) + num_iters_per_epoch = int(tf.data.experimental.cardinality(search_ds).numpy()) + #lr_schedular = tf.keras.experimental.CosineDecay(xargs.w_lr_max, num_iters_per_epoch*xargs.epochs, xargs.w_lr_min / xargs.w_lr_max) + lr_schedular = CosineAnnealingLR(0, xargs.epochs, xargs.w_lr_max, xargs.w_lr_min) + # Choose optimizer + loss_object = tf.keras.losses.CategoricalCrossentropy() + w_optimizer = SGDW(learning_rate=xargs.w_lr_max, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True) + a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate, weight_decay=xargs.arch_weight_decay, beta_1=0.5, beta_2=0.999, epsilon=1e-07) + #w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True) + #a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07) + #### + # metrics + train_loss = tf.keras.metrics.Mean(name='train_loss') + train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy') + valid_loss = tf.keras.metrics.Mean(name='valid_loss') + valid_accuracy = tf.keras.metrics.CategoricalAccuracy(name='valid_accuracy') + test_loss = tf.keras.metrics.Mean(name='test_loss') + test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') + + @tf.function + def search_step(train_images, train_labels, valid_images, valid_labels): + # optimize weights + with tf.GradientTape() as tape: + predictions = model(train_images, True) + w_loss = loss_object(train_labels, predictions) + net_w_param = model.get_weights() + gradients = tape.gradient(w_loss, net_w_param) + w_optimizer.apply_gradients(zip(gradients, net_w_param)) + train_loss(w_loss) + train_accuracy(train_labels, predictions) + # optimize alphas + with tf.GradientTape() as tape: + predictions = model(valid_images, True) + a_loss = loss_object(valid_labels, predictions) + net_a_param = model.get_alphas() + gradients = tape.gradient(a_loss, net_a_param) + a_optimizer.apply_gradients(zip(gradients, net_a_param)) + valid_loss(a_loss) + valid_accuracy(valid_labels, predictions) + + # IFT with Neumann approximation + @tf.function + def search_step_IFTNA(train_images, train_labels, valid_images, valid_labels, max_step): + # optimize weights + with tf.GradientTape() as tape: + predictions = model(train_images, True) + w_loss = loss_object(train_labels, predictions) + # get the weights + net_w_param = model.get_weights() + net_a_param = model.get_alphas() + gradients = tape.gradient(w_loss, net_w_param) + w_optimizer.apply_gradients(zip(gradients, net_w_param)) + train_loss(w_loss) + train_accuracy(train_labels, predictions) + # optimize alphas + with tf.GradientTape(persistent=True) as tape: + predictions = model(valid_images, True) + val_loss = loss_object(valid_labels, predictions) + predictions = model(train_images, True) + trn_loss = loss_object(train_labels, predictions) + # ---- + dV_dW = tape.gradient(val_loss, net_w_param) + # approxInverseHVP to calculate v2 + sum_p = v1 = dV_dW + dT_dW = tape.gradient(trn_loss, net_w_param) + for j in range(1, max_step): + temp_dot = tape.gradient(dT_dW, net_w_param, output_gradients=v1) + v1 = [tf.subtract(A, B) for A, B in zip(v1, temp_dot)] + sum_p = [tf.add(A, B) for A, B in zip(sum_p, v1)] + # calculate v3 + dT_dl = tape.gradient(trn_loss, net_a_param) + import pdb; pdb.set_trace() + v3 = tape.gradient(dT_dl, net_w_param, output_gradients=sum_p) + dV_dl = tape.gradient(val_loss, net_a_param) + a_gradients = [tf.subtract(A, B) for A, B in zip(dV_dl, v3)] + import pdb; pdb.set_trace() + print('--') + + # TEST + @tf.function + def test_step(images, labels): + predictions = model(images) + t_loss = loss_object(labels, predictions) + + test_loss(t_loss) + test_accuracy(labels, predictions) + + print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, num_iters_per_epoch)) + + for epoch in range(xargs.epochs): + # Reset the metrics at the start of the next epoch + train_loss.reset_states() ; train_accuracy.reset_states() + test_loss.reset_states() ; test_accuracy.reset_states() + cur_lr = lr_schedular.get_lr(epoch) + tf.keras.backend.set_value(w_optimizer.lr, cur_lr) + + for trn_imgs, trn_labels, val_imgs, val_labels in search_ds: + #search_step(trn_imgs, trn_labels, val_imgs, val_labels) + trn_labels, val_labels = tf.one_hot(trn_labels, 10), tf.one_hot(val_labels, 10) + search_step_IFTNA(trn_imgs, trn_labels, val_imgs, val_labels, 5) + genotype = model.genotype() + genotype = CellStructure(genotype) + + #for test_images, test_labels in test_ds: + # test_step(test_images, test_labels) + + cur_lr = float(tf.keras.backend.get_value(w_optimizer.lr)) + template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | lr={:.6f}' + print(template.format(time_string(), epoch+1, xargs.epochs, + train_loss.result(), + train_accuracy.result()*100, + valid_loss.result(), + valid_accuracy.result()*100, + cur_lr)) + print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas())) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) + # training details + parser.add_argument('--epochs' , type=int , default= 250 , help='') + parser.add_argument('--w_lr_max' , type=float, default= 0.025, help='') + parser.add_argument('--w_lr_min' , type=float, default= 0.001, help='') + parser.add_argument('--w_weight_decay' , type=float, default=0.0005, help='') + parser.add_argument('--w_momentum' , type=float, default= 0.9 , help='') + parser.add_argument('--arch_learning_rate', type=float, default=0.0003, help='') + parser.add_argument('--arch_weight_decay' , type=float, default=0.001, help='') + # marco structure + parser.add_argument('--channel' , type=int , default=16, help='') + parser.add_argument('--num_cells' , type=int , default= 5, help='') + parser.add_argument('--max_nodes' , type=int , default= 4, help='') + args = parser.parse_args() + main( args ) diff --git a/exps-tf/test-invH.py b/exps-tf/test-invH.py new file mode 100644 index 0000000..b455506 --- /dev/null +++ b/exps-tf/test-invH.py @@ -0,0 +1,46 @@ +import os, sys, math, time, random, argparse +import tensorflow as tf +from pathlib import Path + + +def test_a(): + x = tf.Variable([[1.], [2.], [4.0]]) + with tf.GradientTape(persistent=True) as g: + trn = tf.math.exp(tf.math.reduce_sum(x)) + val = tf.math.cos(tf.math.reduce_sum(x)) + dT_dx = g.gradient(trn, x) + dV_dx = g.gradient(val, x) + hess_vector = g.gradient(dT_dx, x, output_gradients=dV_dx) + print ('calculate ok : {:}'.format(hess_vector)) + +def test_b(): + cce = tf.keras.losses.SparseCategoricalCrossentropy() + L1 = tf.convert_to_tensor([0, 1, 2]) + L2 = tf.convert_to_tensor([2, 0, 1]) + B = tf.Variable([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]]) + with tf.GradientTape(persistent=True) as g: + trn = cce(L1, B) + val = cce(L2, B) + dT_dx = g.gradient(trn, B) + dV_dx = g.gradient(val, B) + hess_vector = g.gradient(dT_dx, B, output_gradients=dV_dx) + print ('calculate ok : {:}'.format(hess_vector)) + +def test_c(): + cce = tf.keras.losses.CategoricalCrossentropy() + L1 = tf.convert_to_tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) + L2 = tf.convert_to_tensor([[0., 0., 1.], [0., 1., 0.], [1., 0., 0.]]) + B = tf.Variable([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]]) + with tf.GradientTape(persistent=True) as g: + trn = cce(L1, B) + val = cce(L2, B) + dT_dx = g.gradient(trn, B) + dV_dx = g.gradient(val, B) + hess_vector = g.gradient(dT_dx, B, output_gradients=dV_dx) + print ('calculate ok : {:}'.format(hess_vector)) + +if __name__ == '__main__': + print(tf.__version__) + test_c() + #test_b() + #test_a() diff --git a/lib/tf_models/__init__.py b/lib/tf_models/__init__.py index ac05da5..d402913 100644 --- a/lib/tf_models/__init__.py +++ b/lib/tf_models/__init__.py @@ -9,7 +9,7 @@ __all__ = ['get_cell_based_tiny_net', 'get_search_spaces'] # the cell-based NAS models def get_cell_based_tiny_net(config): - group_names = ['GDAS'] + group_names = ['GDAS', 'DARTS'] if config.name in group_names: from .cell_searchs import nas_super_nets from .cell_operations import SearchSpaceNames diff --git a/lib/tf_models/cell_searchs/__init__.py b/lib/tf_models/cell_searchs/__init__.py index 479cb03..717fbe4 100644 --- a/lib/tf_models/cell_searchs/__init__.py +++ b/lib/tf_models/cell_searchs/__init__.py @@ -2,5 +2,7 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # ################################################## from .search_model_gdas import TinyNetworkGDAS +from .search_model_darts import TinyNetworkDARTS -nas_super_nets = {'GDAS': TinyNetworkGDAS} +nas_super_nets = {'GDAS' : TinyNetworkGDAS, + 'DARTS': TinyNetworkDARTS} diff --git a/lib/tf_models/cell_searchs/search_model_darts.py b/lib/tf_models/cell_searchs/search_model_darts.py new file mode 100644 index 0000000..ad05b8b --- /dev/null +++ b/lib/tf_models/cell_searchs/search_model_darts.py @@ -0,0 +1,83 @@ +import tensorflow as tf +import numpy as np +from copy import deepcopy +from ..cell_operations import ResNetBasicblock +from .search_cells import NAS201SearchCell as SearchCell + + +class TinyNetworkDARTS(tf.keras.Model): + + def __init__(self, C, N, max_nodes, num_classes, search_space, affine): + super(TinyNetworkDARTS, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = tf.keras.Sequential([ + tf.keras.layers.Conv2D(16, 3, 1, padding='same', use_bias=False), + tf.keras.layers.BatchNormalization()], name='stem') + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev, num_edge, edge2index = C, None, None + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + cell_prefix = 'cell-{:03d}'.format(index) + #with tf.name_scope(cell_prefix) as scope: + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + C_prev = cell.out_dim + setattr(self, cell_prefix, cell) + self.num_layers = len(layer_reductions) + self.op_names = deepcopy( search_space ) + self.edge2index = edge2index + self.num_edge = num_edge + self.lastact = tf.keras.Sequential([ + tf.keras.layers.BatchNormalization(), + tf.keras.layers.ReLU(), + tf.keras.layers.GlobalAvgPool2D(), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact') + #self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + arch_init = tf.random_normal_initializer(mean=0, stddev=0.001) + self.arch_parameters = tf.Variable(initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'), trainable=True, name='arch-encoding') + + def get_alphas(self): + xlist = self.trainable_variables + return [x for x in xlist if 'arch-encoding' in x.name] + + def get_weights(self): + xlist = self.trainable_variables + return [x for x in xlist if 'arch-encoding' not in x.name] + + def get_np_alphas(self): + arch_nps = self.arch_parameters.numpy() + arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True) + return arch_ops + + def genotype(self): + genotypes, arch_nps = [], self.arch_parameters.numpy() + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = arch_nps[ self.edge2index[node_str] ] + op_name = self.op_names[ weights.argmax().item() ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return genotypes + + def call(self, inputs, training): + weightss = tf.nn.softmax(self.arch_parameters, axis=1) + feature = self.stem(inputs, training) + for idx in range(self.num_layers): + cell = getattr(self, 'cell-{:03d}'.format(idx)) + if isinstance(cell, SearchCell): + feature = cell.call(feature, weightss, training) + else: + feature = cell(feature, training) + logits = self.lastact(feature, training) + return logits