diff --git a/CHANGE-LOG.md b/CHANGE-LOG.md index 5850b1b..68ac0bc 100644 --- a/CHANGE-LOG.md +++ b/CHANGE-LOG.md @@ -9,4 +9,4 @@ - [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0 - [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1 - [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl` -- [2021.05.21] [8109ed1](https://github.com/D-X-Y/AutoDL-Projects/tree/8109ed1) `xautodl` is close to ready +- [2021.05.21] [df99173](https://github.com/D-X-Y/AutoDL-Projects/tree/df99173) `xautodl` is close to ready diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index 30eb924..ec7d3c1 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -93,6 +93,38 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): return loss_meter +def online_evaluate(env, meta_model, base_model, criterion, args, logger): + logger.log("Online evaluate: {:}".format(env)) + for idx, (timestamp, (future_x, future_y)) in enumerate(env): + future_time = timestamp.item() + time_seqs = [ + future_time - iseq * env.timestamp_interval + for iseq in range(args.seq_length * 2) + ] + time_seqs.reverse() + with torch.no_grad(): + meta_model.eval() + base_model.eval() + time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device) + [seq_containers], _ = meta_model(time_seqs, None) + future_container = seq_containers[-2] + _, (future_x, future_y) = env(time_seqs[0, -2].item()) + future_x, future_y = future_x.to(args.device), future_y.to(args.device) + future_y_hat = base_model.forward_with_container(future_x, future_container) + future_loss = criterion(future_y_hat, future_y) + logger.log( + "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( + idx, len(env), future_loss.item() + ) + ) + import pdb + + pdb.set_trace() + for iseq in range(args.seq_length): + time_seqs.append(future_time - iseq * 
env.timestamp_interval) + print("-") + + def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): base_model.train() meta_model.train() @@ -176,7 +208,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): ) + ", batch={:}".format(len(total_meta_v1_losses)) + ", success={:}, best={:.4f}".format(success, -best_score) - + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh) + + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh) + ", {:}".format(left_time) ) if success: @@ -194,77 +226,6 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): logger.log("Save the best model into {:}".format(final_best_name)) -def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger): - base_model.train() - meta_model.train() - optimizer = torch.optim.Adam( - meta_model.parameters(), - lr=args.lr, - weight_decay=args.weight_decay, - amsgrad=True, - ) - logger.log("Pre-train the meta-model's embeddings") - logger.log("Using the optimizer: {:}".format(optimizer)) - - meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1") - meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) - per_epoch_time, start_time = AverageMeter(), time.time() - last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh - for iepoch in range(args.epochs): - left_time = "Time Left: {:}".format( - convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) - ) - losses = [] - optimizer.zero_grad() - for ibatch in range(args.meta_batch): - rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) - timestamps = meta_model.meta_timestamps[ - rand_index : rand_index + xenv.seq_length - ] - - seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) - time_embeds = meta_model.super_meta_embed[ - rand_index : rand_index + xenv.seq_length - ] - [seq_containers], time_embeds = meta_model( - None, torch.unsqueeze(time_embeds, dim=0) - ) - 
seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to( - args.device - ) - for container, inputs, targets in zip( - seq_containers, seq_inputs, seq_targets - ): - predictions = base_model.forward_with_container(inputs, container) - loss = criterion(predictions, targets) - losses.append(loss) - final_loss = torch.stack(losses).mean() - final_loss.backward() - optimizer.step() - # success - success, best_score = meta_model.save_best(-final_loss.item()) - logger.log( - "{:} [Pre-V1 {:04d}/{:}] loss : {:.5f}".format( - time_string(), - iepoch, - args.epochs, - final_loss.item(), - ) - + ", batch={:}".format(len(losses)) - + ", success={:}, best={:.4f}".format(success, -best_score) - + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh) - + " {:}".format(left_time) - ) - if success: - last_success_epoch = iepoch - if iepoch - last_success_epoch >= early_stop_thresh: - logger.log("Early stop the pre-training at {:}".format(iepoch)) - break - per_epoch_time.update(time.time() - start_time) - start_time = time.time() - meta_model.load_best() - - def main(args): logger, env_info, model_kwargs = lfna_setup(args) train_env = get_synthetic_env(mode="train", version=args.env_version) @@ -290,7 +251,7 @@ def main(args): batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) train_env.reset_max_seq_length(args.seq_length) - valid_env.reset_max_seq_length(args.seq_length) + # valid_env.reset_max_seq_length(args.seq_length) valid_env_loader = torch.utils.data.DataLoader( valid_env, batch_size=args.meta_batch, @@ -306,6 +267,11 @@ def main(args): ) pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) + # try to evaluate once + online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) + import pdb + + pdb.set_trace() optimizer = torch.optim.Adam( meta_model.get_parameters(True, True, False), # fix hypernet lr=args.lr, @@ -558,7 +524,7 @@ if __name__ == "__main__": parser.add_argument( 
"--pretrain_early_stop_thresh", type=int, - default=200, + default=300, help="The #epochs for early stop.", ) parser.add_argument( diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py index 19f80ad..765c890 100644 --- a/exps/LFNA/lfna_meta_model.py +++ b/exps/LFNA/lfna_meta_model.py @@ -22,6 +22,7 @@ class LFNA_Meta(super_core.SuperModule): meta_timestamps, mha_depth: int = 2, dropout: float = 0.1, + thresh: float = 0.05, ): super(LFNA_Meta, self).__init__() self._shape_container = shape_container @@ -30,6 +31,7 @@ class LFNA_Meta(super_core.SuperModule): for ilayer in range(self._num_layers): self._numel_per_layer.append(shape_container[ilayer].numel()) self._raw_meta_timestamps = meta_timestamps + self._thresh = thresh self.register_parameter( "_super_layer_embed", @@ -168,7 +170,14 @@ class LFNA_Meta(super_core.SuperModule): timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) timestamp_v_embed = meta_embeds.unsqueeze(dim=0) # create the mask - mask = torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) + mask = ( + torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) + ) | ( + torch.abs( + torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1) + ) + > self._thresh + ) timestamp_embeds = self._trans_att( timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask ) diff --git a/xautodl/xlayers/super_module.py b/xautodl/xlayers/super_module.py index ff56f34..f08461f 100644 --- a/xautodl/xlayers/super_module.py +++ b/xautodl/xlayers/super_module.py @@ -117,17 +117,14 @@ class SuperModule(abc.ABC, nn.Module): else: return False, self._meta_info[BEST_SCORE_KEY] - def load_best(self, best_save_path=None): - if best_save_path is None: - if ( - BEST_DIR_KEY not in self._meta_info - or BEST_SCORE_KEY not in self._meta_info - ): - raise ValueError("Please call save_best at first") + def load_best(self, best_save_name=None): + if BEST_DIR_KEY not in self._meta_info: + raise 
ValueError("Please set BEST_DIR_KEY at first") + if best_save_name is None: best_save_name = self._meta_info.get( BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__) ) - best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name) + best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name) state_dict = torch.load(best_save_path) self.load_state_dict(state_dict)