#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer

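# _LRScheduler is a lightweight scheduler base class that, unlike the stock
# PyTorch schedulers, also tracks a fractional position inside the current
# epoch. Training code typically calls `update(epoch, float(step) / steps_per_epoch)`
# before each optimizer step; during the first `warmup_epochs` epochs the
# learning rate ramps linearly from 0 to each group's base LR, after which the
# subclass-specific `get_lr()` rule takes over.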
class _LRScheduler(object):
    def __init__(self, optimizer, warmup_epochs, epochs):
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        self.base_lrs = list(
            map(lambda group: group["initial_lr"], optimizer.param_groups)
        )
        self.max_epochs = epochs
        self.warmup_epochs = warmup_epochs
        self.current_epoch = 0
        self.current_iter = 0

    def extra_repr(self):
        return ""

    def __repr__(self):
        return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
            name=self.__class__.__name__, **self.__dict__
        ) + ", {:})".format(
            self.extra_repr()
        )

    def state_dict(self):
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        raise NotImplementedError

    def get_min_info(self):
        lrs = self.get_lr()
        return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
            min(lrs), max(lrs), self.current_epoch, self.current_iter
        )

    def get_min_lr(self):
        return min(self.get_lr())

    def update(self, cur_epoch, cur_iter):
        if cur_epoch is not None:
            assert (
                isinstance(cur_epoch, int) and cur_epoch >= 0
            ), "invalid cur-epoch : {:}".format(cur_epoch)
            self.current_epoch = cur_epoch
        if cur_iter is not None:
            assert (
                isinstance(cur_iter, float) and cur_iter >= 0
            ), "invalid cur-iter : {:}".format(cur_iter)
            self.current_iter = cur_iter
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr

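# Cosine annealing with linear warmup. After warmup, each group's LR follows
#   lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * (epoch - warmup_epochs) / T_max)) / 2
# and is clamped to eta_min once the epoch counter reaches `epochs`.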
class CosineAnnealingLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, T-max={:}, eta-min={:}".format(
            "cosine", self.T_max, self.eta_min
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if (
                self.current_epoch >= self.warmup_epochs
                and self.current_epoch < self.max_epochs
            ):
                last_epoch = self.current_epoch - self.warmup_epochs
                # if last_epoch < self.T_max:
                # if last_epoch < self.max_epochs:
                lr = (
                    self.eta_min
                    + (base_lr - self.eta_min)
                    * (1 + math.cos(math.pi * last_epoch / self.T_max))
                    / 2
                )
                # else:
                #  lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
            elif self.current_epoch >= self.max_epochs:
                lr = self.eta_min
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs

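# Step decay with linear warmup. `milestones` are epoch indices counted after
# warmup and `gammas` the per-milestone multipliers; e.g. milestones=[30, 60]
# with gammas=[0.1, 0.1] scales the LR by 0.1 at epoch 30 and again at epoch
# 60 (these example values are illustrative only).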
class MultiStepLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
        assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
            len(milestones), len(gammas)
        )
        self.milestones = milestones
        self.gammas = gammas
        super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
            "multistep", self.milestones, self.gammas, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                idx = bisect_right(self.milestones, last_epoch)
                lr = base_lr
                for x in self.gammas[:idx]:
                    lr *= x
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs

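# Exponential decay with linear warmup: lr = base_lr * gamma ** (epoch - warmup_epochs).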
class ExponentialLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, gamma):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, gamma={:}, base-lrs={:}".format(
            "exponential", self.gamma, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                lr = base_lr * (self.gamma ** last_epoch)
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs

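# Linear decay with linear warmup: the LR falls linearly from base_lr toward
# base_lr * min_LR / max_LR as (epoch - warmup_epochs) approaches `epochs`,
# so when base_lr equals max_LR it interpolates from max_LR down to min_LR.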
class LinearLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
        self.max_LR = max_LR
        self.min_LR = min_LR
        super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
            "LinearLR", self.max_LR, self.min_LR, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                ratio = (
                    (self.max_LR - self.min_LR)
                    * last_epoch
                    / self.max_epochs
                    / self.max_LR
                )
                lr = base_lr * (1 - ratio)
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs

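# Cross-entropy with label smoothing: the one-hot targets are replaced by
#   (1 - epsilon) * one_hot + epsilon / num_classes
# and the loss is the batch mean of the negative log-likelihood under this
# smoothed target distribution.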
class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss

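# Builds (optimizer, scheduler, criterion) from a config object. The config is
# expected to expose at least: `optim` ("SGD" or "RMSprop") with `LR`,
# `momentum`, `decay` (and `nesterov` for SGD); `scheduler` ("cos",
# "multistep", "exponential" or "linear") with `warmup`, `epochs` and the
# scheduler-specific fields (`eta_min` and optionally `T_max`, `milestones`
# and `gammas`, `gamma`, or `LR_min`); and `criterion` ("Softmax" or
# "SmoothSoftmax", the latter also needing `class_num` and `label_smooth`).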
def get_optim_scheduler(parameters, config):
    assert (
        hasattr(config, "optim")
        and hasattr(config, "scheduler")
        and hasattr(config, "criterion")
    ), "config must have optim / scheduler / criterion keys instead of {:}".format(
        config
    )
    if config.optim == "SGD":
        optim = torch.optim.SGD(
            parameters,
            config.LR,
            momentum=config.momentum,
            weight_decay=config.decay,
            nesterov=config.nesterov,
        )
    elif config.optim == "RMSprop":
        optim = torch.optim.RMSprop(
            parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
        )
    else:
        raise ValueError("invalid optim : {:}".format(config.optim))

    if config.scheduler == "cos":
        T_max = getattr(config, "T_max", config.epochs)
        scheduler = CosineAnnealingLR(
            optim, config.warmup, config.epochs, T_max, config.eta_min
        )
    elif config.scheduler == "multistep":
        scheduler = MultiStepLR(
            optim, config.warmup, config.epochs, config.milestones, config.gammas
        )
    elif config.scheduler == "exponential":
        scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
    elif config.scheduler == "linear":
        scheduler = LinearLR(
            optim, config.warmup, config.epochs, config.LR, config.LR_min
        )
    else:
        raise ValueError("invalid scheduler : {:}".format(config.scheduler))

    if config.criterion == "Softmax":
        criterion = torch.nn.CrossEntropyLoss()
    elif config.criterion == "SmoothSoftmax":
        criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
    else:
        raise ValueError("invalid criterion : {:}".format(config.criterion))
    return optim, scheduler, criterion
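

if __name__ == "__main__":
    # Minimal usage sketch, separate from any real training pipeline: it only
    # shows how get_optim_scheduler and scheduler.update are meant to be driven
    # from a per-iteration loop. The toy model, the config field values, and
    # the iteration counts below are illustrative assumptions, not defaults
    # taken from an actual configuration file.
    from types import SimpleNamespace

    model = nn.Linear(8, 4)
    config = SimpleNamespace(
        optim="SGD",
        LR=0.1,
        momentum=0.9,
        decay=5e-4,
        nesterov=True,
        scheduler="cos",
        warmup=5,
        epochs=50,
        eta_min=0.0,
        criterion="Softmax",
    )
    optimizer, scheduler, criterion = get_optim_scheduler(model.parameters(), config)
    print(scheduler)

    iters_per_epoch = 10  # stand-in for len(train_loader)
    for epoch in range(config.epochs):
        for step in range(iters_per_epoch):
            # the fractional second argument only matters during warmup
            scheduler.update(epoch, float(step) / iters_per_epoch)
            inputs, targets = torch.randn(2, 8), torch.randint(0, 4, (2,))
            loss = criterion(model(inputs), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % 10 == 0:
            print("epoch {:03d} -> {:}".format(epoch, scheduler.get_min_info()))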