#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 6 14:03:42 2021
@author: hsjomaa
"""
## Original packages
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
import copy
import numpy as np
import os
from torch.utils.tensorboard import SummaryWriter  # required by DKT.setup_writers
import json
import time
## Our packages
import gpytorch
import logging
from transfer_nag_lib.DeepKernelGPHelpers import totorch, prepare_data, Metric, EI
from transfer_nag_lib.MetaD2A_nas_bench_201.generator import Generator
from transfer_nag_lib.MetaD2A_nas_bench_201.main import get_parser
np.random.seed(1203)
RandomQueryGenerator = np.random.RandomState(413)
RandomSupportGenerator = np.random.RandomState(413)
RandomTaskGenerator = np.random.RandomState(413)
class DeepKernelGP(nn.Module):
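    """Deep-kernel GP surrogate for a single task: a neural feature extractor
    (``backbone_fn``) maps inputs to embeddings on which an exact GP is fitted
    over the observed support points. Note that ``train`` here overrides
    ``nn.Module.train`` with a fitting routine.

    A minimal usage sketch; the config values and data shapes below are
    illustrative assumptions, not fixed by this module::

        backbone_fn = lambda: StandardDeepGP(
            {"dim": 32, "output_size_A": [64, 64], "dropout": 0.0})
        config = {"kernel": "rbf", "ard": False,
                  "patience": 10, "loss_tol": 1e-4}
        dkgp = DeepKernelGP(X, Y, Z, kernel="rbf", backbone_fn=backbone_fn,
                            config=config, support=support_idx,
                            log_dir="dkgp.log", seed=0)
        opt = torch.optim.Adam(dkgp.parameters(), lr=1e-3)
        dkgp.train(support_idx, load_model=False, optimizer=opt)
        mu, stddev = dkgp.predict(support_idx)
    """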
def __init__(self,X,Y,Z,kernel,backbone_fn, config, support,log_dir,seed):
super(DeepKernelGP, self).__init__()
torch.manual_seed(seed)
## GP parameters
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.X,self.Y,self.Z = X,Y,Z
self.feature_extractor = backbone_fn().to(self.device)
self.config=config
self.get_model_likelihood_mll(len(support),kernel,backbone_fn)
logging.basicConfig(filename=log_dir, level=logging.DEBUG)
def get_model_likelihood_mll(self, train_size,kernel,backbone_fn):
train_x=torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
train_y=torch.ones(train_size).to(self.device)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, config=self.config,
dims=self.feature_extractor.out_features)
self.model = model.to(self.device)
self.likelihood = likelihood.to(self.device)
self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)
def set_forward(self, x, is_feature=False):
pass
def set_forward_loss(self, x):
pass
    def train(self, support, load_model, optimizer, checkpoint=None, epochs=1000, verbose=False):
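        """Fit the feature extractor and GP hyperparameters on the support set
        by maximizing the exact marginal log-likelihood. Returns the loss
        trajectory plus the best and initial weights."""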
        if load_model:
            assert checkpoint is not None
            self.load_checkpoint(os.path.join(checkpoint, "weights"))
            print("KEYS MATCHED")
inputs,labels = prepare_data(support,support,self.X,self.Y,self.Z)
inputs,labels = totorch(inputs,device=self.device), totorch(labels.reshape(-1,),device=self.device)
        losses = [np.inf]
        best_loss = np.inf
        starttime = time.time()
        initial_weights = copy.deepcopy(self.state_dict())
        weights = initial_weights  # fallback in case training stops before any improvement
        patience = 0
max_patience = self.config["patience"]
        for epoch in range(epochs):
            optimizer.zero_grad()
            z = self.feature_extractor(inputs)
            self.model.set_train_data(inputs=z, targets=labels)
            predictions = self.model(z)
            try:
                loss = -self.mll(predictions, self.model.train_targets)
                loss.backward()
                optimizer.step()
            except Exception as e:
                logging.info(f"Exception {e}")
                break
            if verbose:
                print("Iter {iter}/{epochs} - Loss: {loss:.5f} noise: {noise:.5f}".format(
                    iter=epoch + 1, epochs=epochs, loss=loss.item(), noise=self.likelihood.noise.item()))
losses.append(loss.detach().to("cpu").item())
if best_loss>losses[-1]:
best_loss = losses[-1]
weights = copy.deepcopy(self.state_dict())
if np.allclose(losses[-1],losses[-2],atol=self.config["loss_tol"]):
patience+=1
else:
patience=0
if patience>max_patience:
break
self.load_state_dict(weights)
logging.info(f"Current Iteration: {len(support)} | Incumbent {max(self.Y[support])} | Duration {np.round(time.time()-starttime)} | Epochs {_} | Noise {self.likelihood.noise.item()}")
return losses,weights,initial_weights
def load_checkpoint(self, checkpoint):
ckpt = torch.load(checkpoint,map_location=torch.device(self.device))
self.model.load_state_dict(ckpt['gp'],strict=False)
self.likelihood.load_state_dict(ckpt['likelihood'],strict=False)
self.feature_extractor.load_state_dict(ckpt['net'],strict=False)
    def predict(self, support, query_range=None, noise_fn=None):
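        """Condition the GP on the support indices and return the posterior
        mean and stddev over the query range (all rows of X by default)."""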
card = len(self.Y)
if noise_fn:
self.Y = noise_fn(self.Y)
x_support,y_support = prepare_data(support,support,
self.X,self.Y,self.Z)
if query_range is None:
x_query,_ = prepare_data(np.arange(card),support,
self.X,self.Y,self.Z)
else:
x_query,_ = prepare_data(query_range,support,
self.X,self.Y,self.Z)
self.model.eval()
self.feature_extractor.eval()
self.likelihood.eval()
z_support = self.feature_extractor(totorch(x_support,self.device)).detach()
self.model.set_train_data(inputs=z_support, targets=totorch(y_support.reshape(-1,),self.device), strict=False)
with torch.no_grad():
z_query = self.feature_extractor(totorch(x_query,self.device)).detach()
pred = self.likelihood(self.model(z_query))
mu = pred.mean.detach().to("cpu").numpy().reshape(-1,)
stddev = pred.stddev.detach().to("cpu").numpy().reshape(-1,)
return mu,stddev
class DKT(nn.Module):
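    """Meta-trains a deep-kernel GP across a collection of HPO tasks
    (space/dataset pairs) in an FSBO-style episodic setup: each episode samples
    a fixed-size support set and a query minibatch from one task.

    A hedged sketch of the outer loop; the ``train_data``/``valid_data`` layout
    ({space: {task: {"X": ..., "y": ...}}}) and the optimizer choice are
    assumptions consistent with the methods below, not a verbatim recipe::

        dkt = DKT(train_data, valid_data, kernel="rbf",
                  backbone_fn=backbone_fn, config=config)
        opt = torch.optim.Adam(dkt.parameters(), lr=1e-4)
        for epoch in range(n_epochs):
            dkt.train_loop(epoch, opt)
    """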
def __init__(self, train_data,valid_data, kernel,backbone_fn, config):
super(DKT, self).__init__()
## GP parameters
self.train_data = train_data
self.valid_data = valid_data
self.fixed_context_size = config["fixed_context_size"]
self.minibatch_size = config["minibatch_size"]
self.n_inner_steps = config["n_inner_steps"]
self.checkpoint_path = config["checkpoint_path"]
os.makedirs(self.checkpoint_path,exist_ok=False)
json.dump(config, open(os.path.join(self.checkpoint_path,"configuration.json"),"w"))
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.basicConfig(filename=os.path.join(self.checkpoint_path,"log.txt"), level=logging.DEBUG)
self.feature_extractor = backbone_fn().to(self.device)
self.config=config
self.get_model_likelihood_mll(self.fixed_context_size,kernel,backbone_fn)
self.mse = nn.MSELoss()
self.curr_valid_loss = np.inf
self.get_tasks()
self.setup_writers()
self.train_metrics = Metric()
self.valid_metrics = Metric(prefix="valid: ")
print(self)
    def setup_writers(self):
train_log_dir = os.path.join(self.checkpoint_path,"train")
os.makedirs(train_log_dir,exist_ok=True)
self.train_summary_writer = SummaryWriter(train_log_dir)
valid_log_dir = os.path.join(self.checkpoint_path,"valid")
os.makedirs(valid_log_dir,exist_ok=True)
self.valid_summary_writer = SummaryWriter(valid_log_dir)
def get_tasks(self,):
pairs = []
for space in self.train_data.keys():
for task in self.train_data[space].keys():
pairs.append([space,task])
self.tasks = pairs
##########
pairs = []
for space in self.valid_data.keys():
for task in self.valid_data[space].keys():
pairs.append([space,task])
self.valid_tasks = pairs
def get_model_likelihood_mll(self, train_size,kernel,backbone_fn):
train_x=torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
train_y=torch.ones(train_size).to(self.device)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, config=self.config,dims = self.feature_extractor.out_features)
self.model = model.to(self.device)
self.likelihood = likelihood.to(self.device)
self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)
def set_forward(self, x, is_feature=False):
pass
def set_forward_loss(self, x):
pass
def epoch_end(self):
RandomTaskGenerator.shuffle(self.tasks)
def train_loop(self, epoch, optimizer, scheduler_fn=None):
if scheduler_fn:
scheduler = scheduler_fn(optimizer,len(self.tasks))
self.epoch_end()
assert(self.training)
for task in self.tasks:
inputs, labels = self.get_batch(task)
for _ in range(self.n_inner_steps):
optimizer.zero_grad()
z = self.feature_extractor(inputs)
self.model.set_train_data(inputs=z, targets=labels, strict=False)
predictions = self.model(z)
loss = -self.mll(predictions, self.model.train_targets)
loss.backward()
optimizer.step()
mse = self.mse(predictions.mean, labels)
self.train_metrics.update(loss,self.model.likelihood.noise,mse)
if scheduler_fn:
scheduler.step()
training_results = self.train_metrics.get()
for k,v in training_results.items():
self.train_summary_writer.add_scalar(k, v, epoch)
        for task in self.valid_tasks:
            mse, loss = self.test_loop(task, train=False)
            self.valid_metrics.update(loss, np.array(0), mse)
logging.info(self.train_metrics.report() + " " + self.valid_metrics.report())
validation_results = self.valid_metrics.get()
for k,v in validation_results.items():
self.valid_summary_writer.add_scalar(k, v, epoch)
self.feature_extractor.train()
self.likelihood.train()
self.model.train()
if validation_results["loss"] < self.curr_valid_loss:
self.save_checkpoint(os.path.join(self.checkpoint_path,"weights"))
self.curr_valid_loss = validation_results["loss"]
self.valid_metrics.reset()
self.train_metrics.reset()
def test_loop(self, task, train, optimizer=None): # no optimizer needed for GP
(x_support, y_support),(x_query,y_query) = self.get_support_and_queries(task,train)
z_support = self.feature_extractor(x_support).detach()
self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
self.model.eval()
self.feature_extractor.eval()
self.likelihood.eval()
with torch.no_grad():
z_query = self.feature_extractor(x_query).detach()
pred = self.likelihood(self.model(z_query))
loss = -self.mll(pred, y_query)
lower, upper = pred.confidence_region() #2 standard deviations above and below the mean
mse = self.mse(pred.mean, y_query)
return mse,loss
    def get_batch(self, task):
        # we want to fit the GP given context info to new observations;
        # task is an algorithm/dataset pair
        space, task = task
        Lambda = np.array(self.train_data[space][task]["X"])
        response = MinMaxScaler().fit_transform(
            np.array(self.train_data[space][task]["y"])).reshape(-1,)
card, dim = Lambda.shape
support = RandomSupportGenerator.choice(np.arange(card),
replace=False,size=self.fixed_context_size)
remaining = np.setdiff1d(np.arange(card),support)
indexes = RandomQueryGenerator.choice(
remaining,replace=False,size=self.minibatch_size if len(remaining)>self.minibatch_size else len(remaining))
inputs,labels = prepare_data(support,indexes,Lambda,response,np.zeros(32))
inputs,labels = totorch(inputs,device=self.device), totorch(labels.reshape(-1,),device=self.device)
return inputs, labels
    def get_support_and_queries(self, task, train=False):
        # task is an algorithm/dataset pair
        space, task = task
        hpo_data = self.train_data if train else self.valid_data
        Lambda = np.array(hpo_data[space][task]["X"])
        response = MinMaxScaler().fit_transform(
            np.array(hpo_data[space][task]["y"])).reshape(-1,)
card, dim = Lambda.shape
support = RandomSupportGenerator.choice(np.arange(card),
replace=False,size=self.fixed_context_size)
indexes = RandomQueryGenerator.choice(
np.setdiff1d(np.arange(card),support),replace=False,size=self.minibatch_size)
support_x,support_y = prepare_data(support,support,Lambda,response,np.zeros(32))
query_x,query_y = prepare_data(support,indexes,Lambda,response,np.zeros(32))
return (totorch(support_x,self.device),totorch(support_y.reshape(-1,),self.device)),\
(totorch(query_x,self.device),totorch(query_y.reshape(-1,),self.device))
def save_checkpoint(self, checkpoint):
# save state
gp_state_dict = self.model.state_dict()
likelihood_state_dict = self.likelihood.state_dict()
nn_state_dict = self.feature_extractor.state_dict()
torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'net':nn_state_dict}, checkpoint)
def load_checkpoint(self, checkpoint):
ckpt = torch.load(checkpoint)
self.model.load_state_dict(ckpt['gp'])
self.likelihood.load_state_dict(ckpt['likelihood'])
self.feature_extractor.load_state_dict(ckpt['net'])
class ExactGPLayer(gpytorch.models.ExactGP):
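    """Exact-GP head used on top of the feature extractor: constant mean with
    a scaled RBF or Matern kernel ('rbf'/'RBF' or '52' in config["kernel"]),
    optionally with ARD over the embedding dimensions."""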
def __init__(self, train_x, train_y, likelihood,config,dims ):
super(ExactGPLayer, self).__init__(train_x, train_y, likelihood)
self.mean_module = gpytorch.means.ConstantMean()
        ## RBF kernel
        if config["kernel"] in ("rbf", "RBF"):
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.RBFKernel(ard_num_dims=dims if config["ard"] else None))
        ## Matern kernel (nu taken from config["nu"], e.g. 2.5 for Matern-5/2)
        elif config["kernel"] == '52':
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.MaternKernel(nu=config["nu"], ard_num_dims=dims if config["ard"] else None))
        else:
            raise ValueError("[ERROR] the kernel '" + str(config["kernel"])
                             + "' is not supported for regression, use 'rbf' or '52' (Matern).")
def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
class batch_mlp(nn.Module):
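    """Fully connected ReLU stack with dropout; ``output_sizes`` gives the
    width of every layer, e.g. batch_mlp(32, [64, 64]) maps N x 32 -> N x 64."""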
def __init__(self, d_in, output_sizes, nonlinearity="relu",dropout=0.0):
super(batch_mlp, self).__init__()
assert(nonlinearity=="relu")
self.nonlinearity = nn.ReLU()
self.fc = nn.ModuleList([nn.Linear(in_features=d_in, out_features=output_sizes[0])])
for d_out in output_sizes[1:]:
self.fc.append(nn.Linear(in_features=self.fc[-1].out_features, out_features=d_out))
self.out_features = output_sizes[-1]
self.dropout = nn.Dropout(dropout)
def forward(self,x):
for fc in self.fc[:-1]:
x = fc(x)
x = self.dropout(x)
x = self.nonlinearity(x)
x = self.fc[-1](x)
x = self.dropout(x)
return x
class StandardDeepGP(nn.Module):
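    """Thin wrapper exposing a batch_mlp as the deep-kernel feature extractor;
    expects the configuration keys "dim", "output_size_A", and "dropout"."""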
def __init__(self, configuration):
super(StandardDeepGP, self).__init__()
self.A = batch_mlp(configuration["dim"], configuration["output_size_A"],dropout=configuration["dropout"])
self.out_features = configuration["output_size_A"][-1]
    def forward(self, x):
        hidden = self.A(x.squeeze(dim=-1))  # N x out_features
        return hidden
class DKTNAS(nn.Module):
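    """Deep-kernel GP over MetaD2A set-encoder embeddings for NAS: training
    tasks come from the pretrained Generator (or precomputed .npy dumps), and
    validation uses cifar100 embeddings. With GP_only=True the GP is fitted
    directly on the embeddings, skipping the feature extractor.

    A hedged sketch of a prediction call; the tensors are assumed to already
    be embeddings of the right dimensionality::

        dktnas = DKTNAS(kernel="rbf", backbone_fn=backbone_fn, config=config)
        dktnas.load_checkpoint(os.path.join(config["checkpoint_path"], "weights"))
        mu, stddev = dktnas.predict(x_support, y_support, x_query, y_query)
    """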
def __init__(self, kernel, backbone_fn, config, pretrained_encoder=True, GP_only=False):
super(DKTNAS, self).__init__()
## GP parameters
self.fixed_context_size = config["fixed_context_size"]
self.minibatch_size = config["minibatch_size"]
self.n_inner_steps = config["n_inner_steps"]
self.set_encoder_args = get_parser()
if not os.path.exists(self.set_encoder_args.save_path):
os.makedirs(self.set_encoder_args.save_path)
self.set_encoder_args.model_path = os.path.join(self.set_encoder_args.save_path,
self.set_encoder_args.model_name, 'model')
if not os.path.exists(self.set_encoder_args.model_path):
os.makedirs(self.set_encoder_args.model_path)
self.set_encoder = Generator(self.set_encoder_args)
if pretrained_encoder:
self.dataset_enc, self.arch, self.acc = self.set_encoder.train_dgp(encode=False)
self.dataset_enc_val, self.acc_val = self.set_encoder.test_dgp(data_name='cifar100', encode=False)
else: # In case we want to train the set-encoder from scratch
self.dataset_enc = np.load("train_data_path.npy")
self.acc = np.load("train_acc.npy")
self.dataset_enc_val = np.load("cifar100_data_path.npy")
self.acc_val = np.load("cifar100_acc.npy")
self.valid_data = self.dataset_enc_val
self.checkpoint_path = config["checkpoint_path"]
os.makedirs(self.checkpoint_path, exist_ok=False)
json.dump(config, open(os.path.join(self.checkpoint_path, "configuration.json"), "w"))
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.basicConfig(filename=os.path.join(self.checkpoint_path, "log.txt"), level=logging.DEBUG)
self.feature_extractor = backbone_fn().to(self.device)
self.config = config
self.GP_only = GP_only
self.get_model_likelihood_mll(self.fixed_context_size, kernel, backbone_fn)
self.mse = nn.MSELoss()
self.curr_valid_loss = np.inf
# self.get_tasks()
self.setup_writers()
self.train_metrics = Metric()
self.valid_metrics = Metric(prefix="valid: ")
        self.tasks = len(self.dataset_enc)  # number of meta-training tasks
print(self)
    def setup_writers(self):
        # train_loop dereferences both writers, so they must be created here
        train_log_dir = os.path.join(self.checkpoint_path, "train")
        os.makedirs(train_log_dir, exist_ok=True)
        self.train_summary_writer = SummaryWriter(train_log_dir)
        valid_log_dir = os.path.join(self.checkpoint_path, "valid")
        os.makedirs(valid_log_dir, exist_ok=True)
        self.valid_summary_writer = SummaryWriter(valid_log_dir)
    def get_model_likelihood_mll(self, train_size, kernel, backbone_fn):
        # train_x/train_y are None placeholders; the actual training data is
        # supplied later through set_train_data(..., strict=False).
        dims = self.fixed_context_size if self.GP_only else self.feature_extractor.out_features
        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPLayer(train_x=None, train_y=None, likelihood=likelihood,
                             config=self.config, dims=dims)
        self.model = model.to(self.device)
        self.likelihood = likelihood.to(self.device)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)
def set_forward(self, x, is_feature=False):
pass
def set_forward_loss(self, x):
pass
    def epoch_end(self):
        # single meta-dataset here, so there is no task list to shuffle
        # (kept for interface parity with DKT.epoch_end)
        pass
def train_loop(self, epoch, optimizer, scheduler_fn=None):
if scheduler_fn:
scheduler = scheduler_fn(optimizer, 1)
self.epoch_end()
assert (self.training)
for task in range(self.tasks):
inputs, labels = self.get_batch(task)
for _ in range(self.n_inner_steps):
optimizer.zero_grad()
z = self.feature_extractor(inputs)
self.model.set_train_data(inputs=z, targets=labels, strict=False)
predictions = self.model(z)
loss = -self.mll(predictions, self.model.train_targets)
loss.backward()
optimizer.step()
mse = self.mse(predictions.mean, labels)
self.train_metrics.update(loss, self.model.likelihood.noise, mse)
if scheduler_fn:
scheduler.step()
training_results = self.train_metrics.get()
for k, v in training_results.items():
self.train_summary_writer.add_scalar(k, v, epoch)
mse, loss = self.test_loop(train=False)
        self.valid_metrics.update(loss, np.array(0), mse)
logging.info(self.train_metrics.report() + " " + self.valid_metrics.report())
validation_results = self.valid_metrics.get()
for k, v in validation_results.items():
self.valid_summary_writer.add_scalar(k, v, epoch)
self.feature_extractor.train()
self.likelihood.train()
self.model.train()
if validation_results["loss"] < self.curr_valid_loss:
self.save_checkpoint(os.path.join(self.checkpoint_path, "weights"))
self.curr_valid_loss = validation_results["loss"]
self.valid_metrics.reset()
self.train_metrics.reset()
def test_loop(self, train=None, optimizer=None): # no optimizer needed for GP
(x_support, y_support), (x_query, y_query) = self.get_support_and_queries(train)
z_support = self.feature_extractor(x_support).detach()
self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
self.model.eval()
self.feature_extractor.eval()
self.likelihood.eval()
with torch.no_grad():
z_query = self.feature_extractor(x_query).detach()
pred = self.likelihood(self.model(z_query))
loss = -self.mll(pred, y_query)
lower, upper = pred.confidence_region() # 2 standard deviations above and below the mean
mse = self.mse(pred.mean, y_query)
return mse, loss
    def get_batch(self, task, valid=False):
        # we want to fit the GP given context info to new observations
        # TODO: scale the response as in FSBO (needed for training)
Lambda, response = np.array(self.dataset_enc), np.array(self.acc)
inputs, labels = Lambda[task], response[task]
inputs, labels = totorch([inputs], device=self.device), totorch([labels], device=self.device)
return inputs, labels
    def get_support_and_queries(self, train=False):
        # TODO: scale the response as in FSBO (not necessary for test)
        Lambda, response = np.array(self.dataset_enc_val), np.array(self.acc_val)
card, dim = Lambda.shape
support = RandomSupportGenerator.choice(np.arange(card),
replace=False, size=self.fixed_context_size)
indexes = RandomQueryGenerator.choice(
np.setdiff1d(np.arange(card), support), replace=False, size=self.minibatch_size)
support_x, support_y = Lambda[support], response[support]
query_x, query_y = Lambda[indexes], response[indexes]
return (totorch(support_x, self.device), totorch(support_y.reshape(-1, ), self.device)), \
(totorch(query_x, self.device), totorch(query_y.reshape(-1, ), self.device))
def save_checkpoint(self, checkpoint):
# save state
gp_state_dict = self.model.state_dict()
likelihood_state_dict = self.likelihood.state_dict()
nn_state_dict = self.feature_extractor.state_dict()
torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'net': nn_state_dict}, checkpoint)
def load_checkpoint(self, checkpoint):
ckpt = torch.load(checkpoint)
self.model.load_state_dict(ckpt['gp'])
self.likelihood.load_state_dict(ckpt['likelihood'])
self.feature_extractor.load_state_dict(ckpt['net'])
def predict(self, x_support, y_support, x_query, y_query, GP_only=False):
if not GP_only:
z_support = self.feature_extractor(x_support).detach()
else:
z_support = x_support
self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
self.model.eval()
self.feature_extractor.eval()
self.likelihood.eval()
with torch.no_grad():
if not GP_only:
z_query = self.feature_extractor(x_query).detach()
else:
z_query = x_query
pred = self.likelihood(self.model(z_query))
mu = pred.mean.detach().to("cpu").numpy().reshape(-1, )
stddev = pred.stddev.detach().to("cpu").numpy().reshape(-1, )
return mu, stddev