from torch.utils.tensorboard import SummaryWriter
import argparse
import glob
import logging
import sys
sys.path.insert(0, '../../')
import time
import random
import numpy as np
import os
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable

import nasbench201.utils as utils
from sota.cnn.model_imagenet import NetworkImageNet as Network
import sota.cnn.genotypes as genotypes
from sota.cnn.hdf5 import H5Dataset

parser = argparse.ArgumentParser("imagenet")
parser.add_argument('--data', type=str, default='../../data', help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay')
parser.add_argument('--report_freq', type=float, default=100, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=250, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=48, help='num of init channels')
parser.add_argument('--layers', type=int, default=14, help='total number of layers')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='c10_s3_pgd', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping')
parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
parser.add_argument('--gamma', type=float, default=0.97, help='learning rate decay')
parser.add_argument('--decay_period', type=int, default=1, help='epochs between two learning rate decays')
parser.add_argument('--parallel', action='store_true', default=False, help='darts parallelism')
parser.add_argument('--load', action='store_true', default=False, help='whether load checkpoint for continue training')
args = parser.parse_args()

args.save = '../../experiments/sota/imagenet/eval/{}-{}-{}-{}'.format(
    args.save, time.strftime("%Y%m%d-%H%M%S"), args.arch, args.seed)
if args.auxiliary:
    args.save += '-auxiliary-' + str(args.auxiliary_weight)
args.save += '-' + str(np.random.randint(10000))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')

CLASSES = 1000


class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy with label smoothing: the one-hot target is mixed with a
    uniform distribution over the classes, weighted by epsilon."""

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss


def seed_torch(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False


def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(args.gpu)
    cudnn.enabled = True
    seed_torch(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
    if args.parallel:
        model = nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)
    valid_data = H5Dataset(os.path.join(args.data, 'imagenet-val-256.h5'), transform=test_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)

    if args.load:
        # resume from a previously saved experiment directory
        model, optimizer, start_epoch, best_acc_top1 = utils.load_checkpoint(
            model, optimizer, '../../experiments/sota/imagenet/eval/EXP-20200210-143540-c10_s3_pgd-0-auxiliary-0.4-2753')
    else:
        best_acc_top1 = 0
        start_epoch = 0

    for epoch in range(start_epoch, args.epochs):
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        # drop path probability is ramped up linearly over training
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('Acc/train', train_acc, epoch)
        writer.add_scalar('Obj/train', train_obj, epoch)
        scheduler.step()

        valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)
        writer.add_scalar('Acc/valid_top1', valid_acc_top1, epoch)
        writer.add_scalar('Acc/valid_top5', valid_acc_top5, epoch)

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc_top1': best_acc_top1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save)


def train(train_queue, model, criterion, optimizer):
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.train()

    for step, (input, target) in enumerate(train_queue):
        input = input.cuda()
        target = target.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits, logits_aux = model(input)
        loss = criterion(logits, target)
        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.data, n)
        top1.update(prec1.data, n)
        top5.update(prec5.data, n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            logits, _ = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data, n)
            top1.update(prec1.data, n)
            top5.update(prec5.data, n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, top5.avg, objs.avg


if __name__ == '__main__':
    main()
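
# Example invocation (a sketch, not part of the original script's documentation;
# the filename "train_imagenet.py" is an assumption, substitute whatever this
# file is saved as). The data pipeline above expects imagenet-train-256.h5 and
# imagenet-val-256.h5 under --data:
#
#   python train_imagenet.py --data ../../data --arch c10_s3_pgd \
#       --auxiliary --batch_size 128 --epochs 250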