add autodl
This commit is contained in:
		
							
								
								
									
										319
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/baselines/maml-ft.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										319
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/baselines/maml-ft.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,319 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5 --device cuda | ||||
| # python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5 --device cuda | ||||
| # python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5 --device cuda | ||||
| # python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve() | ||||
| print(lib_dir) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, Top1AccMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.xlayers import super_core | ||||
|  | ||||
|  | ||||
| class MAML: | ||||
|     """A LFNA meta-model that uses the MLP as delta-net.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1 | ||||
|     ): | ||||
|         self.criterion = criterion | ||||
|         self.network = network | ||||
|         self.meta_optimizer = torch.optim.Adam( | ||||
|             self.network.parameters(), lr=meta_lr, amsgrad=True | ||||
|         ) | ||||
|         self.inner_lr = inner_lr | ||||
|         self.inner_step = inner_step | ||||
|         self._best_info = dict(state_dict=None, iepoch=None, score=None) | ||||
|         print("There are {:} weights.".format(self.network.get_w_container().numel())) | ||||
|  | ||||
|     def adapt(self, x, y): | ||||
|         # create a container for the future timestamp | ||||
|         container = self.network.get_w_container() | ||||
|  | ||||
|         for k in range(0, self.inner_step): | ||||
|             y_hat = self.network.forward_with_container(x, container) | ||||
|             loss = self.criterion(y_hat, y) | ||||
|             grads = torch.autograd.grad(loss, container.parameters()) | ||||
|             container = container.additive([-self.inner_lr * grad for grad in grads]) | ||||
|         return container | ||||
|  | ||||
|     def predict(self, x, container=None): | ||||
|         if container is not None: | ||||
|             y_hat = self.network.forward_with_container(x, container) | ||||
|         else: | ||||
|             y_hat = self.network(x) | ||||
|         return y_hat | ||||
|  | ||||
|     def step(self): | ||||
|         torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0) | ||||
|         self.meta_optimizer.step() | ||||
|  | ||||
|     def zero_grad(self): | ||||
|         self.meta_optimizer.zero_grad() | ||||
|  | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.criterion.load_state_dict(state_dict["criterion"]) | ||||
|         self.network.load_state_dict(state_dict["network"]) | ||||
|         self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"]) | ||||
|  | ||||
|     def state_dict(self): | ||||
|         state_dict = dict() | ||||
|         state_dict["criterion"] = self.criterion.state_dict() | ||||
|         state_dict["network"] = self.network.state_dict() | ||||
|         state_dict["meta_optimizer"] = self.meta_optimizer.state_dict() | ||||
|         return state_dict | ||||
|  | ||||
|     def save_best(self, score): | ||||
|         success, best_score = self.network.save_best(score) | ||||
|         return success, best_score | ||||
|  | ||||
|     def load_best(self): | ||||
|         self.network.load_best() | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|     train_env = get_synthetic_env(mode="train", version=args.env_version) | ||||
|     valid_env = get_synthetic_env(mode="valid", version=args.env_version) | ||||
|     trainval_env = get_synthetic_env(mode="trainval", version=args.env_version) | ||||
|     test_env = get_synthetic_env(mode="test", version=args.env_version) | ||||
|     all_env = get_synthetic_env(mode=None, version=args.env_version) | ||||
|     logger.log("The training enviornment: {:}".format(train_env)) | ||||
|     logger.log("The validation enviornment: {:}".format(valid_env)) | ||||
|     logger.log("The trainval enviornment: {:}".format(trainval_env)) | ||||
|     logger.log("The total enviornment: {:}".format(all_env)) | ||||
|     logger.log("The test enviornment: {:}".format(test_env)) | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="norm_mlp"), | ||||
|         input_dim=all_env.meta_info["input_dim"], | ||||
|         output_dim=all_env.meta_info["output_dim"], | ||||
|         hidden_dims=[args.hidden_dim] * 2, | ||||
|         act_cls="relu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|  | ||||
|     model = get_model(**model_kwargs) | ||||
|     model = model.to(args.device) | ||||
|     if all_env.meta_info["task"] == "regression": | ||||
|         criterion = torch.nn.MSELoss() | ||||
|         metric_cls = MSEMetric | ||||
|     elif all_env.meta_info["task"] == "classification": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|         metric_cls = Top1AccMetric | ||||
|     else: | ||||
|         raise ValueError( | ||||
|             "This task ({:}) is not supported.".format(all_env.meta_info["task"]) | ||||
|         ) | ||||
|  | ||||
|     maml = MAML( | ||||
|         model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step | ||||
|     ) | ||||
|  | ||||
|     # meta-training | ||||
|     last_success_epoch = 0 | ||||
|     per_epoch_time, start_time = AverageMeter(), time.time() | ||||
|     for iepoch in range(args.epochs): | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||
|         ) | ||||
|         head_str = ( | ||||
|             "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) | ||||
|             + need_time | ||||
|         ) | ||||
|  | ||||
|         maml.zero_grad() | ||||
|         meta_losses = [] | ||||
|         for ibatch in range(args.meta_batch): | ||||
|             future_idx = random.randint(0, len(trainval_env) - 1) | ||||
|             future_t, (future_x, future_y) = trainval_env[future_idx] | ||||
|             # -->> | ||||
|             seq_times = trainval_env.get_seq_times(future_idx, args.seq_length) | ||||
|             _, (allxs, allys) = trainval_env.seq_call(seq_times) | ||||
|             allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1) | ||||
|             if trainval_env.meta_info["task"] == "classification": | ||||
|                 allys = allys.view(-1) | ||||
|             historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||
|             future_container = maml.adapt(historical_x, historical_y) | ||||
|  | ||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|             future_y_hat = maml.predict(future_x, future_container) | ||||
|             future_loss = maml.criterion(future_y_hat, future_y) | ||||
|             meta_losses.append(future_loss) | ||||
|         meta_loss = torch.stack(meta_losses).mean() | ||||
|         meta_loss.backward() | ||||
|         maml.step() | ||||
|  | ||||
|         logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item())) | ||||
|         success, best_score = maml.save_best(-meta_loss.item()) | ||||
|         if success: | ||||
|             logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) | ||||
|             save_checkpoint(maml.state_dict(), logger.path("model"), logger) | ||||
|             last_success_epoch = iepoch | ||||
|         if iepoch - last_success_epoch >= args.early_stop_thresh: | ||||
|             logger.log("Early stop at {:}".format(iepoch)) | ||||
|             break | ||||
|  | ||||
|         per_epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     # meta-test | ||||
|     maml.load_best() | ||||
|  | ||||
|     def finetune(index): | ||||
|         seq_times = test_env.get_seq_times(index, args.seq_length) | ||||
|         _, (allxs, allys) = test_env.seq_call(seq_times) | ||||
|         allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1) | ||||
|         if test_env.meta_info["task"] == "classification": | ||||
|             allys = allys.view(-1) | ||||
|         historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||
|         future_container = maml.adapt(historical_x, historical_y) | ||||
|  | ||||
|         historical_y_hat = maml.predict(historical_x, future_container) | ||||
|         train_metric = metric_cls(True) | ||||
|         # model.analyze_weights() | ||||
|         with torch.no_grad(): | ||||
|             train_metric(historical_y_hat, historical_y) | ||||
|         train_results = train_metric.get_info() | ||||
|         return train_results, future_container | ||||
|  | ||||
|     metric = metric_cls(True) | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(test_env): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_timestamp_time.avg * (len(test_env) - idx), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(test_env)) | ||||
|             + " " | ||||
|             + need_time | ||||
|         ) | ||||
|  | ||||
|         # build optimizer | ||||
|         train_results, future_container = finetune(idx) | ||||
|  | ||||
|         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|         future_y_hat = maml.predict(future_x, future_container) | ||||
|         future_loss = criterion(future_y_hat, future_y) | ||||
|         metric(future_y_hat, future_y) | ||||
|         log_str = ( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(test_env)) | ||||
|             + " train-score: {:.5f}, eval-score: {:.5f}".format( | ||||
|                 train_results["score"], metric.get_info()["score"] | ||||
|             ) | ||||
|         ) | ||||
|         logger.log(log_str) | ||||
|         logger.log("") | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use the maml.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/GeMOSA-synthetic/use-maml-ft", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         default=16, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_lr", | ||||
|         type=float, | ||||
|         default=0.02, | ||||
|         help="The learning rate for the MAML optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--inner_lr", | ||||
|         type=float, | ||||
|         default=0.005, | ||||
|         help="The learning rate for the inner optimization", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--inner_step", type=int, default=1, help="The inner loop steps for MAML." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--seq_length", type=int, default=20, help="The sequence length." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_batch", | ||||
|         type=int, | ||||
|         default=256, | ||||
|         help="The batch size for the meta-model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=2000, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--early_stop_thresh", | ||||
|         type=int, | ||||
|         default=50, | ||||
|         help="The maximum epochs for early stop.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         type=str, | ||||
|         default="cpu", | ||||
|         help="", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-s{:}-mlr{:}-d{:}-e{:}-env{:}".format( | ||||
|         args.save_dir, | ||||
|         args.inner_step, | ||||
|         args.meta_lr, | ||||
|         args.hidden_dim, | ||||
|         args.epochs, | ||||
|         args.env_version, | ||||
|     ) | ||||
|     main(args) | ||||
							
								
								
									
										319
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/baselines/maml-nof.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										319
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/baselines/maml-nof.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,319 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5 --device cuda | ||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16 --inner_step 5 --device cuda | ||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32 --inner_step 5 --device cuda | ||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve() | ||||
| print(lib_dir) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, Top1AccMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.xlayers import super_core | ||||
|  | ||||
|  | ||||
| class MAML: | ||||
|     """A LFNA meta-model that uses the MLP as delta-net.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1 | ||||
|     ): | ||||
|         self.criterion = criterion | ||||
|         self.network = network | ||||
|         self.meta_optimizer = torch.optim.Adam( | ||||
|             self.network.parameters(), lr=meta_lr, amsgrad=True | ||||
|         ) | ||||
|         self.inner_lr = inner_lr | ||||
|         self.inner_step = inner_step | ||||
|         self._best_info = dict(state_dict=None, iepoch=None, score=None) | ||||
|         print("There are {:} weights.".format(self.network.get_w_container().numel())) | ||||
|  | ||||
|     def adapt(self, x, y): | ||||
|         # create a container for the future timestamp | ||||
|         container = self.network.get_w_container() | ||||
|  | ||||
|         for k in range(0, self.inner_step): | ||||
|             y_hat = self.network.forward_with_container(x, container) | ||||
|             loss = self.criterion(y_hat, y) | ||||
|             grads = torch.autograd.grad(loss, container.parameters()) | ||||
|             container = container.additive([-self.inner_lr * grad for grad in grads]) | ||||
|         return container | ||||
|  | ||||
|     def predict(self, x, container=None): | ||||
|         if container is not None: | ||||
|             y_hat = self.network.forward_with_container(x, container) | ||||
|         else: | ||||
|             y_hat = self.network(x) | ||||
|         return y_hat | ||||
|  | ||||
|     def step(self): | ||||
|         torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0) | ||||
|         self.meta_optimizer.step() | ||||
|  | ||||
|     def zero_grad(self): | ||||
|         self.meta_optimizer.zero_grad() | ||||
|  | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.criterion.load_state_dict(state_dict["criterion"]) | ||||
|         self.network.load_state_dict(state_dict["network"]) | ||||
|         self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"]) | ||||
|  | ||||
|     def state_dict(self): | ||||
|         state_dict = dict() | ||||
|         state_dict["criterion"] = self.criterion.state_dict() | ||||
|         state_dict["network"] = self.network.state_dict() | ||||
|         state_dict["meta_optimizer"] = self.meta_optimizer.state_dict() | ||||
|         return state_dict | ||||
|  | ||||
|     def save_best(self, score): | ||||
|         success, best_score = self.network.save_best(score) | ||||
|         return success, best_score | ||||
|  | ||||
|     def load_best(self): | ||||
|         self.network.load_best() | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|     train_env = get_synthetic_env(mode="train", version=args.env_version) | ||||
|     valid_env = get_synthetic_env(mode="valid", version=args.env_version) | ||||
|     trainval_env = get_synthetic_env(mode="trainval", version=args.env_version) | ||||
|     test_env = get_synthetic_env(mode="test", version=args.env_version) | ||||
|     all_env = get_synthetic_env(mode=None, version=args.env_version) | ||||
|     logger.log("The training enviornment: {:}".format(train_env)) | ||||
|     logger.log("The validation enviornment: {:}".format(valid_env)) | ||||
|     logger.log("The trainval enviornment: {:}".format(trainval_env)) | ||||
|     logger.log("The total enviornment: {:}".format(all_env)) | ||||
|     logger.log("The test enviornment: {:}".format(test_env)) | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="norm_mlp"), | ||||
|         input_dim=all_env.meta_info["input_dim"], | ||||
|         output_dim=all_env.meta_info["output_dim"], | ||||
|         hidden_dims=[args.hidden_dim] * 2, | ||||
|         act_cls="relu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|  | ||||
|     model = get_model(**model_kwargs) | ||||
|     model = model.to(args.device) | ||||
|     if all_env.meta_info["task"] == "regression": | ||||
|         criterion = torch.nn.MSELoss() | ||||
|         metric_cls = MSEMetric | ||||
|     elif all_env.meta_info["task"] == "classification": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|         metric_cls = Top1AccMetric | ||||
|     else: | ||||
|         raise ValueError( | ||||
|             "This task ({:}) is not supported.".format(all_env.meta_info["task"]) | ||||
|         ) | ||||
|  | ||||
|     maml = MAML( | ||||
|         model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step | ||||
|     ) | ||||
|  | ||||
|     # meta-training | ||||
|     last_success_epoch = 0 | ||||
|     per_epoch_time, start_time = AverageMeter(), time.time() | ||||
|     for iepoch in range(args.epochs): | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||
|         ) | ||||
|         head_str = ( | ||||
|             "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) | ||||
|             + need_time | ||||
|         ) | ||||
|  | ||||
|         maml.zero_grad() | ||||
|         meta_losses = [] | ||||
|         for ibatch in range(args.meta_batch): | ||||
|             future_idx = random.randint(0, len(trainval_env) - 1) | ||||
|             future_t, (future_x, future_y) = trainval_env[future_idx] | ||||
|             # -->> | ||||
|             seq_times = trainval_env.get_seq_times(future_idx, args.seq_length) | ||||
|             _, (allxs, allys) = trainval_env.seq_call(seq_times) | ||||
|             allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1) | ||||
|             if trainval_env.meta_info["task"] == "classification": | ||||
|                 allys = allys.view(-1) | ||||
|             historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||
|             future_container = maml.adapt(historical_x, historical_y) | ||||
|  | ||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|             future_y_hat = maml.predict(future_x, future_container) | ||||
|             future_loss = maml.criterion(future_y_hat, future_y) | ||||
|             meta_losses.append(future_loss) | ||||
|         meta_loss = torch.stack(meta_losses).mean() | ||||
|         meta_loss.backward() | ||||
|         maml.step() | ||||
|  | ||||
|         logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item())) | ||||
|         success, best_score = maml.save_best(-meta_loss.item()) | ||||
|         if success: | ||||
|             logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) | ||||
|             save_checkpoint(maml.state_dict(), logger.path("model"), logger) | ||||
|             last_success_epoch = iepoch | ||||
|         if iepoch - last_success_epoch >= args.early_stop_thresh: | ||||
|             logger.log("Early stop at {:}".format(iepoch)) | ||||
|             break | ||||
|  | ||||
|         per_epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     # meta-test | ||||
|     maml.load_best() | ||||
|  | ||||
|     def finetune(index): | ||||
|         seq_times = test_env.get_seq_times(index, args.seq_length) | ||||
|         _, (allxs, allys) = test_env.seq_call(seq_times) | ||||
|         allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1) | ||||
|         if test_env.meta_info["task"] == "classification": | ||||
|             allys = allys.view(-1) | ||||
|         historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||
|         future_container = maml.adapt(historical_x, historical_y) | ||||
|  | ||||
|         historical_y_hat = maml.predict(historical_x, future_container) | ||||
|         train_metric = metric_cls(True) | ||||
|         # model.analyze_weights() | ||||
|         with torch.no_grad(): | ||||
|             train_metric(historical_y_hat, historical_y) | ||||
|         train_results = train_metric.get_info() | ||||
|         return train_results, future_container | ||||
|  | ||||
|     train_results, future_container = finetune(0) | ||||
|  | ||||
|     metric = metric_cls(True) | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(test_env): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_timestamp_time.avg * (len(test_env) - idx), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(test_env)) | ||||
|             + " " | ||||
|             + need_time | ||||
|         ) | ||||
|  | ||||
|         # build optimizer | ||||
|         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|         future_y_hat = maml.predict(future_x, future_container) | ||||
|         future_loss = criterion(future_y_hat, future_y) | ||||
|         metric(future_y_hat, future_y) | ||||
|         log_str = ( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(test_env)) | ||||
|             + " train-score: {:.5f}, eval-score: {:.5f}".format( | ||||
|                 train_results["score"], metric.get_info()["score"] | ||||
|             ) | ||||
|         ) | ||||
|         logger.log(log_str) | ||||
|         logger.log("") | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use the maml.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/GeMOSA-synthetic/use-maml-nft", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         default=16, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_lr", | ||||
|         type=float, | ||||
|         default=0.02, | ||||
|         help="The learning rate for the MAML optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--inner_lr", | ||||
|         type=float, | ||||
|         default=0.005, | ||||
|         help="The learning rate for the inner optimization", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--inner_step", type=int, default=1, help="The inner loop steps for MAML." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--seq_length", type=int, default=20, help="The sequence length." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_batch", | ||||
|         type=int, | ||||
|         default=256, | ||||
|         help="The batch size for the meta-model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=2000, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--early_stop_thresh", | ||||
|         type=int, | ||||
|         default=50, | ||||
|         help="The maximum epochs for early stop.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         type=str, | ||||
|         default="cpu", | ||||
|         help="", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-s{:}-mlr{:}-d{:}-e{:}-env{:}".format( | ||||
|         args.save_dir, | ||||
|         args.inner_step, | ||||
|         args.meta_lr, | ||||
|         args.hidden_dim, | ||||
|         args.epochs, | ||||
|         args.env_version, | ||||
|     ) | ||||
|     main(args) | ||||
							
								
								
									
										228
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/baselines/slbm-ft.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/baselines/slbm-ft.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,228 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/baselines/slbm-ft.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | ||||
| # python exps/GeMOSA/baselines/slbm-ft.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | ||||
| # python exps/GeMOSA/baselines/slbm-ft.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda | ||||
| # python exps/GeMOSA/baselines/slbm-ft.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.procedures.metric_utils import ( | ||||
|     SaveMetric, | ||||
|     MSEMetric, | ||||
|     Top1AccMetric, | ||||
|     ComposeMetric, | ||||
| ) | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.utils import show_mean_var | ||||
|  | ||||
|  | ||||
| def subsample(historical_x, historical_y, maxn=10000): | ||||
|     total = historical_x.size(0) | ||||
|     if total <= maxn: | ||||
|         return historical_x, historical_y | ||||
|     else: | ||||
|         indexes = torch.randint(low=0, high=total, size=[maxn]) | ||||
|         return historical_x[indexes], historical_y[indexes] | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|     env = get_synthetic_env(mode="test", version=args.env_version) | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="norm_mlp"), | ||||
|         input_dim=env.meta_info["input_dim"], | ||||
|         output_dim=env.meta_info["output_dim"], | ||||
|         hidden_dims=[args.hidden_dim] * 2, | ||||
|         act_cls="relu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|     logger.log("The total enviornment: {:}".format(env)) | ||||
|     w_containers = dict() | ||||
|  | ||||
|     if env.meta_info["task"] == "regression": | ||||
|         criterion = torch.nn.MSELoss() | ||||
|         metric_cls = MSEMetric | ||||
|     elif env.meta_info["task"] == "classification": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|         metric_cls = Top1AccMetric | ||||
|     else: | ||||
|         raise ValueError( | ||||
|             "This task ({:}) is not supported.".format(all_env.meta_info["task"]) | ||||
|         ) | ||||
|  | ||||
|     def finetune(index): | ||||
|         seq_times = env.get_seq_times(index, args.seq_length) | ||||
|         _, (allxs, allys) = env.seq_call(seq_times) | ||||
|         allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1) | ||||
|         if env.meta_info["task"] == "classification": | ||||
|             allys = allys.view(-1) | ||||
|         historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||
|         model = get_model(**model_kwargs) | ||||
|         model = model.to(args.device) | ||||
|  | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|             optimizer, | ||||
|             milestones=[ | ||||
|                 int(args.epochs * 0.25), | ||||
|                 int(args.epochs * 0.5), | ||||
|                 int(args.epochs * 0.75), | ||||
|             ], | ||||
|             gamma=0.3, | ||||
|         ) | ||||
|  | ||||
|         train_metric = metric_cls(True) | ||||
|         best_loss, best_param = None, None | ||||
|         for _iepoch in range(args.epochs): | ||||
|             preds = model(historical_x) | ||||
|             optimizer.zero_grad() | ||||
|             loss = criterion(preds, historical_y) | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             lr_scheduler.step() | ||||
|             # save best | ||||
|             if best_loss is None or best_loss > loss.item(): | ||||
|                 best_loss = loss.item() | ||||
|                 best_param = copy.deepcopy(model.state_dict()) | ||||
|         model.load_state_dict(best_param) | ||||
|         # model.analyze_weights() | ||||
|         with torch.no_grad(): | ||||
|             train_metric(preds, historical_y) | ||||
|         train_results = train_metric.get_info() | ||||
|         return train_results, model | ||||
|  | ||||
|     metric = metric_cls(True) | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(env): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(env)) | ||||
|             + " " | ||||
|             + need_time | ||||
|         ) | ||||
|         # train the same data | ||||
|         train_results, model = finetune(idx) | ||||
|  | ||||
|         # build optimizer | ||||
|         xmetric = ComposeMetric(metric_cls(True), SaveMetric()) | ||||
|         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|         future_y_hat = model(future_x) | ||||
|         future_loss = criterion(future_y_hat, future_y) | ||||
|         metric(future_y_hat, future_y) | ||||
|         log_str = ( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(env)) | ||||
|             + " train-score: {:.5f}, eval-score: {:.5f}".format( | ||||
|                 train_results["score"], metric.get_info()["score"] | ||||
|             ) | ||||
|         ) | ||||
|         logger.log(log_str) | ||||
|         logger.log("") | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_containers": w_containers}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|     return metric.get_info()["score"] | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use the data in the past.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/GeMOSA-synthetic/use-same-ft-timestamp", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         required=True, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|         type=float, | ||||
|         default=0.1, | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--seq_length", type=int, default=20, help="The sequence length." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--batch_size", | ||||
|         type=int, | ||||
|         default=512, | ||||
|         help="The batch size", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=300, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         type=str, | ||||
|         default="cpu", | ||||
|         help="", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format( | ||||
|         args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version | ||||
|     ) | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         results = [] | ||||
|         for iseed in range(3): | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             result = main(args) | ||||
|             results.append(result) | ||||
|         show_mean_var(results) | ||||
|     else: | ||||
|         main(args) | ||||
							
								
								
									
										227
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/baselines/slbm-nof.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/baselines/slbm-nof.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,227 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/baselines/slbm-nof.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | ||||
| # python exps/GeMOSA/baselines/slbm-nof.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | ||||
| # python exps/GeMOSA/baselines/slbm-nof.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda | ||||
| # python exps/GeMOSA/baselines/slbm-nof.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.procedures.metric_utils import ( | ||||
|     SaveMetric, | ||||
|     MSEMetric, | ||||
|     Top1AccMetric, | ||||
|     ComposeMetric, | ||||
| ) | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.utils import show_mean_var | ||||
|  | ||||
|  | ||||
| def subsample(historical_x, historical_y, maxn=10000): | ||||
|     total = historical_x.size(0) | ||||
|     if total <= maxn: | ||||
|         return historical_x, historical_y | ||||
|     else: | ||||
|         indexes = torch.randint(low=0, high=total, size=[maxn]) | ||||
|         return historical_x[indexes], historical_y[indexes] | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|     env = get_synthetic_env(mode="test", version=args.env_version) | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="norm_mlp"), | ||||
|         input_dim=env.meta_info["input_dim"], | ||||
|         output_dim=env.meta_info["output_dim"], | ||||
|         hidden_dims=[args.hidden_dim] * 2, | ||||
|         act_cls="relu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|     logger.log("The total enviornment: {:}".format(env)) | ||||
|     w_containers = dict() | ||||
|  | ||||
|     if env.meta_info["task"] == "regression": | ||||
|         criterion = torch.nn.MSELoss() | ||||
|         metric_cls = MSEMetric | ||||
|     elif env.meta_info["task"] == "classification": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|         metric_cls = Top1AccMetric | ||||
|     else: | ||||
|         raise ValueError( | ||||
|             "This task ({:}) is not supported.".format(all_env.meta_info["task"]) | ||||
|         ) | ||||
|  | ||||
|     seq_times = env.get_seq_times(0, args.seq_length) | ||||
|     _, (allxs, allys) = env.seq_call(seq_times) | ||||
|     allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1) | ||||
|     if env.meta_info["task"] == "classification": | ||||
|         allys = allys.view(-1) | ||||
|  | ||||
|     historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||
|     model = get_model(**model_kwargs) | ||||
|     model = model.to(args.device) | ||||
|  | ||||
|     optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|         optimizer, | ||||
|         milestones=[ | ||||
|             int(args.epochs * 0.25), | ||||
|             int(args.epochs * 0.5), | ||||
|             int(args.epochs * 0.75), | ||||
|         ], | ||||
|         gamma=0.3, | ||||
|     ) | ||||
|  | ||||
|     train_metric = metric_cls(True) | ||||
|     best_loss, best_param = None, None | ||||
|     for _iepoch in range(args.epochs): | ||||
|         preds = model(historical_x) | ||||
|         optimizer.zero_grad() | ||||
|         loss = criterion(preds, historical_y) | ||||
|         loss.backward() | ||||
|         optimizer.step() | ||||
|         lr_scheduler.step() | ||||
|         # save best | ||||
|         if best_loss is None or best_loss > loss.item(): | ||||
|             best_loss = loss.item() | ||||
|             best_param = copy.deepcopy(model.state_dict()) | ||||
|     model.load_state_dict(best_param) | ||||
|     model.analyze_weights() | ||||
|     with torch.no_grad(): | ||||
|         train_metric(preds, historical_y) | ||||
|     train_results = train_metric.get_info() | ||||
|     print(train_results) | ||||
|  | ||||
|     metric = metric_cls(True) | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(env): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(env)) | ||||
|             + " " | ||||
|             + need_time | ||||
|         ) | ||||
|         # train the same data | ||||
|  | ||||
|         # build optimizer | ||||
|         xmetric = ComposeMetric(metric_cls(True), SaveMetric()) | ||||
|         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|         future_y_hat = model(future_x) | ||||
|         future_loss = criterion(future_y_hat, future_y) | ||||
|         metric(future_y_hat, future_y) | ||||
|         log_str = ( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(env)) | ||||
|             + " train-score: {:.5f}, eval-score: {:.5f}".format( | ||||
|                 train_results["score"], metric.get_info()["score"] | ||||
|             ) | ||||
|         ) | ||||
|         logger.log(log_str) | ||||
|         logger.log("") | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_containers": w_containers}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|     return metric.get_info()["score"] | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use the data in the past.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/GeMOSA-synthetic/use-same-nof-timestamp", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         required=True, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--seq_length", type=int, default=20, help="The sequence length." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|         type=float, | ||||
|         default=0.1, | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--batch_size", | ||||
|         type=int, | ||||
|         default=512, | ||||
|         help="The batch size", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=300, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         type=str, | ||||
|         default="cpu", | ||||
|         help="", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format( | ||||
|         args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version | ||||
|     ) | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         results = [] | ||||
|         for iseed in range(3): | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             result = main(args) | ||||
|             results.append(result) | ||||
|         show_mean_var(results) | ||||
|     else: | ||||
|         main(args) | ||||
							
								
								
									
										206
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/basic-his.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/basic-his.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,206 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
|  | ||||
|  | ||||
| from lfna_utils import lfna_setup | ||||
|  | ||||
|  | ||||
| def subsample(historical_x, historical_y, maxn=10000): | ||||
|     total = historical_x.size(0) | ||||
|     if total <= maxn: | ||||
|         return historical_x, historical_y | ||||
|     else: | ||||
|         indexes = torch.randint(low=0, high=total, size=[maxn]) | ||||
|         return historical_x[indexes], historical_y[indexes] | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     logger, env_info, model_kwargs = lfna_setup(args) | ||||
|  | ||||
|     # check indexes to be evaluated | ||||
|     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) | ||||
|     logger.log( | ||||
|         "Evaluate {:}, which has {:} timestamps in total.".format( | ||||
|             args.srange, len(to_evaluate_indexes) | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     w_container_per_epoch = dict() | ||||
|  | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for i, idx in enumerate(to_evaluate_indexes): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time( | ||||
|                 per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True | ||||
|             ) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx) | ||||
|             + " " | ||||
|             + need_time | ||||
|         ) | ||||
|         # train the same data | ||||
|         assert idx != 0 | ||||
|         historical_x, historical_y = [], [] | ||||
|         for past_i in range(idx): | ||||
|             historical_x.append(env_info["{:}-x".format(past_i)]) | ||||
|             historical_y.append(env_info["{:}-y".format(past_i)]) | ||||
|         historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) | ||||
|         historical_x, historical_y = subsample(historical_x, historical_y) | ||||
|         # build model | ||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         # build optimizer | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|         criterion = torch.nn.MSELoss() | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|             optimizer, | ||||
|             milestones=[ | ||||
|                 int(args.epochs * 0.25), | ||||
|                 int(args.epochs * 0.5), | ||||
|                 int(args.epochs * 0.75), | ||||
|             ], | ||||
|             gamma=0.3, | ||||
|         ) | ||||
|         train_metric = MSEMetric() | ||||
|         best_loss, best_param = None, None | ||||
|         for _iepoch in range(args.epochs): | ||||
|             preds = model(historical_x) | ||||
|             optimizer.zero_grad() | ||||
|             loss = criterion(preds, historical_y) | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             lr_scheduler.step() | ||||
|             # save best | ||||
|             if best_loss is None or best_loss > loss.item(): | ||||
|                 best_loss = loss.item() | ||||
|                 best_param = copy.deepcopy(model.state_dict()) | ||||
|         model.load_state_dict(best_param) | ||||
|         with torch.no_grad(): | ||||
|             train_metric(preds, historical_y) | ||||
|         train_results = train_metric.get_info() | ||||
|  | ||||
|         metric = ComposeMetric(MSEMetric(), SaveMetric()) | ||||
|         eval_dataset = torch.utils.data.TensorDataset( | ||||
|             env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] | ||||
|         ) | ||||
|         eval_loader = torch.utils.data.DataLoader( | ||||
|             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 | ||||
|         ) | ||||
|         results = basic_eval_fn(eval_loader, model, metric, logger) | ||||
|         log_str = ( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, env_info["total"]) | ||||
|             + " train-mse: {:.5f}, eval-mse: {:.5f}".format( | ||||
|                 train_results["mse"], results["mse"] | ||||
|             ) | ||||
|         ) | ||||
|         logger.log(log_str) | ||||
|  | ||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||
|             idx, env_info["total"] | ||||
|         ) | ||||
|         w_container_per_epoch[idx] = model.get_w_container().no_grad_clone() | ||||
|         save_checkpoint( | ||||
|             { | ||||
|                 "model_state_dict": model.state_dict(), | ||||
|                 "model": model, | ||||
|                 "index": idx, | ||||
|                 "timestamp": env_info["{:}-timestamp".format(idx)], | ||||
|             }, | ||||
|             save_path, | ||||
|             logger, | ||||
|         ) | ||||
|         logger.log("") | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_container_per_epoch": w_container_per_epoch}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use all the past data to train.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/use-all-past-data", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         required=True, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|         type=float, | ||||
|         default=0.1, | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--batch_size", | ||||
|         type=int, | ||||
|         default=512, | ||||
|         help="The batch size", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=1000, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--srange", type=str, required=True, help="The range of models to be evaluated" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-{:}-d{:}".format( | ||||
|         args.save_dir, args.env_version, args.hidden_dim | ||||
|     ) | ||||
|     main(args) | ||||
							
								
								
									
										207
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/basic-prev.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										207
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/basic-prev.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,207 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1 | ||||
| # python exps/GeMOSA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
|  | ||||
| from lfna_utils import lfna_setup | ||||
|  | ||||
|  | ||||
| def subsample(historical_x, historical_y, maxn=10000): | ||||
|     total = historical_x.size(0) | ||||
|     if total <= maxn: | ||||
|         return historical_x, historical_y | ||||
|     else: | ||||
|         indexes = torch.randint(low=0, high=total, size=[maxn]) | ||||
|         return historical_x[indexes], historical_y[indexes] | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     logger, model_kwargs = lfna_setup(args) | ||||
|  | ||||
|     w_containers = dict() | ||||
|  | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for idx in range(args.prev_time, env_info["total"]): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, env_info["total"]) | ||||
|             + " " | ||||
|             + need_time | ||||
|         ) | ||||
|         # train the same data | ||||
|         historical_x = env_info["{:}-x".format(idx - args.prev_time)] | ||||
|         historical_y = env_info["{:}-y".format(idx - args.prev_time)] | ||||
|         # build model | ||||
|         model = get_model(**model_kwargs) | ||||
|         print(model) | ||||
|         # build optimizer | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|         criterion = torch.nn.MSELoss() | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|             optimizer, | ||||
|             milestones=[ | ||||
|                 int(args.epochs * 0.25), | ||||
|                 int(args.epochs * 0.5), | ||||
|                 int(args.epochs * 0.75), | ||||
|             ], | ||||
|             gamma=0.3, | ||||
|         ) | ||||
|         train_metric = MSEMetric() | ||||
|         best_loss, best_param = None, None | ||||
|         for _iepoch in range(args.epochs): | ||||
|             preds = model(historical_x) | ||||
|             optimizer.zero_grad() | ||||
|             loss = criterion(preds, historical_y) | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             lr_scheduler.step() | ||||
|             # save best | ||||
|             if best_loss is None or best_loss > loss.item(): | ||||
|                 best_loss = loss.item() | ||||
|                 best_param = copy.deepcopy(model.state_dict()) | ||||
|         model.load_state_dict(best_param) | ||||
|         model.analyze_weights() | ||||
|         with torch.no_grad(): | ||||
|             train_metric(preds, historical_y) | ||||
|         train_results = train_metric.get_info() | ||||
|  | ||||
|         metric = ComposeMetric(MSEMetric(), SaveMetric()) | ||||
|         eval_dataset = torch.utils.data.TensorDataset( | ||||
|             env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] | ||||
|         ) | ||||
|         eval_loader = torch.utils.data.DataLoader( | ||||
|             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 | ||||
|         ) | ||||
|         results = basic_eval_fn(eval_loader, model, metric, logger) | ||||
|         log_str = ( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, env_info["total"]) | ||||
|             + " train-mse: {:.5f}, eval-mse: {:.5f}".format( | ||||
|                 train_results["mse"], results["mse"] | ||||
|             ) | ||||
|         ) | ||||
|         logger.log(log_str) | ||||
|  | ||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||
|             idx, env_info["total"] | ||||
|         ) | ||||
|         w_containers[idx] = model.get_w_container().no_grad_clone() | ||||
|         save_checkpoint( | ||||
|             { | ||||
|                 "model_state_dict": model.state_dict(), | ||||
|                 "model": model, | ||||
|                 "index": idx, | ||||
|                 "timestamp": env_info["{:}-timestamp".format(idx)], | ||||
|             }, | ||||
|             save_path, | ||||
|             logger, | ||||
|         ) | ||||
|         logger.log("") | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_containers": w_containers}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use the data in the last timestamp.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/use-prev-timestamp", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         required=True, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|         type=float, | ||||
|         default=0.1, | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--prev_time", | ||||
|         type=int, | ||||
|         default=5, | ||||
|         help="The gap between prev_time and current_timestamp", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--batch_size", | ||||
|         type=int, | ||||
|         default=512, | ||||
|         help="The batch size", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=300, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-d{:}_e{:}_lr{:}-prev{:}-env{:}".format( | ||||
|         args.save_dir, | ||||
|         args.hidden_dim, | ||||
|         args.epochs, | ||||
|         args.init_lr, | ||||
|         args.prev_time, | ||||
|         args.env_version, | ||||
|     ) | ||||
|     main(args) | ||||
							
								
								
									
										228
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/basic-same.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/basic-same.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,228 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | ||||
| # python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | ||||
| # python exps/GeMOSA/basic-same.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda | ||||
| # python exps/GeMOSA/basic-same.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
| from xautodl.procedures.metric_utils import ( | ||||
|     SaveMetric, | ||||
|     MSEMetric, | ||||
|     Top1AccMetric, | ||||
|     ComposeMetric, | ||||
| ) | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
|  | ||||
|  | ||||
| def subsample(historical_x, historical_y, maxn=10000): | ||||
|     total = historical_x.size(0) | ||||
|     if total <= maxn: | ||||
|         return historical_x, historical_y | ||||
|     else: | ||||
|         indexes = torch.randint(low=0, high=total, size=[maxn]) | ||||
|         return historical_x[indexes], historical_y[indexes] | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|     env = get_synthetic_env(mode=None, version=args.env_version) | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="norm_mlp"), | ||||
|         input_dim=env.meta_info["input_dim"], | ||||
|         output_dim=env.meta_info["output_dim"], | ||||
|         hidden_dims=[args.hidden_dim] * 2, | ||||
|         act_cls="relu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|     logger.log("The total enviornment: {:}".format(env)) | ||||
|     w_containers = dict() | ||||
|  | ||||
|     if env.meta_info["task"] == "regression": | ||||
|         criterion = torch.nn.MSELoss() | ||||
|         metric_cls = MSEMetric | ||||
|     elif env.meta_info["task"] == "classification": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|         metric_cls = Top1AccMetric | ||||
|     else: | ||||
|         raise ValueError( | ||||
|             "This task ({:}) is not supported.".format(all_env.meta_info["task"]) | ||||
|         ) | ||||
|  | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(env): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(env)) | ||||
|             + " " | ||||
|             + need_time | ||||
|         ) | ||||
|         # train the same data | ||||
|         historical_x = future_x.to(args.device) | ||||
|         historical_y = future_y.to(args.device) | ||||
|         # build model | ||||
|         model = get_model(**model_kwargs) | ||||
|         model = model.to(args.device) | ||||
|         if idx == 0: | ||||
|             print(model) | ||||
|         # build optimizer | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|             optimizer, | ||||
|             milestones=[ | ||||
|                 int(args.epochs * 0.25), | ||||
|                 int(args.epochs * 0.5), | ||||
|                 int(args.epochs * 0.75), | ||||
|             ], | ||||
|             gamma=0.3, | ||||
|         ) | ||||
|         train_metric = metric_cls(True) | ||||
|         best_loss, best_param = None, None | ||||
|         for _iepoch in range(args.epochs): | ||||
|             preds = model(historical_x) | ||||
|             optimizer.zero_grad() | ||||
|             loss = criterion(preds, historical_y) | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             lr_scheduler.step() | ||||
|             # save best | ||||
|             if best_loss is None or best_loss > loss.item(): | ||||
|                 best_loss = loss.item() | ||||
|                 best_param = copy.deepcopy(model.state_dict()) | ||||
|         model.load_state_dict(best_param) | ||||
|         model.analyze_weights() | ||||
|         with torch.no_grad(): | ||||
|             train_metric(preds, historical_y) | ||||
|         train_results = train_metric.get_info() | ||||
|  | ||||
|         xmetric = ComposeMetric(metric_cls(True), SaveMetric()) | ||||
|         eval_dataset = torch.utils.data.TensorDataset( | ||||
|             future_x.to(args.device), future_y.to(args.device) | ||||
|         ) | ||||
|         eval_loader = torch.utils.data.DataLoader( | ||||
|             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 | ||||
|         ) | ||||
|         results = basic_eval_fn(eval_loader, model, xmetric, logger) | ||||
|         log_str = ( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, len(env)) | ||||
|             + " train-score: {:.5f}, eval-score: {:.5f}".format( | ||||
|                 train_results["score"], results["score"] | ||||
|             ) | ||||
|         ) | ||||
|         logger.log(log_str) | ||||
|  | ||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(env)) | ||||
|         w_containers[idx] = model.get_w_container().no_grad_clone() | ||||
|         save_checkpoint( | ||||
|             { | ||||
|                 "model_state_dict": model.state_dict(), | ||||
|                 "model": model, | ||||
|                 "index": idx, | ||||
|                 "timestamp": future_time.item(), | ||||
|             }, | ||||
|             save_path, | ||||
|             logger, | ||||
|         ) | ||||
|         logger.log("") | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_containers": w_containers}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use the data in the past.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/GeMOSA-synthetic/use-same-timestamp", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         required=True, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|         type=float, | ||||
|         default=0.1, | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--batch_size", | ||||
|         type=int, | ||||
|         default=512, | ||||
|         help="The batch size", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=300, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         type=str, | ||||
|         default="cpu", | ||||
|         help="", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format( | ||||
|         args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version | ||||
|     ) | ||||
|     main(args) | ||||
							
								
								
									
										438
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										438
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,438 @@ | ||||
| ########################################################## | ||||
| # Learning to Efficiently Generate Models One Step Ahead # | ||||
| ########################################################## | ||||
| # <----> run on CPU | ||||
| # python exps/GeMOSA/main.py --env_version v1 --workers 0 | ||||
| # <----> run on a GPU | ||||
| # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||
| # <----> ablation commands | ||||
| # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda | ||||
| ########################################################## | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| from torch.nn import functional as F | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric | ||||
|  | ||||
| from meta_model import MetaModelV1 | ||||
| from meta_model_ablation import MetaModel_TraditionalAtt | ||||
|  | ||||
|  | ||||
| def online_evaluate( | ||||
|     env, | ||||
|     meta_model, | ||||
|     base_model, | ||||
|     criterion, | ||||
|     metric, | ||||
|     args, | ||||
|     logger, | ||||
|     save=False, | ||||
|     easy_adapt=False, | ||||
| ): | ||||
|     logger.log("Online evaluate: {:}".format(env)) | ||||
|     metric.reset() | ||||
|     loss_meter = AverageMeter() | ||||
|     w_containers = dict() | ||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(env): | ||||
|         with torch.no_grad(): | ||||
|             meta_model.eval() | ||||
|             base_model.eval() | ||||
|             future_time_embed = meta_model.gen_time_embed( | ||||
|                 future_time.to(args.device).view(-1) | ||||
|             ) | ||||
|             [future_container] = meta_model.gen_model(future_time_embed) | ||||
|             if save: | ||||
|                 w_containers[idx] = future_container.no_grad_clone() | ||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|             future_y_hat = base_model.forward_with_container(future_x, future_container) | ||||
|             future_loss = criterion(future_y_hat, future_y) | ||||
|             loss_meter.update(future_loss.item()) | ||||
|             # accumulate the metric scores | ||||
|             score = metric(future_y_hat, future_y) | ||||
|         if easy_adapt: | ||||
|             meta_model.easy_adapt(future_time.item(), future_time_embed) | ||||
|             refine, post_refine_loss = False, -1 | ||||
|         else: | ||||
|             refine, post_refine_loss = meta_model.adapt( | ||||
|                 base_model, | ||||
|                 criterion, | ||||
|                 future_time.item(), | ||||
|                 future_x, | ||||
|                 future_y, | ||||
|                 args.refine_lr, | ||||
|                 args.refine_epochs, | ||||
|                 {"param": future_time_embed, "loss": future_loss.item()}, | ||||
|             ) | ||||
|         logger.log( | ||||
|             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}, score={:.4f}".format( | ||||
|                 idx, len(env), future_loss.item(), score | ||||
|             ) | ||||
|             + ", post-loss={:.4f}".format(post_refine_loss if refine else -1) | ||||
|         ) | ||||
|     meta_model.clear_fixed() | ||||
|     meta_model.clear_learnt() | ||||
|     return w_containers, loss_meter.avg, metric.get_info()["score"] | ||||
|  | ||||
|  | ||||
| def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | ||||
|     base_model.train() | ||||
|     meta_model.train() | ||||
|     optimizer = torch.optim.Adam( | ||||
|         meta_model.get_parameters(True, True, True), | ||||
|         lr=args.lr, | ||||
|         weight_decay=args.weight_decay, | ||||
|         amsgrad=True, | ||||
|     ) | ||||
|     logger.log("Pre-train the meta-model") | ||||
|     logger.log("Using the optimizer: {:}".format(optimizer)) | ||||
|  | ||||
|     meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2") | ||||
|     final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed) | ||||
|     if meta_model.has_best(final_best_name): | ||||
|         meta_model.load_best(final_best_name) | ||||
|         logger.log("Directly load the best model from {:}".format(final_best_name)) | ||||
|         return | ||||
|  | ||||
|     total_indexes = list(range(meta_model.meta_length)) | ||||
|     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) | ||||
|     last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh | ||||
|     per_epoch_time, start_time = AverageMeter(), time.time() | ||||
|     device = args.device | ||||
|     for iepoch in range(args.epochs): | ||||
|         left_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||
|         ) | ||||
|         optimizer.zero_grad() | ||||
|  | ||||
|         generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps) | ||||
|  | ||||
|         batch_indexes = random.choices(total_indexes, k=args.meta_batch) | ||||
|  | ||||
|         raw_time_steps = meta_model.meta_timestamps[batch_indexes] | ||||
|  | ||||
|         regularization_loss = F.l1_loss( | ||||
|             generated_time_embeds, meta_model.super_meta_embed, reduction="mean" | ||||
|         ) | ||||
|         # future loss | ||||
|         total_future_losses, total_present_losses = [], [] | ||||
|         future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes]) | ||||
|         present_containers = meta_model.gen_model( | ||||
|             meta_model.super_meta_embed[batch_indexes] | ||||
|         ) | ||||
|         for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): | ||||
|             _, (inputs, targets) = xenv(time_step) | ||||
|             inputs, targets = inputs.to(device), targets.to(device) | ||||
|  | ||||
|             predictions = base_model.forward_with_container( | ||||
|                 inputs, future_containers[ibatch] | ||||
|             ) | ||||
|             total_future_losses.append(criterion(predictions, targets)) | ||||
|  | ||||
|             predictions = base_model.forward_with_container( | ||||
|                 inputs, present_containers[ibatch] | ||||
|             ) | ||||
|             total_present_losses.append(criterion(predictions, targets)) | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             meta_std = torch.stack(total_future_losses).std().item() | ||||
|         loss_future = torch.stack(total_future_losses).mean() | ||||
|         loss_present = torch.stack(total_present_losses).mean() | ||||
|         total_loss = loss_future + loss_present + regularization_loss | ||||
|         total_loss.backward() | ||||
|         optimizer.step() | ||||
|         # success | ||||
|         success, best_score = meta_model.save_best(-total_loss.item()) | ||||
|         logger.log( | ||||
|             "{:} [META {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format( | ||||
|                 time_string(), | ||||
|                 iepoch, | ||||
|                 args.epochs, | ||||
|                 total_loss.item(), | ||||
|                 meta_std, | ||||
|                 loss_future.item(), | ||||
|                 loss_present.item(), | ||||
|                 regularization_loss.item(), | ||||
|             ) | ||||
|             + ", batch={:}".format(len(total_future_losses)) | ||||
|             + ", success={:}, best={:.4f}".format(success, -best_score) | ||||
|             + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh) | ||||
|             + ", {:}".format(left_time) | ||||
|         ) | ||||
|         if success: | ||||
|             last_success_epoch = iepoch | ||||
|         if iepoch - last_success_epoch >= early_stop_thresh: | ||||
|             logger.log("Early stop the pre-training at {:}".format(iepoch)) | ||||
|             break | ||||
|         per_epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|     meta_model.load_best() | ||||
|     # save to the final model | ||||
|     meta_model.set_best_name(final_best_name) | ||||
|     success, _ = meta_model.save_best(best_score + 1e-6) | ||||
|     assert success | ||||
|     logger.log("Save the best model into {:}".format(final_best_name)) | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|     train_env = get_synthetic_env(mode="train", version=args.env_version) | ||||
|     valid_env = get_synthetic_env(mode="valid", version=args.env_version) | ||||
|     trainval_env = get_synthetic_env(mode="trainval", version=args.env_version) | ||||
|     test_env = get_synthetic_env(mode="test", version=args.env_version) | ||||
|     all_env = get_synthetic_env(mode=None, version=args.env_version) | ||||
|     logger.log("The training enviornment: {:}".format(train_env)) | ||||
|     logger.log("The validation enviornment: {:}".format(valid_env)) | ||||
|     logger.log("The trainval enviornment: {:}".format(trainval_env)) | ||||
|     logger.log("The total enviornment: {:}".format(all_env)) | ||||
|     logger.log("The test enviornment: {:}".format(test_env)) | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="norm_mlp"), | ||||
|         input_dim=all_env.meta_info["input_dim"], | ||||
|         output_dim=all_env.meta_info["output_dim"], | ||||
|         hidden_dims=[args.hidden_dim] * 2, | ||||
|         act_cls="relu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|  | ||||
|     base_model = get_model(**model_kwargs) | ||||
|     base_model = base_model.to(args.device) | ||||
|     if all_env.meta_info["task"] == "regression": | ||||
|         criterion = torch.nn.MSELoss() | ||||
|         metric = MSEMetric(True) | ||||
|     elif all_env.meta_info["task"] == "classification": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|         metric = Top1AccMetric(True) | ||||
|     else: | ||||
|         raise ValueError( | ||||
|             "This task ({:}) is not supported.".format(all_env.meta_info["task"]) | ||||
|         ) | ||||
|  | ||||
|     shape_container = base_model.get_w_container().to_shape_container() | ||||
|  | ||||
|     # pre-train the hypernetwork | ||||
|     timestamps = trainval_env.get_timestamp(None) | ||||
|     if args.ablation is None: | ||||
|         MetaModel_cls = MetaModelV1 | ||||
|     elif args.ablation == "old": | ||||
|         MetaModel_cls = MetaModel_TraditionalAtt | ||||
|     else: | ||||
|         raise ValueError("Unknown ablation : {:}".format(args.ablation)) | ||||
|     meta_model = MetaModel_cls( | ||||
|         shape_container, | ||||
|         args.layer_dim, | ||||
|         args.time_dim, | ||||
|         timestamps, | ||||
|         seq_length=args.seq_length, | ||||
|         interval=trainval_env.time_interval, | ||||
|     ) | ||||
|     meta_model = meta_model.to(args.device) | ||||
|  | ||||
|     logger.log("The base-model has {:} weights.".format(base_model.numel())) | ||||
|     logger.log("The meta-model has {:} weights.".format(meta_model.numel())) | ||||
|     logger.log("The base-model is\n{:}".format(base_model)) | ||||
|     logger.log("The meta-model is\n{:}".format(meta_model)) | ||||
|  | ||||
|     meta_train_procedure(base_model, meta_model, criterion, trainval_env, args, logger) | ||||
|  | ||||
|     # try to evaluate once | ||||
|     # online_evaluate(train_env, meta_model, base_model, criterion, args, logger) | ||||
|     # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) | ||||
|     """ | ||||
|     w_containers, loss_meter = online_evaluate( | ||||
|         all_env, meta_model, base_model, criterion, args, logger, True | ||||
|     ) | ||||
|     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) | ||||
|     """ | ||||
|     w_containers_care_adapt, loss_adapt_v1, metric_adapt_v1 = online_evaluate( | ||||
|         test_env, meta_model, base_model, criterion, metric, args, logger, True, False | ||||
|     ) | ||||
|     w_containers_easy_adapt, loss_adapt_v2, metric_adapt_v2 = online_evaluate( | ||||
|         test_env, meta_model, base_model, criterion, metric, args, logger, True, True | ||||
|     ) | ||||
|     logger.log( | ||||
|         "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format( | ||||
|             loss_adapt_v1, metric_adapt_v1 | ||||
|         ) | ||||
|     ) | ||||
|     logger.log( | ||||
|         "[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format( | ||||
|             loss_adapt_v2, metric_adapt_v2 | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     save_checkpoint( | ||||
|         { | ||||
|             "w_containers_care_adapt": w_containers_care_adapt, | ||||
|             "w_containers_easy_adapt": w_containers_easy_adapt, | ||||
|             "test_loss_adapt_v1": loss_adapt_v1, | ||||
|             "test_loss_adapt_v2": loss_adapt_v2, | ||||
|             "test_metric_adapt_v1": metric_adapt_v1, | ||||
|             "test_metric_adapt_v2": metric_adapt_v2, | ||||
|         }, | ||||
|         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed), | ||||
|         logger, | ||||
|     ) | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(".") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/GeMOSA-synthetic/GeMOSA", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         default=16, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--layer_dim", | ||||
|         type=int, | ||||
|         default=16, | ||||
|         help="The layer chunk dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--time_dim", | ||||
|         type=int, | ||||
|         default=16, | ||||
|         help="The timestamp dimension.", | ||||
|     ) | ||||
|     ##### | ||||
|     parser.add_argument( | ||||
|         "--lr", | ||||
|         type=float, | ||||
|         default=0.002, | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--weight_decay", | ||||
|         type=float, | ||||
|         default=0.00001, | ||||
|         help="The weight decay for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_batch", | ||||
|         type=int, | ||||
|         default=64, | ||||
|         help="The batch size for the meta-model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--sampler_enlarge", | ||||
|         type=int, | ||||
|         default=5, | ||||
|         help="Enlarge the #iterations for an epoch", | ||||
|     ) | ||||
|     parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.") | ||||
|     parser.add_argument( | ||||
|         "--refine_lr", | ||||
|         type=float, | ||||
|         default=0.001, | ||||
|         help="The learning rate for the optimizer, during refine", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--refine_epochs", type=int, default=150, help="The final refine #epochs." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--early_stop_thresh", | ||||
|         type=int, | ||||
|         default=20, | ||||
|         help="The #epochs for early stop.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--pretrain_early_stop_thresh", | ||||
|         type=int, | ||||
|         default=300, | ||||
|         help="The #epochs for early stop.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--seq_length", type=int, default=10, help="The sequence length." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", type=int, default=4, help="The number of workers in parallel." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--ablation", type=str, default=None, help="The ablation indicator." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         type=str, | ||||
|         default="cpu", | ||||
|         help="", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     if args.ablation is None: | ||||
|         args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( | ||||
|             args.save_dir, | ||||
|             args.meta_batch, | ||||
|             args.hidden_dim, | ||||
|             args.layer_dim, | ||||
|             args.time_dim, | ||||
|             args.seq_length, | ||||
|             args.lr, | ||||
|             args.weight_decay, | ||||
|             args.epochs, | ||||
|             args.env_version, | ||||
|         ) | ||||
|     else: | ||||
|         args.save_dir = ( | ||||
|             "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format( | ||||
|                 args.save_dir, | ||||
|                 args.meta_batch, | ||||
|                 args.hidden_dim, | ||||
|                 args.layer_dim, | ||||
|                 args.time_dim, | ||||
|                 args.seq_length, | ||||
|                 args.lr, | ||||
|                 args.weight_decay, | ||||
|                 args.epochs, | ||||
|                 args.ablation, | ||||
|                 args.env_version, | ||||
|             ) | ||||
|         ) | ||||
|     main(args) | ||||
							
								
								
									
										257
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/meta_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										257
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/meta_model.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,257 @@ | ||||
| import torch | ||||
|  | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xautodl.xlayers import super_core | ||||
| from xautodl.xlayers import trunc_normal_ | ||||
| from xautodl.xmodels.xcore import get_model | ||||
|  | ||||
|  | ||||
| class MetaModelV1(super_core.SuperModule): | ||||
|     """Learning to Generate Models One Step Ahead (Meta Model Design).""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         shape_container, | ||||
|         layer_dim, | ||||
|         time_dim, | ||||
|         meta_timestamps, | ||||
|         dropout: float = 0.1, | ||||
|         seq_length: int = None, | ||||
|         interval: float = None, | ||||
|         thresh: float = None, | ||||
|     ): | ||||
|         super(MetaModelV1, self).__init__() | ||||
|         self._shape_container = shape_container | ||||
|         self._num_layers = len(shape_container) | ||||
|         self._numel_per_layer = [] | ||||
|         for ilayer in range(self._num_layers): | ||||
|             self._numel_per_layer.append(shape_container[ilayer].numel()) | ||||
|         self._raw_meta_timestamps = meta_timestamps | ||||
|         assert interval is not None | ||||
|         self._interval = interval | ||||
|         self._thresh = interval * seq_length if thresh is None else thresh | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)), | ||||
|         ) | ||||
|         self.register_parameter( | ||||
|             "_super_meta_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)), | ||||
|         ) | ||||
|         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) | ||||
|         self._time_embed_dim = time_dim | ||||
|         self._append_meta_embed = dict(fixed=None, learnt=None) | ||||
|         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||
|  | ||||
|         self._tscalar_embed = super_core.SuperDynamicPositionE( | ||||
|             time_dim, scale=1 / interval | ||||
|         ) | ||||
|  | ||||
|         # build transformer | ||||
|         self._trans_att = super_core.SuperQKVAttentionV2( | ||||
|             qk_att_dim=time_dim, | ||||
|             in_v_dim=time_dim, | ||||
|             hidden_dim=time_dim, | ||||
|             num_heads=4, | ||||
|             proj_dim=time_dim, | ||||
|             qkv_bias=True, | ||||
|             attn_drop=None, | ||||
|             proj_drop=dropout, | ||||
|         ) | ||||
|  | ||||
|         model_kwargs = dict( | ||||
|             config=dict(model_type="dual_norm_mlp"), | ||||
|             input_dim=layer_dim + time_dim, | ||||
|             output_dim=max(self._numel_per_layer), | ||||
|             hidden_dims=[(layer_dim + time_dim) * 2] * 3, | ||||
|             act_cls="gelu", | ||||
|             norm_cls="layer_norm_1d", | ||||
|             dropout=dropout, | ||||
|         ) | ||||
|         self._generator = get_model(**model_kwargs) | ||||
|  | ||||
|         # initialization | ||||
|         trunc_normal_( | ||||
|             [self._super_layer_embed, self._super_meta_embed], | ||||
|             std=0.02, | ||||
|         ) | ||||
|  | ||||
|     def get_parameters(self, time_embed, attention, generator): | ||||
|         parameters = [] | ||||
|         if time_embed: | ||||
|             parameters.append(self._super_meta_embed) | ||||
|         if attention: | ||||
|             parameters.extend(list(self._trans_att.parameters())) | ||||
|         if generator: | ||||
|             parameters.append(self._super_layer_embed) | ||||
|             parameters.extend(list(self._generator.parameters())) | ||||
|         return parameters | ||||
|  | ||||
|     @property | ||||
|     def meta_timestamps(self): | ||||
|         with torch.no_grad(): | ||||
|             meta_timestamps = [self._meta_timestamps] | ||||
|             for key in ("fixed", "learnt"): | ||||
|                 if self._append_meta_timestamps[key] is not None: | ||||
|                     meta_timestamps.append(self._append_meta_timestamps[key]) | ||||
|         return torch.cat(meta_timestamps) | ||||
|  | ||||
|     @property | ||||
|     def super_meta_embed(self): | ||||
|         meta_embed = [self._super_meta_embed] | ||||
|         for key in ("fixed", "learnt"): | ||||
|             if self._append_meta_embed[key] is not None: | ||||
|                 meta_embed.append(self._append_meta_embed[key]) | ||||
|         return torch.cat(meta_embed) | ||||
|  | ||||
|     def create_meta_embed(self): | ||||
|         param = torch.Tensor(1, self._time_embed_dim) | ||||
|         trunc_normal_(param, std=0.02) | ||||
|         param = param.to(self._super_meta_embed.device) | ||||
|         param = torch.nn.Parameter(param, True) | ||||
|         return param | ||||
|  | ||||
|     def get_closest_meta_distance(self, timestamp): | ||||
|         with torch.no_grad(): | ||||
|             distances = torch.abs(self.meta_timestamps - timestamp) | ||||
|             return torch.min(distances).item() | ||||
|  | ||||
|     def replace_append_learnt(self, timestamp, meta_embed): | ||||
|         self._append_meta_timestamps["learnt"] = timestamp | ||||
|         self._append_meta_embed["learnt"] = meta_embed | ||||
|  | ||||
|     @property | ||||
|     def meta_length(self): | ||||
|         return self.meta_timestamps.numel() | ||||
|  | ||||
|     def clear_fixed(self): | ||||
|         self._append_meta_timestamps["fixed"] = None | ||||
|         self._append_meta_embed["fixed"] = None | ||||
|  | ||||
|     def clear_learnt(self): | ||||
|         self.replace_append_learnt(None, None) | ||||
|  | ||||
|     def append_fixed(self, timestamp, meta_embed): | ||||
|         with torch.no_grad(): | ||||
|             device = self._super_meta_embed.device | ||||
|             timestamp = timestamp.detach().clone().to(device) | ||||
|             meta_embed = meta_embed.detach().clone().to(device) | ||||
|             if self._append_meta_timestamps["fixed"] is None: | ||||
|                 self._append_meta_timestamps["fixed"] = timestamp | ||||
|             else: | ||||
|                 self._append_meta_timestamps["fixed"] = torch.cat( | ||||
|                     (self._append_meta_timestamps["fixed"], timestamp), dim=0 | ||||
|                 ) | ||||
|             if self._append_meta_embed["fixed"] is None: | ||||
|                 self._append_meta_embed["fixed"] = meta_embed | ||||
|             else: | ||||
|                 self._append_meta_embed["fixed"] = torch.cat( | ||||
|                     (self._append_meta_embed["fixed"], meta_embed), dim=0 | ||||
|                 ) | ||||
|  | ||||
|     def gen_time_embed(self, timestamps): | ||||
|         # timestamps is a batch of timestamps | ||||
|         [B] = timestamps.shape | ||||
|         # batch, seq = timestamps.shape | ||||
|         timestamps = timestamps.view(-1, 1) | ||||
|         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||
|         timestamp_qk_att_embed = self._tscalar_embed( | ||||
|             torch.unsqueeze(timestamps, dim=-1) - meta_timestamps | ||||
|         ) | ||||
|         # create the mask | ||||
|         mask = ( | ||||
|             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) | ||||
|         ) | ( | ||||
|             torch.abs( | ||||
|                 torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1) | ||||
|             ) | ||||
|             > self._thresh | ||||
|         ) | ||||
|         timestamp_embeds = self._trans_att( | ||||
|             timestamp_qk_att_embed, | ||||
|             timestamp_v_embed, | ||||
|             mask, | ||||
|         ) | ||||
|         return timestamp_embeds[:, -1, :] | ||||
|  | ||||
|     def gen_model(self, time_embeds): | ||||
|         B, _ = time_embeds.shape | ||||
|         # create joint embed | ||||
|         num_layer, _ = self._super_layer_embed.shape | ||||
|         # The shape of `joint_embed` is batch * num-layers * input-dim | ||||
|         joint_embeds = torch.cat( | ||||
|             ( | ||||
|                 time_embeds.view(B, 1, -1).expand(-1, num_layer, -1), | ||||
|                 self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1), | ||||
|             ), | ||||
|             dim=-1, | ||||
|         ) | ||||
|         batch_weights = self._generator(joint_embeds) | ||||
|         batch_containers = [] | ||||
|         for weights in torch.split(batch_weights, 1): | ||||
|             batch_containers.append( | ||||
|                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||
|             ) | ||||
|         return batch_containers | ||||
|  | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def easy_adapt(self, timestamp, time_embed): | ||||
|         with torch.no_grad(): | ||||
|             timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device) | ||||
|             self.replace_append_learnt(None, None) | ||||
|             self.append_fixed(timestamp, time_embed) | ||||
|  | ||||
|     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): | ||||
|         distance = self.get_closest_meta_distance(timestamp) | ||||
|         if distance + self._interval * 1e-2 <= self._interval: | ||||
|             return False, None | ||||
|         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device) | ||||
|         with torch.set_grad_enabled(True): | ||||
|             new_param = self.create_meta_embed() | ||||
|  | ||||
|             optimizer = torch.optim.Adam( | ||||
|                 [new_param], lr=lr, weight_decay=1e-5, amsgrad=True | ||||
|             ) | ||||
|             timestamp = torch.Tensor([timestamp]).to(new_param.device) | ||||
|             self.replace_append_learnt(timestamp, new_param) | ||||
|             self.train() | ||||
|             base_model.train() | ||||
|             if init_info is not None: | ||||
|                 best_loss = init_info["loss"] | ||||
|                 new_param.data.copy_(init_info["param"].data) | ||||
|             else: | ||||
|                 best_loss = 1e9 | ||||
|             with torch.no_grad(): | ||||
|                 best_new_param = new_param.detach().clone() | ||||
|             for iepoch in range(epochs): | ||||
|                 optimizer.zero_grad() | ||||
|                 time_embed = self.gen_time_embed(timestamp.view(1)) | ||||
|                 match_loss = F.l1_loss(new_param, time_embed) | ||||
|  | ||||
|                 [container] = self.gen_model(new_param.view(1, -1)) | ||||
|                 y_hat = base_model.forward_with_container(x, container) | ||||
|                 meta_loss = criterion(y_hat, y) | ||||
|                 loss = meta_loss + match_loss | ||||
|                 loss.backward() | ||||
|                 optimizer.step() | ||||
|                 if meta_loss.item() < best_loss: | ||||
|                     with torch.no_grad(): | ||||
|                         best_loss = meta_loss.item() | ||||
|                         best_new_param = new_param.detach().clone() | ||||
|         self.easy_adapt(timestamp, best_new_param) | ||||
|         return True, best_loss | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format( | ||||
|             list(self._super_layer_embed.shape), | ||||
|             list(self._super_meta_embed.shape), | ||||
|             list(self._meta_timestamps.shape), | ||||
|         ) | ||||
							
								
								
									
										260
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/meta_model_ablation.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										260
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/meta_model_ablation.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,260 @@ | ||||
| # | ||||
| # This is used for the ablation studies: | ||||
| # The meta-model in this file uses the traditional attention in | ||||
| # transformer. | ||||
| # | ||||
| import torch | ||||
|  | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xautodl.xlayers import super_core | ||||
| from xautodl.xlayers import trunc_normal_ | ||||
| from xautodl.models.xcore import get_model | ||||
|  | ||||
|  | ||||
| class MetaModel_TraditionalAtt(super_core.SuperModule): | ||||
|     """Learning to Generate Models One Step Ahead (Meta Model Design).""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         shape_container, | ||||
|         layer_dim, | ||||
|         time_dim, | ||||
|         meta_timestamps, | ||||
|         dropout: float = 0.1, | ||||
|         seq_length: int = None, | ||||
|         interval: float = None, | ||||
|         thresh: float = None, | ||||
|     ): | ||||
|         super(MetaModel_TraditionalAtt, self).__init__() | ||||
|         self._shape_container = shape_container | ||||
|         self._num_layers = len(shape_container) | ||||
|         self._numel_per_layer = [] | ||||
|         for ilayer in range(self._num_layers): | ||||
|             self._numel_per_layer.append(shape_container[ilayer].numel()) | ||||
|         self._raw_meta_timestamps = meta_timestamps | ||||
|         assert interval is not None | ||||
|         self._interval = interval | ||||
|         self._thresh = interval * seq_length if thresh is None else thresh | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)), | ||||
|         ) | ||||
|         self.register_parameter( | ||||
|             "_super_meta_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)), | ||||
|         ) | ||||
|         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) | ||||
|         self._time_embed_dim = time_dim | ||||
|         self._append_meta_embed = dict(fixed=None, learnt=None) | ||||
|         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||
|  | ||||
|         self._tscalar_embed = super_core.SuperDynamicPositionE( | ||||
|             time_dim, scale=1 / interval | ||||
|         ) | ||||
|  | ||||
|         # build transformer | ||||
|         self._trans_att = super_core.SuperQKVAttention( | ||||
|             in_q_dim=time_dim, | ||||
|             in_k_dim=time_dim, | ||||
|             in_v_dim=time_dim, | ||||
|             num_heads=4, | ||||
|             proj_dim=time_dim, | ||||
|             qkv_bias=True, | ||||
|             attn_drop=None, | ||||
|             proj_drop=dropout, | ||||
|         ) | ||||
|  | ||||
|         model_kwargs = dict( | ||||
|             config=dict(model_type="dual_norm_mlp"), | ||||
|             input_dim=layer_dim + time_dim, | ||||
|             output_dim=max(self._numel_per_layer), | ||||
|             hidden_dims=[(layer_dim + time_dim) * 2] * 3, | ||||
|             act_cls="gelu", | ||||
|             norm_cls="layer_norm_1d", | ||||
|             dropout=dropout, | ||||
|         ) | ||||
|         self._generator = get_model(**model_kwargs) | ||||
|  | ||||
|         # initialization | ||||
|         trunc_normal_( | ||||
|             [self._super_layer_embed, self._super_meta_embed], | ||||
|             std=0.02, | ||||
|         ) | ||||
|  | ||||
|     def get_parameters(self, time_embed, attention, generator): | ||||
|         parameters = [] | ||||
|         if time_embed: | ||||
|             parameters.append(self._super_meta_embed) | ||||
|         if attention: | ||||
|             parameters.extend(list(self._trans_att.parameters())) | ||||
|         if generator: | ||||
|             parameters.append(self._super_layer_embed) | ||||
|             parameters.extend(list(self._generator.parameters())) | ||||
|         return parameters | ||||
|  | ||||
|     @property | ||||
|     def meta_timestamps(self): | ||||
|         with torch.no_grad(): | ||||
|             meta_timestamps = [self._meta_timestamps] | ||||
|             for key in ("fixed", "learnt"): | ||||
|                 if self._append_meta_timestamps[key] is not None: | ||||
|                     meta_timestamps.append(self._append_meta_timestamps[key]) | ||||
|         return torch.cat(meta_timestamps) | ||||
|  | ||||
|     @property | ||||
|     def super_meta_embed(self): | ||||
|         meta_embed = [self._super_meta_embed] | ||||
|         for key in ("fixed", "learnt"): | ||||
|             if self._append_meta_embed[key] is not None: | ||||
|                 meta_embed.append(self._append_meta_embed[key]) | ||||
|         return torch.cat(meta_embed) | ||||
|  | ||||
|     def create_meta_embed(self): | ||||
|         param = torch.Tensor(1, self._time_embed_dim) | ||||
|         trunc_normal_(param, std=0.02) | ||||
|         param = param.to(self._super_meta_embed.device) | ||||
|         param = torch.nn.Parameter(param, True) | ||||
|         return param | ||||
|  | ||||
|     def get_closest_meta_distance(self, timestamp): | ||||
|         with torch.no_grad(): | ||||
|             distances = torch.abs(self.meta_timestamps - timestamp) | ||||
|             return torch.min(distances).item() | ||||
|  | ||||
|     def replace_append_learnt(self, timestamp, meta_embed): | ||||
|         self._append_meta_timestamps["learnt"] = timestamp | ||||
|         self._append_meta_embed["learnt"] = meta_embed | ||||
|  | ||||
|     @property | ||||
|     def meta_length(self): | ||||
|         return self.meta_timestamps.numel() | ||||
|  | ||||
|     def clear_fixed(self): | ||||
|         self._append_meta_timestamps["fixed"] = None | ||||
|         self._append_meta_embed["fixed"] = None | ||||
|  | ||||
|     def clear_learnt(self): | ||||
|         self.replace_append_learnt(None, None) | ||||
|  | ||||
|     def append_fixed(self, timestamp, meta_embed): | ||||
|         with torch.no_grad(): | ||||
|             device = self._super_meta_embed.device | ||||
|             timestamp = timestamp.detach().clone().to(device) | ||||
|             meta_embed = meta_embed.detach().clone().to(device) | ||||
|             if self._append_meta_timestamps["fixed"] is None: | ||||
|                 self._append_meta_timestamps["fixed"] = timestamp | ||||
|             else: | ||||
|                 self._append_meta_timestamps["fixed"] = torch.cat( | ||||
|                     (self._append_meta_timestamps["fixed"], timestamp), dim=0 | ||||
|                 ) | ||||
|             if self._append_meta_embed["fixed"] is None: | ||||
|                 self._append_meta_embed["fixed"] = meta_embed | ||||
|             else: | ||||
|                 self._append_meta_embed["fixed"] = torch.cat( | ||||
|                     (self._append_meta_embed["fixed"], meta_embed), dim=0 | ||||
|                 ) | ||||
|  | ||||
|     def gen_time_embed(self, timestamps): | ||||
|         # timestamps is a batch of timestamps | ||||
|         [B] = timestamps.shape | ||||
|         # batch, seq = timestamps.shape | ||||
|         timestamps = timestamps.view(-1, 1) | ||||
|         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||
|         timestamp_q_embed = self._tscalar_embed(timestamps) | ||||
|         timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) | ||||
|  | ||||
|         # create the mask | ||||
|         mask = ( | ||||
|             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) | ||||
|         ) | ( | ||||
|             torch.abs( | ||||
|                 torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1) | ||||
|             ) | ||||
|             > self._thresh | ||||
|         ) | ||||
|         timestamp_embeds = self._trans_att( | ||||
|             timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask | ||||
|         ) | ||||
|         return timestamp_embeds[:, -1, :] | ||||
|  | ||||
|     def gen_model(self, time_embeds): | ||||
|         B, _ = time_embeds.shape | ||||
|         # create joint embed | ||||
|         num_layer, _ = self._super_layer_embed.shape | ||||
|         # The shape of `joint_embed` is batch * num-layers * input-dim | ||||
|         joint_embeds = torch.cat( | ||||
|             ( | ||||
|                 time_embeds.view(B, 1, -1).expand(-1, num_layer, -1), | ||||
|                 self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1), | ||||
|             ), | ||||
|             dim=-1, | ||||
|         ) | ||||
|         batch_weights = self._generator(joint_embeds) | ||||
|         batch_containers = [] | ||||
|         for weights in torch.split(batch_weights, 1): | ||||
|             batch_containers.append( | ||||
|                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||
|             ) | ||||
|         return batch_containers | ||||
|  | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def easy_adapt(self, timestamp, time_embed): | ||||
|         with torch.no_grad(): | ||||
|             timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device) | ||||
|             self.replace_append_learnt(None, None) | ||||
|             self.append_fixed(timestamp, time_embed) | ||||
|  | ||||
|     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): | ||||
|         distance = self.get_closest_meta_distance(timestamp) | ||||
|         if distance + self._interval * 1e-2 <= self._interval: | ||||
|             return False, None | ||||
|         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device) | ||||
|         with torch.set_grad_enabled(True): | ||||
|             new_param = self.create_meta_embed() | ||||
|  | ||||
|             optimizer = torch.optim.Adam( | ||||
|                 [new_param], lr=lr, weight_decay=1e-5, amsgrad=True | ||||
|             ) | ||||
|             timestamp = torch.Tensor([timestamp]).to(new_param.device) | ||||
|             self.replace_append_learnt(timestamp, new_param) | ||||
|             self.train() | ||||
|             base_model.train() | ||||
|             if init_info is not None: | ||||
|                 best_loss = init_info["loss"] | ||||
|                 new_param.data.copy_(init_info["param"].data) | ||||
|             else: | ||||
|                 best_loss = 1e9 | ||||
|             with torch.no_grad(): | ||||
|                 best_new_param = new_param.detach().clone() | ||||
|             for iepoch in range(epochs): | ||||
|                 optimizer.zero_grad() | ||||
|                 time_embed = self.gen_time_embed(timestamp.view(1)) | ||||
|                 match_loss = F.l1_loss(new_param, time_embed) | ||||
|  | ||||
|                 [container] = self.gen_model(new_param.view(1, -1)) | ||||
|                 y_hat = base_model.forward_with_container(x, container) | ||||
|                 meta_loss = criterion(y_hat, y) | ||||
|                 loss = meta_loss + match_loss | ||||
|                 loss.backward() | ||||
|                 optimizer.step() | ||||
|                 if meta_loss.item() < best_loss: | ||||
|                     with torch.no_grad(): | ||||
|                         best_loss = meta_loss.item() | ||||
|                         best_new_param = new_param.detach().clone() | ||||
|         self.easy_adapt(timestamp, best_new_param) | ||||
|         return True, best_loss | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format( | ||||
|             list(self._super_layer_embed.shape), | ||||
|             list(self._super_meta_embed.shape), | ||||
|             list(self._meta_timestamps.shape), | ||||
|         ) | ||||
							
								
								
									
										441
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/vis-synthetic.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										441
									
								
								AutoDL-Projects/exps/experimental/GeMOSA/vis-synthetic.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,441 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ############################################################################ | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v1                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v3                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v4                     # | ||||
| ############################################################################ | ||||
| import os, sys, copy, random | ||||
| import torch | ||||
| import numpy as np | ||||
| import argparse | ||||
| from collections import OrderedDict, defaultdict | ||||
| from pathlib import Path | ||||
| from tqdm import tqdm | ||||
| from pprint import pprint | ||||
|  | ||||
| import matplotlib | ||||
| from matplotlib import cm | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.procedures.metric_utils import MSEMetric | ||||
|  | ||||
|  | ||||
| def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): | ||||
|     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label) | ||||
|     cur_ax.scatter( | ||||
|         xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): | ||||
|     save_path = save_dir / "{:04d}".format(timestamp) | ||||
|     # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path)) | ||||
|     dpi, width, height = 40, wh[0], wh[1] | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize, font_gap = 80, 80, 5 | ||||
|  | ||||
|     fig = plt.figure(figsize=figsize) | ||||
|     if fig_title is not None: | ||||
|         fig.suptitle( | ||||
|             fig_title, fontsize=LegendFontsize, fontweight="bold", x=0.5, y=0.92 | ||||
|         ) | ||||
|  | ||||
|     for idx, scatter_dict in enumerate(scatter_list): | ||||
|         cur_ax = fig.add_subplot(len(scatter_list), 1, idx + 1) | ||||
|         plot_scatter( | ||||
|             cur_ax, | ||||
|             scatter_dict["xaxis"], | ||||
|             scatter_dict["yaxis"], | ||||
|             scatter_dict["color"], | ||||
|             scatter_dict["alpha"], | ||||
|             scatter_dict["linewidths"], | ||||
|             scatter_dict["label"], | ||||
|         ) | ||||
|         cur_ax.set_xlabel("X", fontsize=LabelSize) | ||||
|         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) | ||||
|         cur_ax.set_xlim(scatter_dict["xlim"][0], scatter_dict["xlim"][1]) | ||||
|         cur_ax.set_ylim(scatter_dict["ylim"][0], scatter_dict["ylim"][1]) | ||||
|         for tick in cur_ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|             tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|     fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|     fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| def find_min(cur, others): | ||||
|     if cur is None: | ||||
|         return float(others) | ||||
|     else: | ||||
|         return float(min(cur, others)) | ||||
|  | ||||
|  | ||||
| def find_max(cur, others): | ||||
|     if cur is None: | ||||
|         return float(others.max()) | ||||
|     else: | ||||
|         return float(max(cur, others)) | ||||
|  | ||||
|  | ||||
| def compare_cl(save_dir): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     dynamic_env, cl_function = create_example_v1( | ||||
|         # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), | ||||
|         timestamp_config=dict(num=200), | ||||
|         num_per_task=1000, | ||||
|     ) | ||||
|  | ||||
|     models = dict() | ||||
|  | ||||
|     cl_function.set_timestamp(0) | ||||
|     cl_xaxis_min = None | ||||
|     cl_xaxis_max = None | ||||
|  | ||||
|     all_data = OrderedDict() | ||||
|  | ||||
|     for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         xaxis_all = dataset[0][:, 0].numpy() | ||||
|         yaxis_all = dataset[1][:, 0].numpy() | ||||
|         current_data = dict() | ||||
|         current_data["lfna_xaxis_all"] = xaxis_all | ||||
|         current_data["lfna_yaxis_all"] = yaxis_all | ||||
|  | ||||
|         # compute cl-min | ||||
|         cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std()) | ||||
|         cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) | ||||
|         all_data[timestamp] = current_data | ||||
|  | ||||
|     global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1) | ||||
|     global_cl_yaxis_all = cl_function.noise_call(global_cl_xaxis_all) | ||||
|  | ||||
|     for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)): | ||||
|         scatter_list = [] | ||||
|         scatter_list.append( | ||||
|             { | ||||
|                 "xaxis": xdata["lfna_xaxis_all"], | ||||
|                 "yaxis": xdata["lfna_yaxis_all"], | ||||
|                 "color": "k", | ||||
|                 "linewidths": 15, | ||||
|                 "alpha": 0.99, | ||||
|                 "xlim": (-6, 6), | ||||
|                 "ylim": (-40, 40), | ||||
|                 "label": "LFNA", | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|         cur_cl_xaxis_min = cl_xaxis_min | ||||
|         cur_cl_xaxis_max = cl_xaxis_min + (cl_xaxis_max - cl_xaxis_min) * ( | ||||
|             idx + 1 | ||||
|         ) / len(all_data) | ||||
|         cl_xaxis_all = np.arange(cur_cl_xaxis_min, cur_cl_xaxis_max, step=0.01) | ||||
|         cl_yaxis_all = cl_function.noise_call(cl_xaxis_all, std=0.2) | ||||
|  | ||||
|         scatter_list.append( | ||||
|             { | ||||
|                 "xaxis": cl_xaxis_all, | ||||
|                 "yaxis": cl_yaxis_all, | ||||
|                 "color": "k", | ||||
|                 "linewidths": 15, | ||||
|                 "xlim": (round(cl_xaxis_min, 1), round(cl_xaxis_max, 1)), | ||||
|                 "ylim": (-20, 6), | ||||
|                 "alpha": 0.99, | ||||
|                 "label": "Continual Learning", | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|         draw_multi_fig( | ||||
|             save_dir, | ||||
|             idx, | ||||
|             scatter_list, | ||||
|             wh=(2200, 1800), | ||||
|             fig_title="Timestamp={:03d}".format(idx), | ||||
|         ) | ||||
|     print("Save all figures into {:}".format(save_dir)) | ||||
|     save_dir = save_dir.resolve() | ||||
|     base_cmd = ( | ||||
|         "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format( | ||||
|             xdir=save_dir | ||||
|         ) | ||||
|     ) | ||||
|     video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format( | ||||
|         base_cmd, xdir=save_dir | ||||
|     ) | ||||
|     print(video_cmd + "\n") | ||||
|     os.system(video_cmd) | ||||
|     os.system( | ||||
|         "{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir) | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def visualize_env(save_dir, version): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     for substr in ("pdf", "png"): | ||||
|         sub_save_dir = save_dir / "{:}-{:}".format(substr, version) | ||||
|         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dynamic_env = get_synthetic_env(version=version) | ||||
|     print("env: {:}".format(dynamic_env)) | ||||
|     print("oracle_map: {:}".format(dynamic_env.oracle_map)) | ||||
|     allxs, allys = [], [] | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         allxs.append(allx) | ||||
|         allys.append(ally) | ||||
|     if dynamic_env.meta_info["task"] == "regression": | ||||
|         allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|         print( | ||||
|             "x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()) | ||||
|         ) | ||||
|         print( | ||||
|             "y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()) | ||||
|         ) | ||||
|     elif dynamic_env.meta_info["task"] == "classification": | ||||
|         allxs = torch.cat(allxs) | ||||
|         print( | ||||
|             "x[0] - min={:.3f}, max={:.3f}".format( | ||||
|                 allxs[:, 0].min().item(), allxs[:, 0].max().item() | ||||
|             ) | ||||
|         ) | ||||
|         print( | ||||
|             "x[1] - min={:.3f}, max={:.3f}".format( | ||||
|                 allxs[:, 1].min().item(), allxs[:, 1].max().item() | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("Unknown task".format(dynamic_env.meta_info["task"])) | ||||
|  | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         dpi, width, height = 30, 1800, 1400 | ||||
|         figsize = width / float(dpi), height / float(dpi) | ||||
|         LabelSize, LegendFontsize, font_gap = 80, 80, 5 | ||||
|         fig = plt.figure(figsize=figsize) | ||||
|  | ||||
|         cur_ax = fig.add_subplot(1, 1, 1) | ||||
|         if dynamic_env.meta_info["task"] == "regression": | ||||
|             allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy() | ||||
|             plot_scatter( | ||||
|                 cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx) | ||||
|             ) | ||||
|             cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||
|             cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||
|         elif dynamic_env.meta_info["task"] == "classification": | ||||
|             positive, negative = ally == 1, ally == 0 | ||||
|             # plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx)) | ||||
|             plot_scatter( | ||||
|                 cur_ax, | ||||
|                 allx[positive, 0], | ||||
|                 allx[positive, 1], | ||||
|                 "r", | ||||
|                 0.99, | ||||
|                 (20, 10), | ||||
|                 "positive", | ||||
|             ) | ||||
|             plot_scatter( | ||||
|                 cur_ax, | ||||
|                 allx[negative, 0], | ||||
|                 allx[negative, 1], | ||||
|                 "g", | ||||
|                 0.99, | ||||
|                 (20, 10), | ||||
|                 "negative", | ||||
|             ) | ||||
|             cur_ax.set_xlim( | ||||
|                 round(allxs[:, 0].min().item(), 1), round(allxs[:, 0].max().item(), 1) | ||||
|             ) | ||||
|             cur_ax.set_ylim( | ||||
|                 round(allxs[:, 1].min().item(), 1), round(allxs[:, 1].max().item(), 1) | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("Unknown task".format(dynamic_env.meta_info["task"])) | ||||
|  | ||||
|         cur_ax.set_xlabel("X", fontsize=LabelSize) | ||||
|         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) | ||||
|         for tick in cur_ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|             tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|         pdf_save_path = ( | ||||
|             save_dir | ||||
|             / "pdf-{:}".format(version) | ||||
|             / "v{:}-{:05d}.pdf".format(version, idx) | ||||
|         ) | ||||
|         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|         png_save_path = ( | ||||
|             save_dir | ||||
|             / "png-{:}".format(version) | ||||
|             / "v{:}-{:05d}.png".format(version, idx) | ||||
|         ) | ||||
|         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") | ||||
|         plt.close("all") | ||||
|     save_dir = save_dir.resolve() | ||||
|     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( | ||||
|         xdir=save_dir / "png-{:}".format(version), version=version | ||||
|     ) | ||||
|     print(base_cmd) | ||||
|     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) | ||||
|     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) | ||||
|  | ||||
|  | ||||
| def compare_algs(save_dir, version, alg_dir="./outputs/GeMOSA-synthetic"): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     for substr in ("pdf", "png"): | ||||
|         sub_save_dir = save_dir / substr | ||||
|         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dpi, width, height = 30, 3200, 2000 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize, font_gap = 80, 80, 5 | ||||
|  | ||||
|     dynamic_env = get_synthetic_env(mode=None, version=version) | ||||
|     allxs, allys = [], [] | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         allxs.append(allx) | ||||
|         allys.append(ally) | ||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|  | ||||
|     alg_name2dir = OrderedDict() | ||||
|     # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" | ||||
|     # alg_name2dir["MAML"] = "use-maml-s1" | ||||
|     # alg_name2dir["LFNA (fix init)"] = "lfna-fix-init" | ||||
|     if version == "v1": | ||||
|         # alg_name2dir["Optimal"] = "use-same-timestamp" | ||||
|         alg_name2dir[ | ||||
|             "GMOA" | ||||
|         ] = "lfna-battle-bs128-d16_16_16-s16-lr0.002-wd1e-05-e10000-envv1" | ||||
|     else: | ||||
|         raise ValueError("Invalid version: {:}".format(version)) | ||||
|     alg_name2all_containers = OrderedDict() | ||||
|     for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||
|         ckp_path = Path(alg_dir) / str(xdir) / "final-ckp.pth" | ||||
|         xdata = torch.load(ckp_path, map_location="cpu") | ||||
|         alg_name2all_containers[alg] = xdata["w_containers"] | ||||
|     # load the basic model | ||||
|     model = get_model( | ||||
|         dict(model_type="norm_mlp"), | ||||
|         input_dim=1, | ||||
|         output_dim=1, | ||||
|         hidden_dims=[16] * 2, | ||||
|         act_cls="gelu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|  | ||||
|     alg2xs, alg2ys = defaultdict(list), defaultdict(list) | ||||
|     colors = ["r", "g", "b", "m", "y"] | ||||
|  | ||||
|     linewidths, skip = 10, 5 | ||||
|     for idx, (timestamp, (ori_allx, ori_ally)) in enumerate( | ||||
|         tqdm(dynamic_env, ncols=50) | ||||
|     ): | ||||
|         if idx <= skip: | ||||
|             continue | ||||
|         fig = plt.figure(figsize=figsize) | ||||
|         cur_ax = fig.add_subplot(2, 1, 1) | ||||
|  | ||||
|         # the data | ||||
|         allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy() | ||||
|         plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data") | ||||
|  | ||||
|         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||
|             with torch.no_grad(): | ||||
|                 predicts = model.forward_with_container( | ||||
|                     ori_allx, alg_name2all_containers[alg][idx] | ||||
|                 ) | ||||
|                 predicts = predicts.cpu() | ||||
|                 # keep data | ||||
|                 metric = MSEMetric() | ||||
|                 metric(predicts, ori_ally) | ||||
|                 predicts = predicts.view(-1).numpy() | ||||
|                 alg2xs[alg].append(idx) | ||||
|                 alg2ys[alg].append(metric.get_info()["mse"]) | ||||
|             plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg) | ||||
|  | ||||
|         cur_ax.set_xlabel("X", fontsize=LabelSize) | ||||
|         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) | ||||
|         for tick in cur_ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|             tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||
|         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|         # the trajectory data | ||||
|         cur_ax = fig.add_subplot(2, 1, 2) | ||||
|         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||
|             # plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], olors[idx_alg], 0.99, linewidths, alg) | ||||
|             cur_ax.plot( | ||||
|                 alg2xs[alg], | ||||
|                 alg2ys[alg], | ||||
|                 color=colors[idx_alg], | ||||
|                 linestyle="-", | ||||
|                 linewidth=5, | ||||
|                 label=alg, | ||||
|             ) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|         cur_ax.set_xlabel("Timestamp", fontsize=LabelSize) | ||||
|         cur_ax.set_ylabel("MSE", fontsize=LabelSize) | ||||
|         for tick in cur_ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|             tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         cur_ax.set_xlim(1, len(dynamic_env)) | ||||
|         cur_ax.set_ylim(0, 10) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|         pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx - skip) | ||||
|         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|         png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx - skip) | ||||
|         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") | ||||
|         plt.close("all") | ||||
|     save_dir = save_dir.resolve() | ||||
|     base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( | ||||
|         xdir=save_dir / "png", w=width, h=height, ver=version | ||||
|     ) | ||||
|     os.system( | ||||
|         "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version) | ||||
|     ) | ||||
|     os.system( | ||||
|         "{:} {xdir}/com-alg-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version) | ||||
|     ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     parser = argparse.ArgumentParser("Visualize synthetic data.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/vis-synthetic", | ||||
|         help="The save directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     visualize_env(os.path.join(args.save_dir, "vis-env"), args.env_version) | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||
|     # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||
							
								
								
									
										66
									
								
								AutoDL-Projects/exps/experimental/example-nas-bench.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								AutoDL-Projects/exps/experimental/example-nas-bench.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ########################################################################################################################################################### | ||||
| # Before run these commands, the files must be properly put. | ||||
| # | ||||
| # python exps/experimental/example-nas-bench.py --api_path $HOME/.torch/NAS-Bench-201-v1_1-096897.pth --archive_path $HOME/.torch/NAS-Bench-201-v1_1-archive | ||||
| ########################################################################################################################################################### | ||||
| import os, gc, sys, math, argparse, psutil | ||||
| import numpy as np | ||||
| import torch | ||||
| from pathlib import Path | ||||
| from collections import OrderedDict | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from nas_201_api import NASBench201API | ||||
| from log_utils import time_string | ||||
| from models import get_cell_based_tiny_net | ||||
| from utils import weight_watcher | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | ||||
|     parser.add_argument( | ||||
|         "--api_path", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help="The path to the NAS-Bench-201 benchmark file and weight dir.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--archive_path", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help="The path to the NAS-Bench-201 weight dir.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     meta_file = Path(args.api_path) | ||||
|     weight_dir = Path(args.archive_path) | ||||
|     assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) | ||||
|     assert ( | ||||
|         weight_dir.exists() and weight_dir.is_dir() | ||||
|     ), "invalid path for weight dir : {:}".format(weight_dir) | ||||
|  | ||||
|     api = NASBench201API(meta_file, verbose=True) | ||||
|  | ||||
|     arch_index = 3  # query the 3-th architecture | ||||
|     api.reload(weight_dir, arch_index)  # reload the data of 3-th from archive dir | ||||
|  | ||||
|     data = "cifar10"  # query the info from CIFAR-10 | ||||
|     config = api.get_net_config(arch_index, data) | ||||
|     net = get_cell_based_tiny_net(config) | ||||
|     meta_info = api.query_meta_info_by_index( | ||||
|         arch_index, hp="200" | ||||
|     )  # all info about this architecture | ||||
|     params = meta_info.get_net_param(data, 888) | ||||
|  | ||||
|     net.load_state_dict(params) | ||||
|     _, summary = weight_watcher.analyze(net, alphas=False) | ||||
|     print("The summary of {:}-th architecture:\n{:}".format(arch_index, summary)) | ||||
							
								
								
									
										57
									
								
								AutoDL-Projects/exps/experimental/test-dks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								AutoDL-Projects/exps/experimental/test-dks.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | ||||
| from dks.base.activation_getter import ( | ||||
|     get_activation_function as _get_numpy_activation_function, | ||||
| ) | ||||
| from dks.base.activation_transform import _get_activations_params | ||||
|  | ||||
|  | ||||
| def subnet_max_func(x, r_fn): | ||||
|     depth = 7 | ||||
|     res_x = r_fn(x) | ||||
|     x = r_fn(x) | ||||
|     for _ in range(depth): | ||||
|         x = r_fn(r_fn(x)) + x | ||||
|     return max(x, res_x) | ||||
|  | ||||
|  | ||||
| def subnet_max_func_v2(x, r_fn): | ||||
|     depth = 2 | ||||
|     res_x = r_fn(x) | ||||
|  | ||||
|     x = r_fn(x) | ||||
|     for _ in range(depth): | ||||
|         x = 0.8 * r_fn(r_fn(x)) + 0.2 * x | ||||
|  | ||||
|     return max(x, res_x) | ||||
|  | ||||
|  | ||||
| def get_transformed_activations( | ||||
|     activation_names, | ||||
|     method="TAT", | ||||
|     dks_params=None, | ||||
|     tat_params=None, | ||||
|     max_slope_func=None, | ||||
|     max_curv_func=None, | ||||
|     subnet_max_func=None, | ||||
|     activation_getter=_get_numpy_activation_function, | ||||
| ): | ||||
|     params = _get_activations_params( | ||||
|         activation_names, | ||||
|         method=method, | ||||
|         dks_params=dks_params, | ||||
|         tat_params=tat_params, | ||||
|         max_slope_func=max_slope_func, | ||||
|         max_curv_func=max_curv_func, | ||||
|         subnet_max_func=subnet_max_func, | ||||
|     ) | ||||
|     return params | ||||
|  | ||||
|  | ||||
| params = get_transformed_activations( | ||||
|     ["swish"], method="TAT", subnet_max_func=subnet_max_func | ||||
| ) | ||||
| print(params) | ||||
|  | ||||
| params = get_transformed_activations( | ||||
|     ["leaky_relu"], method="TAT", subnet_max_func=subnet_max_func_v2 | ||||
| ) | ||||
| print(params) | ||||
							
								
								
									
										21
									
								
								AutoDL-Projects/exps/experimental/test-dynamic.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								AutoDL-Projects/exps/experimental/test-dynamic.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| # python test-dynamic.py | ||||
| ##################################################### | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.datasets.math_core import ConstantFunc | ||||
| from xautodl.datasets.math_core import GaussianDGenerator | ||||
|  | ||||
| mean_generator = ConstantFunc(0) | ||||
| cov_generator = ConstantFunc(1) | ||||
|  | ||||
| generator = GaussianDGenerator([mean_generator], [[cov_generator]], (-1, 1)) | ||||
| generator(0, 10) | ||||
							
								
								
									
										28
									
								
								AutoDL-Projects/exps/experimental/test-flops.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								AutoDL-Projects/exps/experimental/test-flops.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | ||||
| import sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| import torchvision.models as models | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from utils import get_model_infos | ||||
|  | ||||
| # from models.ImageNet_MobileNetV2 import MobileNetV2 | ||||
| from torchvision.models.mobilenet import MobileNetV2 | ||||
|  | ||||
|  | ||||
| def main(width_mult): | ||||
|     # model = MobileNetV2(1001, width_mult, 32, 1280, 'InvertedResidual', 0.2) | ||||
|     model = MobileNetV2(width_mult=width_mult) | ||||
|     print(model) | ||||
|     flops, params = get_model_infos(model, (2, 3, 224, 224)) | ||||
|     print("FLOPs : {:}".format(flops)) | ||||
|     print("Params : {:}".format(params)) | ||||
|     print("-" * 50) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main(1.0) | ||||
|     main(1.4) | ||||
							
								
								
									
										168
									
								
								AutoDL-Projects/exps/experimental/test-nas-plot.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										168
									
								
								AutoDL-Projects/exps/experimental/test-nas-plot.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,168 @@ | ||||
| # python ./exps/vis/test.py | ||||
| import os, sys, random | ||||
| from pathlib import Path | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import numpy as np | ||||
| from collections import OrderedDict | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def test_nas_api(): | ||||
|     from nas_201_api import ArchResults | ||||
|  | ||||
|     xdata = torch.load( | ||||
|         "/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth" | ||||
|     ) | ||||
|     for key in ["full", "less"]: | ||||
|         print("\n------------------------- {:} -------------------------".format(key)) | ||||
|         archRes = ArchResults.create_from_state_dict(xdata[key]) | ||||
|         print(archRes) | ||||
|         print(archRes.arch_idx_str()) | ||||
|         print(archRes.get_dataset_names()) | ||||
|         print(archRes.get_comput_costs("cifar10-valid")) | ||||
|         # get the metrics | ||||
|         print(archRes.get_metrics("cifar10-valid", "x-valid", None, False)) | ||||
|         print(archRes.get_metrics("cifar10-valid", "x-valid", None, True)) | ||||
|         print(archRes.query("cifar10-valid", 777)) | ||||
|  | ||||
|  | ||||
| OPS = ["skip-connect", "conv-1x1", "conv-3x3", "pool-3x3"] | ||||
| COLORS = ["chartreuse", "cyan", "navyblue", "chocolate1"] | ||||
|  | ||||
|  | ||||
| def plot(filename): | ||||
|     from graphviz import Digraph | ||||
|  | ||||
|     g = Digraph( | ||||
|         format="png", | ||||
|         edge_attr=dict(fontsize="20", fontname="times"), | ||||
|         node_attr=dict( | ||||
|             style="filled", | ||||
|             shape="rect", | ||||
|             align="center", | ||||
|             fontsize="20", | ||||
|             height="0.5", | ||||
|             width="0.5", | ||||
|             penwidth="2", | ||||
|             fontname="times", | ||||
|         ), | ||||
|         engine="dot", | ||||
|     ) | ||||
|     g.body.extend(["rankdir=LR"]) | ||||
|  | ||||
|     steps = 5 | ||||
|     for i in range(0, steps): | ||||
|         if i == 0: | ||||
|             g.node(str(i), fillcolor="darkseagreen2") | ||||
|         elif i + 1 == steps: | ||||
|             g.node(str(i), fillcolor="palegoldenrod") | ||||
|         else: | ||||
|             g.node(str(i), fillcolor="lightblue") | ||||
|  | ||||
|     for i in range(1, steps): | ||||
|         for xin in range(i): | ||||
|             op_i = random.randint(0, len(OPS) - 1) | ||||
|             # g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i]) | ||||
|             g.edge( | ||||
|                 str(xin), | ||||
|                 str(i), | ||||
|                 label=OPS[op_i], | ||||
|                 color=COLORS[op_i], | ||||
|                 fillcolor=COLORS[op_i], | ||||
|             ) | ||||
|             # import pdb; pdb.set_trace() | ||||
|     g.render(filename, cleanup=True, view=False) | ||||
|  | ||||
|  | ||||
| def test_auto_grad(): | ||||
|     class Net(torch.nn.Module): | ||||
|         def __init__(self, iS): | ||||
|             super(Net, self).__init__() | ||||
|             self.layer = torch.nn.Linear(iS, 1) | ||||
|  | ||||
|         def forward(self, inputs): | ||||
|             outputs = self.layer(inputs) | ||||
|             outputs = torch.exp(outputs) | ||||
|             return outputs.mean() | ||||
|  | ||||
|     net = Net(10) | ||||
|     inputs = torch.rand(256, 10) | ||||
|     loss = net(inputs) | ||||
|     first_order_grads = torch.autograd.grad( | ||||
|         loss, net.parameters(), retain_graph=True, create_graph=True | ||||
|     ) | ||||
|     first_order_grads = torch.cat([x.view(-1) for x in first_order_grads]) | ||||
|     second_order_grads = [] | ||||
|     for grads in first_order_grads: | ||||
|         s_grads = torch.autograd.grad(grads, net.parameters()) | ||||
|         second_order_grads.append(s_grads) | ||||
|  | ||||
|  | ||||
| def test_one_shot_model(ckpath, use_train): | ||||
|     from models import get_cell_based_tiny_net, get_search_spaces | ||||
|     from datasets import get_datasets, SearchDataset | ||||
|     from config_utils import load_config, dict2config | ||||
|     from utils.nas_utils import evaluate_one_shot | ||||
|  | ||||
|     use_train = int(use_train) > 0 | ||||
|     # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' | ||||
|     # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' | ||||
|     print("ckpath : {:}".format(ckpath)) | ||||
|     ckp = torch.load(ckpath) | ||||
|     xargs = ckp["args"] | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         xargs.dataset, xargs.data_path, -1 | ||||
|     ) | ||||
|     # config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None) | ||||
|     config = load_config( | ||||
|         "./configs/nas-benchmark/algos/DARTS.config", | ||||
|         {"class_num": class_num, "xshape": xshape}, | ||||
|         None, | ||||
|     ) | ||||
|     if xargs.dataset == "cifar10": | ||||
|         cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None) | ||||
|         xvalid_data = deepcopy(train_data) | ||||
|         xvalid_data.transform = valid_data.transform | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             xvalid_data, | ||||
|             batch_size=2048, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), | ||||
|             num_workers=12, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid dataset : {:}".format(xargs.dataseet)) | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     model_config = dict2config( | ||||
|         { | ||||
|             "name": "SETN", | ||||
|             "C": xargs.channel, | ||||
|             "N": xargs.num_cells, | ||||
|             "max_nodes": xargs.max_nodes, | ||||
|             "num_classes": class_num, | ||||
|             "space": search_space, | ||||
|             "affine": False, | ||||
|             "track_running_stats": True, | ||||
|         }, | ||||
|         None, | ||||
|     ) | ||||
|     search_model = get_cell_based_tiny_net(model_config) | ||||
|     search_model.load_state_dict(ckp["search_model"]) | ||||
|     search_model = search_model.cuda() | ||||
|     api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth") | ||||
|     archs, probs, accuracies = evaluate_one_shot( | ||||
|         search_model, valid_loader, api, use_train | ||||
|     ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     # test_nas_api() | ||||
|     # for i in range(200): plot('{:04d}'.format(i)) | ||||
|     # test_auto_grad() | ||||
|     test_one_shot_model(sys.argv[1], sys.argv[2]) | ||||
							
								
								
									
										31
									
								
								AutoDL-Projects/exps/experimental/test-resnest.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								AutoDL-Projects/exps/experimental/test-resnest.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 # | ||||
| ##################################################### | ||||
| # python exps/experimental/test-resnest.py | ||||
| ##################################################### | ||||
| import sys, time, torch, random, argparse | ||||
| from PIL import ImageFile | ||||
|  | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from utils import get_model_infos | ||||
|  | ||||
| torch.hub.list("zhanghang1989/ResNeSt", force_reload=True) | ||||
|  | ||||
| for model_name, xshape in [ | ||||
|     ("resnest50", (1, 3, 224, 224)), | ||||
|     ("resnest101", (1, 3, 256, 256)), | ||||
|     ("resnest200", (1, 3, 320, 320)), | ||||
|     ("resnest269", (1, 3, 416, 416)), | ||||
| ]: | ||||
|     # net = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=True) | ||||
|     net = torch.hub.load("zhanghang1989/ResNeSt", model_name, pretrained=False) | ||||
|     print("Model : {:}, input shape : {:}".format(model_name, xshape)) | ||||
|     flops, param = get_model_infos(net, xshape) | ||||
|     print("flops  : {:.3f}M".format(flops)) | ||||
|     print("params : {:.3f}M".format(param)) | ||||
							
								
								
									
										198
									
								
								AutoDL-Projects/exps/experimental/test-ww-bench.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										198
									
								
								AutoDL-Projects/exps/experimental/test-ww-bench.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,198 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ########################################################################################################################################################### | ||||
| # Before run these commands, the files must be properly put. | ||||
| # | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar100 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset ImageNet16-120 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10 | ||||
| ########################################################################################################################################################### | ||||
| import os, gc, sys, math, argparse, psutil | ||||
| import numpy as np | ||||
| import torch | ||||
| from pathlib import Path | ||||
| from collections import OrderedDict | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from log_utils import time_string | ||||
| from nats_bench import create | ||||
| from models import get_cell_based_tiny_net | ||||
| from utils import weight_watcher | ||||
|  | ||||
|  | ||||
| """ | ||||
| def get_cor(A, B): | ||||
|   return float(np.corrcoef(A, B)[0,1]) | ||||
|  | ||||
|  | ||||
| def tostr(accdict, norms): | ||||
|   xstr = [] | ||||
|   for key, accs in accdict.items(): | ||||
|     cor = get_cor(accs, norms) | ||||
|     xstr.append('{:}: {:.3f}'.format(key, cor)) | ||||
|   return ' '.join(xstr) | ||||
| """ | ||||
|  | ||||
|  | ||||
| def evaluate(api, weight_dir, data: str): | ||||
|     print("\nEvaluate dataset={:}".format(data)) | ||||
|     process = psutil.Process(os.getpid()) | ||||
|     norms, accuracies = [], [] | ||||
|     ok, total = 0, 5000 | ||||
|     for idx in range(total): | ||||
|         arch_index = api.random() | ||||
|         api.reload(weight_dir, arch_index) | ||||
|         # compute the weight watcher results | ||||
|         config = api.get_net_config(arch_index, data) | ||||
|         net = get_cell_based_tiny_net(config) | ||||
|         meta_info = api.query_meta_info_by_index( | ||||
|             arch_index, hp="200" if api.search_space_name == "topology" else "90" | ||||
|         ) | ||||
|         params = meta_info.get_net_param( | ||||
|             data, 888 if api.search_space_name == "topology" else 777 | ||||
|         ) | ||||
|         with torch.no_grad(): | ||||
|             net.load_state_dict(params) | ||||
|             _, summary = weight_watcher.analyze(net, alphas=False) | ||||
|             if "lognorm" not in summary: | ||||
|                 api.clear_params(arch_index, None) | ||||
|                 del net | ||||
|                 continue | ||||
|                 continue | ||||
|             cur_norm = -summary["lognorm"] | ||||
|         api.clear_params(arch_index, None) | ||||
|         if math.isnan(cur_norm): | ||||
|             del net, meta_info | ||||
|             continue | ||||
|         else: | ||||
|             ok += 1 | ||||
|             norms.append(cur_norm) | ||||
|         # query the accuracy | ||||
|         info = meta_info.get_metrics( | ||||
|             data, | ||||
|             "ori-test", | ||||
|             iepoch=None, | ||||
|             is_random=888 if api.search_space_name == "topology" else 777, | ||||
|         ) | ||||
|         accuracies.append(info["accuracy"]) | ||||
|         del net, meta_info | ||||
|         # print the information | ||||
|         if idx % 20 == 0: | ||||
|             gc.collect() | ||||
|             print( | ||||
|                 "{:} {:04d}_{:04d}/{:04d} ({:.2f} MB memory)".format( | ||||
|                     time_string(), ok, idx, total, process.memory_info().rss / 1e6 | ||||
|                 ) | ||||
|             ) | ||||
|     return norms, accuracies | ||||
|  | ||||
|  | ||||
| def main(search_space, meta_file: str, weight_dir, save_dir, xdata): | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     api = create(meta_file, search_space, verbose=False) | ||||
|     datasets = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"] | ||||
|     print(time_string() + " " + "=" * 50) | ||||
|     for data in datasets: | ||||
|         hps = api.avaliable_hps | ||||
|         for hp in hps: | ||||
|             nums = api.statistics(data, hp=hp) | ||||
|             total = sum([k * v for k, v in nums.items()]) | ||||
|             print( | ||||
|                 "Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format( | ||||
|                     hp, data, total, nums | ||||
|                 ) | ||||
|             ) | ||||
|     print(time_string() + " " + "=" * 50) | ||||
|  | ||||
|     norms, accuracies = evaluate(api, weight_dir, xdata) | ||||
|  | ||||
|     indexes = list(range(len(norms))) | ||||
|     norm_indexes = sorted(indexes, key=lambda i: norms[i]) | ||||
|     accy_indexes = sorted(indexes, key=lambda i: accuracies[i]) | ||||
|     labels = [] | ||||
|     for index in norm_indexes: | ||||
|         labels.append(accy_indexes.index(index)) | ||||
|  | ||||
|     dpi, width, height = 200, 1400, 800 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 18, 12 | ||||
|     resnet_scale, resnet_alpha = 120, 0.5 | ||||
|  | ||||
|     fig = plt.figure(figsize=figsize) | ||||
|     ax = fig.add_subplot(111) | ||||
|     plt.xlim(min(indexes), max(indexes)) | ||||
|     plt.ylim(min(indexes), max(indexes)) | ||||
|     # plt.ylabel('y').set_rotation(30) | ||||
|     plt.yticks( | ||||
|         np.arange(min(indexes), max(indexes), max(indexes) // 3), | ||||
|         fontsize=LegendFontsize, | ||||
|         rotation="vertical", | ||||
|     ) | ||||
|     plt.xticks( | ||||
|         np.arange(min(indexes), max(indexes), max(indexes) // 5), | ||||
|         fontsize=LegendFontsize, | ||||
|     ) | ||||
|     ax.scatter(indexes, labels, marker="*", s=0.5, c="tab:red", alpha=0.8) | ||||
|     ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | ||||
|     ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Test accuracy") | ||||
|     ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="Weight watcher") | ||||
|     plt.grid(zorder=0) | ||||
|     ax.set_axisbelow(True) | ||||
|     plt.legend(loc=0, fontsize=LegendFontsize) | ||||
|     ax.set_xlabel( | ||||
|         "architecture ranking sorted by the test accuracy ", fontsize=LabelSize | ||||
|     ) | ||||
|     ax.set_ylabel("architecture ranking computed by weight watcher", fontsize=LabelSize) | ||||
|     save_path = (save_dir / "{:}-{:}-test-ww.pdf".format(search_space, xdata)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|     save_path = (save_dir / "{:}-{:}-test-ww.png".format(search_space, xdata)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|  | ||||
|     print("{:} finish this test.".format(time_string())) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./output/vis-nas-bench/", | ||||
|         help="The base-name of folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         choices=["tss", "sss"], | ||||
|         help="The search space.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--base_path", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help="The path to the NAS-Bench-201 benchmark file and weight dir.", | ||||
|     ) | ||||
|     parser.add_argument("--dataset", type=str, default=None, help=".") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     save_dir = Path(args.save_dir) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     meta_file = Path(args.base_path + ".pth") | ||||
|     weight_dir = Path(args.base_path + "-full") | ||||
|     assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) | ||||
|     assert ( | ||||
|         weight_dir.exists() and weight_dir.is_dir() | ||||
|     ), "invalid path for weight dir : {:}".format(weight_dir) | ||||
|  | ||||
|     main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset) | ||||
							
								
								
									
										30
									
								
								AutoDL-Projects/exps/experimental/test-ww.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								AutoDL-Projects/exps/experimental/test-ww.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| import sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| import torchvision.models as models | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.utils import weight_watcher | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     # model = models.vgg19_bn(pretrained=True) | ||||
|     # _, summary = weight_watcher.analyze(model, alphas=False) | ||||
|     # for key, value in summary.items(): | ||||
|     #   print('{:10s} : {:}'.format(key, value)) | ||||
|  | ||||
|     _, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False) | ||||
|     print("vgg-13 : {:}".format(summary["lognorm"])) | ||||
|     _, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False) | ||||
|     print("vgg-13-BN : {:}".format(summary["lognorm"])) | ||||
|     _, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False) | ||||
|     print("vgg-16 : {:}".format(summary["lognorm"])) | ||||
|     _, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False) | ||||
|     print("vgg-16-BN : {:}".format(summary["lognorm"])) | ||||
|     _, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False) | ||||
|     print("vgg-19 : {:}".format(summary["lognorm"])) | ||||
|     _, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False) | ||||
|     print("vgg-19-BN : {:}".format(summary["lognorm"])) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
							
								
								
									
										178
									
								
								AutoDL-Projects/exps/experimental/vis-nats-bench-algos.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								AutoDL-Projects/exps/experimental/vis-nats-bench-algos.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,178 @@ | ||||
| ############################################################### | ||||
| # NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/experimental/vis-nats-bench-algos.py --search_space tss | ||||
| # Usage: python exps/experimental/vis-nats-bench-algos.py --search_space sss | ||||
| ############################################################### | ||||
| import os, gc, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict, OrderedDict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import dict2config, load_config | ||||
| from nats_bench import create | ||||
| from log_utils import time_string | ||||
|  | ||||
|  | ||||
| def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): | ||||
|     ss_dir = "{:}-{:}".format(root_dir, search_space) | ||||
|     alg2name, alg2path = OrderedDict(), OrderedDict() | ||||
|     alg2name["REA"] = "R-EA-SS3" | ||||
|     alg2name["REINFORCE"] = "REINFORCE-0.01" | ||||
|     alg2name["RANDOM"] = "RANDOM" | ||||
|     alg2name["BOHB"] = "BOHB" | ||||
|     for alg, name in alg2name.items(): | ||||
|         alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth") | ||||
|         assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg]) | ||||
|     alg2data = OrderedDict() | ||||
|     for alg, path in alg2path.items(): | ||||
|         data = torch.load(path) | ||||
|         for index, info in data.items(): | ||||
|             info["time_w_arch"] = [ | ||||
|                 (x, y) for x, y in zip(info["all_total_times"], info["all_archs"]) | ||||
|             ] | ||||
|             for j, arch in enumerate(info["all_archs"]): | ||||
|                 assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( | ||||
|                     alg, search_space, dataset, index, j | ||||
|                 ) | ||||
|         alg2data[alg] = data | ||||
|     return alg2data | ||||
|  | ||||
|  | ||||
| def query_performance(api, data, dataset, ticket): | ||||
|     results, is_size_space = [], api.search_space_name == "size" | ||||
|     for i, info in data.items(): | ||||
|         time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) | ||||
|         time_a, arch_a = time_w_arch[0] | ||||
|         time_b, arch_b = time_w_arch[1] | ||||
|         info_a = api.get_more_info( | ||||
|             arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         info_b = api.get_more_info( | ||||
|             arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"] | ||||
|         interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + ( | ||||
|             ticket - time_a | ||||
|         ) / (time_b - time_a) * accuracy_b | ||||
|         results.append(interplate) | ||||
|     return sum(results) / len(results) | ||||
|  | ||||
|  | ||||
| y_min_s = { | ||||
|     ("cifar10", "tss"): 90, | ||||
|     ("cifar10", "sss"): 92, | ||||
|     ("cifar100", "tss"): 65, | ||||
|     ("cifar100", "sss"): 65, | ||||
|     ("ImageNet16-120", "tss"): 36, | ||||
|     ("ImageNet16-120", "sss"): 40, | ||||
| } | ||||
|  | ||||
| y_max_s = { | ||||
|     ("cifar10", "tss"): 94.5, | ||||
|     ("cifar10", "sss"): 93.3, | ||||
|     ("cifar100", "tss"): 72, | ||||
|     ("cifar100", "sss"): 70, | ||||
|     ("ImageNet16-120", "tss"): 44, | ||||
|     ("ImageNet16-120", "sss"): 46, | ||||
| } | ||||
|  | ||||
| name2label = { | ||||
|     "cifar10": "CIFAR-10", | ||||
|     "cifar100": "CIFAR-100", | ||||
|     "ImageNet16-120": "ImageNet-16-120", | ||||
| } | ||||
|  | ||||
|  | ||||
| def visualize_curve(api, vis_save_dir, search_space, max_time): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dpi, width, height = 250, 5200, 1400 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 16, 16 | ||||
|  | ||||
|     def sub_plot_fn(ax, dataset): | ||||
|         alg2data = fetch_data(search_space=search_space, dataset=dataset) | ||||
|         alg2accuracies = OrderedDict() | ||||
|         total_tickets = 150 | ||||
|         time_tickets = [ | ||||
|             float(i) / total_tickets * max_time for i in range(total_tickets) | ||||
|         ] | ||||
|         colors = ["b", "g", "c", "m", "y"] | ||||
|         ax.set_xlim(0, 200) | ||||
|         ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) | ||||
|         for idx, (alg, data) in enumerate(alg2data.items()): | ||||
|             print("plot alg : {:}".format(alg)) | ||||
|             accuracies = [] | ||||
|             for ticket in time_tickets: | ||||
|                 accuracy = query_performance(api, data, dataset, ticket) | ||||
|                 accuracies.append(accuracy) | ||||
|             alg2accuracies[alg] = accuracies | ||||
|             ax.plot( | ||||
|                 [x / 100 for x in time_tickets], | ||||
|                 accuracies, | ||||
|                 c=colors[idx], | ||||
|                 label="{:}".format(alg), | ||||
|             ) | ||||
|             ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize) | ||||
|             ax.set_ylabel( | ||||
|                 "Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize | ||||
|             ) | ||||
|             ax.set_title( | ||||
|                 "Searching results on {:}".format(name2label[dataset]), | ||||
|                 fontsize=LabelSize + 4, | ||||
|             ) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||
|     datasets = ["cifar10", "cifar100", "ImageNet16-120"] | ||||
|     for dataset, ax in zip(datasets, axs): | ||||
|         sub_plot_fn(ax, dataset) | ||||
|         print("sub-plot {:} on {:} done.".format(dataset, search_space)) | ||||
|     save_path = (vis_save_dir / "{:}-curve.png".format(search_space)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NAS-Bench-X", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench/nas-algos", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         choices=["tss", "sss"], | ||||
|         help="Choose the search space.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--max_time", type=float, default=20000, help="The maximum time budget." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     save_dir = Path(args.save_dir) | ||||
|  | ||||
|     api = create(None, args.search_space, verbose=False) | ||||
|     visualize_curve(api, save_dir, args.search_space, args.max_time) | ||||
							
								
								
									
										185
									
								
								AutoDL-Projects/exps/experimental/vis-nats-bench-ws.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										185
									
								
								AutoDL-Projects/exps/experimental/vis-nats-bench-ws.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,185 @@ | ||||
| ############################################################### | ||||
| # NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/experimental/vis-nats-bench-ws.py --search_space tss | ||||
| # Usage: python exps/experimental/vis-nats-bench-ws.py --search_space sss | ||||
| ############################################################### | ||||
| import os, gc, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict, OrderedDict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import dict2config, load_config | ||||
| from nats_bench import create | ||||
| from log_utils import time_string | ||||
|  | ||||
|  | ||||
| # def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'): | ||||
| def fetch_data( | ||||
|     root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3" | ||||
| ): | ||||
|     ss_dir = "{:}-{:}".format(root_dir, search_space) | ||||
|     alg2name, alg2path = OrderedDict(), OrderedDict() | ||||
|     seeds = [777, 888, 999] | ||||
|     print("\n[fetch data] from {:} on {:}".format(search_space, dataset)) | ||||
|     if search_space == "tss": | ||||
|         alg2name["GDAS"] = "gdas-affine0_BN0-None" | ||||
|         alg2name["RSPS"] = "random-affine0_BN0-None" | ||||
|         alg2name["DARTS (1st)"] = "darts-v1-affine0_BN0-None" | ||||
|         alg2name["DARTS (2nd)"] = "darts-v2-affine0_BN0-None" | ||||
|         alg2name["ENAS"] = "enas-affine0_BN0-None" | ||||
|         alg2name["SETN"] = "setn-affine0_BN0-None" | ||||
|     else: | ||||
|         # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) | ||||
|         # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) | ||||
|         # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) | ||||
|         alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format( | ||||
|             suffix | ||||
|         ) | ||||
|         alg2name[ | ||||
|             "masking + Gumbel-Softmax" | ||||
|         ] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix) | ||||
|         alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix) | ||||
|     for alg, name in alg2name.items(): | ||||
|         alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth") | ||||
|     alg2data = OrderedDict() | ||||
|     for alg, path in alg2path.items(): | ||||
|         alg2data[alg], ok_num = [], 0 | ||||
|         for seed in seeds: | ||||
|             xpath = path.format(seed) | ||||
|             if os.path.isfile(xpath): | ||||
|                 ok_num += 1 | ||||
|             else: | ||||
|                 print("This is an invalid path : {:}".format(xpath)) | ||||
|                 continue | ||||
|             data = torch.load(xpath, map_location=torch.device("cpu")) | ||||
|             data = torch.load(data["last_checkpoint"], map_location=torch.device("cpu")) | ||||
|             alg2data[alg].append(data["genotypes"]) | ||||
|         print("This algorithm : {:} has {:} valid ckps.".format(alg, ok_num)) | ||||
|         assert ok_num > 0, "Must have at least 1 valid ckps." | ||||
|     return alg2data | ||||
|  | ||||
|  | ||||
| y_min_s = { | ||||
|     ("cifar10", "tss"): 90, | ||||
|     ("cifar10", "sss"): 92, | ||||
|     ("cifar100", "tss"): 65, | ||||
|     ("cifar100", "sss"): 65, | ||||
|     ("ImageNet16-120", "tss"): 36, | ||||
|     ("ImageNet16-120", "sss"): 40, | ||||
| } | ||||
|  | ||||
| y_max_s = { | ||||
|     ("cifar10", "tss"): 94.5, | ||||
|     ("cifar10", "sss"): 93.3, | ||||
|     ("cifar100", "tss"): 72, | ||||
|     ("cifar100", "sss"): 70, | ||||
|     ("ImageNet16-120", "tss"): 44, | ||||
|     ("ImageNet16-120", "sss"): 46, | ||||
| } | ||||
|  | ||||
| name2label = { | ||||
|     "cifar10": "CIFAR-10", | ||||
|     "cifar100": "CIFAR-100", | ||||
|     "ImageNet16-120": "ImageNet-16-120", | ||||
| } | ||||
|  | ||||
|  | ||||
| def visualize_curve(api, vis_save_dir, search_space): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dpi, width, height = 250, 5200, 1400 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 16, 16 | ||||
|  | ||||
|     def sub_plot_fn(ax, dataset): | ||||
|         alg2data = fetch_data(search_space=search_space, dataset=dataset) | ||||
|         alg2accuracies = OrderedDict() | ||||
|         epochs = 100 | ||||
|         colors = ["b", "g", "c", "m", "y", "r"] | ||||
|         ax.set_xlim(0, epochs) | ||||
|         # ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) | ||||
|         for idx, (alg, data) in enumerate(alg2data.items()): | ||||
|             print("plot alg : {:}".format(alg)) | ||||
|             xs, accuracies = [], [] | ||||
|             for iepoch in range(epochs + 1): | ||||
|                 try: | ||||
|                     structures, accs = [_[iepoch - 1] for _ in data], [] | ||||
|                 except: | ||||
|                     raise ValueError( | ||||
|                         "This alg {:} on {:} has invalid checkpoints.".format( | ||||
|                             alg, dataset | ||||
|                         ) | ||||
|                     ) | ||||
|                 for structure in structures: | ||||
|                     info = api.get_more_info( | ||||
|                         structure, | ||||
|                         dataset=dataset, | ||||
|                         hp=90 if api.search_space_name == "size" else 200, | ||||
|                         is_random=False, | ||||
|                     ) | ||||
|                     accs.append(info["test-accuracy"]) | ||||
|                 accuracies.append(sum(accs) / len(accs)) | ||||
|                 xs.append(iepoch) | ||||
|             alg2accuracies[alg] = accuracies | ||||
|             ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg)) | ||||
|             ax.set_xlabel("The searching epoch", fontsize=LabelSize) | ||||
|             ax.set_ylabel( | ||||
|                 "Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize | ||||
|             ) | ||||
|             ax.set_title( | ||||
|                 "Searching results on {:}".format(name2label[dataset]), | ||||
|                 fontsize=LabelSize + 4, | ||||
|             ) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||
|     datasets = ["cifar10", "cifar100", "ImageNet16-120"] | ||||
|     for dataset, ax in zip(datasets, axs): | ||||
|         sub_plot_fn(ax, dataset) | ||||
|         print("sub-plot {:} on {:} done.".format(dataset, search_space)) | ||||
|     save_path = (vis_save_dir / "{:}-ws-curve.png".format(search_space)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NAS-Bench-X", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench/nas-algos", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         default="tss", | ||||
|         choices=["tss", "sss"], | ||||
|         help="Choose the search space.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     save_dir = Path(args.save_dir) | ||||
|  | ||||
|     api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||
|     visualize_curve(api, save_dir, args.search_space) | ||||
							
								
								
									
										657
									
								
								AutoDL-Projects/exps/experimental/visualize-nas-bench-x.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										657
									
								
								AutoDL-Projects/exps/experimental/visualize-nas-bench-x.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,657 @@ | ||||
| ############################################################### | ||||
| # NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/experimental/visualize-nas-bench-x.py | ||||
| ############################################################### | ||||
| import os, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import dict2config, load_config | ||||
| from log_utils import time_string | ||||
| from models import get_cell_based_tiny_net | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def visualize_info(api, vis_save_dir, indicator): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar10", indicator | ||||
|     ) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar100", indicator | ||||
|     ) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "ImageNet16-120", indicator | ||||
|     ) | ||||
|     cifar010_info = torch.load(cifar010_cache_path) | ||||
|     cifar100_info = torch.load(cifar100_cache_path) | ||||
|     imagenet_info = torch.load(imagenet_cache_path) | ||||
|     indexes = list(range(len(cifar010_info["params"]))) | ||||
|  | ||||
|     print("{:} start to visualize relative ranking".format(time_string())) | ||||
|  | ||||
|     cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info["test_accs"][i]) | ||||
|     cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info["test_accs"][i]) | ||||
|     imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info["test_accs"][i]) | ||||
|  | ||||
|     cifar100_labels, imagenet_labels = [], [] | ||||
|     for idx in cifar010_ord_indexes: | ||||
|         cifar100_labels.append(cifar100_ord_indexes.index(idx)) | ||||
|         imagenet_labels.append(imagenet_ord_indexes.index(idx)) | ||||
|     print("{:} prepare data done.".format(time_string())) | ||||
|  | ||||
|     dpi, width, height = 200, 1400, 800 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 18, 12 | ||||
|     resnet_scale, resnet_alpha = 120, 0.5 | ||||
|  | ||||
|     fig = plt.figure(figsize=figsize) | ||||
|     ax = fig.add_subplot(111) | ||||
|     plt.xlim(min(indexes), max(indexes)) | ||||
|     plt.ylim(min(indexes), max(indexes)) | ||||
|     # plt.ylabel('y').set_rotation(30) | ||||
|     plt.yticks( | ||||
|         np.arange(min(indexes), max(indexes), max(indexes) // 3), | ||||
|         fontsize=LegendFontsize, | ||||
|         rotation="vertical", | ||||
|     ) | ||||
|     plt.xticks( | ||||
|         np.arange(min(indexes), max(indexes), max(indexes) // 5), | ||||
|         fontsize=LegendFontsize, | ||||
|     ) | ||||
|     ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | ||||
|     ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8) | ||||
|     ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | ||||
|     ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10") | ||||
|     ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100") | ||||
|     ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120") | ||||
|     plt.grid(zorder=0) | ||||
|     ax.set_axisbelow(True) | ||||
|     plt.legend(loc=0, fontsize=LegendFontsize) | ||||
|     ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize) | ||||
|     ax.set_ylabel("architecture ranking", fontsize=LabelSize) | ||||
|     save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|     save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|  | ||||
|  | ||||
| def visualize_sss_info(api, dataset, vis_save_dir): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     print("{:} start to visualize {:} information".format(time_string(), dataset)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset) | ||||
|     if not cache_file_path.exists(): | ||||
|         print("Do not find cache file : {:}".format(cache_file_path)) | ||||
|         params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] | ||||
|         for index in range(len(api)): | ||||
|             cost_info = api.get_cost_info(index, dataset, hp="90") | ||||
|             params.append(cost_info["params"]) | ||||
|             flops.append(cost_info["flops"]) | ||||
|             # accuracy | ||||
|             info = api.get_more_info(index, dataset, hp="90", is_random=False) | ||||
|             train_accs.append(info["train-accuracy"]) | ||||
|             test_accs.append(info["test-accuracy"]) | ||||
|             if dataset == "cifar10": | ||||
|                 info = api.get_more_info( | ||||
|                     index, "cifar10-valid", hp="90", is_random=False | ||||
|                 ) | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|             else: | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|         info = { | ||||
|             "params": params, | ||||
|             "flops": flops, | ||||
|             "train_accs": train_accs, | ||||
|             "valid_accs": valid_accs, | ||||
|             "test_accs": test_accs, | ||||
|         } | ||||
|         torch.save(info, cache_file_path) | ||||
|     else: | ||||
|         print("Find cache file : {:}".format(cache_file_path)) | ||||
|         info = torch.load(cache_file_path) | ||||
|         params, flops, train_accs, valid_accs, test_accs = ( | ||||
|             info["params"], | ||||
|             info["flops"], | ||||
|             info["train_accs"], | ||||
|             info["valid_accs"], | ||||
|             info["test_accs"], | ||||
|         ) | ||||
|     print("{:} collect data done.".format(time_string())) | ||||
|  | ||||
|     pyramid = [ | ||||
|         "8:16:32:48:64", | ||||
|         "8:8:16:32:48", | ||||
|         "8:8:16:16:32", | ||||
|         "8:8:16:16:48", | ||||
|         "8:8:16:16:64", | ||||
|         "16:16:32:32:64", | ||||
|         "32:32:64:64:64", | ||||
|     ] | ||||
|     pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid] | ||||
|     largest_indexes = [api.query_index_by_arch("64:64:64:64:64")] | ||||
|  | ||||
|     indexes = list(range(len(params))) | ||||
|     dpi, width, height = 250, 8500, 1300 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 24, 24 | ||||
|     # resnet_scale, resnet_alpha = 120, 0.5 | ||||
|     xscale, xalpha = 120, 0.8 | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 4, figsize=figsize) | ||||
|     # ax1, ax2, ax3, ax4, ax5 = axs | ||||
|     for ax in axs: | ||||
|         for tick in ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|         ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) | ||||
|         for tick in ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|     ax2, ax3, ax4, ax5 = axs | ||||
|     # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) | ||||
|     # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') | ||||
|     # ax1.set_xlabel('architecture ID', fontsize=LabelSize) | ||||
|     # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) | ||||
|  | ||||
|     ax2.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax2.scatter( | ||||
|         [params[x] for x in pyramid_indexes], | ||||
|         [train_accs[x] for x in pyramid_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="Pyramid Structure", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax2.scatter( | ||||
|         [params[x] for x in largest_indexes], | ||||
|         [train_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax2.set_xlabel("#parameters (MB)", fontsize=LabelSize) | ||||
|     ax2.set_ylabel("train accuracy (%)", fontsize=LabelSize) | ||||
|     ax2.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax3.scatter( | ||||
|         [params[x] for x in pyramid_indexes], | ||||
|         [test_accs[x] for x in pyramid_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="Pyramid Structure", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax3.scatter( | ||||
|         [params[x] for x in largest_indexes], | ||||
|         [test_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) | ||||
|     ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) | ||||
|     ax3.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax4.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax4.scatter( | ||||
|         [flops[x] for x in pyramid_indexes], | ||||
|         [train_accs[x] for x in pyramid_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="Pyramid Structure", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax4.scatter( | ||||
|         [flops[x] for x in largest_indexes], | ||||
|         [train_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | ||||
|     ax4.set_ylabel("train accuracy (%)", fontsize=LabelSize) | ||||
|     ax4.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax5.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax5.scatter( | ||||
|         [flops[x] for x in pyramid_indexes], | ||||
|         [test_accs[x] for x in pyramid_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="Pyramid Structure", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax5.scatter( | ||||
|         [flops[x] for x in largest_indexes], | ||||
|         [test_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax5.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | ||||
|     ax5.set_ylabel("test accuracy (%)", fontsize=LabelSize) | ||||
|     ax5.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     save_path = vis_save_dir / "sss-{:}.png".format(dataset) | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| def visualize_tss_info(api, dataset, vis_save_dir): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     print("{:} start to visualize {:} information".format(time_string(), dataset)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset) | ||||
|     if not cache_file_path.exists(): | ||||
|         print("Do not find cache file : {:}".format(cache_file_path)) | ||||
|         params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] | ||||
|         for index in range(len(api)): | ||||
|             cost_info = api.get_cost_info(index, dataset, hp="12") | ||||
|             params.append(cost_info["params"]) | ||||
|             flops.append(cost_info["flops"]) | ||||
|             # accuracy | ||||
|             info = api.get_more_info(index, dataset, hp="200", is_random=False) | ||||
|             train_accs.append(info["train-accuracy"]) | ||||
|             test_accs.append(info["test-accuracy"]) | ||||
|             if dataset == "cifar10": | ||||
|                 info = api.get_more_info( | ||||
|                     index, "cifar10-valid", hp="200", is_random=False | ||||
|                 ) | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|             else: | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|             print("") | ||||
|         info = { | ||||
|             "params": params, | ||||
|             "flops": flops, | ||||
|             "train_accs": train_accs, | ||||
|             "valid_accs": valid_accs, | ||||
|             "test_accs": test_accs, | ||||
|         } | ||||
|         torch.save(info, cache_file_path) | ||||
|     else: | ||||
|         print("Find cache file : {:}".format(cache_file_path)) | ||||
|         info = torch.load(cache_file_path) | ||||
|         params, flops, train_accs, valid_accs, test_accs = ( | ||||
|             info["params"], | ||||
|             info["flops"], | ||||
|             info["train_accs"], | ||||
|             info["valid_accs"], | ||||
|             info["test_accs"], | ||||
|         ) | ||||
|     print("{:} collect data done.".format(time_string())) | ||||
|  | ||||
|     resnet = [ | ||||
|         "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|" | ||||
|     ] | ||||
|     resnet_indexes = [api.query_index_by_arch(x) for x in resnet] | ||||
|     largest_indexes = [ | ||||
|         api.query_index_by_arch( | ||||
|             "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|" | ||||
|         ) | ||||
|     ] | ||||
|  | ||||
|     indexes = list(range(len(params))) | ||||
|     dpi, width, height = 250, 8500, 1300 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 24, 24 | ||||
|     # resnet_scale, resnet_alpha = 120, 0.5 | ||||
|     xscale, xalpha = 120, 0.8 | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 4, figsize=figsize) | ||||
|     # ax1, ax2, ax3, ax4, ax5 = axs | ||||
|     for ax in axs: | ||||
|         for tick in ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|         ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) | ||||
|         for tick in ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|     ax2, ax3, ax4, ax5 = axs | ||||
|     # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) | ||||
|     # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') | ||||
|     # ax1.set_xlabel('architecture ID', fontsize=LabelSize) | ||||
|     # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) | ||||
|  | ||||
|     ax2.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax2.scatter( | ||||
|         [params[x] for x in resnet_indexes], | ||||
|         [train_accs[x] for x in resnet_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="ResNet", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax2.scatter( | ||||
|         [params[x] for x in largest_indexes], | ||||
|         [train_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax2.set_xlabel("#parameters (MB)", fontsize=LabelSize) | ||||
|     ax2.set_ylabel("train accuracy (%)", fontsize=LabelSize) | ||||
|     ax2.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax3.scatter( | ||||
|         [params[x] for x in resnet_indexes], | ||||
|         [test_accs[x] for x in resnet_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="ResNet", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax3.scatter( | ||||
|         [params[x] for x in largest_indexes], | ||||
|         [test_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) | ||||
|     ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) | ||||
|     ax3.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax4.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax4.scatter( | ||||
|         [flops[x] for x in resnet_indexes], | ||||
|         [train_accs[x] for x in resnet_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="ResNet", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax4.scatter( | ||||
|         [flops[x] for x in largest_indexes], | ||||
|         [train_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | ||||
|     ax4.set_ylabel("train accuracy (%)", fontsize=LabelSize) | ||||
|     ax4.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax5.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax5.scatter( | ||||
|         [flops[x] for x in resnet_indexes], | ||||
|         [test_accs[x] for x in resnet_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="ResNet", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax5.scatter( | ||||
|         [flops[x] for x in largest_indexes], | ||||
|         [test_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax5.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | ||||
|     ax5.set_ylabel("test accuracy (%)", fontsize=LabelSize) | ||||
|     ax5.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     save_path = vis_save_dir / "tss-{:}.png".format(dataset) | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| def visualize_rank_info(api, vis_save_dir, indicator): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar10", indicator | ||||
|     ) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar100", indicator | ||||
|     ) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "ImageNet16-120", indicator | ||||
|     ) | ||||
|     cifar010_info = torch.load(cifar010_cache_path) | ||||
|     cifar100_info = torch.load(cifar100_cache_path) | ||||
|     imagenet_info = torch.load(imagenet_cache_path) | ||||
|     indexes = list(range(len(cifar010_info["params"]))) | ||||
|  | ||||
|     print("{:} start to visualize relative ranking".format(time_string())) | ||||
|  | ||||
|     dpi, width, height = 250, 3800, 1200 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 14, 14 | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||
|     ax1, ax2, ax3 = axs | ||||
|  | ||||
|     def get_labels(info): | ||||
|         ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i]) | ||||
|         ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i]) | ||||
|         labels = [] | ||||
|         for idx in ord_test_indexes: | ||||
|             labels.append(ord_valid_indexes.index(idx)) | ||||
|         return labels | ||||
|  | ||||
|     def plot_ax(labels, ax, name): | ||||
|         for tick in ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|         for tick in ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|             tick.label.set_rotation(90) | ||||
|         ax.set_xlim(min(indexes), max(indexes)) | ||||
|         ax.set_ylim(min(indexes), max(indexes)) | ||||
|         ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 3)) | ||||
|         ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) | ||||
|         ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | ||||
|         ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | ||||
|         ax.scatter( | ||||
|             [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name) | ||||
|         ) | ||||
|         ax.scatter( | ||||
|             [-1], | ||||
|             [-1], | ||||
|             marker="o", | ||||
|             s=100, | ||||
|             c="tab:blue", | ||||
|             label="{:} validation".format(name), | ||||
|         ) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|         ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize) | ||||
|         ax.set_ylabel("architecture ranking", fontsize=LabelSize) | ||||
|  | ||||
|     labels = get_labels(cifar010_info) | ||||
|     plot_ax(labels, ax1, "CIFAR-10") | ||||
|     labels = get_labels(cifar100_info) | ||||
|     plot_ax(labels, ax2, "CIFAR-100") | ||||
|     labels = get_labels(imagenet_info) | ||||
|     plot_ax(labels, ax3, "ImageNet-16-120") | ||||
|  | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-same-relative-rank.png".format(indicator) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| def calculate_correlation(*vectors): | ||||
|     matrix = [] | ||||
|     for i, vectori in enumerate(vectors): | ||||
|         x = [] | ||||
|         for j, vectorj in enumerate(vectors): | ||||
|             x.append(np.corrcoef(vectori, vectorj)[0, 1]) | ||||
|         matrix.append(x) | ||||
|     return np.array(matrix) | ||||
|  | ||||
|  | ||||
| def visualize_all_rank_info(api, vis_save_dir, indicator): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar10", indicator | ||||
|     ) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar100", indicator | ||||
|     ) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "ImageNet16-120", indicator | ||||
|     ) | ||||
|     cifar010_info = torch.load(cifar010_cache_path) | ||||
|     cifar100_info = torch.load(cifar100_cache_path) | ||||
|     imagenet_info = torch.load(imagenet_cache_path) | ||||
|     indexes = list(range(len(cifar010_info["params"]))) | ||||
|  | ||||
|     print("{:} start to visualize relative ranking".format(time_string())) | ||||
|  | ||||
|     dpi, width, height = 250, 3200, 1400 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 14, 14 | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 2, figsize=figsize) | ||||
|     ax1, ax2 = axs | ||||
|  | ||||
|     sns_size = 15 | ||||
|     CoRelMatrix = calculate_correlation( | ||||
|         cifar010_info["valid_accs"], | ||||
|         cifar010_info["test_accs"], | ||||
|         cifar100_info["valid_accs"], | ||||
|         cifar100_info["test_accs"], | ||||
|         imagenet_info["valid_accs"], | ||||
|         imagenet_info["test_accs"], | ||||
|     ) | ||||
|  | ||||
|     sns.heatmap( | ||||
|         CoRelMatrix, | ||||
|         annot=True, | ||||
|         annot_kws={"size": sns_size}, | ||||
|         fmt=".3f", | ||||
|         linewidths=0.5, | ||||
|         ax=ax1, | ||||
|         xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | ||||
|         yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | ||||
|     ) | ||||
|  | ||||
|     selected_indexes, acc_bar = [], 92 | ||||
|     for i, acc in enumerate(cifar010_info["test_accs"]): | ||||
|         if acc > acc_bar: | ||||
|             selected_indexes.append(i) | ||||
|     cifar010_valid_accs = np.array(cifar010_info["valid_accs"])[selected_indexes] | ||||
|     cifar010_test_accs = np.array(cifar010_info["test_accs"])[selected_indexes] | ||||
|     cifar100_valid_accs = np.array(cifar100_info["valid_accs"])[selected_indexes] | ||||
|     cifar100_test_accs = np.array(cifar100_info["test_accs"])[selected_indexes] | ||||
|     imagenet_valid_accs = np.array(imagenet_info["valid_accs"])[selected_indexes] | ||||
|     imagenet_test_accs = np.array(imagenet_info["test_accs"])[selected_indexes] | ||||
|     CoRelMatrix = calculate_correlation( | ||||
|         cifar010_valid_accs, | ||||
|         cifar010_test_accs, | ||||
|         cifar100_valid_accs, | ||||
|         cifar100_test_accs, | ||||
|         imagenet_valid_accs, | ||||
|         imagenet_test_accs, | ||||
|     ) | ||||
|  | ||||
|     sns.heatmap( | ||||
|         CoRelMatrix, | ||||
|         annot=True, | ||||
|         annot_kws={"size": sns_size}, | ||||
|         fmt=".3f", | ||||
|         linewidths=0.5, | ||||
|         ax=ax2, | ||||
|         xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | ||||
|         yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | ||||
|     ) | ||||
|     ax1.set_title("Correlation coefficient over ALL candidates") | ||||
|     ax2.set_title( | ||||
|         "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar) | ||||
|     ) | ||||
|     save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NAS-Bench-X", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     # use for train the model | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     to_save_dir = Path(args.save_dir) | ||||
|  | ||||
|     datasets = ["cifar10", "cifar100", "ImageNet16-120"] | ||||
|     api201 = create(None, "tss", verbose=True) | ||||
|     for xdata in datasets: | ||||
|         visualize_tss_info(api201, xdata, to_save_dir) | ||||
|  | ||||
|     api_sss = create(None, "size", verbose=True) | ||||
|     for xdata in datasets: | ||||
|         visualize_sss_info(api_sss, xdata, to_save_dir) | ||||
|  | ||||
|     visualize_info(None, to_save_dir, "tss") | ||||
|     visualize_info(None, to_save_dir, "sss") | ||||
|     visualize_rank_info(None, to_save_dir, "tss") | ||||
|     visualize_rank_info(None, to_save_dir, "sss") | ||||
|  | ||||
|     visualize_all_rank_info(None, to_save_dir, "tss") | ||||
|     visualize_all_rank_info(None, to_save_dir, "sss") | ||||
		Reference in New Issue
	
	Block a user