From 6d9db64a48095fb226b73c1fb69571928897975c Mon Sep 17 00:00:00 2001 From: mhz Date: Wed, 21 Aug 2024 10:26:02 +0200 Subject: [PATCH] explore the 201 space script --- graph_dit/exp_201/main.py | 82 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 graph_dit/exp_201/main.py diff --git a/graph_dit/exp_201/main.py b/graph_dit/exp_201/main.py new file mode 100644 index 0000000..2cad786 --- /dev/null +++ b/graph_dit/exp_201/main.py @@ -0,0 +1,82 @@ + +import matplotlib.pyplot as plt +import pandas as pd +from nas_201_api import NASBench201API as API +from naswot.score_networks import get_nasbench201_idx_score +from naswot import datasets as dt +from naswot import nasspace + +class Args(): + pass +args = Args() +args.trainval = True +args.augtype = 'none' +args.repeat = 1 +args.score = 'hook_logdet' +args.sigma = 0.05 +args.nasspace = 'nasbench201' +args.batch_size = 128 +args.GPU = '0' +args.dataset = 'cifar10' +args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' +args.data_loc = '../cifardata/' +args.seed = 777 +args.init = '' +args.save_loc = 'results' +args.save_string = 'naswot' +args.dropout = False +args.maxofn = 1 +args.n_samples = 100 +args.n_runs = 500 +args.stem_out_channels = 16 +args.num_stacks = 3 +args.num_modules_per_stack = 3 +args.num_labels = 1 +searchspace = nasspace.get_search_space(args) +train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) +device = torch.device('cuda:2') + + + +# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' +# api = API(source) + + + +# 示例百分数列表,精确到小数点后两位 +# percentages = [5.12, 15.78, 25.43, 35.22, 45.99, 55.34, 65.12, 75.68, 85.99, 95.25, 23.45, 12.34, 37.89, 58.67, 64.23, 72.15, 81.76, 99.99, 42.11, 61.58, 77.34, 14.56] +percentages = [] + +len_201 = 15625 + +for i in range(len_201): + percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) + percentages.append(percentage) + +# 定义10%区间 +bins = [i for i in range(0, 101, 10)] + +# 对数据进行分箱,计算每个区间的数据量 +hist, bin_edges = pd.cut(percentages, bins=bins, right=False, retbins=True, include_lowest=True) +bin_counts = hist.value_counts().sort_index() + +total_counts = len(percentages) +percentages_in_bins = (bin_counts / total_counts) * 100 + +# 绘制条形图 +plt.figure(figsize=(10, 6)) +bars = plt.bar(bin_counts.index.astype(str), bin_counts.values, width=0.9, color='skyblue') + +for bar, percentage in zip(bars, percentages_in_bins): + plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), + f'{percentage:.2f}%', ha='center', va='bottom') + +# 添加标题和标签 +plt.title('Distribution of Percentages in 10% Intervals') +plt.xlabel('Percentage Interval') +plt.ylabel('Count') + +# 显示图表 +plt.xticks(rotation=45) +plt.savefig('barplog.png') +