Update xlayers
This commit is contained in:
		| @@ -21,6 +21,57 @@ from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from datasets.synthetic_core import get_synthetic_env | ||||
| from models.xcore import get_model | ||||
| from xlayers import super_core | ||||
|  | ||||
|  | ||||
| class LFNAmlp: | ||||
|     """A LFNA meta-model that uses the MLP as delta-net.""" | ||||
|  | ||||
|     def __init__(self, obs_dim, hidden_sizes, act_name): | ||||
|         self.delta_net = super_core.SuperSequential( | ||||
|             super_core.SuperLinear(obs_dim, hidden_sizes[0]), | ||||
|             super_core.super_name2activation[act_name](), | ||||
|             super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]), | ||||
|             super_core.super_name2activation[act_name](), | ||||
|             super_core.SuperLinear(hidden_sizes[1], 1), | ||||
|         ) | ||||
|         self.meta_optimizer = torch.optim.Adam( | ||||
|             self.delta_net.parameters(), lr=0.01, amsgrad=True | ||||
|         ) | ||||
|  | ||||
|     def adapt(self, model, criterion, w_container, xs, ys): | ||||
|         containers = [w_container] | ||||
|         for idx, (x, y) in enumerate(zip(xs, ys)): | ||||
|             y_hat = model.forward_with_container(x, containers[-1]) | ||||
|             loss = criterion(y_hat, y) | ||||
|             gradients = torch.autograd.grad(loss, containers[-1].tensors) | ||||
|             with torch.no_grad(): | ||||
|                 flatten_w = containers[-1].flatten().view(-1, 1) | ||||
|                 flatten_g = containers[-1].flatten(gradients).view(-1, 1) | ||||
|                 input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2) | ||||
|                 input_statistics = input_statistics.expand(flatten_w.numel(), -1) | ||||
|             delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) | ||||
|             delta = self.delta_net(delta_inputs).view(-1) | ||||
|             # delta = torch.clamp(delta, -0.5, 0.5) | ||||
|             unflatten_delta = containers[-1].unflatten(delta) | ||||
|             future_container = containers[-1].additive(unflatten_delta) | ||||
|             containers.append(future_container) | ||||
|         # containers = containers[1:] | ||||
|         meta_loss = [] | ||||
|         for idx, (x, y) in enumerate(zip(xs, ys)): | ||||
|             if idx == 0: | ||||
|                 continue | ||||
|             current_container = containers[idx] | ||||
|             y_hat = model.forward_with_container(x, current_container) | ||||
|             loss = criterion(y_hat, y) | ||||
|             meta_loss.append(loss) | ||||
|         meta_loss = sum(meta_loss) | ||||
|         meta_loss.backward() | ||||
|         self.meta_optimizer.step() | ||||
|  | ||||
|     def zero_grad(self): | ||||
|         self.meta_optimizer.zero_grad() | ||||
|         self.delta_net.zero_grad() | ||||
|  | ||||
|  | ||||
| class Population: | ||||
| @@ -28,11 +79,23 @@ class Population: | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._time2model = dict() | ||||
|         self._time2score = dict()  # higher is better | ||||
|  | ||||
|     def append(self, timestamp, model): | ||||
|     def append(self, timestamp, model, score): | ||||
|         if timestamp in self._time2model: | ||||
|             raise ValueError("This timestamp has been added.") | ||||
|         self._time2model[timestamp] = model | ||||
|         self._time2score[timestamp] = score | ||||
|  | ||||
|     def query(self, timestamp): | ||||
|         closet_timestamp = None | ||||
|         for xtime, model in self._time2model.items(): | ||||
|             if ( | ||||
|                 closet_timestamp is None | ||||
|                 or timestamp - closet_timestamp >= timestamp - xtime | ||||
|             ): | ||||
|                 closet_timestamp = xtime | ||||
|         return self._time2model[closet_timestamp], closet_timestamp | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
| @@ -70,100 +133,39 @@ def main(args): | ||||
|     ) | ||||
|  | ||||
|     w_container = base_model.named_parameters_buffers() | ||||
|     criterion = torch.nn.MSELoss() | ||||
|     print("There are {:} weights.".format(w_container.numel())) | ||||
|  | ||||
|     adaptor = LFNAmlp(4, (50, 20), "leaky_relu") | ||||
|  | ||||
|     pool = Population() | ||||
|     pool.append(0, w_container) | ||||
|  | ||||
|     # LFNA meta-training | ||||
|     per_epoch_time, start_time = AverageMeter(), time.time() | ||||
|     for iepoch in range(args.epochs): | ||||
|         import pdb | ||||
|  | ||||
|         pdb.set_trace() | ||||
|         print("-") | ||||
|  | ||||
|     for i, idx in enumerate(to_evaluate_indexes): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time( | ||||
|                 per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True | ||||
|             ) | ||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx) | ||||
|             + " " | ||||
|             "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) | ||||
|             + need_time | ||||
|         ) | ||||
|         # train the same data | ||||
|         assert idx != 0 | ||||
|         historical_x = env_info["{:}-x".format(idx)] | ||||
|         historical_y = env_info["{:}-y".format(idx)] | ||||
|         # build model | ||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||
|         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         # build optimizer | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|         criterion = torch.nn.MSELoss() | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|             optimizer, | ||||
|             milestones=[ | ||||
|                 int(args.epochs * 0.25), | ||||
|                 int(args.epochs * 0.5), | ||||
|                 int(args.epochs * 0.75), | ||||
|             ], | ||||
|             gamma=0.3, | ||||
|         ) | ||||
|         train_metric = MSEMetric() | ||||
|         best_loss, best_param = None, None | ||||
|         for _iepoch in range(args.epochs): | ||||
|             preds = model(historical_x) | ||||
|             optimizer.zero_grad() | ||||
|             loss = criterion(preds, historical_y) | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             lr_scheduler.step() | ||||
|             # save best | ||||
|             if best_loss is None or best_loss > loss.item(): | ||||
|                 best_loss = loss.item() | ||||
|                 best_param = copy.deepcopy(model.state_dict()) | ||||
|         model.load_state_dict(best_param) | ||||
|         with torch.no_grad(): | ||||
|             train_metric(preds, historical_y) | ||||
|         train_results = train_metric.get_info() | ||||
|  | ||||
|         metric = ComposeMetric(MSEMetric(), SaveMetric()) | ||||
|         eval_dataset = torch.utils.data.TensorDataset( | ||||
|             env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] | ||||
|         ) | ||||
|         eval_loader = torch.utils.data.DataLoader( | ||||
|             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 | ||||
|         ) | ||||
|         results = basic_eval_fn(eval_loader, model, metric, logger) | ||||
|         log_str = ( | ||||
|             "[{:}]".format(time_string()) | ||||
|             + " [{:04d}/{:04d}]".format(idx, env_info["total"]) | ||||
|             + " train-mse: {:.5f}, eval-mse: {:.5f}".format( | ||||
|                 train_results["mse"], results["mse"] | ||||
|             ) | ||||
|         ) | ||||
|         logger.log(log_str) | ||||
|         for ibatch in range(args.meta_batch): | ||||
|             sampled_timestamp = random.randint(0, train_time_bar) | ||||
|             query_w_container, query_timestamp = pool.query(sampled_timestamp) | ||||
|             # def adapt(self, model, w_container, xs, ys): | ||||
|             xs, ys = [], [] | ||||
|             for it in range(sampled_timestamp, sampled_timestamp + args.max_seq): | ||||
|                 xs.append(env_info["{:}-x".format(it)]) | ||||
|                 ys.append(env_info["{:}-y".format(it)]) | ||||
|             adaptor.adapt(base_model, criterion, query_w_container, xs, ys) | ||||
|             import pdb | ||||
|  | ||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||
|             idx, env_info["total"] | ||||
|         ) | ||||
|         save_checkpoint( | ||||
|             { | ||||
|                 "model_state_dict": model.state_dict(), | ||||
|                 "model": model, | ||||
|                 "index": idx, | ||||
|                 "timestamp": env_info["{:}-timestamp".format(idx)], | ||||
|             }, | ||||
|             save_path, | ||||
|             logger, | ||||
|         ) | ||||
|             pdb.set_trace() | ||||
|         print("-") | ||||
|         logger.log("") | ||||
|  | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
| @@ -188,10 +190,10 @@ if __name__ == "__main__": | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--batch_size", | ||||
|         "--meta_batch", | ||||
|         type=int, | ||||
|         default=512, | ||||
|         help="The batch size", | ||||
|         default=2, | ||||
|         help="The batch size for the meta-model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
| @@ -199,6 +201,12 @@ if __name__ == "__main__": | ||||
|         default=1000, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--max_seq", | ||||
|         type=int, | ||||
|         default=5, | ||||
|         help="The maximum length of the sequence.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user