update scripts
This commit is contained in:
		| @@ -7,6 +7,7 @@ import torch.nn.functional as F | ||||
| import torchvision.datasets as dset | ||||
| import torch.backends.cudnn as cudnn | ||||
| import torchvision.transforms as transforms | ||||
| import multiprocessing | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| print ('lib-dir : {:}'.format(lib_dir)) | ||||
| @@ -29,7 +30,7 @@ parser.add_argument('--config_path',       type=str, help='the training configur | ||||
| parser.add_argument('--save_path',         type=str, help='Folder to save checkpoints and log.') | ||||
| parser.add_argument('--print_freq',        type=int, help='print frequency (default: 200)') | ||||
| parser.add_argument('--manualSeed',        type=int, help='manual seed') | ||||
| parser.add_argument('--threads',           type=int, default=10, help='the number of threads') | ||||
| parser.add_argument('--threads',           type=int, default=4, help='the number of threads') | ||||
| args = parser.parse_args() | ||||
|  | ||||
| assert torch.cuda.is_available(), 'torch.cuda is not available' | ||||
| @@ -50,7 +51,7 @@ def main(): | ||||
|   if not os.path.isdir(args.save_path): | ||||
|     os.makedirs(args.save_path) | ||||
|   log = open(os.path.join(args.save_path, 'log-seed-{:}-{:}.txt'.format(args.manualSeed, time_file_str())), 'w') | ||||
|   print_log('save path : {}'.format(args.save_path), log) | ||||
|   print_log('save path : {:}'.format(args.save_path), log) | ||||
|   state = {k: v for k, v in args._get_kwargs()} | ||||
|   print_log(state, log) | ||||
|   print_log("Random Seed: {}".format(args.manualSeed), log) | ||||
| @@ -59,6 +60,7 @@ def main(): | ||||
|   print_log("CUDA   version : {}".format(torch.version.cuda), log) | ||||
|   print_log("cuDNN  version : {}".format(cudnn.version()), log) | ||||
|   print_log("Num of GPUs    : {}".format(torch.cuda.device_count()), log) | ||||
|   print_log("Num of CPUs    : {}".format(multiprocessing.cpu_count()), log) | ||||
|  | ||||
|   config = load_config( args.config_path ) | ||||
|   genotype = Networks[ args.arch ] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user