#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer


class _LRScheduler(object):
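    """Base class shared by the schedulers below.

    It tracks the integer current_epoch and the fractional current_iter, and
    update() rewrites every param group's LR from the subclass's get_lr().
    Warmup is handled inside each subclass: for the first warmup_epochs
    epochs the LR ramps linearly from 0 toward the base LR.
    """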
					
						
    def __init__(self, optimizer, warmup_epochs, epochs):
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        # Remember each group's starting LR so the schedule can be recomputed
        # from scratch at any (epoch, iter) point.
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        self.base_lrs = list(
            map(lambda group: group["initial_lr"], optimizer.param_groups)
        )
        self.max_epochs = epochs
        self.warmup_epochs = warmup_epochs
        self.current_epoch = 0
        self.current_iter = 0

    def extra_repr(self):
        return ""

    def __repr__(self):
        return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
            name=self.__class__.__name__, **self.__dict__
        ) + ", {:})".format(self.extra_repr())

    def state_dict(self):
        # The optimizer holds tensors and is saved/restored separately.
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        raise NotImplementedError

    def get_min_info(self):
        lrs = self.get_lr()
        return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
            min(lrs), max(lrs), self.current_epoch, self.current_iter
        )

    def get_min_lr(self):
        return min(self.get_lr())

    def update(self, cur_epoch, cur_iter):
        # cur_epoch is the integer epoch index; cur_iter is the fractional
        # progress within the current epoch (e.g. step / steps_per_epoch).
        # Either may be None to leave the corresponding counter unchanged.
        if cur_epoch is not None:
            assert (
                isinstance(cur_epoch, int) and cur_epoch >= 0
            ), "invalid cur-epoch : {:}".format(cur_epoch)
            self.current_epoch = cur_epoch
        if cur_iter is not None:
            assert (
                isinstance(cur_iter, float) and cur_iter >= 0
            ), "invalid cur-iter : {:}".format(cur_iter)
            self.current_iter = cur_iter
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr
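
    # Usage sketch (illustrative; train_loader and the loop body are
    # placeholder names, not part of this module):
    #
    #   for epoch in range(config.epochs):
    #       for step, (inputs, labels) in enumerate(train_loader):
    #           scheduler.update(epoch, step / len(train_loader))
    #           ...  # forward / backward / optimizer.step()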
					
						
class CosineAnnealingLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, T-max={:}, eta-min={:}".format(
            "cosine", self.T_max, self.eta_min
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if (
                self.current_epoch >= self.warmup_epochs
                and self.current_epoch < self.max_epochs
            ):
                # Cosine annealing over the post-warmup epochs:
                #   lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
                last_epoch = self.current_epoch - self.warmup_epochs
                lr = (
                    self.eta_min
                    + (base_lr - self.eta_min)
                    * (1 + math.cos(math.pi * last_epoch / self.T_max))
                    / 2
                )
            elif self.current_epoch >= self.max_epochs:
                lr = self.eta_min
            else:
                # Linear warmup from 0 toward base_lr over warmup_epochs.
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs
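
    # Worked example (a sketch, with assumed values): given base_lr = 0.1,
    # eta_min = 0.0, warmup = 5, and T_max = 95, epoch 5 (the first
    # post-warmup epoch) yields lr = 0.1 since cos(0) = 1, and any
    # epoch >= max_epochs yields eta_min.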
					
						
class MultiStepLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
        assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
            len(milestones), len(gammas)
        )
        self.milestones = milestones
        self.gammas = gammas
        super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
            "multistep", self.milestones, self.gammas, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                # Apply every gamma whose milestone has been passed.
                idx = bisect_right(self.milestones, last_epoch)
                lr = base_lr
                for x in self.gammas[:idx]:
                    lr *= x
            else:
                # Linear warmup, identical to the other schedulers.
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs
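
    # Worked example (a sketch, with assumed values): given warmup = 0,
    # milestones = [30, 60], and gammas = [0.1, 0.1], epochs [0, 30) use
    # base_lr, [30, 60) use 0.1 * base_lr, and epoch 60 onward uses
    # 0.01 * base_lr.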
					
						
class ExponentialLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, gamma):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, gamma={:}, base-lrs={:}".format(
            "exponential", self.gamma, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                # Geometric decay after warmup: lr = base_lr * gamma ** t.
                lr = base_lr * (self.gamma ** last_epoch)
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs
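
    # Worked example (a sketch, with an assumed gamma): gamma = 0.97 roughly
    # halves the LR every 23 post-warmup epochs, since 0.97 ** 23 ~= 0.50.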
					
						
class LinearLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
        self.max_LR = max_LR
        self.min_LR = min_LR
        super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
            "LinearLR", self.max_LR, self.min_LR, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                # Scale the LR down linearly so that, when base_lr == max_LR,
                # it moves from max_LR toward min_LR over max_epochs.
                ratio = (
                    (self.max_LR - self.min_LR)
                    * last_epoch
                    / self.max_epochs
                    / self.max_LR
                )
                lr = base_lr * (1 - ratio)
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs
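
    # Worked example (a sketch): with warmup = 0 and base_lr == max_LR (as
    # wired up in get_optim_scheduler below), ratio reaches
    # (max_LR - min_LR) / max_LR at last_epoch == max_epochs, i.e. the LR
    # decays linearly to min_LR.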
					
						
class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # One-hot encode the targets, then blend with a uniform distribution:
        # the true class gets 1 - epsilon and the rest share epsilon equally.
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss
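
    # Usage sketch (the shapes below are illustrative assumptions):
    #   criterion = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)
    #   loss = criterion(torch.randn(4, 10), torch.randint(0, 10, (4,)))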
					
						
def get_optim_scheduler(parameters, config):
    assert (
        hasattr(config, "optim")
        and hasattr(config, "scheduler")
        and hasattr(config, "criterion")
    ), "config must have optim / scheduler / criterion attributes, but got {:}".format(
        config
    )
    if config.optim == "SGD":
        optim = torch.optim.SGD(
            parameters,
            config.LR,
            momentum=config.momentum,
            weight_decay=config.decay,
            nesterov=config.nesterov,
        )
    elif config.optim == "RMSprop":
        optim = torch.optim.RMSprop(
            parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
        )
    else:
        raise ValueError("invalid optim : {:}".format(config.optim))

    if config.scheduler == "cos":
        T_max = getattr(config, "T_max", config.epochs)
        scheduler = CosineAnnealingLR(
            optim, config.warmup, config.epochs, T_max, config.eta_min
        )
    elif config.scheduler == "multistep":
        scheduler = MultiStepLR(
            optim, config.warmup, config.epochs, config.milestones, config.gammas
        )
    elif config.scheduler == "exponential":
        scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
    elif config.scheduler == "linear":
        scheduler = LinearLR(
            optim, config.warmup, config.epochs, config.LR, config.LR_min
        )
    else:
        raise ValueError("invalid scheduler : {:}".format(config.scheduler))

    if config.criterion == "Softmax":
        criterion = torch.nn.CrossEntropyLoss()
    elif config.criterion == "SmoothSoftmax":
        criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
    else:
        raise ValueError("invalid criterion : {:}".format(config.criterion))
    return optim, scheduler, criterion
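

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming only a config object that exposes
    # the attributes read above. SimpleNamespace stands in for whatever config
    # type callers normally pass; the hyper-parameter values are illustrative.
    from types import SimpleNamespace

    config = SimpleNamespace(
        optim="SGD",
        LR=0.1,
        momentum=0.9,
        decay=5e-4,
        nesterov=True,
        scheduler="cos",
        warmup=5,
        epochs=100,
        T_max=95,
        eta_min=0.0,
        criterion="Softmax",
    )
    net = nn.Linear(8, 4)
    optim, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
    scheduler.update(0, 0.0)
    print(scheduler)
    print(scheduler.get_min_info())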