Update ablation for GeMOSA
This commit is contained in:
		| @@ -11,3 +11,4 @@ from .affine_utils import normalize_points, denormalize_points | |||||||
| from .affine_utils import identity2affine, solve2theta, affine2image | from .affine_utils import identity2affine, solve2theta, affine2image | ||||||
| from .hash_utils import get_md5_file | from .hash_utils import get_md5_file | ||||||
| from .str_utils import split_str2indexes | from .str_utils import split_str2indexes | ||||||
|  | from .str_utils import show_mean_var | ||||||
|   | |||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  |  | ||||||
| def split_str2indexes(string: str, max_check: int, length_limit=5): | def split_str2indexes(string: str, max_check: int, length_limit=5): | ||||||
|     if not isinstance(string, str): |     if not isinstance(string, str): | ||||||
|         raise ValueError("Invalid scheme for {:}".format(string)) |         raise ValueError("Invalid scheme for {:}".format(string)) | ||||||
| @@ -19,3 +22,13 @@ def split_str2indexes(string: str, max_check: int, length_limit=5): | |||||||
|         for i in range(srange[0], srange[1] + 1): |         for i in range(srange[0], srange[1] + 1): | ||||||
|             indexes.add(i) |             indexes.add(i) | ||||||
|     return indexes |     return indexes | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def show_mean_var(xlist): | ||||||
|  |     values = np.array(xlist) | ||||||
|  |     print( | ||||||
|  |         "{:.3f}".format(values.mean()) | ||||||
|  |         + "$_{{\pm}{" | ||||||
|  |         + "{:.3f}".format(values.std()) | ||||||
|  |         + "}}$" | ||||||
|  |     ) | ||||||
|   | |||||||
| @@ -20,9 +20,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1): | |||||||
|         SuperLinear(100, 1), |         SuperLinear(100, 1), | ||||||
|     ).to(device) |     ).to(device) | ||||||
|     model.train() |     model.train() | ||||||
|     optimizer = torch.optim.Adam( |     optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, amsgrad=True) | ||||||
|         model.parameters(), lr=max_lr, amsgrad=True |  | ||||||
|     ) |  | ||||||
|     loss_func = torch.nn.MSELoss() |     loss_func = torch.nn.MSELoss() | ||||||
|     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|         optimizer, |         optimizer, | ||||||
| @@ -47,7 +45,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1): | |||||||
|         if best_loss is None or best_loss > loss.item(): |         if best_loss is None or best_loss > loss.item(): | ||||||
|             best_loss = loss.item() |             best_loss = loss.item() | ||||||
|             best_param = copy.deepcopy(model.state_dict()) |             best_param = copy.deepcopy(model.state_dict()) | ||||||
|          |  | ||||||
|         # print('loss={:}, best-loss={:}'.format(loss.item(), best_loss)) |         # print('loss={:}, best-loss={:}'.format(loss.item(), best_loss)) | ||||||
|     model.load_state_dict(best_param) |     model.load_state_dict(best_param) | ||||||
|     return model, loss_func, best_loss |     return model, loss_func, best_loss | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user