update the taskmodel

This commit is contained in:
mhz 2024-06-30 16:39:42 +02:00
parent 66fe70028e
commit 7274b3f606

View File

@ -15,6 +15,17 @@ from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem import rdMolDescriptors
rdBase.DisableLog('rdApp.error')
import json
op_type = {
'nor_conv_1x1': 1,
'nor_conv_3x3': 2,
'avg_pool_3x3': 3,
'skip_connect': 4,
'output': 5,
'none': 6,
'input': 7
}
task_to_colname = {
'hiv_b': 'HIV_active',
@ -32,8 +43,10 @@ tasktype_name = {
'O2': 'regression',
'N2': 'regression',
'CO2': 'regression',
'nasbench201': 'regression',
}
class TaskModel():
"""Scores based on an ECFP classifier."""
def __init__(self, model_path, task_name):
@ -55,8 +68,47 @@ class TaskModel():
perfermance = self.train()
dump(self.model, model_path)
print('Oracle peformance: ', perfermance)
def train(self):
def read_adj_ops_from_json(filename):
with open(filename, 'r') as json_file:
data = json.load(json_file)
adj_ops_pairs = []
for item in data:
adj_matrix = np.array(item['adj_matrix'])
ops = item['ops']
acc = item['train'][0]['accuracy']
adj_ops_pairs.append((adj_matrix, ops, acc))
return adj_ops_pairs
def feature_from_adj_and_ops(adj, ops):
return np.concatenate([adj.flatten(), ops])
filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
graphs = read_adj_ops_from_json(filename)
adjs = []
opss = []
accs = []
features = []
for graph in graphs:
adj, ops, acc=graph
op_code = [op_type[op] for op in ops]
adjs.append(adj)
opss.append(op_code)
accs.append(acc)
features.append(feature_from_adj_and_ops(adj, op_code))
features = np.array(features)
labels = np.array(accs)
mask = ~np.isnan(labels)
labels = labels[mask]
features = features[mask]
self.model.fit(features, labels)
y_pred = self.model.predict(features)
perf = self.metric_func(labels, y_pred)
print(f'{self.task_name} performance: {perf}')
return perf
def train__(self):
data_path = os.path.dirname(self.model_path)
data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz')
df = pd.read_csv(data_path)