diff --git a/exps/algos-v2/search-cell.py b/exps/algos-v2/search-cell.py index 7804cfa..976bc0a 100644 --- a/exps/algos-v2/search-cell.py +++ b/exps/algos-v2/search-cell.py @@ -459,7 +459,7 @@ def main(xargs): # the final post procedure : count the time start_time = time.time() genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo) - if xargs.algo == 'setn': + if xargs.algo == 'setn' or xargs.algo == 'enas': network.set_cal_mode('dynamic', genotype) elif xargs.algo == 'gdas': network.set_cal_mode('gdas', None)