#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
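# This file implements the per-seed training / evaluation helpers used to
# benchmark a candidate cell-based architecture: `pure_evaluate` measures the
# loss, accuracy, and latency of a trained network, `procedure` runs a single
# training or validation epoch, and `evaluate_for_seed` trains an architecture
# from scratch under a fixed random seed and collects per-epoch statistics.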
					
						
import time
import torch

from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from config_utils import dict2config
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net
					
						
__all__ = ["evaluate_for_seed", "pure_evaluate"]


def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
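    """Evaluate `network` on `xloader` with gradients disabled.

    Returns the average loss, top-1 / top-5 accuracy, and a list of per-batch
    latencies (forward time excluding data loading, recorded only for batches
    whose size matches the first batch).
    """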
					
						
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies = []
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(non_blocking=True)
            inputs = inputs.cuda(non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            # record latency only while the batch size matches the first batch
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    # drop the first (warm-up) latency when more than two measurements were collected
    if len(latencies) > 2:
        latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies
					
						
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
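    """Run one epoch over `xloader` in either 'train' or 'valid' mode.

    In 'train' mode the scheduler is updated per batch and the optimizer steps;
    in 'valid' mode the caller is expected to wrap the call in `torch.no_grad()`.
    Returns the average loss, top-1 / top-5 accuracy, and the summed batch time.
    """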
					
						
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("Invalid mode: {:}".format(mode))

    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))

        targets = targets.cuda(non_blocking=True)
        if mode == "train":
            optimizer.zero_grad()
        # forward
        # note: inputs are not moved to GPU here; the DataParallel wrapper used
        # by the caller scatters them to the device(s) inside network(...)
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == "train":
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum
					
						
def evaluate_for_seed(
    arch_config, config, arch, train_loader, valid_loaders, seed, logger
):
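    """Train architecture `arch` from scratch under a fixed random `seed`.

    Builds an `infer.tiny` cell-based network from `arch_config` / `config`,
    trains it for `config.epochs + config.warmup` epochs, evaluates it on every
    loader in `valid_loaders` after each epoch, and returns a dictionary with
    FLOPs, parameter count, per-epoch statistics, and the final state dict.
    """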
					
						
    prepare_seed(seed)  # fix the random seed for reproducibility
    net = get_cell_based_tiny_net(
        dict2config(
            {
                "name": "infer.tiny",
                "C": arch_config["channel"],
                "N": arch_config["num_cells"],
                "genotype": arch,
                "num_classes": config.class_num,
            },
            None,
        )
    )
    # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
    flop, param = get_model_infos(net, config.xshape)
    logger.log("Network : {:}".format(net.get_message()), False)
    logger.log(
        "{:} Seed-------------------------- {:} --------------------------".format(
            time_string(), seed
        )
    )
    logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
    network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
    # start training
    start_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    # per-epoch records: train_* keyed by epoch, valid_* keyed by "<loader-key>@<epoch>"
    (
        train_losses,
        train_acc1es,
        train_acc5es,
        valid_losses,
        valid_acc1es,
        valid_acc5es,
    ) = ({}, {}, {}, {}, {}, {})
    train_times, valid_times = {}, {}
    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)

        train_loss, train_acc1, train_acc5, train_tm = procedure(
            train_loader, network, criterion, scheduler, optimizer, "train"
        )
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm
        with torch.no_grad():
            for key, xloader in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
                    xloader, network, criterion, None, None, "valid"
                )
                valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
                valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
                valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
                valid_times["{:}@{:}".format(key, epoch)] = valid_tm

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
        )
        logger.log(
            "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format(
                time_string(),
                need_time,
                epoch,
                total_epoch,
                train_loss,
                train_acc1,
                train_acc5,
                valid_loss,
                valid_acc1,
                valid_acc5,
            )
        )
    info_seed = {
        "flop": flop,
        "param": param,
        "channel": arch_config["channel"],
        "num_cells": arch_config["num_cells"],
        "config": config._asdict(),
        "total_epoch": total_epoch,
        "train_losses": train_losses,
        "train_acc1es": train_acc1es,
        "train_acc5es": train_acc5es,
        "train_times": train_times,
        "valid_losses": valid_losses,
        "valid_acc1es": valid_acc1es,
        "valid_acc5es": valid_acc5es,
        "valid_times": valid_times,
        "net_state_dict": net.state_dict(),
        "net_string": "{:}".format(net),
        "finish-train": True,
    }
    return info_seed
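

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a rough illustration of how
# `evaluate_for_seed` might be driven. The helper names and signatures below
# (`load_config`, `get_datasets`, the config path, and the loader keys) are
# assumptions about this repository's utilities and may need adjusting; `arch`
# is a genotype from the search space and `logger` a `log_utils` logger.
#
#   import torch
#   from config_utils import load_config
#   from datasets import get_datasets
#
#   train_data, valid_data, xshape, class_num = get_datasets("cifar10", "./data", -1)
#   config = load_config(
#       "configs/nas-benchmark/CIFAR.config",
#       {"class_num": class_num, "xshape": xshape},
#       None,
#   )
#   train_loader = torch.utils.data.DataLoader(train_data, batch_size=256, shuffle=True)
#   valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=256, shuffle=False)
#   arch_config = {"channel": 16, "num_cells": 5}
#   info = evaluate_for_seed(
#       arch_config, config, arch, train_loader, {"x-valid": valid_loader}, 777, logger
#   )
# ---------------------------------------------------------------------------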