Re-organize GeMOSA
This commit is contained in:
		| @@ -1,10 +1,9 @@ | ||||
| ##################################################### | ||||
| # Learning to Generate Model One Step Ahead         # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/lfna.py --env_version v1 --workers 0 | ||||
| # python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 | ||||
| # python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --workers 0 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| @@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData | ||||
| from meta_model import MetaModelV1 | ||||
|  | ||||
|  | ||||
| def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): | ||||
| def online_evaluate( | ||||
|     env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False | ||||
| ): | ||||
|     logger.log("Online evaluate: {:}".format(env)) | ||||
|     loss_meter = AverageMeter() | ||||
|     w_containers = dict() | ||||
| @@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F | ||||
|         with torch.no_grad(): | ||||
|             meta_model.eval() | ||||
|             base_model.eval() | ||||
|             [future_container], time_embeds = meta_model( | ||||
|                 future_time.to(args.device).view(-1), None, False | ||||
|             future_time_embed = meta_model.gen_time_embed( | ||||
|                 future_time.to(args.device).view(-1) | ||||
|             ) | ||||
|             [future_container] = meta_model.gen_model(future_time_embed) | ||||
|             if save: | ||||
|                 w_containers[idx] = future_container.no_grad_clone() | ||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|             future_y_hat = base_model.forward_with_container(future_x, future_container) | ||||
|             future_loss = criterion(future_y_hat, future_y) | ||||
|             loss_meter.update(future_loss.item()) | ||||
|         refine, post_refine_loss = meta_model.adapt( | ||||
|             base_model, | ||||
|             criterion, | ||||
|             future_time.item(), | ||||
|             future_x, | ||||
|             future_y, | ||||
|             args.refine_lr, | ||||
|             args.refine_epochs, | ||||
|             {"param": time_embeds, "loss": future_loss.item()}, | ||||
|         ) | ||||
|         if easy_adapt: | ||||
|             meta_model.easy_adapt(future_time.item(), future_time_embed) | ||||
|             refine, post_refine_loss = False, -1 | ||||
|         else: | ||||
|             refine, post_refine_loss = meta_model.adapt( | ||||
|                 base_model, | ||||
|                 criterion, | ||||
|                 future_time.item(), | ||||
|                 future_x, | ||||
|                 future_y, | ||||
|                 args.refine_lr, | ||||
|                 args.refine_epochs, | ||||
|                 {"param": future_time_embed, "loss": future_loss.item()}, | ||||
|             ) | ||||
|         logger.log( | ||||
|             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( | ||||
|                 idx, len(env), future_loss.item() | ||||
| @@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | ||||
|         ) | ||||
|         optimizer.zero_grad() | ||||
|  | ||||
|         generated_time_embeds = gen_time_embed(meta_model.meta_timestamps) | ||||
|         generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps) | ||||
|  | ||||
|         batch_indexes = random.choices(total_indexes, k=args.meta_batch) | ||||
|  | ||||
| @@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | ||||
|         ) | ||||
|         # future loss | ||||
|         total_future_losses, total_present_losses = [], [] | ||||
|         future_containers, _ = meta_model( | ||||
|             None, generated_time_embeds[batch_indexes], False | ||||
|         ) | ||||
|         present_containers, _ = meta_model( | ||||
|             None, meta_model.super_meta_embed[batch_indexes], False | ||||
|         future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes]) | ||||
|         present_containers = meta_model.gen_model( | ||||
|             meta_model.super_meta_embed[batch_indexes] | ||||
|         ) | ||||
|         for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): | ||||
|             _, (inputs, targets) = xenv(time_step) | ||||
| @@ -216,13 +220,34 @@ def main(args): | ||||
|     # try to evaluate once | ||||
|     # online_evaluate(train_env, meta_model, base_model, criterion, args, logger) | ||||
|     # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) | ||||
|     """ | ||||
|     w_containers, loss_meter = online_evaluate( | ||||
|         all_env, meta_model, base_model, criterion, args, logger, True | ||||
|     ) | ||||
|     logger.log("In this environment, the total loss-meter is {:}".format(loss_meter)) | ||||
|     """ | ||||
|     _, test_loss_meter_adapt_v1 = online_evaluate( | ||||
|         valid_env, meta_model, base_model, criterion, args, logger, False, False | ||||
|     ) | ||||
|     _, test_loss_meter_adapt_v2 = online_evaluate( | ||||
|         valid_env, meta_model, base_model, criterion, args, logger, False, True | ||||
|     ) | ||||
|     logger.log( | ||||
|         "In the online test environment, the total loss for refine-adapt is {:}".format( | ||||
|             test_loss_meter_adapt_v1 | ||||
|         ) | ||||
|     ) | ||||
|     logger.log( | ||||
|         "In the online test environment, the total loss for easy-adapt is {:}".format( | ||||
|             test_loss_meter_adapt_v2 | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"all_w_containers": w_containers}, | ||||
|         { | ||||
|             "test_loss_adapt_v1": test_loss_meter_adapt_v1.avg, | ||||
|             "test_loss_adapt_v2": test_loss_meter_adapt_v2.avg, | ||||
|         }, | ||||
|         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed), | ||||
|         logger, | ||||
|     ) | ||||
|   | ||||
| @@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             batch_containers.append( | ||||
|                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||
|             ) | ||||
|         return batch_containers, time_embeds | ||||
|         return batch_containers | ||||
|  | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         raise NotImplementedError | ||||
| @@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule): | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def easy_adapt(self, timestamp, time_embed): | ||||
|         with torch.no_grad(): | ||||
|             timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device) | ||||
|             self.replace_append_learnt(None, None) | ||||
|             self.append_fixed(timestamp, time_embed) | ||||
|  | ||||
|     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): | ||||
|         distance = self.get_closest_meta_distance(timestamp) | ||||
|         if distance + self._interval * 1e-2 <= self._interval: | ||||
| @@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule): | ||||
|                 best_new_param = new_param.detach().clone() | ||||
|             for iepoch in range(epochs): | ||||
|                 optimizer.zero_grad() | ||||
|                 _, time_embed = self(timestamp.view(1), None) | ||||
|                 time_embed = self.gen_time_embed(timestamp.view(1)) | ||||
|                 match_loss = criterion(new_param, time_embed) | ||||
|  | ||||
|                 [container], time_embed = self(None, new_param.view(1, -1)) | ||||
|                 [container] = self.gen_model(new_param.view(1, -1)) | ||||
|                 y_hat = base_model.forward_with_container(x, container) | ||||
|                 meta_loss = criterion(y_hat, y) | ||||
|                 loss = meta_loss + match_loss | ||||
|                 loss.backward() | ||||
|                 optimizer.step() | ||||
|                 # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item())) | ||||
|                 if meta_loss.item() < best_loss: | ||||
|                     with torch.no_grad(): | ||||
|                         best_loss = meta_loss.item() | ||||
|                         best_new_param = new_param.detach().clone() | ||||
|         with torch.no_grad(): | ||||
|             self.replace_append_learnt(None, None) | ||||
|             self.append_fixed(timestamp, best_new_param) | ||||
|         self.easy_adapt(timestamp, best_new_param) | ||||
|         return True, best_loss | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|   | ||||
| @@ -191,6 +191,8 @@ def visualize_env(save_dir, version): | ||||
|         allxs.append(allx) | ||||
|         allys.append(ally) | ||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|     print("env: {:}".format(dynamic_env)) | ||||
|     print("oracle_map: {:}".format(dynamic_env.oracle_map)) | ||||
|     print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) | ||||
|     print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user