Update LFNA

This commit is contained in:
D-X-Y 2021-05-22 23:49:09 +08:00
parent 8109ed166a
commit df9917371e
2 changed files with 41 additions and 34 deletions

View File

@ -9,4 +9,4 @@
- [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0 - [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
- [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1 - [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
- [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl` - [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
- [2021.05.21] [5b09f05](https://github.com/D-X-Y/AutoDL-Projects/tree/5b09f05) `xautodl` is close to ready - [2021.05.21] [8109ed1](https://github.com/D-X-Y/AutoDL-Projects/tree/8109ed1) `xautodl` is close to ready

View File

@ -1,5 +1,5 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # # Learning to Generate Model One Step Ahead #
##################################################### #####################################################
# python exps/LFNA/lfna.py --env_version v1 --workers 0 # python exps/LFNA/lfna.py --env_version v1 --workers 0
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
@ -109,6 +109,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed) final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
if meta_model.has_best(final_best_name): if meta_model.has_best(final_best_name):
meta_model.load_best(final_best_name) meta_model.load_best(final_best_name)
logger.log("Directly load the best model from {:}".format(final_best_name))
return return
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
@ -118,58 +119,64 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
left_time = "Time Left: {:}".format( left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
) )
total_meta_losses, total_match_losses = [], [] total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
optimizer.zero_grad() optimizer.zero_grad()
for ibatch in range(args.meta_batch): for ibatch in range(args.meta_batch):
rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
timestamps = meta_model.meta_timestamps[ timestamps = meta_model.meta_timestamps[
rand_index : rand_index + xenv.seq_length rand_index : rand_index + xenv.seq_length
] ]
meta_embeds = meta_model.super_meta_embed[
rand_index : rand_index + xenv.seq_length
]
seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
[seq_containers], time_embeds = meta_model(
torch.unsqueeze(timestamps, dim=0), None
)
# performance loss
losses = []
seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to( seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
args.device args.device
) )
# generate models one step ahead
[seq_containers], time_embeds = meta_model(
torch.unsqueeze(timestamps, dim=0), None
)
for container, inputs, targets in zip( for container, inputs, targets in zip(
seq_containers, seq_inputs, seq_targets seq_containers, seq_inputs, seq_targets
): ):
predictions = base_model.forward_with_container(inputs, container) predictions = base_model.forward_with_container(inputs, container)
loss = criterion(predictions, targets) total_meta_v1_losses.append(criterion(predictions, targets))
losses.append(loss) # the matching loss
meta_loss = torch.stack(losses).mean() match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds)
match_loss = criterion(
torch.squeeze(time_embeds, dim=0),
meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
)
total_meta_losses.append(meta_loss)
total_match_losses.append(match_loss) total_match_losses.append(match_loss)
# generate models via memory
[seq_containers], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0))
for container, inputs, targets in zip(
seq_containers, seq_inputs, seq_targets
):
predictions = base_model.forward_with_container(inputs, container)
total_meta_v2_losses.append(criterion(predictions, targets))
with torch.no_grad(): with torch.no_grad():
meta_std = torch.stack(total_meta_losses).std().item() meta_std = torch.stack(total_meta_v1_losses).std().item()
final_meta_loss = torch.stack(total_meta_losses).mean() meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
final_match_loss = torch.stack(total_match_losses).mean() meta_v2_loss = torch.stack(total_meta_v2_losses).mean()
total_loss = final_meta_loss + final_match_loss match_loss = torch.stack(total_match_losses).mean()
total_loss = meta_v1_loss + meta_v2_loss + match_loss
total_loss.backward() total_loss.backward()
optimizer.step() optimizer.step()
# success # success
success, best_score = meta_model.save_best(-total_loss.item()) success, best_score = meta_model.save_best(-total_loss.item())
logger.log( logger.log(
"{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format( "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format(
time_string(), time_string(),
iepoch, iepoch,
args.epochs, args.epochs,
total_loss.item(), total_loss.item(),
meta_std, meta_std,
final_meta_loss.item(), meta_v1_loss.item(),
final_match_loss.item(), meta_v2_loss.item(),
match_loss.item(),
) )
+ ", batch={:}".format(len(total_meta_losses)) + ", batch={:}".format(len(total_meta_v1_losses))
+ ", success={:}, best_score={:.4f}".format(success, -best_score) + ", success={:}, best={:.4f}".format(success, -best_score)
+ ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh) + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
+ ", {:}".format(left_time) + ", {:}".format(left_time)
) )
if success: if success:
@ -184,6 +191,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
meta_model.set_best_name(final_best_name) meta_model.set_best_name(final_best_name)
success, _ = meta_model.save_best(best_score + 1e-6) success, _ = meta_model.save_best(best_score + 1e-6)
assert success assert success
logger.log("Save the best model into {:}".format(final_best_name))
def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger): def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
@ -243,8 +251,8 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
final_loss.item(), final_loss.item(),
) )
+ ", batch={:}".format(len(losses)) + ", batch={:}".format(len(losses))
+ ", success={:}, best_score={:.4f}".format(success, -best_score) + ", success={:}, best={:.4f}".format(success, -best_score)
+ ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh) + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
+ " {:}".format(left_time) + " {:}".format(left_time)
) )
if success: if success:
@ -277,6 +285,8 @@ def main(args):
logger.log("The base-model has {:} weights.".format(base_model.numel())) logger.log("The base-model has {:} weights.".format(base_model.numel()))
logger.log("The meta-model has {:} weights.".format(meta_model.numel())) logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
logger.log("The base-model is\n{:}".format(base_model))
logger.log("The meta-model is\n{:}".format(meta_model))
batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
train_env.reset_max_seq_length(args.seq_length) train_env.reset_max_seq_length(args.seq_length)
@ -294,9 +304,10 @@ def main(args):
num_workers=args.workers, num_workers=args.workers,
pin_memory=True, pin_memory=True,
) )
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
meta_model.parameters(), meta_model.get_parameters(True, True, False), # fix hypernet
lr=args.lr, lr=args.lr,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
amsgrad=True, amsgrad=True,
@ -306,14 +317,10 @@ def main(args):
milestones=[1, 2, 3, 4, 5], milestones=[1, 2, 3, 4, 5],
gamma=0.2, gamma=0.2,
) )
logger.log("The base-model is\n{:}".format(base_model))
logger.log("The meta-model is\n{:}".format(meta_model))
logger.log("The optimizer is\n{:}".format(optimizer)) logger.log("The optimizer is\n{:}".format(optimizer))
logger.log("The scheduler is\n{:}".format(lr_scheduler)) logger.log("The scheduler is\n{:}".format(lr_scheduler))
logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
if logger.path("model").exists(): if logger.path("model").exists():
ckp_data = torch.load(logger.path("model")) ckp_data = torch.load(logger.path("model"))
base_model.load_state_dict(ckp_data["base_model"]) base_model.load_state_dict(ckp_data["base_model"])