Update test weights and shapes
This commit is contained in:
		| @@ -17,13 +17,13 @@ __all__ = ['evaluate_for_seed', 'pure_evaluate', 'get_nas_bench_loaders'] | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   latencies = [] | ||||
|   latencies, device = [], torch.cuda.current_device() | ||||
|   network.eval() | ||||
|   with torch.no_grad(): | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|       targets = targets.cuda(non_blocking=True) | ||||
|       inputs  = inputs.cuda(non_blocking=True) | ||||
|       targets = targets.cuda(device=device, non_blocking=True) | ||||
|       inputs  = inputs.cuda(device=device, non_blocking=True) | ||||
|       data_time.update(time.time() - end) | ||||
|       # forward | ||||
|       features, logits = network(inputs) | ||||
| @@ -48,12 +48,12 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode: str): | ||||
|   if mode == 'train'  : network.train() | ||||
|   elif mode == 'valid': network.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|   device = torch.cuda.current_device() | ||||
|   data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|  | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     targets = targets.cuda(device=device, non_blocking=True) | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|     # forward | ||||
|     features, logits = network(inputs) | ||||
| @@ -84,7 +84,9 @@ def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed | ||||
|   logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param)) | ||||
|   # train and valid | ||||
|   optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) | ||||
|   network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda() | ||||
|   default_device = torch.cuda.current_device() | ||||
|   network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device) | ||||
|   criterion = criterion.cuda(device=default_device) | ||||
|   # start training | ||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup | ||||
|   train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user