Updates
@@ -1,7 +1,6 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-import copy
 import torch
 
 import torch.nn.functional as F
@@ -294,7 +293,9 @@ class LFNA_Meta(super_core.SuperModule):
                 best_loss = init_info["loss"]
                 new_param.data.copy_(init_info["param"].data)
             else:
-                best_new_param, best_loss = None, 1e9
+                best_loss = 1e9
+            with torch.no_grad():
+                best_new_param = new_param.detach().clone()
             for iepoch in range(epochs):
                 optimizer.zero_grad()
                 _, [_], time_embed = self(timestamp.view(1, 1), None, True)
@@ -310,7 +311,7 @@ class LFNA_Meta(super_core.SuperModule):
                 if meta_loss.item() < best_loss:
                     with torch.no_grad():
                         best_loss = meta_loss.item()
-                        best_new_param = new_param.detach()
+                        best_new_param = new_param.detach().clone()
         with torch.no_grad():
             self.replace_append_learnt(None, None)
             self.append_fixed(timestamp, best_new_param)
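The substantive change is twofold: best_new_param is now pre-initialized to a detached clone of new_param before the optimization loop, so append_fixed can never receive None when no epoch improves on the initial loss, and the cached best value is taken with .detach().clone() instead of .detach(). A minimal sketch of why the .clone() matters, using illustrative names rather than this repository's code: .detach() returns a tensor that shares storage with the live parameter, so later in-place optimizer steps overwrite the cached value, whereas .detach().clone() takes an independent snapshot.

import torch

# Illustrative sketch (not this repository's code): caching a "best" parameter
# with .detach() keeps a view into the live parameter, while .detach().clone()
# takes an independent snapshot.
param = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.SGD([param], lr=1.0)

best_view = param.detach()          # shares storage with `param`
best_copy = param.detach().clone()  # independent copy

param.grad = torch.ones_like(param)
optimizer.step()                    # SGD updates `param` in place

print(best_view)  # tensor([-1., -1., -1.]) -- the cached "best" silently changed
print(best_copy)  # tensor([0., 0., 0.])    -- still the value at snapshot time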