Clean unnecessary files
This commit is contained in:
parent
178d84a7e5
commit
672a9ef0db
@ -1,69 +0,0 @@
|
||||
import os, sys, time, queue, torch
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from log_utils import time_string
|
||||
from aa_nas_api import AANASBenchAPI, ArchResults
|
||||
from models import CellStructure
|
||||
|
||||
|
||||
def get_unique_matrix(archs, consider_zero):
|
||||
UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
|
||||
print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
|
||||
Unique2Index = dict()
|
||||
for index, xstr in enumerate(UniquStrs):
|
||||
if xstr not in Unique2Index: Unique2Index[xstr] = list()
|
||||
Unique2Index[xstr].append( index )
|
||||
sm_matrix = torch.eye(len(archs)).bool()
|
||||
for _, xlist in Unique2Index.items():
|
||||
for i in xlist:
|
||||
for j in xlist:
|
||||
sm_matrix[i,j] = True
|
||||
unique_ids, unique_num = [-1 for _ in archs], 0
|
||||
for i in range(len(unique_ids)):
|
||||
if unique_ids[i] > -1: continue
|
||||
neighbours = sm_matrix[i].nonzero().view(-1).tolist()
|
||||
for nghb in neighbours:
|
||||
assert unique_ids[nghb] == -1, 'impossible'
|
||||
unique_ids[nghb] = unique_num
|
||||
unique_num += 1
|
||||
return sm_matrix, unique_ids, unique_num
|
||||
|
||||
|
||||
def check_unique_arch():
|
||||
print ('{:} start'.format(time_string()))
|
||||
meta_info = torch.load('./output/AA-NAS-BENCH-4/meta-node-4.pth')
|
||||
arch_strs = meta_info['archs']
|
||||
archs = [CellStructure.str2structure(arch_str) for arch_str in arch_strs]
|
||||
"""
|
||||
for i, arch in enumerate(archs):
|
||||
if not arch.check_valid():
|
||||
print('{:05d} {:}'.format(i, arch))
|
||||
#start = int(i / 390.) * 390
|
||||
#xxend = start + 389
|
||||
#print ('/home/dxy/search-configures/output/TINY-NAS-BENCHMARK-4/{:06d}-{:06d}-C16-N5/arch-{:06d}-seed-0888.pth'.format(start, xxend, i))
|
||||
"""
|
||||
print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in archs) ))
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(archs, False)
|
||||
save_dir = './output/cell-search-tiny/same-matrix.pth'
|
||||
torch.save(sm_matrix, save_dir)
|
||||
print ('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num))
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(archs, True)
|
||||
print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
|
||||
|
||||
|
||||
def test_aa_nas_api():
|
||||
arch_result = ArchResults.create_from_state_dict('output/AA-NAS-BENCH-4/simplifies/architectures/000002-FULL.pth')
|
||||
arch_result.show(True)
|
||||
result = arch_result.query('cifar100')
|
||||
#xfile = 'output/AA-NAS-BENCH-4/simplifies/000000-000389-C16-N5.pth'
|
||||
api = AANASBenchAPI('output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth')
|
||||
results = api.query_by_index(1, 'cifar100')
|
||||
print ('There are {:} trials for this architecture [{:}] on cifar10'.format(len(results), api[1]))
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
if __name__ == '__main__':
|
||||
#check_unique_arch()
|
||||
test_aa_nas_api()
|
Loading…
Reference in New Issue
Block a user