Prototype generic NAS model (cont.) for GDAS.

This commit is contained in:
D-X-Y 2020-07-20 08:45:41 +00:00
parent 5cf66d24a1
commit 8d27050f6f
2 changed files with 22 additions and 9 deletions

View File

@ -377,8 +377,7 @@ def main(xargs):
start_epoch = last_info['epoch'] start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint']) checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes'] genotypes = checkpoint['genotypes']
if xargs.algo == 'enas': baseline = checkpoint['baseline']
baseline = checkpoint['baseline']
valid_accuracies = checkpoint['valid_accuracies'] valid_accuracies = checkpoint['valid_accuracies']
search_model.load_state_dict( checkpoint['search_model'] ) search_model.load_state_dict( checkpoint['search_model'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
@ -401,7 +400,7 @@ def main(xargs):
network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate) network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate)
if xargs.algo == 'gdas': if xargs.algo == 'gdas':
network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) ) network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) )
logger.log('[Reset tau as : {:}'.format(network.tau)) logger.log('[RESET tau as : {:} and drop_path as {:}]'.format(network.tau, network.drop_path))
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger) = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger)
search_time.update(time.time() - start_time) search_time.update(time.time() - start_time)
@ -423,6 +422,7 @@ def main(xargs):
network.set_cal_mode('urs', None) network.set_cal_mode('urs', None)
else: else:
raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo)) raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
logger.log('[{:}] - [get_best_arch] : {:} -> {:}'.format(epoch_str, genotype, temp_accuracy))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger) valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype)) logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
valid_accuracies[epoch] = valid_a_top1 valid_accuracies[epoch] = valid_a_top1
@ -494,7 +494,7 @@ if __name__ == '__main__':
parser.add_argument('--eval_candidate_num', type=int, default=100, help='The number of selected architectures to evaluate.') parser.add_argument('--eval_candidate_num', type=int, default=100, help='The number of selected architectures to evaluate.')
# #
parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
parser.add_argument('--affine' , type=int, default=1, choices=[0,1],help='Whether use affine=True or False in the BN layer.') parser.add_argument('--affine' , type=int, default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.')
parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.') parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
# architecture learning rate # architecture learning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')

View File

@ -102,17 +102,18 @@ class GenericNAS201Model(nn.Module):
self._op_names = deepcopy(search_space) self._op_names = deepcopy(search_space)
self._Layer = len(self._cells) self._Layer = len(self._cells)
self.edge2index = edge2index self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev, affine=affine, track_running_stats=track_running_stats), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1) self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes) self.classifier = nn.Linear(C_prev, num_classes)
self._num_edge = num_edge self._num_edge = num_edge
# algorithm related # algorithm related
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) self.arch_parameters = nn.Parameter(1e-3*torch.randn(num_edge, len(search_space)))
self._mode = None self._mode = None
self.dynamic_cell = None self.dynamic_cell = None
self._tau = None self._tau = None
self._algo = None self._algo = None
self._drop_path = None self._drop_path = None
self.verbose = False
def set_algo(self, algo: Text): def set_algo(self, algo: Text):
# used for searching # used for searching
@ -256,33 +257,45 @@ class GenericNAS201Model(nn.Module):
else: break else: break
with torch.no_grad(): with torch.no_grad():
hardwts_cpu = hardwts.detach().cpu() hardwts_cpu = hardwts.detach().cpu()
return hardwts, hardwts_cpu, index return hardwts, hardwts_cpu, index, 'GUMBEL'
else: else:
alphas = nn.functional.softmax(self.arch_parameters, dim=-1) alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
index = alphas.max(-1, keepdim=True)[1] index = alphas.max(-1, keepdim=True)[1]
with torch.no_grad(): with torch.no_grad():
alphas_cpu = alphas.detach().cpu() alphas_cpu = alphas.detach().cpu()
return alphas, alphas_cpu, index return alphas, alphas_cpu, index, 'SOFTMAX'
def forward(self, inputs): def forward(self, inputs):
alphas, alphas_cpu, index = self.normalize_archp() alphas, alphas_cpu, index, verbose_str = self.normalize_archp()
feature = self._stem(inputs) feature = self._stem(inputs)
for i, cell in enumerate(self._cells): for i, cell in enumerate(self._cells):
if isinstance(cell, SearchCell): if isinstance(cell, SearchCell):
if self.mode == 'urs': if self.mode == 'urs':
feature = cell.forward_urs(feature) feature = cell.forward_urs(feature)
if self.verbose:
verbose_str += '-forward_urs'
elif self.mode == 'select': elif self.mode == 'select':
feature = cell.forward_select(feature, alphas_cpu) feature = cell.forward_select(feature, alphas_cpu)
if self.verbose:
verbose_str += '-forward_select'
elif self.mode == 'joint': elif self.mode == 'joint':
feature = cell.forward_joint(feature, alphas) feature = cell.forward_joint(feature, alphas)
if self.verbose:
verbose_str += '-forward_joint'
elif self.mode == 'dynamic': elif self.mode == 'dynamic':
feature = cell.forward_dynamic(feature, self.dynamic_cell) feature = cell.forward_dynamic(feature, self.dynamic_cell)
if self.verbose:
verbose_str += '-forward_dynamic'
elif self.mode == 'gdas': elif self.mode == 'gdas':
feature = cell.forward_gdas(feature, alphas, index) feature = cell.forward_gdas(feature, alphas, index)
if self.verbose:
verbose_str += '-forward_gdas'
else: raise ValueError('invalid mode={:}'.format(self.mode)) else: raise ValueError('invalid mode={:}'.format(self.mode))
else: feature = cell(feature) else: feature = cell(feature)
if self.drop_path is not None: if self.drop_path is not None:
feature = drop_path(feature, self.drop_path) feature = drop_path(feature, self.drop_path)
if self.verbose and random.random() < 0.001:
print(verbose_str)
out = self.lastact(feature) out = self.lastact(feature)
out = self.global_pooling(out) out = self.global_pooling(out)
out = out.view(out.size(0), -1) out = out.view(out.size(0), -1)