update TF
exps-tf/one-shot-nas.py (new file, 206 lines)
@@ -0,0 +1,206 @@
# [D-X-Y]
# Run DARTS
# CUDA_VISIBLE_DEVICES=0 python exps-tf/one-shot-nas.py --epochs 50
#
import os, sys, math, time, random, argparse
import tensorflow as tf
from pathlib import Path

lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))

# self-lib
from tf_models import get_cell_based_tiny_net
from tf_optimizers import SGDW, AdamW
from config_utils import dict2config
from log_utils import time_string
from models import CellStructure


def pre_process(image_a, label_a, image_b, label_b):
  def standard_func(image):
    x = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
    x = tf.image.random_crop(x, [32, 32, 3])
    x = tf.image.random_flip_left_right(x)
    return x
  return standard_func(image_a), label_a, standard_func(image_b), label_b
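# Note: standard_func applies the standard CIFAR-10 augmentation (pad each image by
# 4 pixels per side, take a random 32x32 crop, then a random horizontal flip); the
# same augmentation is applied to both halves of the paired search batch.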


class CosineAnnealingLR(object):
  def __init__(self, warmup_epochs, epochs, initial_lr, min_lr):
    self.warmup_epochs = warmup_epochs
    self.epochs = epochs
    self.initial_lr = initial_lr
    self.min_lr = min_lr

  def get_lr(self, epoch):
    if epoch < self.warmup_epochs:
      lr = self.min_lr + (epoch/self.warmup_epochs) * (self.initial_lr-self.min_lr)
    elif epoch >= self.epochs:
      lr = self.min_lr
    else:
      lr = self.min_lr + (self.initial_lr-self.min_lr) * 0.5 * (1 + math.cos(math.pi * epoch / self.epochs))
    return lr
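# The schedule above is a plain cosine decay with an optional linear warmup:
#   lr(e) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * e / epochs))
# Illustrative values (assuming warmup_epochs=0, epochs=50, lr_max=0.025, lr_min=0.001):
#   get_lr(0)  -> 0.025     (start of search)
#   get_lr(25) -> 0.013     (halfway through the cosine)
#   get_lr(49) -> ~0.001    (near the minimum learning rate)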


def main(xargs):
  cifar10 = tf.keras.datasets.cifar10

  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  x_train, x_test = x_train / 255.0, x_test / 255.0
  x_train, x_test = x_train.astype('float32'), x_test.astype('float32')
  y_train, y_test = y_train.reshape(-1), y_test.reshape(-1)

  # split the original training set into two halves: one for the weights, one for the architecture
  all_indexes = list(range(x_train.shape[0]))
  random.shuffle(all_indexes)
  s_train_idxs, s_valid_idxs = all_indexes[::2], all_indexes[1::2]
  search_train_x, search_train_y = x_train[s_train_idxs], y_train[s_train_idxs]
  search_valid_x, search_valid_y = x_train[s_valid_idxs], y_train[s_valid_idxs]
  #x_train, x_test = x_train[..., tf.newaxis], x_test[..., tf.newaxis]

  # Use tf.data
  #train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
  search_ds = tf.data.Dataset.from_tensor_slices((search_train_x, search_train_y, search_valid_x, search_valid_y))
  search_ds = search_ds.map(pre_process).shuffle(1000).batch(64)
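  # Each element of search_ds is a (train, valid) pair: one iterator yields both the
  # batch used to update the network weights and the batch used to update the
  # architecture parameters, which is the bi-level split used by DARTS-style search.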
  test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

  # Create an instance of the model
  config = dict2config({'name': 'DARTS',
                        'C'   : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes,
                        'num_classes': 10, 'space': 'nas-bench-201', 'affine': True}, None)
  model = get_cell_based_tiny_net(config)
  num_iters_per_epoch = int(tf.data.experimental.cardinality(search_ds).numpy())
  #lr_schedular = tf.keras.experimental.CosineDecay(xargs.w_lr_max, num_iters_per_epoch*xargs.epochs, xargs.w_lr_min / xargs.w_lr_max)
  lr_schedular = CosineAnnealingLR(0, xargs.epochs, xargs.w_lr_max, xargs.w_lr_min)
  # Choose optimizer
  loss_object = tf.keras.losses.CategoricalCrossentropy()
  w_optimizer = SGDW(learning_rate=xargs.w_lr_max, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True)
  a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate, weight_decay=xargs.arch_weight_decay, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
  #w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True)
  #a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
  ####
  # metrics
  train_loss = tf.keras.metrics.Mean(name='train_loss')
  train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
  valid_loss = tf.keras.metrics.Mean(name='valid_loss')
  valid_accuracy = tf.keras.metrics.CategoricalAccuracy(name='valid_accuracy')
  test_loss = tf.keras.metrics.Mean(name='test_loss')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

  @tf.function
  def search_step(train_images, train_labels, valid_images, valid_labels):
    # optimize weights
    with tf.GradientTape() as tape:
      predictions = model(train_images, True)
      w_loss = loss_object(train_labels, predictions)
    net_w_param = model.get_weights()
    gradients = tape.gradient(w_loss, net_w_param)
    w_optimizer.apply_gradients(zip(gradients, net_w_param))
    train_loss(w_loss)
    train_accuracy(train_labels, predictions)
    # optimize alphas
    with tf.GradientTape() as tape:
      predictions = model(valid_images, True)
      a_loss = loss_object(valid_labels, predictions)
    net_a_param = model.get_alphas()
    gradients = tape.gradient(a_loss, net_a_param)
    a_optimizer.apply_gradients(zip(gradients, net_a_param))
    valid_loss(a_loss)
    valid_accuracy(valid_labels, predictions)
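  # search_step is the first-order DARTS update: network weights and architecture
  # parameters are optimized alternately inside one step, each on its own data split
  # and with its own optimizer (SGDW for the weights, AdamW for the alphas), without
  # back-propagating through the weight update itself.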

  # IFT with Neumann approximation
  @tf.function
  def search_step_IFTNA(train_images, train_labels, valid_images, valid_labels, max_step):
    # optimize weights
    with tf.GradientTape() as tape:
      predictions = model(train_images, True)
      w_loss = loss_object(train_labels, predictions)
    # get the weights and the architecture parameters
    net_w_param = model.get_weights()
    net_a_param = model.get_alphas()
    gradients = tape.gradient(w_loss, net_w_param)
    w_optimizer.apply_gradients(zip(gradients, net_w_param))
    train_loss(w_loss)
    train_accuracy(train_labels, predictions)
    # optimize alphas via the implicit-function-theorem hypergradient
    with tf.GradientTape(persistent=True) as tape:
      val_predictions = model(valid_images, True)
      val_loss = loss_object(valid_labels, val_predictions)
      trn_predictions = model(train_images, True)
      trn_loss = loss_object(train_labels, trn_predictions)
      # ----
      dV_dW = tape.gradient(val_loss, net_w_param)
      # approxInverseHVP: sum_p approximates v2 = (d^2 L_trn / dw^2)^{-1} * dV_dW
      sum_p = v1 = dV_dW
      dT_dW = tape.gradient(trn_loss, net_w_param)
      for j in range(1, max_step):
        temp_dot = tape.gradient(dT_dW, net_w_param, output_gradients=v1)  # Hessian-vector product
        v1 = [tf.subtract(A, B) for A, B in zip(v1, temp_dot)]
        sum_p = [tf.add(A, B) for A, B in zip(sum_p, v1)]
      # calculate v3 = (d^2 L_trn / (d alpha d w)) applied to the Neumann sum
      v3 = tape.gradient(dT_dW, net_a_param, output_gradients=sum_p)
    dV_dl = tape.gradient(val_loss, net_a_param)
    # hypergradient = direct gradient minus the implicit (indirect) term
    a_gradients = [tf.subtract(A, B) for A, B in zip(dV_dl, v3)]
    a_optimizer.apply_gradients(zip(a_gradients, net_a_param))
    valid_loss(val_loss)
    valid_accuracy(valid_labels, val_predictions)
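  # search_step_IFTNA computes the architecture gradient via the implicit function
  # theorem with a truncated Neumann series for the inverse Hessian (cf. Lorraine et
  # al., "Optimizing Millions of Hyperparameters by Implicit Differentiation"):
  #   dL_val/dalpha = dL_val/dalpha  -  d2L_trn/(dalpha dw) * H^{-1} * dL_val/dw,
  #   H^{-1} * v  ~=  sum_{j=0..K} (I - H)^j * v,   with H = d2L_trn/dw2.
  # In the code, v1 holds (I - H)^j * dV_dW, sum_p accumulates the truncated series,
  # and v3 is the mixed second derivative applied to that sum.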

  # TEST
  @tf.function
  def test_step(images, labels):
    predictions = model(images, False)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

  print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, num_iters_per_epoch))

  for epoch in range(xargs.epochs):
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states() ; train_accuracy.reset_states()
    valid_loss.reset_states() ; valid_accuracy.reset_states()
    test_loss.reset_states()  ; test_accuracy.reset_states()
    cur_lr  = lr_schedular.get_lr(epoch)
    tf.keras.backend.set_value(w_optimizer.lr, cur_lr)

    for trn_imgs, trn_labels, val_imgs, val_labels in search_ds:
      #search_step(trn_imgs, trn_labels, val_imgs, val_labels)
      trn_labels, val_labels = tf.one_hot(trn_labels, 10), tf.one_hot(val_labels, 10)
      search_step_IFTNA(trn_imgs, trn_labels, val_imgs, val_labels, 5)
    genotype = model.genotype()
    genotype = CellStructure(genotype)

    #for test_images, test_labels in test_ds:
    #  test_step(test_images, test_labels)

    cur_lr = float(tf.keras.backend.get_value(w_optimizer.lr))
    template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | lr={:.6f}'
    print(template.format(time_string(), epoch+1, xargs.epochs,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          valid_loss.result(),
                          valid_accuracy.result()*100,
                          cur_lr))
    print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas()))


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  # training details
  parser.add_argument('--epochs'            , type=int  ,   default= 250  ,   help='')
  parser.add_argument('--w_lr_max'          , type=float,   default= 0.025,   help='')
  parser.add_argument('--w_lr_min'          , type=float,   default= 0.001,   help='')
  parser.add_argument('--w_weight_decay'    , type=float,   default=0.0005,   help='')
  parser.add_argument('--w_momentum'        , type=float,   default= 0.9  ,   help='')
  parser.add_argument('--arch_learning_rate', type=float,   default=0.0003,   help='')
  parser.add_argument('--arch_weight_decay' , type=float,   default=0.001,    help='')
  # macro structure
  parser.add_argument('--channel'           , type=int  ,   default=16,       help='')
  parser.add_argument('--num_cells'         , type=int  ,   default= 5,       help='')
  parser.add_argument('--max_nodes'         , type=int  ,   default= 4,       help='')
  args = parser.parse_args()
  main( args )

exps-tf/test-invH.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import os, sys, math, time, random, argparse
import tensorflow as tf
from pathlib import Path


def test_a():
  x = tf.Variable([[1.], [2.], [4.0]])
  with tf.GradientTape(persistent=True) as g:
    trn = tf.math.exp(tf.math.reduce_sum(x))
    val = tf.math.cos(tf.math.reduce_sum(x))
    dT_dx = g.gradient(trn, x)
    dV_dx = g.gradient(val, x)
    hess_vector = g.gradient(dT_dx, x, output_gradients=dV_dx)
  print('calculate ok : {:}'.format(hess_vector))

def test_b():
  cce = tf.keras.losses.SparseCategoricalCrossentropy()
  L1 = tf.convert_to_tensor([0, 1, 2])
  L2 = tf.convert_to_tensor([2, 0, 1])
  B = tf.Variable([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])
  with tf.GradientTape(persistent=True) as g:
    trn = cce(L1, B)
    val = cce(L2, B)
    dT_dx = g.gradient(trn, B)
    dV_dx = g.gradient(val, B)
    hess_vector = g.gradient(dT_dx, B, output_gradients=dV_dx)
  print('calculate ok : {:}'.format(hess_vector))

def test_c():
  cce = tf.keras.losses.CategoricalCrossentropy()
  L1 = tf.convert_to_tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
  L2 = tf.convert_to_tensor([[0., 0., 1.], [0., 1., 0.], [1., 0., 0.]])
  B = tf.Variable([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])
  with tf.GradientTape(persistent=True) as g:
    trn = cce(L1, B)
    val = cce(L2, B)
    dT_dx = g.gradient(trn, B)
    dV_dx = g.gradient(val, B)
    hess_vector = g.gradient(dT_dx, B, output_gradients=dV_dx)
  print('calculate ok : {:}'.format(hess_vector))

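# All three tests exercise the pattern used by exps-tf/one-shot-nas.py: calling
# g.gradient on a gradient that was itself computed inside a persistent tape, with
# output_gradients supplied, yields a Hessian-vector product
#   hess_vector = (d^2 trn / dx^2) @ dV_dx
# without ever materializing the full Hessian.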
if __name__ == '__main__':
  print(tf.__version__)
  test_c()
  #test_b()
  #test_a()
@@ -9,7 +9,7 @@ __all__ = ['get_cell_based_tiny_net', 'get_search_spaces']
 
 # the cell-based NAS models
 def get_cell_based_tiny_net(config):
-  group_names = ['GDAS']
+  group_names = ['GDAS', 'DARTS']
   if config.name in group_names:
     from .cell_searchs import nas_super_nets
     from .cell_operations import SearchSpaceNames
@@ -2,5 +2,7 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
 from .search_model_gdas     import TinyNetworkGDAS
+from .search_model_darts    import TinyNetworkDARTS
 
-nas_super_nets = {'GDAS': TinyNetworkGDAS}
+nas_super_nets = {'GDAS' : TinyNetworkGDAS,
+                  'DARTS': TinyNetworkDARTS}

lib/tf_models/cell_searchs/search_model_darts.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import tensorflow as tf
import numpy as np
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells     import NAS201SearchCell as SearchCell


class TinyNetworkDARTS(tf.keras.Model):

  def __init__(self, C, N, max_nodes, num_classes, search_space, affine):
    super(TinyNetworkDARTS, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = tf.keras.Sequential([
                    tf.keras.layers.Conv2D(16, 3, 1, padding='same', use_bias=False),
                    tf.keras.layers.BatchNormalization()], name='stem')

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      cell_prefix = 'cell-{:03d}'.format(index)
      #with tf.name_scope(cell_prefix) as scope:
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      C_prev = cell.out_dim
      setattr(self, cell_prefix, cell)
    self.num_layers = len(layer_reductions)
    self.op_names   = deepcopy( search_space )
    self.edge2index = edge2index
    self.num_edge   = num_edge
    self.lastact    = tf.keras.Sequential([
                        tf.keras.layers.BatchNormalization(),
                        tf.keras.layers.ReLU(),
                        tf.keras.layers.GlobalAvgPool2D(),
                        tf.keras.layers.Flatten(),
                        tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact')
    #self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
    arch_init = tf.random_normal_initializer(mean=0, stddev=0.001)
    self.arch_parameters = tf.Variable(initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'), trainable=True, name='arch-encoding')

  def get_alphas(self):
    xlist = self.trainable_variables
    return [x for x in xlist if 'arch-encoding' in x.name]

  def get_weights(self):
    xlist = self.trainable_variables
    return [x for x in xlist if 'arch-encoding' not in x.name]

  def get_np_alphas(self):
    arch_nps = self.arch_parameters.numpy()
    arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True)
    return arch_ops

  def genotype(self):
    genotypes, arch_nps = [], self.arch_parameters.numpy()
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights = arch_nps[ self.edge2index[node_str] ]
        op_name = self.op_names[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return genotypes
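  # Note: genotype() decodes the architecture deterministically by picking, for every
  # edge "i<-j", the operation with the largest raw architecture weight; the softmax in
  # get_np_alphas() and call() is only needed while mixing candidate operations during
  # the search itself.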

  def call(self, inputs, training):
    weightss = tf.nn.softmax(self.arch_parameters, axis=1)
    feature = self.stem(inputs, training)
    for idx in range(self.num_layers):
      cell = getattr(self, 'cell-{:03d}'.format(idx))
      if isinstance(cell, SearchCell):
        feature = cell.call(feature, weightss, training)
      else:
        feature = cell(feature, training)
    logits = self.lastact(feature, training)
    return logits