update scripts
This commit is contained in:
		| @@ -1,3 +1,4 @@ | ||||
| # DARTS First Order, Refer to https://github.com/quark0/darts | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
|   | ||||
| @@ -13,25 +13,11 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from utils import AverageMeter, time_string, convert_secs2time | ||||
| from utils import print_log, obtain_accuracy | ||||
| from utils import Cutout, count_parameters_in_MB | ||||
| from nas import DARTS_V1, DARTS_V2, NASNet, PNASNet, AmoebaNet, ENASNet | ||||
| from nas import DMS_V1, DMS_F1, GDAS_CC | ||||
| from meta_nas import META_V1, META_V2 | ||||
| from nas import model_types as models | ||||
| from train_utils import main_procedure | ||||
| from train_utils_imagenet import main_procedure_imagenet | ||||
| from scheduler import load_config | ||||
|  | ||||
| models = {'DARTS_V1': DARTS_V1, | ||||
|           'DARTS_V2': DARTS_V2, | ||||
|           'NASNet'  : NASNet, | ||||
|           'PNASNet' : PNASNet, | ||||
|           'ENASNet' : ENASNet, | ||||
|           'DMS_V1'  : DMS_V1, | ||||
|           'DMS_F1'  : DMS_F1, | ||||
|           'GDAS_CC' : GDAS_CC, | ||||
|           'META_V1' : META_V1, | ||||
|           'META_V2' : META_V2, | ||||
|           'AmoebaNet' : AmoebaNet} | ||||
|  | ||||
|  | ||||
| parser = argparse.ArgumentParser("cifar") | ||||
| parser.add_argument('--data_path',         type=str,   help='Path to dataset') | ||||
|   | ||||
| @@ -10,6 +10,7 @@ from utils import time_string, convert_secs2time | ||||
| from utils import count_parameters_in_MB | ||||
| from utils import Cutout | ||||
| from nas import NetworkCIFAR as Network | ||||
| from datasets import get_datasets | ||||
|  | ||||
| def obtain_best(accuracies): | ||||
|   if len(accuracies) == 0: return (0, 0) | ||||
| @@ -17,38 +18,10 @@ def obtain_best(accuracies): | ||||
|   s2b = sorted( tops ) | ||||
|   return s2b[-1] | ||||
|  | ||||
|  | ||||
| def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, log): | ||||
|    | ||||
|   # Mean + Std | ||||
|   if dataset == 'cifar10': | ||||
|     mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|     std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|   elif dataset == 'cifar100': | ||||
|     mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|     std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(dataset)) | ||||
|   # Dataset Transformation | ||||
|   if dataset == 'cifar10' or dataset == 'cifar100': | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), | ||||
|              transforms.Normalize(mean, std)] | ||||
|     if config.cutout > 0 : lists += [Cutout(config.cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(dataset)) | ||||
|   # Dataset Defination | ||||
|   if dataset == 'cifar10': | ||||
|     train_data = dset.CIFAR10(data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR10(data_path, train=False, transform=test_transform , download=True) | ||||
|     class_num  = 10 | ||||
|   elif dataset == 'cifar100': | ||||
|     train_data = dset.CIFAR100(data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR100(data_path, train=False, transform=test_transform , download=True) | ||||
|     class_num  = 100 | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(dataset)) | ||||
|  | ||||
|   train_data, test_data, class_num = get_datasets(dataset, data_path, args.cutout) | ||||
|  | ||||
|   print_log('-------------------------------------- main-procedure', log) | ||||
|   print_log('config        : {:}'.format(config), log) | ||||
|   | ||||
| @@ -12,6 +12,7 @@ from utils import count_parameters_in_MB | ||||
| from utils import print_FLOPs | ||||
| from utils import Cutout | ||||
| from nas import NetworkImageNet as Network | ||||
| from datasets import get_datasets | ||||
|  | ||||
|  | ||||
| def obtain_best(accuracies): | ||||
| @@ -40,30 +41,7 @@ class CrossEntropyLabelSmooth(nn.Module): | ||||
| def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, log): | ||||
|    | ||||
|   # training data and testing data | ||||
|   traindir = os.path.join(data_path, 'train') | ||||
|   validdir = os.path.join(data_path, 'val') | ||||
|   normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | ||||
|   train_data = dset.ImageFolder( | ||||
|     traindir, | ||||
|     transforms.Compose([ | ||||
|       transforms.RandomResizedCrop(224), | ||||
|       transforms.RandomHorizontalFlip(), | ||||
|       transforms.ColorJitter( | ||||
|         brightness=0.4, | ||||
|         contrast=0.4, | ||||
|         saturation=0.4, | ||||
|         hue=0.2), | ||||
|       transforms.ToTensor(), | ||||
|       normalize, | ||||
|     ])) | ||||
|   valid_data = dset.ImageFolder( | ||||
|     validdir, | ||||
|     transforms.Compose([ | ||||
|       transforms.Resize(256), | ||||
|       transforms.CenterCrop(224), | ||||
|       transforms.ToTensor(), | ||||
|       normalize, | ||||
|     ])) | ||||
|   train_data, valid_data, class_num = get_datasets('imagenet-1k', data_path, -1) | ||||
|  | ||||
|   train_queue = torch.utils.data.DataLoader( | ||||
|     train_data, batch_size=config.batch_size, shuffle= True, pin_memory=True, num_workers=args.workers) | ||||
| @@ -73,7 +51,6 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la | ||||
|  | ||||
|   class_num   = 1000 | ||||
|  | ||||
|  | ||||
|   print_log('-------------------------------------- main-procedure', log) | ||||
|   print_log('config        : {:}'.format(config), log) | ||||
|   print_log('genotype      : {:}'.format(genotype), log) | ||||
| @@ -98,8 +75,7 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la | ||||
|   criterion_smooth = CrossEntropyLabelSmooth(class_num, config.label_smooth).cuda() | ||||
|  | ||||
|  | ||||
 | optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay) | ||||
 | #optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=True) | ||||
 | optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=True) | ||||
|   if config.type == 'cosine': | ||||
|     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(config.epochs)) | ||||
|   elif config.type == 'steplr': | ||||
|   | ||||
		Reference in New Issue
	
	Block a user