# Source: autodl-projects/xautodl/xlayers/super_container.py
# (web-viewer residue: 121 lines, 4.2 KiB, Python; blame 2021-03-21 06:26:52 +01:00)
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
from itertools import islice
import operator
from collections import OrderedDict
from typing import Optional, Union, Callable, TypeVar, Iterator
# blame: 2021-05-19 07:00:33 +02:00
from xautodl import spaces
# blame: 2021-03-21 06:26:52 +01:00
from .super_module import SuperModule
# Type variable bound to SuperModule; used below as the return annotation for
# single-item lookups in SuperSequential (_get_item_by_idx / __getitem__).
T = TypeVar("T", bound=SuperModule)
class SuperSequential(SuperModule):
    """A sequential container wrapped with 'Super' ability.

    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.

    To make it easier to understand, here is a small example::

        # Example of using SuperSequential
        model = SuperSequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using SuperSequential with OrderedDict
        model = SuperSequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
    """

    def __init__(self, *args):
        """Register the given modules as children, in order.

        Args:
            *args: either a single ``OrderedDict`` mapping names to modules
                (names become the child keys), or any number of modules
                passed positionally (keys become ``"0"``, ``"1"``, ...).
        """
        super(SuperSequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            # ``*args`` is always a tuple, so no container-type check is needed.
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx) -> T:
        """Get the idx-th item of the iterator.

        Args:
            iterator: an iterable with ``len(self)`` items (the keys or the
                values of ``self._modules``).
            idx: an integer-like position; negative values count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError("index {} is out of range".format(idx))
        # Map a negative index onto its non-negative equivalent.
        idx %= size
        return next(islice(iterator, idx, None))

    def __getitem__(self, idx) -> Union["SuperSequential", T]:
        """Return the child at position ``idx``, or a new container for a slice."""
        if isinstance(idx, slice):
            # Build a container of the same class holding the sliced children.
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: SuperModule) -> None:
        """Replace the child at position ``idx`` with ``module``, keeping its key."""
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        """Delete the child (or slice of children) at ``idx``.

        The keys of the remaining children are left as they are; they are
        not renumbered.
        """
        if isinstance(idx, slice):
            # Snapshot the keys first: delattr changes self._modules.
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)

    def __len__(self) -> int:
        """Return the number of children."""
        return len(self._modules)

    def __dir__(self):
        """Return the parent's ``dir()`` listing without purely numeric names."""
        keys = super(SuperSequential, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def __iter__(self) -> Iterator[SuperModule]:
        """Iterate over the children in registration order."""
        return iter(self._modules.values())

    @property
    def abstract_search_space(self):
        """Collect the search spaces of the children into one virtual node.

        Children that are not ``SuperModule`` instances are skipped, as are
        children whose space ``spaces.is_determined`` reports as determined.
        Each remaining space is stored under the child's position (as a
        string), not under its module key.
        """
        root_node = spaces.VirtualNode(id(self))
        for index, module in enumerate(self):
            if not isinstance(module, SuperModule):
                continue
            space = module.abstract_search_space
            if not spaces.is_determined(space):
                root_node.append(str(index), space)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        """Apply ``abstract_child`` to this container and to matching children.

        A child receives ``abstract_child[str(index)]`` only when its
        position (as a string) is present in ``abstract_child``.
        """
        super(SuperSequential, self).apply_candidate(abstract_child)
        for index, module in enumerate(self):
            if str(index) in abstract_child:
                module.apply_candidate(abstract_child[str(index)])

    def forward_candidate(self, input):
        """Forward in candidate mode; for this container it is ``forward_raw``."""
        return self.forward_raw(input)

    def forward_raw(self, input):
        """Feed ``input`` through every child in order and return the result."""
        for module in self:
            input = module(input)
        return input

    def forward_with_container(self, input, container, prefix=None):
        """Chain ``forward_with_container`` through every child in order.

        Args:
            input: the value fed to the first child.
            container: passed through unchanged to every child.
            prefix: list of name components identifying this container;
                each child receives ``prefix + [str(index)]``. ``None``
                (the default) is treated as an empty list.

        Returns:
            The output of the last child (or ``input`` if there are none).
        """
        # A None sentinel avoids sharing one mutable default list across calls.
        prefix = [] if prefix is None else prefix
        # NOTE(review): every child must provide forward_with_container here,
        # whereas abstract_search_space tolerates non-SuperModule children;
        # confirm whether plain modules are expected in this path.
        for index, module in enumerate(self):
            input = module.forward_with_container(
                input, container, prefix + [str(index)]
            )
        return input