update the taskmodel
This commit is contained in:
		| @@ -15,6 +15,17 @@ from rdkit.Chem import AllChem | ||||
| from rdkit import DataStructs | ||||
| from rdkit.Chem import rdMolDescriptors | ||||
| rdBase.DisableLog('rdApp.error') | ||||
| import json | ||||
|  | ||||
# Integer code assigned to each node/operation label. Codes run from 1 to 7
# in the order the labels are listed here, so the position in this tuple is
# what determines the code -- do not reorder.
_OP_LABELS_IN_CODE_ORDER = (
    'nor_conv_1x1',
    'nor_conv_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'output',
    'none',
    'input',
)

op_type = {
    label: code
    for code, label in enumerate(_OP_LABELS_IN_CODE_ORDER, start=1)
}
|  | ||||
| task_to_colname = { | ||||
|     'hiv_b': 'HIV_active', | ||||
| @@ -32,8 +43,10 @@ tasktype_name = { | ||||
|     'O2': 'regression', | ||||
|     'N2': 'regression', | ||||
|     'CO2': 'regression', | ||||
|     'nasbench201': 'regression', | ||||
| } | ||||
|  | ||||
|  | ||||
| class TaskModel(): | ||||
|     """Scores based on an ECFP classifier.""" | ||||
|     def __init__(self, model_path, task_name): | ||||
| @@ -55,8 +68,47 @@ class TaskModel(): | ||||
|             perfermance = self.train() | ||||
|             dump(self.model, model_path) | ||||
|             print('Oracle peformance: ', perfermance) | ||||
|  | ||||
|     def train(self): | ||||
|         def read_adj_ops_from_json(filename): | ||||
|             with open(filename, 'r') as json_file: | ||||
|                 data = json.load(json_file) | ||||
|  | ||||
|             adj_ops_pairs = [] | ||||
|             for item in data: | ||||
|                 adj_matrix = np.array(item['adj_matrix']) | ||||
|                 ops = item['ops'] | ||||
|                 acc = item['train'][0]['accuracy'] | ||||
|                 adj_ops_pairs.append((adj_matrix, ops, acc)) | ||||
|              | ||||
|             return adj_ops_pairs | ||||
|         def feature_from_adj_and_ops(adj, ops): | ||||
|             return np.concatenate([adj.flatten(), ops]) | ||||
|         filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json' | ||||
|         graphs = read_adj_ops_from_json(filename) | ||||
|         adjs = [] | ||||
|         opss = [] | ||||
|         accs = [] | ||||
|         features = [] | ||||
|         for graph in graphs: | ||||
|             adj, ops, acc=graph | ||||
|             op_code = [op_type[op] for op in ops] | ||||
|             adjs.append(adj) | ||||
|             opss.append(op_code) | ||||
|             accs.append(acc) | ||||
|             features.append(feature_from_adj_and_ops(adj, op_code)) | ||||
|         features = np.array(features) | ||||
|         labels = np.array(accs) | ||||
|  | ||||
|         mask = ~np.isnan(labels) | ||||
|         labels = labels[mask] | ||||
|         features = features[mask] | ||||
|         self.model.fit(features, labels) | ||||
|         y_pred = self.model.predict(features) | ||||
|         perf = self.metric_func(labels, y_pred) | ||||
|         print(f'{self.task_name} performance: {perf}') | ||||
|         return perf | ||||
|  | ||||
|     def train__(self): | ||||
|         data_path = os.path.dirname(self.model_path) | ||||
|         data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz') | ||||
|         df = pd.read_csv(data_path) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user