diffusionNAG/MobileNetV3/analysis/arch_metrics.py
2024-03-15 14:38:51 +00:00

114 lines
5.5 KiB
Python

from analysis.arch_functions import compute_arch_metrics, compute_arch_metrics_meta
from torch import Tensor
import wandb
import torch.nn as nn
class SamplingArchMetrics(nn.Module):
    """Score a batch of sampled architectures and dump the results to text files.

    Thin ``nn.Module`` wrapper around ``compute_arch_metrics``: the training-set
    architecture strings and the relevant config fields are captured once at
    construction time and forwarded on every call.
    """

    def __init__(self, config, train_ds, exp_name):
        """Cache the dataset handle and the config fields read by ``forward``.

        :params config: config object exposing ``data.name``,
            ``data.except_inout`` and ``data.root``
        :params train_ds: training dataset; must expose ``x_list_`` when
            ``config.data.name == 'ofa'`` and ``arch_str_list_`` otherwise
        :params exp_name: experiment name (stored only; not read in this class)
        """
        super().__init__()
        self.exp_name = exp_name
        self.train_ds = train_ds
        self.name = config.data.name
        # The 'ofa' dataset keeps its reference architectures under a
        # different attribute than the other datasets.
        if self.name == 'ofa':
            self.train_arch_str_list = train_ds.x_list_
        else:
            self.train_arch_str_list = train_ds.arch_str_list_
        self.except_inout = config.data.except_inout
        self.data_root = config.data.root

    def forward(self, arch_list: list, adj, mask, this_sample_dir, test=False, timestep=None):
        """Compute metrics for the sampled architectures and write them to disk.

        :params arch_list: list of archs
        :params adj: [batch_size, num_nodes, num_nodes]
        :params mask: [batch_size, num_nodes, num_nodes]
        :params this_sample_dir: existing directory that receives
            ``valid_unique_archs.txt`` (skipped when the dataset is 'ofa')
        :params test: when true and the dataset is not 'ofa', additionally
            write every architecture string to ``final_.txt`` in the current
            working directory
        :params timestep: passed through unchanged to ``compute_arch_metrics``
        :returns: the ``arch_metrics`` object produced by ``compute_arch_metrics``
        """
        arch_metrics, all_arch_str = compute_arch_metrics(
            arch_list, adj, mask, self.train_arch_str_list, self.train_ds, timestep=timestep,
            name=self.name, except_inout=self.except_inout, data_root=self.data_root)
        # arch_metrics layout:
        # ([validity, uniqueness, novelty],
        #  unique,
        #  dict(test_acc_list=..., flops_list=..., params_list=..., latency_list=...),
        #  all_arch_str)

        # Nothing is written to disk for the 'ofa' dataset.
        if self.name == 'ofa':
            return arch_metrics

        if test:
            with open(r'final_.txt', 'w') as fp:
                # One architecture string per line.
                fp.writelines("%s\n" % arch_str for arch_str in all_arch_str)
            print('All archs saved')

        valid_unique_arch = arch_metrics[1]
        valid_unique_arch_prop_dict = arch_metrics[2]  # test_acc, flops, params, latency
        # Context manager guarantees the handle is released even if one of the
        # per-architecture lookups below raises (e.g. a property list is short).
        with open(f'{this_sample_dir}/valid_unique_archs.txt', "w") as textfile:
            for i, arch in enumerate(valid_unique_arch):
                textfile.write(f"Arch: {arch} \n")
                textfile.write(f"Test Acc: {valid_unique_arch_prop_dict['test_acc_list'][i]} \n")
                textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n ")
                textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
                textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n \n")
            # Trailing dump of the raw architecture strings (no separators added).
            textfile.writelines(valid_unique_arch)

        return arch_metrics
class SamplingArchMetricsMeta(nn.Module):
    """Score sampled architectures via ``compute_arch_metrics_meta`` and log them.

    Thin ``nn.Module`` wrapper: the training-split architecture strings are
    collected once at construction time and forwarded on every call.
    """

    def __init__(self, config, train_ds, exp_name, train_index=None, nasbench=None):
        """Cache the dataset handle and build the training architecture list.

        :params config: config object exposing ``data.name`` (the search space)
        :params train_ds: training dataset; unless the search space is 'ofa'
            it must expose ``arch_str_list`` and ``idx_lst['train']``
        :params exp_name: experiment name (stored only; not read in this class)
        :params train_index: accepted for interface compatibility; unused here
        :params nasbench: accepted for interface compatibility; unused here
        """
        super().__init__()
        self.exp_name = exp_name
        self.train_ds = train_ds
        self.search_space = config.data.name
        if self.search_space == 'ofa':
            # No reference list is kept for the 'ofa' search space.
            self.train_arch_str_list = None
        else:
            # Keep only the architectures that belong to the training split.
            self.train_arch_str_list = [train_ds.arch_str_list[i] for i in train_ds.idx_lst['train']]

    def forward(self, arch_list: list, adj, mask, this_sample_dir, test=False,
                timestep=None, check_dataname='cifar10'):
        """Compute metrics for the sampled architectures and write them to disk.

        :params arch_list: list of archs
        :params adj: [batch_size, num_nodes, num_nodes]
        :params mask: [batch_size, num_nodes, num_nodes]
        :params this_sample_dir: existing directory that receives
            ``valid_unique_archs.txt`` (skipped when the search space is 'ofa')
        :params test: when true, additionally write every architecture string
            to ``final_.txt`` in the current working directory
        :params timestep: passed through unchanged to ``compute_arch_metrics_meta``
        :params check_dataname: passed through unchanged to
            ``compute_arch_metrics_meta``
        :returns: the ``arch_metrics`` object produced by
            ``compute_arch_metrics_meta``
        """
        arch_metrics = compute_arch_metrics_meta(
            arch_list, adj, mask, self.train_arch_str_list,
            self.train_ds, timestep=timestep, check_dataname=check_dataname,
            name=self.search_space)
        # The last element of the result holds every architecture string.
        all_arch_str = arch_metrics[-1]

        if test:
            with open(r'final_.txt', 'w') as fp:
                # One architecture string per line.
                fp.writelines("%s\n" % arch_str for arch_str in all_arch_str)
            print('All archs saved')

        valid_unique_arch = arch_metrics[1]  # arch_str
        valid_unique_arch_prop_dict = arch_metrics[2]  # arch_idx, test_acc, flops, params, latency

        if self.search_space != 'ofa':
            # Context manager guarantees the handle is released even if one of
            # the per-architecture lookups below raises.
            with open(f'{this_sample_dir}/valid_unique_archs.txt', "w") as textfile:
                for i, arch in enumerate(valid_unique_arch):
                    textfile.write(f"Arch: {arch} \n")
                    textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
                    textfile.write(f"Test Acc: {valid_unique_arch_prop_dict['test_acc_list'][i]} \n")
                    textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n ")
                    textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
                    textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n \n")
                # Trailing dump of the raw architecture strings (no separators added).
                textfile.writelines(valid_unique_arch)

        return arch_metrics