Unfinished Codes
This commit is contained in:
		
							
								
								
									
										125
									
								
								lib/procedures/advanced_main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								lib/procedures/advanced_main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,125 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| import os, sys, time, torch | ||||
| from typing import import Optional, Text, Callable | ||||
|  | ||||
| # modules in AutoDL | ||||
| from log_utils import AverageMeter | ||||
| from log_utils import time_string | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def basic_train( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, | ||||
|         network, | ||||
|         criterion, | ||||
|         scheduler, | ||||
|         optimizer, | ||||
|         "train", | ||||
|         optim_config, | ||||
|         extra_info, | ||||
|         print_freq, | ||||
|         logger, | ||||
|     ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def basic_valid( | ||||
|     xloader, network, criterion, optim_config, extra_info, print_freq, logger | ||||
| ): | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, | ||||
|             network, | ||||
|             criterion, | ||||
|             None, | ||||
|             None, | ||||
|             "valid", | ||||
|             None, | ||||
|             extra_info, | ||||
|             print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def procedure( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     optimizer, | ||||
|     mode: Text, | ||||
|     print_freq: int = 100, | ||||
|     logger_fn: Callable = None | ||||
| ): | ||||
|     data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     if mode.lower() == "train": | ||||
|         network.train() | ||||
|     elif mode.lower() == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|         outputs = network(inputs) | ||||
|         loss = criterion(outputs, targets) | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         metrics =  | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|             Sstr = ( | ||||
|                 " {:5s} ".format(mode.upper()) | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|             ) | ||||
|             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Istr = "Size={:}".format(list(inputs.size())) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
| @@ -1,3 +1,8 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| import abc | ||||
|  | ||||
| def obtain_accuracy(output, target, topk=(1,)): | ||||
|     """Computes the precision@k for the specified values of k""" | ||||
|     maxk = max(topk) | ||||
| @@ -12,3 +17,12 @@ def obtain_accuracy(output, target, topk=(1,)): | ||||
|         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||||
|         res.append(correct_k.mul_(100.0 / batch_size)) | ||||
|     return res | ||||
|  | ||||
|  | ||||
| class EvaluationMetric(abc.ABC): | ||||
|      | ||||
|     def __init__(self): | ||||
|         self._total_metrics = 0 | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self._total_metrics | ||||
|   | ||||
							
								
								
									
										36
									
								
								lib/xlayers/super_activations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								lib/xlayers/super_activations.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| import math | ||||
| from typing import Optional, Callable | ||||
|  | ||||
| import spaces | ||||
| from .super_module import SuperModule | ||||
| from .super_module import IntSpaceType | ||||
| from .super_module import BoolSpaceType | ||||
|  | ||||
|  | ||||
| class SuperReLU(SuperModule): | ||||
|     """Applies a the rectified linear unit function element-wise.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, inplace=False) -> None: | ||||
|         super(SuperReLU, self).__init__() | ||||
|         self._inplace = inplace | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         return spaces.VirtualNode(id(self)) | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return self.forward_raw(input) | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return F.relu(input, inplace=self._inplace) | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return 'inplace=True' if self._inplace else '' | ||||
| @@ -14,5 +14,8 @@ from .super_norm import SuperLayerNorm1D | ||||
| from .super_attention import SuperAttention | ||||
| from .super_transformer import SuperTransformerEncoderLayer | ||||
|  | ||||
| from .super_activations import SuperReLU | ||||
|  | ||||
| from .super_trade_stem import SuperAlphaEBDv1 | ||||
| from .super_positional_embedding import SuperPositionalEncoder | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user