update
This commit is contained in:
		| @@ -96,12 +96,11 @@ def project_op(model, input, target, args, cell_type, proj_queue=None, selected_ | |||||||
|  |  | ||||||
|             model.candidate_flags[cell_type][selected_eid] = False |             model.candidate_flags[cell_type][selected_eid] = False | ||||||
|             # print(model.get_projected_weights()) |             # print(model.get_projected_weights()) | ||||||
|             else: |             measures = predictive.find_measures(model, | ||||||
|                 measures = predictive.find_measures(model, |                                                 proj_queue, | ||||||
|                                                     proj_queue, |                                                 ('random', 1, n_classes), | ||||||
|                                                     ('random', 1, n_classes), |                                                 torch.device("cuda"), | ||||||
|                                                     torch.device("cuda"), |                                                 measure_names=[proj_crit]) | ||||||
|                                                     measure_names=[proj_crit]) |  | ||||||
|  |  | ||||||
|             # print(measures) |             # print(measures) | ||||||
|             for idx in range(num_ops): |             for idx in range(num_ops): | ||||||
|   | |||||||
| @@ -223,12 +223,11 @@ def main(): | |||||||
|                 else: |                 else: | ||||||
|                     #score =  score_loop(network, None, train_queue, args.gpu, None, args.proj_crit) |                     #score =  score_loop(network, None, train_queue, args.gpu, None, args.proj_crit) | ||||||
|                     network.requires_feature = False |                     network.requires_feature = False | ||||||
|                     else: |                     measures = predictive.find_measures(network, | ||||||
|                         measures = predictive.find_measures(network, |                                                         train_queue, | ||||||
|                                                             train_queue, |                                                         ('random', 1, n_classes), | ||||||
|                                                             ('random', 1, n_classes), |                                                         torch.device("cuda"), | ||||||
|                                                             torch.device("cuda"), |                                                         measure_names=[args.proj_crit]) | ||||||
|                                                             measure_names=[args.proj_crit]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user