Update LFNA

parent 8109ed166a
commit df9917371e
@@ -9,4 +9,4 @@
 - [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
 - [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
 - [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
-- [2021.05.21] [5b09f05](https://github.com/D-X-Y/AutoDL-Projects/tree/5b09f05) `xautodl` is close to ready
+- [2021.05.21] [8109ed1](https://github.com/D-X-Y/AutoDL-Projects/tree/8109ed1) `xautodl` is close to ready
exps/LFNA/lfna.py
@@ -1,5 +1,5 @@
 #####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
+# Learning to Generate Model One Step Ahead         #
 #####################################################
 # python exps/LFNA/lfna.py --env_version v1 --workers 0
 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
@@ -109,6 +109,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
     if meta_model.has_best(final_best_name):
         meta_model.load_best(final_best_name)
+        logger.log("Directly load the best model from {:}".format(final_best_name))
         return
 
     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
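
Note: `has_best` / `load_best` / `set_best_name` / `save_best` above are methods of the repo's meta-model. For readers following the diff, here is a minimal, hypothetical sketch of what such a best-checkpoint helper could look like; only the method names and call shapes are taken from the diff, the implementation is illustrative.

```python
import os
import torch
import torch.nn as nn

class BestCheckpointMixin(nn.Module):
    """Hypothetical sketch: keep the best-scoring state_dict on disk."""

    def set_best_name(self, name):
        self._best_name, self._best_score = name, float("-inf")

    def has_best(self, name):
        return os.path.exists(name)

    def load_best(self, name):
        self.load_state_dict(torch.load(name))

    def save_best(self, score):
        success = score > self._best_score  # keep only improvements
        if success:
            self._best_score = score
            torch.save(self.state_dict(), self._best_name)
        return success, self._best_score

class ToyModel(BestCheckpointMixin):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)

model = ToyModel()
model.set_best_name("pretrain-demo.pth")
print(model.save_best(-1.23))  # (True, -1.23): the first score always wins
```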
@@ -118,58 +119,64 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
-        total_meta_losses, total_match_losses = [], []
+        total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
         optimizer.zero_grad()
         for ibatch in range(args.meta_batch):
             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
             timestamps = meta_model.meta_timestamps[
                 rand_index : rand_index + xenv.seq_length
             ]
+            meta_embeds = meta_model.super_meta_embed[
+                rand_index : rand_index + xenv.seq_length
+            ]
 
-            seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
-            [seq_containers], time_embeds = meta_model(
-                torch.unsqueeze(timestamps, dim=0), None
-            )
-            # performance loss
-            losses = []
+            _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
             seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
                 args.device
             )
+            # generate models one step ahead
+            [seq_containers], time_embeds = meta_model(
+                torch.unsqueeze(timestamps, dim=0), None
+            )
             for container, inputs, targets in zip(
                 seq_containers, seq_inputs, seq_targets
             ):
                 predictions = base_model.forward_with_container(inputs, container)
-                loss = criterion(predictions, targets)
-                losses.append(loss)
-            meta_loss = torch.stack(losses).mean()
-            match_loss = criterion(
-                torch.squeeze(time_embeds, dim=0),
-                meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
-            )
-            total_meta_losses.append(meta_loss)
+                total_meta_v1_losses.append(criterion(predictions, targets))
+            # the matching loss
+            match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds)
             total_match_losses.append(match_loss)
+            # generate models via memory
+            [seq_containers], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0))
+            for container, inputs, targets in zip(
+                seq_containers, seq_inputs, seq_targets
+            ):
+                predictions = base_model.forward_with_container(inputs, container)
+                total_meta_v2_losses.append(criterion(predictions, targets))
         with torch.no_grad():
-            meta_std = torch.stack(total_meta_losses).std().item()
-        final_meta_loss = torch.stack(total_meta_losses).mean()
-        final_match_loss = torch.stack(total_match_losses).mean()
-        total_loss = final_meta_loss + final_match_loss
+            meta_std = torch.stack(total_meta_v1_losses).std().item()
+        meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
+        meta_v2_loss = torch.stack(total_meta_v2_losses).mean()
+        match_loss = torch.stack(total_match_losses).mean()
+        total_loss = meta_v1_loss + meta_v2_loss + match_loss
         total_loss.backward()
         optimizer.step()
         # success
         success, best_score = meta_model.save_best(-total_loss.item())
         logger.log(
-            "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
+            "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format(
                 time_string(),
                 iepoch,
                 args.epochs,
                 total_loss.item(),
                 meta_std,
-                final_meta_loss.item(),
-                final_match_loss.item(),
+                meta_v1_loss.item(),
+                meta_v2_loss.item(),
+                match_loss.item(),
             )
-            + ", batch={:}".format(len(total_meta_losses))
-            + ", success={:}, best_score={:.4f}".format(success, -best_score)
-            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+            + ", batch={:}".format(len(total_meta_v1_losses))
+            + ", success={:}, best={:.4f}".format(success, -best_score)
+            + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
+            + ", {:}".format(left_time)
         )
         if success:
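
Note: the hunk above rewrites the pre-training objective from two terms (meta + match) into three: a v1 loss on models generated one step ahead from timestamps, a v2 loss on models generated from the stored memory embeddings, and a matching loss tying the two embedding spaces together. A minimal, self-contained sketch of that three-term objective follows; `time_encoder`, `hypernet`, and `forward_with_container` here are toy stand-ins, not the repo's real classes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.MSELoss()
seq_length, embed_dim, num_steps = 4, 8, 16

# stand-in for meta_model.super_meta_embed: one learned embedding per timestamp
super_meta_embed = nn.Parameter(torch.randn(num_steps, embed_dim))
# stand-in for the timestamp -> time-embedding path inside the meta-model
time_encoder = nn.Linear(1, embed_dim)
# stand-in hypernet mapping an embedding to a tiny weight "container" (3 weights)
hypernet = nn.Linear(embed_dim, 3)

def forward_with_container(inputs, container):
    # toy base model: a linear map whose weights are generated, not stored
    return inputs @ container.view(3, 1)

rand_index = 5
timestamps = torch.arange(
    rand_index, rand_index + seq_length, dtype=torch.float32
).view(-1, 1)
meta_embeds = super_meta_embed[rand_index : rand_index + seq_length]
seq_inputs = torch.randn(seq_length, 10, 3)   # 10 samples of dim 3 per step
seq_targets = torch.randn(seq_length, 10, 1)

# only the encoder and the memory are trained in this toy setup
optimizer = torch.optim.Adam(
    list(time_encoder.parameters()) + [super_meta_embed], lr=0.001, amsgrad=True
)
optimizer.zero_grad()

time_embeds = time_encoder(timestamps)   # generate models one step ahead
containers_v1 = hypernet(time_embeds)
containers_v2 = hypernet(meta_embeds)    # generate models via memory
v1_losses, v2_losses = [], []
for t in range(seq_length):
    pred_v1 = forward_with_container(seq_inputs[t], containers_v1[t])
    pred_v2 = forward_with_container(seq_inputs[t], containers_v2[t])
    v1_losses.append(criterion(pred_v1, seq_targets[t]))
    v2_losses.append(criterion(pred_v2, seq_targets[t]))

meta_v1_loss = torch.stack(v1_losses).mean()
meta_v2_loss = torch.stack(v2_losses).mean()
match_loss = criterion(time_embeds, meta_embeds)  # tie the two embedding spaces
total_loss = meta_v1_loss + meta_v2_loss + match_loss
total_loss.backward()
optimizer.step()
```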
@@ -184,6 +191,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     meta_model.set_best_name(final_best_name)
     success, _ = meta_model.save_best(best_score + 1e-6)
     assert success
+    logger.log("Save the best model into {:}".format(final_best_name))
 
 
 def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
@@ -243,8 +251,8 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
                 final_loss.item(),
             )
             + ", batch={:}".format(len(losses))
-            + ", success={:}, best_score={:.4f}".format(success, -best_score)
+            + ", success={:}, best={:.4f}".format(success, -best_score)
-            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+            + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
             + " {:}".format(left_time)
         )
         if success:
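
Note: both log changes above make the `LS={:}/{:}` field report how many epochs have passed since the last improvement, the quantity that is compared against `early_stop_thresh`. A toy sketch of that bookkeeping (the scoring loop is hypothetical):

```python
import random

random.seed(1)
early_stop_thresh, last_success_epoch = 20, 0
best_score = float("-inf")
for iepoch in range(500):
    score = random.random()  # toy stand-in for -total_loss after an epoch
    success = score > best_score
    if success:
        best_score, last_success_epoch = score, iepoch
    # "LS" reports epochs elapsed since the last improvement
    print("LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh))
    if iepoch - last_success_epoch >= early_stop_thresh:
        break
```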
@@ -277,6 +285,8 @@ def main(args):
 
     logger.log("The base-model has {:} weights.".format(base_model.numel()))
     logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
+    logger.log("The base-model is\n{:}".format(base_model))
+    logger.log("The meta-model is\n{:}".format(meta_model))
 
     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
     train_env.reset_max_seq_length(args.seq_length)
@@ -294,9 +304,10 @@ def main(args):
         num_workers=args.workers,
         pin_memory=True,
     )
+    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
 
     optimizer = torch.optim.Adam(
-        meta_model.parameters(),
+        meta_model.get_parameters(True, True, False),  # fix hypernet
         lr=args.lr,
         weight_decay=args.weight_decay,
         amsgrad=True,
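
Note: the optimizer change above builds the optimizer over only a subset of the meta-model's parameters so that the hypernetwork stays fixed during this stage; `get_parameters(True, True, False)` is the repo's own selector. An illustrative equivalent using name-based filtering (the toy module and the meaning of the three flags are assumptions, not the real API):

```python
import torch
import torch.nn as nn

class ToyMetaModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_encoder = nn.Linear(1, 8)                  # trained
        self.meta_embed = nn.Parameter(torch.randn(16, 8))   # trained
        self.hypernet = nn.Linear(8, 3)                      # kept fixed

meta_model = ToyMetaModel()
# select everything except the hypernet, mirroring the "fix hypernet" intent
trainable = [
    p for name, p in meta_model.named_parameters()
    if not name.startswith("hypernet")
]
optimizer = torch.optim.Adam(
    trainable, lr=0.001, weight_decay=1e-5, amsgrad=True
)
```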
@@ -306,14 +317,10 @@ def main(args):
         milestones=[1, 2, 3, 4, 5],
         gamma=0.2,
     )
-    logger.log("The base-model is\n{:}".format(base_model))
-    logger.log("The meta-model is\n{:}".format(meta_model))
     logger.log("The optimizer is\n{:}".format(optimizer))
     logger.log("The scheduler is\n{:}".format(lr_scheduler))
     logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
 
-    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
-
     if logger.path("model").exists():
         ckp_data = torch.load(logger.path("model"))
         base_model.load_state_dict(ckp_data["base_model"])