diff --git a/CHANGE-LOG.md b/CHANGE-LOG.md
index aaeae65..5850b1b 100644
--- a/CHANGE-LOG.md
+++ b/CHANGE-LOG.md
@@ -9,4 +9,4 @@
 - [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
 - [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
 - [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
-- [2021.05.21] [5b09f05](https://github.com/D-X-Y/AutoDL-Projects/tree/5b09f05) `xautodl` is close to ready
+- [2021.05.21] [8109ed1](https://github.com/D-X-Y/AutoDL-Projects/tree/8109ed1) `xautodl` is close to ready
diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py
index 73e56f1..30eb924 100644
--- a/exps/LFNA/lfna.py
+++ b/exps/LFNA/lfna.py
@@ -1,5 +1,5 @@
 #####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
+# Learning to Generate Model One Step Ahead         #
 #####################################################
 # python exps/LFNA/lfna.py --env_version v1 --workers 0
 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
@@ -109,6 +109,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
     if meta_model.has_best(final_best_name):
         meta_model.load_best(final_best_name)
+        logger.log("Directly load the best model from {:}".format(final_best_name))
         return
 
     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
@@ -118,58 +119,64 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
-        total_meta_losses, total_match_losses = [], []
+        total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
         optimizer.zero_grad()
         for ibatch in range(args.meta_batch):
             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
             timestamps = meta_model.meta_timestamps[
                 rand_index : rand_index + xenv.seq_length
             ]
+            meta_embeds = meta_model.super_meta_embed[
+                rand_index : rand_index + xenv.seq_length
+            ]
 
-            seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
-            [seq_containers], time_embeds = meta_model(
-                torch.unsqueeze(timestamps, dim=0), None
-            )
-            # performance loss
-            losses = []
+            _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
             seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
                 args.device
             )
+            # generate models one step ahead
+            [seq_containers], time_embeds = meta_model(
+                torch.unsqueeze(timestamps, dim=0), None
+            )
             for container, inputs, targets in zip(
                 seq_containers, seq_inputs, seq_targets
             ):
                 predictions = base_model.forward_with_container(inputs, container)
-                loss = criterion(predictions, targets)
-                losses.append(loss)
-            meta_loss = torch.stack(losses).mean()
-            match_loss = criterion(
-                torch.squeeze(time_embeds, dim=0),
-                meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
-            )
-            total_meta_losses.append(meta_loss)
+                total_meta_v1_losses.append(criterion(predictions, targets))
+            # the matching loss
+            match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds)
             total_match_losses.append(match_loss)
+            # generate models via memory
+            [seq_containers], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0))
+            for container, inputs, targets in zip(
+                seq_containers, seq_inputs, seq_targets
+            ):
+                predictions = base_model.forward_with_container(inputs, container)
+                total_meta_v2_losses.append(criterion(predictions, targets))
         with torch.no_grad():
-            meta_std = torch.stack(total_meta_losses).std().item()
-        final_meta_loss = torch.stack(total_meta_losses).mean()
-        final_match_loss = torch.stack(total_match_losses).mean()
-        total_loss = final_meta_loss + final_match_loss
+            meta_std = torch.stack(total_meta_v1_losses).std().item()
+        meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
+        meta_v2_loss = torch.stack(total_meta_v2_losses).mean()
+        match_loss = torch.stack(total_match_losses).mean()
+        total_loss = meta_v1_loss + meta_v2_loss + match_loss
         total_loss.backward()
         optimizer.step()
         # success
         success, best_score = meta_model.save_best(-total_loss.item())
         logger.log(
-            "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
+            "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format(
                 time_string(),
                 iepoch,
                 args.epochs,
                 total_loss.item(),
                 meta_std,
-                final_meta_loss.item(),
-                final_match_loss.item(),
+                meta_v1_loss.item(),
+                meta_v2_loss.item(),
+                match_loss.item(),
             )
-            + ", batch={:}".format(len(total_meta_losses))
-            + ", success={:}, best_score={:.4f}".format(success, -best_score)
-            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+            + ", batch={:}".format(len(total_meta_v1_losses))
+            + ", success={:}, best={:.4f}".format(success, -best_score)
+            + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
             + ", {:}".format(left_time)
         )
         if success:
@@ -184,6 +191,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     meta_model.set_best_name(final_best_name)
     success, _ = meta_model.save_best(best_score + 1e-6)
     assert success
+    logger.log("Save the best model into {:}".format(final_best_name))
 
 
 def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
@@ -243,8 +251,8 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
                 final_loss.item(),
             )
             + ", batch={:}".format(len(losses))
-            + ", success={:}, best_score={:.4f}".format(success, -best_score)
-            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+            + ", success={:}, best={:.4f}".format(success, -best_score)
+            + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
             + " {:}".format(left_time)
         )
         if success:
@@ -277,6 +285,8 @@ def main(args):
 
     logger.log("The base-model has {:} weights.".format(base_model.numel()))
     logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
+    logger.log("The base-model is\n{:}".format(base_model))
+    logger.log("The meta-model is\n{:}".format(meta_model))
 
     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
     train_env.reset_max_seq_length(args.seq_length)
@@ -294,9 +304,10 @@ def main(args):
         num_workers=args.workers,
         pin_memory=True,
     )
+    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
 
     optimizer = torch.optim.Adam(
-        meta_model.parameters(),
+        meta_model.get_parameters(True, True, False),  # fix hypernet
         lr=args.lr,
         weight_decay=args.weight_decay,
         amsgrad=True,
@@ -306,14 +317,10 @@ def main(args):
         milestones=[1, 2, 3, 4, 5],
         gamma=0.2,
     )
-    logger.log("The base-model is\n{:}".format(base_model))
-    logger.log("The meta-model is\n{:}".format(meta_model))
     logger.log("The optimizer is\n{:}".format(optimizer))
     logger.log("The scheduler is\n{:}".format(lr_scheduler))
     logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
 
-    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
-
     if logger.path("model").exists():
         ckp_data = torch.load(logger.path("model"))
         base_model.load_state_dict(ckp_data["base_model"])