| 
									
										
										
										
											2019-11-15 17:15:07 +11:00
										 |  |  | ################################################## | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | 
					
						
							|  |  |  | ################################################## | 
					
						
							| 
									
										
										
										
											2020-01-09 22:26:23 +11:00
										 |  |  | import os, sys, torch, random, PIL, copy, numpy as np | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | from os import path as osp | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  | from shutil import copyfile | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def prepare_seed(rand_seed): | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     random.seed(rand_seed) | 
					
						
							|  |  |  |     np.random.seed(rand_seed) | 
					
						
							|  |  |  |     torch.manual_seed(rand_seed) | 
					
						
							|  |  |  |     torch.cuda.manual_seed(rand_seed) | 
					
						
							|  |  |  |     torch.cuda.manual_seed_all(rand_seed) | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def prepare_logger(xargs): | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     args = copy.deepcopy(xargs) | 
					
						
							|  |  |  |     from log_utils import Logger | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     logger = Logger(args.save_dir, args.rand_seed) | 
					
						
							|  |  |  |     logger.log("Main Function with logger : {:}".format(logger)) | 
					
						
							|  |  |  |     logger.log("Arguments : -------------------------------") | 
					
						
							|  |  |  |     for name, value in args._get_kwargs(): | 
					
						
							|  |  |  |         logger.log("{:16} : {:}".format(name, value)) | 
					
						
							|  |  |  |     logger.log("Python  Version  : {:}".format(sys.version.replace("\n", " "))) | 
					
						
							|  |  |  |     logger.log("Pillow  Version  : {:}".format(PIL.__version__)) | 
					
						
							|  |  |  |     logger.log("PyTorch Version  : {:}".format(torch.__version__)) | 
					
						
							|  |  |  |     logger.log("cuDNN   Version  : {:}".format(torch.backends.cudnn.version())) | 
					
						
							|  |  |  |     logger.log("CUDA available   : {:}".format(torch.cuda.is_available())) | 
					
						
							|  |  |  |     logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) | 
					
						
							|  |  |  |     logger.log( | 
					
						
							|  |  |  |         "CUDA_VISIBLE_DEVICES : {:}".format( | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |             os.environ["CUDA_VISIBLE_DEVICES"] | 
					
						
							|  |  |  |             if "CUDA_VISIBLE_DEVICES" in os.environ | 
					
						
							|  |  |  |             else "None" | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     return logger | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_machine_info(): | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     info = "Python  Version  : {:}".format(sys.version.replace("\n", " ")) | 
					
						
							|  |  |  |     info += "\nPillow  Version  : {:}".format(PIL.__version__) | 
					
						
							|  |  |  |     info += "\nPyTorch Version  : {:}".format(torch.__version__) | 
					
						
							|  |  |  |     info += "\ncuDNN   Version  : {:}".format(torch.backends.cudnn.version()) | 
					
						
							|  |  |  |     info += "\nCUDA available   : {:}".format(torch.cuda.is_available()) | 
					
						
							|  |  |  |     info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count()) | 
					
						
							|  |  |  |     if "CUDA_VISIBLE_DEVICES" in os.environ: | 
					
						
							|  |  |  |         info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ["CUDA_VISIBLE_DEVICES"]) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         info += "\nDoes not set CUDA_VISIBLE_DEVICES" | 
					
						
							|  |  |  |     return info | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def save_checkpoint(state, filename, logger): | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     if osp.isfile(filename): | 
					
						
							|  |  |  |         if hasattr(logger, "log"): | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |             logger.log( | 
					
						
							|  |  |  |                 "Find {:} exist, delete is at first before saving".format(filename) | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         os.remove(filename) | 
					
						
							|  |  |  |     torch.save(state, filename) | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |     assert osp.isfile( | 
					
						
							|  |  |  |         filename | 
					
						
							|  |  |  |     ), "save filename : {:} failed, which is not found.".format(filename) | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     if hasattr(logger, "log"): | 
					
						
							|  |  |  |         logger.log("save checkpoint into {:}".format(filename)) | 
					
						
							|  |  |  |     return filename | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def copy_checkpoint(src, dst, logger): | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     if osp.isfile(dst): | 
					
						
							|  |  |  |         if hasattr(logger, "log"): | 
					
						
							|  |  |  |             logger.log("Find {:} exist, delete is at first before saving".format(dst)) | 
					
						
							|  |  |  |         os.remove(dst) | 
					
						
							|  |  |  |     copyfile(src, dst) | 
					
						
							|  |  |  |     if hasattr(logger, "log"): | 
					
						
							|  |  |  |         logger.log("copy the file from {:} into {:}".format(src, dst)) |