#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F

from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType


class SuperLayerNorm1D(SuperModule):
    """Super layer norm over the last dimension, searchable over `dim`."""

    def __init__(
        self, dim: IntSpaceType, eps: float = 1e-6, elementwise_affine: bool = True
    ) -> None:
        super(SuperLayerNorm1D, self).__init__()
        self._in_dim = dim
        self._eps = eps
        self._elementwise_affine = elementwise_affine
        if self._elementwise_affine:
            # Allocate the affine parameters at the maximum candidate dimension;
            # each candidate slices the leading `expected_input_dim` entries.
            self.register_parameter("weight", nn.Parameter(torch.Tensor(self.in_dim)))
            self.register_parameter("bias", nn.Parameter(torch.Tensor(self.in_dim)))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.reset_parameters()

    @property
    def in_dim(self):
        return spaces.get_max(self._in_dim)

    @property
    def eps(self):
        return self._eps

    def reset_parameters(self) -> None:
        if self._elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        if not spaces.is_determined(self._in_dim):
            root_node.append("_in_dim", self._in_dim.abstract(reuse_last=True))
        return root_node

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # Check that the last input dimension matches the sampled
        # (or already determined) candidate dimension.
        if not spaces.is_determined(self._in_dim):
            expected_input_dim = self.abstract_child["_in_dim"].value
        else:
            expected_input_dim = spaces.get_determined_value(self._in_dim)
        if input.size(-1) != expected_input_dim:
            raise ValueError(
                "Expected the last input dimension to be {:}, but got {:}".format(
                    expected_input_dim, input.size(-1)
                )
            )
        if self._elementwise_affine:
            weight = self.weight[:expected_input_dim]
            bias = self.bias[:expected_input_dim]
        else:
            weight, bias = None, None
        return F.layer_norm(input, (expected_input_dim,), weight, bias, self.eps)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(input, (self.in_dim,), self.weight, self.bias, self.eps)

    def forward_with_container(self, input, container, prefix=[]):
        super_weight_name = ".".join(prefix + ["weight"])
        if container.has(super_weight_name):
            weight = container.query(super_weight_name)
        else:
            weight = None
        super_bias_name = ".".join(prefix + ["bias"])
        if container.has(super_bias_name):
            bias = container.query(super_bias_name)
        else:
            bias = None
        return F.layer_norm(input, (self.in_dim,), weight, bias, self.eps)

    def extra_repr(self) -> str:
        return (
            "shape={in_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(
                in_dim=self._in_dim,
                eps=self._eps,
                elementwise_affine=self._elementwise_affine,
            )
        )


class SuperSimpleNorm(SuperModule):
    """Super simple normalization: (input - mean) / std with fixed statistics."""

    def __init__(self, mean, std, inplace=False) -> None:
        super(SuperSimpleNorm, self).__init__()
        self.register_buffer("_mean", torch.tensor(mean, dtype=torch.float))
        self.register_buffer("_std", torch.tensor(std, dtype=torch.float))
        self._inplace = inplace

    @property
    def abstract_search_space(self):
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # No candidate-specific behavior: fall back to the raw forward.
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        if not self._inplace:
            tensor = input.clone()
        else:
            tensor = input
        mean = torch.as_tensor(self._mean, dtype=tensor.dtype, device=tensor.device)
        std = torch.as_tensor(self._std, dtype=tensor.dtype, device=tensor.device)
        if (std == 0).any():
            raise ValueError(
                "std evaluated to zero after conversion to {}, leading to division by zero.".format(
                    tensor.dtype
                )
            )
        while mean.ndim < tensor.ndim:
            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
        return tensor.sub_(mean).div_(std)

    def extra_repr(self) -> str:
        return "mean={mean}, std={std}, inplace={inplace}".format(
            mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
        )


class SuperSimpleLearnableNorm(SuperModule):
    """Super simple normalization with learnable mean and std."""

    def __init__(self, mean=0, std=1, eps=1e-6, inplace=False) -> None:
        super(SuperSimpleLearnableNorm, self).__init__()
        self.register_parameter(
            "_mean", nn.Parameter(torch.tensor(mean, dtype=torch.float))
        )
        self.register_parameter(
            "_std", nn.Parameter(torch.tensor(std, dtype=torch.float))
        )
        self._eps = eps
        self._inplace = inplace

    @property
    def abstract_search_space(self):
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # No candidate-specific behavior: fall back to the raw forward.
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        if not self._inplace:
            tensor = input.clone()
        else:
            tensor = input
        # abs(std) + eps keeps the learnable scale strictly positive.
        mean, std = (
            self._mean.to(tensor.device),
            torch.abs(self._std.to(tensor.device)) + self._eps,
        )
        if (std == 0).any():
            raise ValueError("std leads to division by zero.")
        while mean.ndim < tensor.ndim:
            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
        return tensor.sub_(mean).div_(std)

    def forward_with_container(self, input, container, prefix=[]):
        if not self._inplace:
            tensor = input.clone()
        else:
            tensor = input
        mean_name = ".".join(prefix + ["_mean"])
        std_name = ".".join(prefix + ["_std"])
        mean, std = (
            container.query(mean_name).to(tensor.device),
            torch.abs(container.query(std_name).to(tensor.device)) + self._eps,
        )
        while mean.ndim < tensor.ndim:
            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
        return tensor.sub_(mean).div_(std)

    def extra_repr(self) -> str:
        return "mean={mean}, std={std}, inplace={inplace}".format(
            mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
        )


class SuperIdentity(SuperModule):
    """Super identity mapping layer."""

    def __init__(self, inplace=False, **kwargs) -> None:
        super(SuperIdentity, self).__init__()
        self._inplace = inplace

    @property
    def abstract_search_space(self):
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # No candidate-specific behavior: fall back to the raw forward.
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        if not self._inplace:
            tensor = input.clone()
        else:
            tensor = input
        return tensor

    def extra_repr(self) -> str:
        return "inplace={inplace}".format(inplace=self._inplace)

    def forward_with_container(self, input, container, prefix=[]):
        return self.forward_raw(input)
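

# ---------------------------------------------------------------------------
# Usage sketch: a minimal smoke test of the layers above. This block is an
# illustrative addition, not part of the original module, and it only
# exercises `forward_raw`, which needs no sampled candidate. Exercising
# `forward_candidate` on SuperLayerNorm1D would additionally require sampling
# a child from `abstract_search_space` and applying it via the SuperModule API.
if __name__ == "__main__":
    x = torch.randn(4, 8)

    # Fixed (buffer-based) normalization: (x - 0.5) / 0.25, broadcast over x.
    simple = SuperSimpleNorm(mean=0.5, std=0.25)
    assert simple.forward_raw(x).shape == (4, 8)

    # Learnable normalization: mean/std are nn.Parameters; the divisor is
    # abs(std) + eps, so it stays strictly positive during training.
    learnable = SuperSimpleLearnableNorm(mean=0.0, std=1.0)
    assert learnable.forward_raw(x).shape == (4, 8)

    # Identity mapping; with inplace=False the input is cloned, not aliased.
    identity = SuperIdentity(inplace=False)
    out = identity.forward_raw(x)
    assert torch.equal(out, x) and out is not x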