| 
									
										
										
										
											2019-11-15 17:15:07 +11:00
										 |  |  | ################################################## | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | 
					
						
							|  |  |  | ################################################## | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | import os, sys, time, torch | 
					
						
							| 
									
										
										
										
											2021-03-07 01:44:26 -08:00
										 |  |  | from log_utils import AverageMeter | 
					
						
							|  |  |  | from log_utils import time_string | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  | from utils import obtain_accuracy | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  | def basic_train( | 
					
						
							|  |  |  |     xloader, | 
					
						
							|  |  |  |     network, | 
					
						
							|  |  |  |     criterion, | 
					
						
							|  |  |  |     scheduler, | 
					
						
							|  |  |  |     optimizer, | 
					
						
							|  |  |  |     optim_config, | 
					
						
							|  |  |  |     extra_info, | 
					
						
							|  |  |  |     print_freq, | 
					
						
							|  |  |  |     logger, | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     loss, acc1, acc5 = procedure( | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |         xloader, | 
					
						
							|  |  |  |         network, | 
					
						
							|  |  |  |         criterion, | 
					
						
							|  |  |  |         scheduler, | 
					
						
							|  |  |  |         optimizer, | 
					
						
							|  |  |  |         "train", | 
					
						
							|  |  |  |         optim_config, | 
					
						
							|  |  |  |         extra_info, | 
					
						
							|  |  |  |         print_freq, | 
					
						
							|  |  |  |         logger, | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  |     return loss, acc1, acc5 | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  | def basic_valid( | 
					
						
							|  |  |  |     xloader, network, criterion, optim_config, extra_info, print_freq, logger | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     with torch.no_grad(): | 
					
						
							|  |  |  |         loss, acc1, acc5 = procedure( | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |             xloader, | 
					
						
							|  |  |  |             network, | 
					
						
							|  |  |  |             criterion, | 
					
						
							|  |  |  |             None, | 
					
						
							|  |  |  |             None, | 
					
						
							|  |  |  |             "valid", | 
					
						
							|  |  |  |             None, | 
					
						
							|  |  |  |             extra_info, | 
					
						
							|  |  |  |             print_freq, | 
					
						
							|  |  |  |             logger, | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         ) | 
					
						
							|  |  |  |     return loss, acc1, acc5 | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  | def procedure( | 
					
						
							|  |  |  |     xloader, | 
					
						
							|  |  |  |     network, | 
					
						
							|  |  |  |     criterion, | 
					
						
							|  |  |  |     scheduler, | 
					
						
							|  |  |  |     optimizer, | 
					
						
							|  |  |  |     mode, | 
					
						
							|  |  |  |     config, | 
					
						
							|  |  |  |     extra_info, | 
					
						
							|  |  |  |     print_freq, | 
					
						
							|  |  |  |     logger, | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     data_time, batch_time, losses, top1, top5 = ( | 
					
						
							|  |  |  |         AverageMeter(), | 
					
						
							|  |  |  |         AverageMeter(), | 
					
						
							|  |  |  |         AverageMeter(), | 
					
						
							|  |  |  |         AverageMeter(), | 
					
						
							|  |  |  |         AverageMeter(), | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     if mode == "train": | 
					
						
							|  |  |  |         network.train() | 
					
						
							|  |  |  |     elif mode == "valid": | 
					
						
							|  |  |  |         network.eval() | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         raise ValueError("The mode is not right : {:}".format(mode)) | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     # logger.log('[{:5s}] config ::  auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) | 
					
						
							|  |  |  |     logger.log( | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |         "[{:5s}] config ::  auxiliary={:}".format( | 
					
						
							|  |  |  |             mode, config.auxiliary if hasattr(config, "auxiliary") else -1 | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  |     end = time.time() | 
					
						
							|  |  |  |     for i, (inputs, targets) in enumerate(xloader): | 
					
						
							|  |  |  |         if mode == "train": | 
					
						
							|  |  |  |             scheduler.update(None, 1.0 * i / len(xloader)) | 
					
						
							|  |  |  |         # measure data loading time | 
					
						
							|  |  |  |         data_time.update(time.time() - end) | 
					
						
							|  |  |  |         # calculate prediction and loss | 
					
						
							|  |  |  |         targets = targets.cuda(non_blocking=True) | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         if mode == "train": | 
					
						
							|  |  |  |             optimizer.zero_grad() | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         features, logits = network(inputs) | 
					
						
							|  |  |  |         if isinstance(logits, list): | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |             assert len(logits) == 2, "logits must has {:} items instead of {:}".format( | 
					
						
							|  |  |  |                 2, len(logits) | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |             logits, logits_aux = logits | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             logits, logits_aux = logits, None | 
					
						
							|  |  |  |         loss = criterion(logits, targets) | 
					
						
							|  |  |  |         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: | 
					
						
							|  |  |  |             loss_aux = criterion(logits_aux, targets) | 
					
						
							|  |  |  |             loss += config.auxiliary * loss_aux | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         if mode == "train": | 
					
						
							|  |  |  |             loss.backward() | 
					
						
							|  |  |  |             optimizer.step() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # record | 
					
						
							|  |  |  |         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | 
					
						
							|  |  |  |         losses.update(loss.item(), inputs.size(0)) | 
					
						
							|  |  |  |         top1.update(prec1.item(), inputs.size(0)) | 
					
						
							|  |  |  |         top5.update(prec5.item(), inputs.size(0)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # measure elapsed time | 
					
						
							|  |  |  |         batch_time.update(time.time() - end) | 
					
						
							|  |  |  |         end = time.time() | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         if i % print_freq == 0 or (i + 1) == len(xloader): | 
					
						
							|  |  |  |             Sstr = ( | 
					
						
							|  |  |  |                 " {:5s} ".format(mode.upper()) | 
					
						
							|  |  |  |                 + time_string() | 
					
						
							|  |  |  |                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             if scheduler is not None: | 
					
						
							|  |  |  |                 Sstr += " {:}".format(scheduler.get_min_info()) | 
					
						
							|  |  |  |             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | 
					
						
							|  |  |  |                 batch_time=batch_time, data_time=data_time | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | 
					
						
							|  |  |  |                 loss=losses, top1=top1, top5=top5 | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             Istr = "Size={:}".format(list(inputs.size())) | 
					
						
							|  |  |  |             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |     logger.log( | 
					
						
							|  |  |  |         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |             mode=mode.upper(), | 
					
						
							|  |  |  |             top1=top1, | 
					
						
							|  |  |  |             top5=top5, | 
					
						
							|  |  |  |             error1=100 - top1.avg, | 
					
						
							|  |  |  |             error5=100 - top5.avg, | 
					
						
							|  |  |  |             loss=losses.avg, | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     return losses.avg, top1.avg, top5.avg |