| 
									
										
										
										
											2020-02-23 10:30:37 +11:00
										 |  |  | ##################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | 
					
						
							|  |  |  | ##################################################### | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
# I wrote this package to make AutoDL-Projects compatible with the old GDAS projects.
					
						
							|  |  |  | # Ideally, this package will be merged into lib/models/cell_infers in future. | 
					
						
							|  |  |  | # Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019). | 
					
						
							|  |  |  | ################################################## | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-06 19:29:07 +11:00
										 |  |  | import os, torch | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-06 19:29:07 +11:00
										 |  |  | def obtain_nas_infer_model(config, extra_model_path=None): | 
					
						
							|  |  |  |    | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  |   if config.arch == 'dxys': | 
					
						
							|  |  |  |     from .DXYs import CifarNet, ImageNet, Networks | 
					
						
							| 
									
										
										
										
											2020-03-06 19:29:07 +11:00
										 |  |  |     from .DXYs import build_genotype_from_dict | 
					
						
							|  |  |  |     if config.genotype is None: | 
					
						
							|  |  |  |       if extra_model_path is not None and not os.path.isfile(extra_model_path): | 
					
						
							|  |  |  |         raise ValueError('When genotype in confiig is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path)) | 
					
						
							|  |  |  |       xdata = torch.load(extra_model_path) | 
					
						
							|  |  |  |       current_epoch = xdata['epoch'] | 
					
						
							|  |  |  |       genotype_dict = xdata['genotypes'][current_epoch-1] | 
					
						
							|  |  |  |       genotype = build_genotype_from_dict(genotype_dict) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |       genotype = Networks[config.genotype] | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  |     if config.dataset == 'cifar': | 
					
						
							|  |  |  |       return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num) | 
					
						
							|  |  |  |     elif config.dataset == 'imagenet': | 
					
						
							|  |  |  |       return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num) | 
					
						
							|  |  |  |     else: raise ValueError('invalid dataset : {:}'.format(config.dataset)) | 
					
						
							|  |  |  |   else: | 
					
						
							|  |  |  |     raise ValueError('invalid nas arch type : {:}'.format(config.arch)) |