From 5a1dc897569c2ed7703191e999e74d8a60c23f91 Mon Sep 17 00:00:00 2001 From: HamsterMimi Date: Thu, 4 May 2023 13:41:59 +0800 Subject: [PATCH] update --- sota/cnn/init_projection.py | 11 +++++------ zerocostnas/post_validate.py | 11 +++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/sota/cnn/init_projection.py b/sota/cnn/init_projection.py index e1d6a6a..9393be3 100644 --- a/sota/cnn/init_projection.py +++ b/sota/cnn/init_projection.py @@ -96,12 +96,11 @@ def project_op(model, input, target, args, cell_type, proj_queue=None, selected_ model.candidate_flags[cell_type][selected_eid] = False # print(model.get_projected_weights()) - else: - measures = predictive.find_measures(model, - proj_queue, - ('random', 1, n_classes), - torch.device("cuda"), - measure_names=[proj_crit]) + measures = predictive.find_measures(model, + proj_queue, + ('random', 1, n_classes), + torch.device("cuda"), + measure_names=[proj_crit]) # print(measures) for idx in range(num_ops): diff --git a/zerocostnas/post_validate.py b/zerocostnas/post_validate.py index a5b6a4d..7617494 100644 --- a/zerocostnas/post_validate.py +++ b/zerocostnas/post_validate.py @@ -223,12 +223,11 @@ def main(): else: #score = score_loop(network, None, train_queue, args.gpu, None, args.proj_crit) network.requires_feature = False - else: - measures = predictive.find_measures(network, - train_queue, - ('random', 1, n_classes), - torch.device("cuda"), - measure_names=[args.proj_crit]) + measures = predictive.find_measures(network, + train_queue, + ('random', 1, n_classes), + torch.device("cuda"), + measure_names=[args.proj_crit])