Add more algorithms

others/GDAS/lib/scheduler/__init__.py (new file, +5 lines)
@@ -0,0 +1,5 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .utils import load_config
from .scheduler import MultiStepLR, obtain_scheduler
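
The package __init__ re-exports the config loader and the scheduler entry points, so callers need only one import. A minimal sketch, assuming others/GDAS/lib is on sys.path (the commit itself does not show how the package is imported):

  from scheduler import load_config, obtain_scheduler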

others/GDAS/lib/scheduler/scheduler.py (new file, +32 lines)
@@ -0,0 +1,32 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from bisect import bisect_right


class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):

  def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
    if not list(milestones) == sorted(milestones):
      raise ValueError('Milestones should be a list of'
                       ' increasing integers. Got {:}'.format(milestones))
    assert len(milestones) == len(gammas), '{:} vs {:}'.format(milestones, gammas)
    self.milestones = milestones
    self.gammas = gammas
    super(MultiStepLR, self).__init__(optimizer, last_epoch)

  def get_lr(self):
    LR = 1
    for x in self.gammas[:bisect_right(self.milestones, self.last_epoch)]: LR = LR * x
    return [base_lr * LR for base_lr in self.base_lrs]


def obtain_scheduler(config, optimizer):
  if config.type == 'multistep':
    scheduler = MultiStepLR(optimizer, milestones=config.milestones, gammas=config.gammas)
  elif config.type == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)
  else:
    raise ValueError('Unknown learning rate scheduler type : {:}'.format(config.type))
  return scheduler
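
For context: unlike torch.optim.lr_scheduler.MultiStepLR, which reuses one gamma at every milestone, this variant takes a gamma per milestone, so each decay step can use a different factor (get_lr multiplies the base learning rate by the product of all gammas whose milestone has passed). A sketch of how it behaves; the model, optimizer, and values below are illustrative, not from the commit, and MultiStepLR is the class defined above:

  import torch

  model = torch.nn.Linear(10, 2)                           # toy model
  optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  # decay by 0.1 at epoch 30, then by a further 0.5 at epoch 60
  scheduler = MultiStepLR(optimizer, milestones=[30, 60], gammas=[0.1, 0.5])

  for epoch in range(90):
    optimizer.step()      # stands in for one epoch of training
    scheduler.step()
  # lr: 0.1 for epochs 0-29, 0.01 for epochs 30-59, 0.005 from epoch 60 on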

others/GDAS/lib/scheduler/utils.py (new file, +42 lines)
@@ -0,0 +1,42 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, json
from pathlib import Path
from collections import namedtuple

support_types = ('str', 'int', 'bool', 'float')

def convert_param(original_lists):
  assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
  ctype, value = original_lists[0], original_lists[1]
  assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
  is_list = isinstance(value, list)
  if not is_list: value = [value]
  outs = []
  for x in value:
    if ctype == 'int':
      x = int(x)
    elif ctype == 'str':
      x = str(x)
    elif ctype == 'bool':
      x = bool(int(x))
    elif ctype == 'float':
      x = float(x)
    else:
      raise TypeError('Does not know this type : {:}'.format(ctype))
    outs.append(x)
  if not is_list: outs = outs[0]
  return outs

def load_config(path):
  path = str(path)
  assert os.path.exists(path), 'Can not find {:}'.format(path)
  # Reading data back
  with open(path, 'r') as f:
    data = json.load(f)
  f.close()
  content = { k: convert_param(v) for k,v in data.items()}
  Arguments = namedtuple('Configure', ' '.join(content.keys()))
  content = Arguments(**content)
  return content
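
Putting the three files together: load_config reads a JSON file in which every key maps to a [type, value] pair (value may be a scalar or a list, per convert_param) and returns the parsed fields as a namedtuple. A hypothetical config for the multistep scheduler above; the file name and values are invented, but the field names match what obtain_scheduler reads, and optimizer is reused from the earlier sketch:

  # scheduler.config (JSON):
  # {
  #   "type"      : ["str",   "multistep"],
  #   "milestones": ["int",   [30, 60]],
  #   "gammas"    : ["float", [0.1, 0.5]],
  #   "epochs"    : ["int",   90]
  # }
  config    = load_config('scheduler.config')
  scheduler = obtain_scheduler(config, optimizer)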