45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
|
#####################################################
|
||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||
|
#############################################################
|
||
|
# Borrows the idea from https://github.com/arogozhnikov/einops #
|
||
|
#############################################################
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
import math
|
||
|
from typing import Optional, Callable
|
||
|
|
||
|
from xautodl import spaces
|
||
|
from .super_module import SuperModule
|
||
|
from .super_module import IntSpaceType
|
||
|
from .super_module import BoolSpaceType
|
||
|
|
||
|
|
||
|
class SuperRearrange(SuperModule):
    """Applies the rearrange operation.

    The operation is described by an einops-style ``pattern`` string plus
    optional named axis lengths (presumably following einops syntax, e.g.
    ``"b c h w -> b (h w) c"`` -- TODO confirm once parsing is implemented).

    NOTE(review): the tensor transformation itself is not implemented yet;
    both ``forward_candidate`` and ``forward_raw`` raise NotImplementedError.

    Args:
        pattern: the rearrange pattern; stored as-is (not parsed or validated
            in this class).
        **axes_lengths: optional sizes of named axes; stored as-is and shown
            in the module's repr.
    """

    def __init__(self, pattern, **axes_lengths):
        super(SuperRearrange, self).__init__()
        self._pattern = pattern
        self._axes_lengths = axes_lengths
        self.reset_parameters()

    @property
    def abstract_search_space(self):
        """Return the search space of this module.

        No child spaces are attached to the root node, i.e. this module
        exposes no searchable hyper-parameters of its own.
        """
        root_node = spaces.VirtualNode(id(self))
        return root_node

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass with the candidate (sampled) configuration.

        Raises:
            NotImplementedError: always; rearrange is not implemented yet.
        """
        raise NotImplementedError(
            "{}.forward_candidate is not implemented".format(self.__class__.__name__)
        )

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass with the raw (full) configuration.

        Raises:
            NotImplementedError: always; rearrange is not implemented yet.
        """
        raise NotImplementedError(
            "{}.forward_raw is not implemented".format(self.__class__.__name__)
        )

    def extra_repr(self) -> str:
        """Return the argument summary shown inside the module's repr.

        ``nn.Module.__repr__`` already renders ``ClassName(<extra_repr>)``,
        so only the arguments are returned here. Prefixing the class name, as
        einops does in its own ``__repr__``, would print the name twice.
        NOTE(review): assumes SuperModule keeps nn.Module's ``__repr__``.
        """
        params = [repr(self._pattern)]
        params.extend(
            "{}={}".format(axis, length)
            for axis, length in self._axes_lengths.items()
        )
        return ", ".join(params)
|