update the taskmodel
This commit is contained in:
		| @@ -15,6 +15,17 @@ from rdkit.Chem import AllChem | |||||||
| from rdkit import DataStructs | from rdkit import DataStructs | ||||||
| from rdkit.Chem import rdMolDescriptors | from rdkit.Chem import rdMolDescriptors | ||||||
| rdBase.DisableLog('rdApp.error') | rdBase.DisableLog('rdApp.error') | ||||||
|  | import json | ||||||
|  |  | ||||||
|  | op_type = { | ||||||
|  |     'nor_conv_1x1': 1, | ||||||
|  |     'nor_conv_3x3': 2, | ||||||
|  |     'avg_pool_3x3': 3, | ||||||
|  |     'skip_connect': 4, | ||||||
|  |     'output': 5, | ||||||
|  |     'none': 6, | ||||||
|  |     'input': 7 | ||||||
|  | } | ||||||
|  |  | ||||||
| task_to_colname = { | task_to_colname = { | ||||||
|     'hiv_b': 'HIV_active', |     'hiv_b': 'HIV_active', | ||||||
| @@ -32,8 +43,10 @@ tasktype_name = { | |||||||
|     'O2': 'regression', |     'O2': 'regression', | ||||||
|     'N2': 'regression', |     'N2': 'regression', | ||||||
|     'CO2': 'regression', |     'CO2': 'regression', | ||||||
|  |     'nasbench201': 'regression', | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| class TaskModel(): | class TaskModel(): | ||||||
|     """Scores based on an ECFP classifier.""" |     """Scores based on an ECFP classifier.""" | ||||||
|     def __init__(self, model_path, task_name): |     def __init__(self, model_path, task_name): | ||||||
| @@ -55,8 +68,47 @@ class TaskModel(): | |||||||
|             perfermance = self.train() |             perfermance = self.train() | ||||||
|             dump(self.model, model_path) |             dump(self.model, model_path) | ||||||
|             print('Oracle peformance: ', perfermance) |             print('Oracle peformance: ', perfermance) | ||||||
|  |  | ||||||
|     def train(self): |     def train(self): | ||||||
|  |         def read_adj_ops_from_json(filename): | ||||||
|  |             with open(filename, 'r') as json_file: | ||||||
|  |                 data = json.load(json_file) | ||||||
|  |  | ||||||
|  |             adj_ops_pairs = [] | ||||||
|  |             for item in data: | ||||||
|  |                 adj_matrix = np.array(item['adj_matrix']) | ||||||
|  |                 ops = item['ops'] | ||||||
|  |                 acc = item['train'][0]['accuracy'] | ||||||
|  |                 adj_ops_pairs.append((adj_matrix, ops, acc)) | ||||||
|  |              | ||||||
|  |             return adj_ops_pairs | ||||||
|  |         def feature_from_adj_and_ops(adj, ops): | ||||||
|  |             return np.concatenate([adj.flatten(), ops]) | ||||||
|  |         filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json' | ||||||
|  |         graphs = read_adj_ops_from_json(filename) | ||||||
|  |         adjs = [] | ||||||
|  |         opss = [] | ||||||
|  |         accs = [] | ||||||
|  |         features = [] | ||||||
|  |         for graph in graphs: | ||||||
|  |             adj, ops, acc=graph | ||||||
|  |             op_code = [op_type[op] for op in ops] | ||||||
|  |             adjs.append(adj) | ||||||
|  |             opss.append(op_code) | ||||||
|  |             accs.append(acc) | ||||||
|  |             features.append(feature_from_adj_and_ops(adj, op_code)) | ||||||
|  |         features = np.array(features) | ||||||
|  |         labels = np.array(accs) | ||||||
|  |  | ||||||
|  |         mask = ~np.isnan(labels) | ||||||
|  |         labels = labels[mask] | ||||||
|  |         features = features[mask] | ||||||
|  |         self.model.fit(features, labels) | ||||||
|  |         y_pred = self.model.predict(features) | ||||||
|  |         perf = self.metric_func(labels, y_pred) | ||||||
|  |         print(f'{self.task_name} performance: {perf}') | ||||||
|  |         return perf | ||||||
|  |  | ||||||
|  |     def train__(self): | ||||||
|         data_path = os.path.dirname(self.model_path) |         data_path = os.path.dirname(self.model_path) | ||||||
|         data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz') |         data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz') | ||||||
|         df = pd.read_csv(data_path) |         df = pd.read_csv(data_path) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user