Add more algorithms

This commit is contained in:
D-X-Y
2019-09-28 18:24:47 +10:00
parent bfd6b648fd
commit cfb462e463
286 changed files with 10557 additions and 122955 deletions

View File

@@ -0,0 +1,13 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .configure_utils import load_config, dict2config, configure2str
from .basic_args import obtain_basic_args
from .attention_args import obtain_attention_args
from .random_baseline import obtain_RandomSearch_args
from .cls_kd_args import obtain_cls_kd_args
from .cls_init_args import obtain_cls_init_args
from .search_single_args import obtain_search_single_args
from .search_args import obtain_search_args
# for network pruning
from .pruning_args import obtain_pruning_args

View File

@@ -0,0 +1,25 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_attention_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--att_channel' , type=int, help='.')
parser.add_argument('--att_spatial' , type=str, help='.')
parser.add_argument('--att_active' , type=str, help='.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
return args

View File

@@ -0,0 +1,23 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_basic_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--model_source', type=str, default='normal',help='The source of model defination.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
return args

View File

@@ -0,0 +1,23 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_cls_init_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--init_checkpoint', type=str, help='The checkpoint path to the initial model.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
return args

View File

@@ -0,0 +1,26 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_cls_kd_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--KD_checkpoint', type=str, help='The teacher checkpoint in knowledge distillation.')
parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.')
parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.')
#parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
return args

View File

@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, sys, json
from os import path as osp
from pathlib import Path
from collections import namedtuple
support_types = ('str', 'int', 'bool', 'float', 'none')
def convert_param(original_lists):
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
ctype, value = original_lists[0], original_lists[1]
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
is_list = isinstance(value, list)
if not is_list: value = [value]
outs = []
for x in value:
if ctype == 'int':
x = int(x)
elif ctype == 'str':
x = str(x)
elif ctype == 'bool':
x = bool(int(x))
elif ctype == 'float':
x = float(x)
elif ctype == 'none':
assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x)
x = None
else:
raise TypeError('Does not know this type : {:}'.format(ctype))
outs.append(x)
if not is_list: outs = outs[0]
return outs
def load_config(path, extra, logger):
path = str(path)
if hasattr(logger, 'log'): logger.log(path)
assert os.path.exists(path), 'Can not find {:}'.format(path)
# Reading data back
with open(path, 'r') as f:
data = json.load(f)
content = { k: convert_param(v) for k,v in data.items()}
assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
if isinstance(extra, dict): content = {**content, **extra}
Arguments = namedtuple('Configure', ' '.join(content.keys()))
content = Arguments(**content)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content
def configure2str(config, xpath=None):
if not isinstance(config, dict):
config = config._asdict()
def cstring(x):
return "\"{:}\"".format(x)
def gtype(x):
if isinstance(x, list): x = x[0]
if isinstance(x, str) : return 'str'
elif isinstance(x, bool) : return 'bool'
elif isinstance(x, int): return 'int'
elif isinstance(x, float): return 'float'
elif x is None : return 'none'
else: raise ValueError('invalid : {:}'.format(x))
def cvalue(x, xtype):
if isinstance(x, list): is_list = True
else:
is_list, x = False, [x]
temps = []
for temp in x:
if xtype == 'bool' : temp = cstring(int(temp))
elif xtype == 'none': temp = cstring('None')
else : temp = cstring(temp)
temps.append( temp )
if is_list:
return "[{:}]".format( ', '.join( temps ) )
else:
return temps[0]
xstrings = []
for key, value in config.items():
xtype = gtype(value)
string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype))
xstrings.append(string)
Fstring = '{\n' + ',\n'.join(xstrings) + '\n}'
if xpath is not None:
parent = Path(xpath).resolve().parent
parent.mkdir(parents=True, exist_ok=True)
if osp.isfile(xpath): os.remove(xpath)
with open(xpath, "w") as text_file:
text_file.write('{:}'.format(Fstring))
return Fstring
def dict2config(xdict, logger):
assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
Arguments = namedtuple('Configure', ' '.join(xdict.keys()))
content = Arguments(**xdict)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content

View File

@@ -0,0 +1,29 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_pruning_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--keep_ratio' , type=float, help='The left channel ratio compared to the original network.')
parser.add_argument('--model_version', type=str, help='The network version.')
parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.')
parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.')
parser.add_argument('--Regular_W_feat', type=float, help='The .')
parser.add_argument('--Regular_W_conv', type=float, help='The .')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
assert args.keep_ratio > 0 and args.keep_ratio <= 1, 'invalid keep ratio : {:}'.format(args.keep_ratio)
return args

View File

@@ -0,0 +1,27 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_RandomSearch_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--expect_flop', type=float, help='The expected flop keep ratio.')
parser.add_argument('--arch_nums' , type=int, help='The maximum number of running random arch generating..')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--random_mode', type=str, choices=['random', 'fix'], help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
#assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max)
return args

View File

@@ -0,0 +1,35 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_search_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--model_config' , type=str, help='The path to the model configuration')
parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration')
parser.add_argument('--split_path' , type=str, help='The split file path.')
#parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.')
parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.')
parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.')
parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.')
# ablation studies
parser.add_argument('--ablation_num_select', type=int, help='The number of randomly selected channels.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant)
#assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
#args.arch_para_pure = bool(args.arch_para_pure)
return args

View File

@@ -0,0 +1,34 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_search_single_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--model_config' , type=str, help='The path to the model configuration')
parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration')
parser.add_argument('--split_path' , type=str, help='The split file path.')
parser.add_argument('--search_shape' , type=str, help='The shape to be searched.')
#parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.')
parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.')
parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.')
parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant)
#assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
#args.arch_para_pure = bool(args.arch_para_pure)
return args

View File

@@ -0,0 +1,20 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, random, argparse
def add_shared_args( parser ):
# Data Generation
parser.add_argument('--dataset', type=str, help='The dataset name.')
parser.add_argument('--data_path', type=str, help='The dataset name.')
parser.add_argument('--cutout_length', type=int, help='The cutout length, negative means not use.')
# Printing
parser.add_argument('--print_freq', type=int, default=100, help='print frequency (default: 200)')
parser.add_argument('--print_freq_eval', type=int, default=100, help='print frequency (default: 200)')
# Checkpoints
parser.add_argument('--eval_frequency', type=int, default=1, help='evaluation frequency (default: 200)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
# Acceleration
parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)')
# Random Seed
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')

View File

@@ -0,0 +1,126 @@
import os, sys, hashlib, torch
import numpy as np
from PIL import Image
import torch.utils.data as data
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
def calculate_md5(fpath, chunk_size=1024 * 1024):
md5 = hashlib.md5()
with open(fpath, 'rb') as f:
for chunk in iter(lambda: f.read(chunk_size), b''):
md5.update(chunk)
return md5.hexdigest()
def check_md5(fpath, md5, **kwargs):
return md5 == calculate_md5(fpath, **kwargs)
def check_integrity(fpath, md5=None):
if not os.path.isfile(fpath): return False
if md5 is None: return True
else : return check_md5(fpath, md5)
class ImageNet16(data.Dataset):
# http://image-net.org/download-images
# A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
# https://arxiv.org/pdf/1707.08819.pdf
train_list = [
['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
]
valid_list = [
['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
]
def __init__(self, root, train, transform, use_num_of_class_only=None):
self.root = root
self.transform = transform
self.train = train # training set or valid set
if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')
if self.train: downloaded_list = self.train_list
else : downloaded_list = self.valid_list
self.data = []
self.targets = []
# now load the picked numpy arrays
for i, (file_name, checksum) in enumerate(downloaded_list):
file_path = os.path.join(self.root, file_name)
#print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
self.targets.extend(entry['labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
if use_num_of_class_only is not None:
assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
new_data, new_targets = [], []
for I, L in zip(self.data, self.targets):
if 1 <= L <= use_num_of_class_only:
new_data.append( I )
new_targets.append( L )
self.data = new_data
self.targets = new_targets
# self.mean.append(entry['mean'])
#self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
#self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
#print ('Mean : {:}'.format(self.mean))
#temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
#std_data = np.std(temp, axis=0)
#std_data = np.mean(np.mean(std_data, axis=0), axis=0)
#print ('Std : {:}'.format(std_data))
def __getitem__(self, index):
img, target = self.data[index], self.targets[index] - 1
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.data)
def _check_integrity(self):
root = self.root
for fentry in (self.train_list + self.valid_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, filename)
if not check_integrity(fpath, md5):
return False
return True
#
if __name__ == '__main__':
train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)
print ( len(train) )
print ( len(valid) )
image, label = train[111]
trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200)
print ( len(trainX) )
print ( len(validX) )
#import pdb; pdb.set_trace()

View File

@@ -0,0 +1,191 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from os import path as osp
from copy import deepcopy as copy
from tqdm import tqdm
import warnings, time, random, numpy as np
from pts_utils import generate_label_map
from xvision import denormalize_points
from xvision import identity2affine, solve2theta, affine2image
from .dataset_utils import pil_loader
from .landmark_utils import PointMeta2V
from .augmentation_utils import CutOut
import torch
import torch.utils.data as data
class LandmarkDataset(data.Dataset):
def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None):
self.transform = transform
self.sigma = sigma
self.downsample = downsample
self.heatmap_type = heatmap_type
self.dataset_name = data_indicator
self.shape = shape # [H,W]
self.use_gray = use_gray
assert transform is not None, 'transform : {:}'.format(transform)
self.mean_file = mean_file
if mean_file is None:
self.mean_data = None
warnings.warn('LandmarkDataset initialized with mean_data = None')
else:
assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file)
self.mean_data = torch.load(mean_file)
self.reset()
self.cutout = None
self.cache_images = cache_images
print ('The general dataset initialization done : {:}'.format(self))
warnings.simplefilter( 'once' )
def __repr__(self):
return ('{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'.format(name=self.__class__.__name__, **self.__dict__))
def set_cutout(self, length):
if length is not None and length >= 1:
self.cutout = CutOut( int(length) )
else: self.cutout = None
def reset(self, num_pts=-1, boxid='default', only_pts=False):
self.NUM_PTS = num_pts
if only_pts: return
self.length = 0
self.datas = []
self.labels = []
self.NormDistances = []
self.BOXID = boxid
if self.mean_data is None:
self.mean_face = None
else:
self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face)
#assert self.dataset_name is not None, 'The dataset name is None'
def __len__(self):
assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length)
return self.length
def append(self, data, label, distance):
assert osp.isfile(data), 'The image path is not a file : {:}'.format(data)
self.datas.append( data ) ; self.labels.append( label )
self.NormDistances.append( distance )
self.length = self.length + 1
def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
if reset: self.reset(num_pts, boxindicator)
else : assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of point is inconsistance : {:} vs {:}'.format(self.NUM_PTS, num_pts)
if isinstance(file_lists, str): file_lists = [file_lists]
samples = []
for idx, file_path in enumerate(file_lists):
print (':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path))
xdata = torch.load(file_path)
if isinstance(xdata, list) : data = xdata # image or video dataset list
elif isinstance(xdata, dict): data = xdata['datas'] # multi-view dataset list
else: raise ValueError('Invalid Type Error : {:}'.format( type(xdata) ))
samples = samples + data
# samples is a dict, where the key is the image-path and the value is the annotation
# each annotation is a dict, contains 'points' (3,num_pts), and various box
print ('GeneralDataset-V2 : {:} samples'.format(len(samples)))
#for index, annotation in enumerate(samples):
for index in tqdm( range( len(samples) ) ):
annotation = samples[index]
image_path = annotation['current_frame']
points, box = annotation['points'], annotation['box-{:}'.format(boxindicator)]
label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name)
if normalizeL is None: normDistance = None
else : normDistance = annotation['normalizeL-{:}'.format(normalizeL)]
self.append(image_path, label, normDistance)
assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas))
assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels))
assert len(self.NormDistances) == self.length, 'The length and the NormDistances is not right {} vs {}'.format(self.length, len(self.NormDistance))
print ('Load data done for LandmarkDataset, which has {:} images.'.format(self.length))
def __getitem__(self, index):
assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index)
if self.cache_images is not None and self.datas[index] in self.cache_images:
image = self.cache_images[ self.datas[index] ].clone()
else:
image = pil_loader(self.datas[index], self.use_gray)
target = self.labels[index].copy()
return self._process_(image, target, index)
def _process_(self, image, target, index):
# transform the image and points
image, target, theta = self.transform(image, target)
(C, H, W), (height, width) = image.size(), self.shape
# obtain the visiable indicator vector
if target.is_none(): nopoints = True
else : nopoints = False
if index == -1: __path = None
else : __path = self.datas[index]
if isinstance(theta, list) or isinstance(theta, tuple):
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = [], [], [], [], [], []
for _theta in theta:
_affineImage, _heatmaps, _mask, _norm_trans_points, _theta, _transpose_theta \
= self.__process_affine(image, target, _theta, nopoints, 'P[{:}]@{:}'.format(index, __path))
affineImage.append(_affineImage)
heatmaps.append(_heatmaps)
mask.append(_mask)
norm_trans_points.append(_norm_trans_points)
THETA.append(_theta)
transpose_theta.append(_transpose_theta)
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \
torch.stack(affineImage), torch.stack(heatmaps), torch.stack(mask), torch.stack(norm_trans_points), torch.stack(THETA), torch.stack(transpose_theta)
else:
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, __path))
torch_index = torch.IntTensor([index])
torch_nopoints = torch.ByteTensor( [ nopoints ] )
torch_shape = torch.IntTensor([H,W])
return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape
def __process_affine(self, image, target, theta, nopoints, aux_info=None):
image, target, theta = image.clone(), target.copy(), theta.clone()
(C, H, W), (height, width) = image.size(), self.shape
if nopoints: # do not have label
norm_trans_points = torch.zeros((3, self.NUM_PTS))
heatmaps = torch.zeros((self.NUM_PTS+1, height//self.downsample, width//self.downsample))
mask = torch.ones((self.NUM_PTS+1, 1, 1), dtype=torch.uint8)
transpose_theta = identity2affine(False)
else:
norm_trans_points = apply_affine2point(target.get_points(), theta, (H,W))
norm_trans_points = apply_boundary(norm_trans_points)
real_trans_points = norm_trans_points.clone()
real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2,:])
heatmaps, mask = generate_label_map(real_trans_points.numpy(), height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C
heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor)
mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
if self.mean_face is None:
#warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
transpose_theta = identity2affine(False)
else:
if torch.sum(norm_trans_points[2,:] == 1) < 3:
warnings.warn('In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'.format(aux_info))
transpose_theta = identity2affine(False)
else:
transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone())
affineImage = affine2image(image, theta, self.shape)
if self.cutout is not None: affineImage = self.cutout( affineImage )
return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta

View File

@@ -1,122 +0,0 @@
import os
import torch
from collections import Counter
class Dictionary(object):
def __init__(self):
self.word2idx = {}
self.idx2word = []
self.counter = Counter()
self.total = 0
def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
token_id = self.word2idx[word]
self.counter[token_id] += 1
self.total += 1
return self.word2idx[word]
def __len__(self):
return len(self.idx2word)
class Corpus(object):
def __init__(self, path):
self.dictionary = Dictionary()
self.train = self.tokenize(os.path.join(path, 'train.txt'))
self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
self.test = self.tokenize(os.path.join(path, 'test.txt'))
def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path)
# Add words to the dictionary
with open(path, 'r', encoding='utf-8') as f:
tokens = 0
for line in f:
words = line.split() + ['<eos>']
tokens += len(words)
for word in words:
self.dictionary.add_word(word)
# Tokenize file content
with open(path, 'r', encoding='utf-8') as f:
ids = torch.LongTensor(tokens)
token = 0
for line in f:
words = line.split() + ['<eos>']
for word in words:
ids[token] = self.dictionary.word2idx[word]
token += 1
return ids
class SentCorpus(object):
def __init__(self, path):
self.dictionary = Dictionary()
self.train = self.tokenize(os.path.join(path, 'train.txt'))
self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
self.test = self.tokenize(os.path.join(path, 'test.txt'))
def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path)
# Add words to the dictionary
with open(path, 'r', encoding='utf-8') as f:
tokens = 0
for line in f:
words = line.split() + ['<eos>']
tokens += len(words)
for word in words:
self.dictionary.add_word(word)
# Tokenize file content
sents = []
with open(path, 'r', encoding='utf-8') as f:
for line in f:
if not line:
continue
words = line.split() + ['<eos>']
sent = torch.LongTensor(len(words))
for i, word in enumerate(words):
sent[i] = self.dictionary.word2idx[word]
sents.append(sent)
return sents
class BatchSentLoader(object):
def __init__(self, sents, batch_size, pad_id=0, cuda=False, volatile=False):
self.sents = sents
self.batch_size = batch_size
self.sort_sents = sorted(sents, key=lambda x: x.size(0))
self.cuda = cuda
self.volatile = volatile
self.pad_id = pad_id
def __next__(self):
if self.idx >= len(self.sort_sents):
raise StopIteration
batch_size = min(self.batch_size, len(self.sort_sents)-self.idx)
batch = self.sort_sents[self.idx:self.idx+batch_size]
max_len = max([s.size(0) for s in batch])
tensor = torch.LongTensor(max_len, batch_size).fill_(self.pad_id)
for i in range(len(batch)):
s = batch[i]
tensor[:s.size(0),i].copy_(s)
if self.cuda:
tensor = tensor.cuda()
self.idx += batch_size
return tensor
next = __next__
def __iter__(self):
self.idx = 0
return self

View File

@@ -1,65 +0,0 @@
# coding=utf-8
import numpy as np
import torch
class MetaBatchSampler(object):
def __init__(self, labels, classes_per_it, num_samples, iterations):
'''
Initialize MetaBatchSampler
Args:
- labels: an iterable containing all the labels for the current dataset
samples indexes will be infered from this iterable.
- classes_per_it: number of random classes for each iteration
- num_samples: number of samples for each iteration for each class (support + query)
- iterations: number of iterations (episodes) per epoch
'''
super(MetaBatchSampler, self).__init__()
self.labels = labels.copy()
self.classes_per_it = classes_per_it
self.sample_per_class = num_samples
self.iterations = iterations
self.classes, self.counts = np.unique(self.labels, return_counts=True)
assert len(self.classes) == np.max(self.classes) + 1 and np.min(self.classes) == 0
assert classes_per_it < len(self.classes), '{:} vs. {:}'.format(classes_per_it, len(self.classes))
self.classes = torch.LongTensor(self.classes)
# create a matrix, indexes, of dim: classes X max(elements per class)
# fill it with nans
# for every class c, fill the relative row with the indices samples belonging to c
# in numel_per_class we store the number of samples for each class/row
self.indexes = { x.item() : [] for x in self.classes }
indexes = { x.item() : [] for x in self.classes }
for idx, label in enumerate(self.labels):
indexes[ label.item() ].append( idx )
for key, value in indexes.items():
self.indexes[ key ] = torch.LongTensor( value )
def __iter__(self):
# yield a batch of indexes
spc = self.sample_per_class
cpi = self.classes_per_it
for it in range(self.iterations):
batch_size = spc * cpi
batch = torch.LongTensor(batch_size)
assert cpi < len(self.classes), '{:} vs. {:}'.format(cpi, len(self.classes))
c_idxs = torch.randperm(len(self.classes))[:cpi]
for i, cls in enumerate(self.classes[c_idxs]):
s = slice(i * spc, (i + 1) * spc)
num = self.indexes[ cls.item() ].nelement()
assert spc < num, '{:} vs. {:}'.format(spc, num)
sample_idxs = torch.randperm( num )[:spc]
batch[s] = self.indexes[ cls.item() ][sample_idxs]
batch = batch[torch.randperm(len(batch))]
yield batch
def __len__(self):
# returns the number of iterations (episodes) per epoch
return self.iterations

View File

@@ -0,0 +1,26 @@
import torch, copy, random
import torch.utils.data as data
class SearchDataset(data.Dataset):
def __init__(self, name, data, train_split, valid_split):
self.datasetname = name
self.data = data
self.train_split = train_split.copy()
self.valid_split = valid_split.copy()
self.length = len(self.train_split)
def __repr__(self):
return ('{name}(name={datasetname}, length={length})'.format(name=self.__class__.__name__, **self.__dict__))
def __len__(self):
return self.length
def __getitem__(self, index):
assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
train_index = self.train_split[index]
valid_index = random.choice( self.valid_split )
train_image, train_label = self.data[train_index]
valid_image, valid_label = self.data[valid_index]
return train_image, train_label, valid_image, valid_label

View File

@@ -1,84 +0,0 @@
from __future__ import print_function
import numpy as np
from PIL import Image
import pickle as pkl
import os, cv2, csv, glob
import torch
import torch.utils.data as data
class TieredImageNet(data.Dataset):
def __init__(self, root_dir, split, transform=None):
self.split = split
self.root_dir = root_dir
self.transform = transform
splits = split.split('-')
images, labels, last = [], [], 0
for split in splits:
labels_name = '{:}/{:}_labels.pkl'.format(self.root_dir, split)
images_name = '{:}/{:}_images.npz'.format(self.root_dir, split)
# decompress images if npz not exits
if not os.path.exists(images_name):
png_pkl = images_name[:-4] + '_png.pkl'
if os.path.exists(png_pkl):
decompress(images_name, png_pkl)
else:
raise ValueError('png_pkl {:} not exits'.format( png_pkl ))
assert os.path.exists(images_name) and os.path.exists(labels_name), '{:} & {:}'.format(images_name, labels_name)
print ("Prepare {:} done".format(images_name))
try:
with open(labels_name) as f:
data = pkl.load(f)
label_specific = data["label_specific"]
except:
with open(labels_name, 'rb') as f:
data = pkl.load(f, encoding='bytes')
label_specific = data[b'label_specific']
with np.load(images_name, mmap_mode="r", encoding='latin1') as data:
image_data = data["images"]
images.append( image_data )
label_specific = label_specific + last
labels.append( label_specific )
last = np.max(label_specific) + 1
print ("Load {:} done, with image shape = {:}, label shape = {:}, [{:} ~ {:}]".format(images_name, image_data.shape, label_specific.shape, np.min(label_specific), np.max(label_specific)))
images, labels = np.concatenate(images), np.concatenate(labels)
self.images = images
self.labels = labels
self.n_classes = int( np.max(labels) + 1 )
self.dict_index_label = {}
for cls in range(self.n_classes):
idxs = np.where(labels==cls)[0]
self.dict_index_label[cls] = idxs
self.length = len(labels)
print ("There are {:} images, {:} labels [{:} ~ {:}]".format(images.shape, labels.shape, np.min(labels), np.max(labels)))
def __repr__(self):
return ('{name}(length={length}, classes={n_classes})'.format(name=self.__class__.__name__, **self.__dict__))
def __len__(self):
return self.length
def __getitem__(self, index):
assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
image = self.images[index].copy()
label = int(self.labels[index])
image = Image.fromarray(image[:,:,::-1].astype('uint8'), 'RGB')
if self.transform is not None:
image = self.transform( image )
return image, label
def decompress(path, output):
with open(output, 'rb') as f:
array = pkl.load(f, encoding='bytes')
images = np.zeros([len(array), 84, 84, 3], dtype=np.uint8)
for ii, item in enumerate(array):
im = cv2.imdecode(item, 1)
images[ii] = im
np.savez(path, images=images)

View File

@@ -1,7 +1,5 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .MetaBatchSampler import MetaBatchSampler
from .TieredImageNet import TieredImageNet
from .LanguageDataset import Corpus
from .get_dataset_with_transform import get_datasets
from .SearchDatasetWrap import SearchDataset

View File

@@ -3,75 +3,181 @@
##################################################
import os, sys, torch
import os.path as osp
import numpy as np
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from utils import Cutout
from .TieredImageNet import TieredImageNet
from PIL import Image
from .DownsampledImageNet import ImageNet16
Dataset2Class = {'cifar10' : 10,
'cifar100': 100,
'tiered' : -1,
'imagenet-1k-s':1000,
'imagenet-1k' : 1000,
'imagenet-100': 100}
'ImageNet16' : 1000,
'ImageNet16-150': 150,
'ImageNet16-120': 120,
'ImageNet16-200': 200}
class CUTOUT(object):
def __init__(self, length):
self.length = length
def __repr__(self):
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
imagenet_pca = {
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
'eigvec': np.asarray([
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
])
}
class Lighting(object):
def __init__(self, alphastd,
eigval=imagenet_pca['eigval'],
eigvec=imagenet_pca['eigvec']):
self.alphastd = alphastd
assert eigval.shape == (3,)
assert eigvec.shape == (3, 3)
self.eigval = eigval
self.eigvec = eigvec
def __call__(self, img):
if self.alphastd == 0.:
return img
rnd = np.random.randn(3) * self.alphastd
rnd = rnd.astype('float32')
v = rnd
old_dtype = np.asarray(img).dtype
v = v * self.eigval
v = v.reshape((3, 1))
inc = np.dot(self.eigvec, v).reshape((3,))
img = np.add(img, inc)
if old_dtype == np.uint8:
img = np.clip(img, 0, 255)
img = Image.fromarray(img.astype(old_dtype), 'RGB')
return img
def __repr__(self):
return self.__class__.__name__ + '()'
def get_datasets(name, root, cutout):
# Mean + Std
if name == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif name == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name == 'tiered':
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name.startswith('imagenet-1k'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name == 'imagenet-1k' or name == 'imagenet-100':
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
else: raise TypeError("Unknow dataset : {:}".format(name))
elif name.startswith('ImageNet16'):
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26 , 65.09]]
else:
raise TypeError("Unknow dataset : {:}".format(name))
# Data Argumentation
if name == 'cifar10' or name == 'cifar100':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
transforms.Normalize(mean, std)]
if cutout > 0 : lists += [Cutout(cutout)]
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('ImageNet16'):
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 16, 16)
elif name == 'tiered':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [Cutout(cutout)]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
elif name == 'imagenet-1k' or name == 'imagenet-100':
xshape = (1, 3, 32, 32)
elif name.startswith('imagenet-1k'):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(
if name == 'imagenet-1k':
xlists = [transforms.RandomResizedCrop(224)]
xlists.append(
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2),
transforms.ToTensor(),
normalize,
])
test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
else: raise TypeError("Unknow dataset : {:}".format(name))
hue=0.2))
xlists.append( Lighting(0.1))
elif name == 'imagenet-1k-s':
xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
else: raise ValueError('invalid name : {:}'.format(name))
xlists.append( transforms.RandomHorizontalFlip(p=0.5) )
xlists.append( transforms.ToTensor() )
xlists.append( normalize )
train_transform = transforms.Compose(xlists)
test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
xshape = (1, 3, 224, 224)
else:
raise TypeError("Unknow dataset : {:}".format(name))
if name == 'cifar10':
train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'cifar100':
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
elif name == 'imagenet-1k' or name == 'imagenet-100':
assert len(train_data) == 50000 and len(test_data) == 10000
elif name.startswith('imagenet-1k'):
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000)
elif name == 'ImageNet16':
train_data = ImageNet16(root, True , train_transform)
test_data = ImageNet16(root, False, test_transform)
assert len(train_data) == 1281167 and len(test_data) == 50000
elif name == 'ImageNet16-120':
train_data = ImageNet16(root, True , train_transform, 120)
test_data = ImageNet16(root, False, test_transform , 120)
assert len(train_data) == 151700 and len(test_data) == 6000
elif name == 'ImageNet16-150':
train_data = ImageNet16(root, True , train_transform, 150)
test_data = ImageNet16(root, False, test_transform , 150)
assert len(train_data) == 190272 and len(test_data) == 7500
elif name == 'ImageNet16-200':
train_data = ImageNet16(root, True , train_transform, 200)
test_data = ImageNet16(root, False, test_transform , 200)
assert len(train_data) == 254775 and len(test_data) == 10000
else: raise TypeError("Unknow dataset : {:}".format(name))
class_num = Dataset2Class[name]
return train_data, test_data, class_num
return train_data, test_data, xshape, class_num
#if __name__ == '__main__':
# train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
# import pdb; pdb.set_trace()

View File

@@ -0,0 +1 @@
from .point_meta import PointMeta2V, apply_affine2point, apply_boundary

View File

@@ -0,0 +1,116 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import copy, math, torch, numpy as np
from xvision import normalize_points
from xvision import denormalize_points
class PointMeta():
# points : 3 x num_pts (x, y, oculusion)
# image_size: original [width, height]
def __init__(self, num_point, points, box, image_path, dataset_name):
self.num_point = num_point
if box is not None:
assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4
self.box = torch.Tensor(box)
else: self.box = None
if points is None:
self.points = points
else:
assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points )
self.points = torch.Tensor(points.copy())
self.image_path = image_path
self.datasets = dataset_name
def __repr__(self):
if self.box is None: boxstr = 'None'
else : boxstr = 'box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]'.format(*self.box.tolist())
return ('{name}(points={num_point}, '.format(name=self.__class__.__name__, **self.__dict__) + boxstr + ')')
def get_box(self, return_diagonal=False):
if self.box is None: return None
if return_diagonal == False:
return self.box.clone()
else:
W = (self.box[2]-self.box[0]).item()
H = (self.box[3]-self.box[1]).item()
return math.sqrt(H*H+W*W)
def get_points(self, ignore_indicator=False):
if ignore_indicator: last = 2
else : last = 3
if self.points is not None: return self.points.clone()[:last, :]
else : return torch.zeros((last, self.num_point))
def is_none(self):
#assert self.box is not None, 'The box should not be None'
return self.points is None
#if self.box is None: return True
#else : return self.points is None
def copy(self):
return copy.deepcopy(self)
def visiable_pts_num(self):
with torch.no_grad():
ans = self.points[2,:] > 0
ans = torch.sum(ans)
ans = ans.item()
return ans
def special_fun(self, indicator):
if indicator == '68to49': # For 300W or 300VW, convert the default 68 points to 49 points.
assert self.num_point == 68, 'num-point must be 68 vs. {:}'.format(self.num_point)
self.num_point = 49
out = torch.ones((68), dtype=torch.uint8)
out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = 0
if self.points is not None: self.points = self.points.clone()[:, out]
else:
raise ValueError('Invalid indicator : {:}'.format( indicator ))
def apply_horizontal_flip(self):
#self.points[0, :] = width - self.points[0, :] - 1
# Mugsy spefic or Synthetic
if self.datasets.startswith('HandsyROT'):
ori = np.array(list(range(0, 42)))
pos = np.array(list(range(21,42)) + list(range(0,21)))
self.points[:, pos] = self.points[:, ori]
elif self.datasets.startswith('face68'):
ori = np.array(list(range(0, 68)))
pos = np.array([17,16,15,14,13,12,11,10, 9, 8,7,6,5,4,3,2,1, 27,26,25,24,23,22,21,20,19,18, 28,29,30,31, 36,35,34,33,32, 46,45,44,43,48,47, 40,39,38,37,42,41, 55,54,53,52,51,50,49,60,59,58,57,56,65,64,63,62,61,68,67,66])-1
self.points[:, ori] = self.points[:, pos]
else:
raise ValueError('Does not support {:}'.format(self.datasets))
# shape = (H,W)
def apply_affine2point(points, theta, shape):
assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size())
with torch.no_grad():
ok_points = points[2,:] == 1
assert torch.sum(ok_points).item() > 0, 'there is no visiable point'
points[:2,:] = normalize_points(shape, points[:2,:])
norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()
trans_points, ___ = torch.gesv(points[:, ok_points], theta)
norm_trans_points[:, ok_points] = trans_points
return norm_trans_points
def apply_boundary(norm_trans_points):
with torch.no_grad():
norm_trans_points = norm_trans_points.clone()
oks = torch.stack((norm_trans_points[0]>-1, norm_trans_points[0]<1, norm_trans_points[1]>-1, norm_trans_points[1]<1, norm_trans_points[2]>0))
oks = torch.sum(oks, dim=0) == 5
norm_trans_points[2, :] = oks
return norm_trans_points

View File

@@ -1,10 +0,0 @@
import os, sys, torch
from LanguageDataset import SentCorpus, BatchSentLoader
if __name__ == '__main__':
path = '../../data/data/penn'
corpus = SentCorpus( path )
loader = BatchSentLoader(corpus.test, 10)
for i, d in enumerate(loader):
print('{:} :: {:}'.format(i, d.size()))

View File

@@ -1,33 +0,0 @@
import os, sys, torch
import torchvision.transforms as transforms
from TieredImageNet import TieredImageNet
from MetaBatchSampler import MetaBatchSampler
root_dir = os.environ['TORCH_HOME'] + '/tiered-imagenet'
print ('root : {:}'.format(root_dir))
means, stds = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(84, padding=8), transforms.ToTensor(), transforms.Normalize(means, stds)]
transform = transforms.Compose(lists)
dataset = TieredImageNet(root_dir, 'val-test', transform)
image, label = dataset[111]
print ('image shape = {:}, label = {:}'.format(image.size(), label))
print ('image : min = {:}, max = {:} ||| label : {:}'.format(image.min(), image.max(), label))
sampler = MetaBatchSampler(dataset.labels, 250, 100, 10)
dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
print ('the length of dataset : {:}'.format( len(dataset) ))
print ('the length of loader : {:}'.format( len(dataloader) ))
for images, labels in dataloader:
print ('images : {:}'.format( images.size() ))
print ('labels : {:}'.format( labels.size() ))
for i in range(3):
print ('image-value-[{:}] : {:} ~ {:}, mean={:}, std={:}'.format(i, images[:,i].min(), images[:,i].max(), images[:,i].mean(), images[:,i].std()))
print('-----')

View File

@@ -0,0 +1,14 @@
def test_imagenet_data(imagenet):
total_length = len(imagenet)
assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length)
map_id = {}
for index in range(total_length):
path, target = imagenet.imgs[index]
folder, image_name = os.path.split(path)
_, folder = os.path.split(folder)
if folder not in map_id:
map_id[folder] = target
else:
assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target)
assert image_name.find(folder) == 0, '{} is wrong.'.format(path)
print ('Check ImageNet Dataset OK')

View File

@@ -0,0 +1,7 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .logger import Logger
from .print_logger import PrintLogger
from .meter import AverageMeter
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_size2str, convert_secs2time

140
lib/log_utils/logger.py Normal file
View File

@@ -0,0 +1,140 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from pathlib import Path
import importlib, warnings
import os, sys, time, numpy as np
if sys.version_info.major == 2: # Python 2.x
from StringIO import StringIO as BIO
else: # Python 3.x
from io import BytesIO as BIO
if importlib.util.find_spec('tensorflow'):
import tensorflow as tf
class Logger(object):
def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
"""Create a summary writer logging to log_dir."""
self.seed = int(seed)
self.log_dir = Path(log_dir)
self.model_dir = Path(log_dir) / 'checkpoint'
self.log_dir.mkdir (parents=True, exist_ok=True)
if create_model_dir:
self.model_dir.mkdir(parents=True, exist_ok=True)
#self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
self.use_tf = bool(use_tf)
self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h', time.gmtime(time.time()) )))
#self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
self.logger_path = self.log_dir / 'seed-{:}-T-{:}.log'.format(self.seed, time.strftime('%d-%h-at-%H-%M-%S', time.gmtime(time.time())))
self.logger_file = open(self.logger_path, 'w')
if self.use_tf:
self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
self.writer = tf.summary.FileWriter(str(self.tensorboard_dir))
else:
self.writer = None
def __repr__(self):
return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__))
def path(self, mode):
valids = ('model', 'best', 'info', 'log')
if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed)
elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed)
elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed)
elif mode == 'log' : return self.log_dir
else: raise TypeError('Unknow mode = {:}, valid modes = {:}'.format(mode, valids))
def extract_log(self):
return self.logger_file
def close(self):
self.logger_file.close()
if self.writer is not None:
self.writer.close()
def log(self, string, save=True, stdout=False):
if stdout:
sys.stdout.write(string); sys.stdout.flush()
else:
print (string)
if save:
self.logger_file.write('{:}\n'.format(string))
self.logger_file.flush()
def scalar_summary(self, tags, values, step):
"""Log a scalar variable."""
if not self.use_tf:
warnings.warn('Do set use-tensorflow installed but call scalar_summary')
else:
assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values))
if not isinstance(tags, list):
tags, values = [tags], [values]
for tag, value in zip(tags, values):
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
self.writer.flush()
def image_summary(self, tag, images, step):
"""Log a list of images."""
import scipy
if not self.use_tf:
warnings.warn('Do set use-tensorflow installed but call scalar_summary')
return
img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
try:
s = StringIO()
except:
s = BytesIO()
scipy.misc.toimage(img).save(s, format="png")
# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum))
# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)
self.writer.flush()
def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""
if not self.use_tf: raise ValueError('Do not have tensorflow')
import tensorflow as tf
# Create a histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)
# Fill the fields of the histogram proto
hist = tf.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values**2))
# Drop the start of the first bin
bin_edges = bin_edges[1:]
# Add bin edges and counts
for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)
# Create and write Summary
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step)
self.writer.flush()

View File

@@ -1,26 +1,29 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time
import time, sys
import numpy as np
import random
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0.0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
self.avg = self.sum / self.count
def __repr__(self):
return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
class RecorderMeter(object):
@@ -29,12 +32,11 @@ class RecorderMeter(object):
self.reset(total_epoch)
def reset(self, total_epoch):
assert total_epoch > 0
assert total_epoch > 0, 'total_epoch should be greater than 0 vs {:}'.format(total_epoch)
self.total_epoch = total_epoch
self.current_epoch = 0
self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
self.epoch_losses = self.epoch_losses - 1
self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
self.epoch_accuracy= self.epoch_accuracy
@@ -98,43 +100,3 @@ class RecorderMeter(object):
fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
print ('---- save figure {} into {}'.format(title, save_path))
plt.close(fig)
def print_log(print_string, log):
print ("{:}".format(print_string))
if log is not None:
log.write('{}\n'.format(print_string))
log.flush()
def time_file_str():
ISOTIMEFORMAT='%Y-%m-%d'
string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string + '-{}'.format(random.randint(1, 10000))
def time_string():
ISOTIMEFORMAT='%Y-%m-%d-%X'
string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def convert_secs2time(epoch_time, return_str=False):
need_hour = int(epoch_time / 3600)
need_mins = int((epoch_time - 3600*need_hour) / 60)
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
if return_str == False:
return need_hour, need_mins, need_secs
else:
return '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
def test_imagenet_data(imagenet):
total_length = len(imagenet)
assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length)
map_id = {}
for index in range(total_length):
path, target = imagenet.imgs[index]
folder, image_name = os.path.split(path)
_, folder = os.path.split(folder)
if folder not in map_id:
map_id[folder] = target
else:
assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target)
assert image_name.find(folder) == 0, '{} is wrong.'.format(path)
print ('Check ImageNet Dataset OK')

View File

@@ -0,0 +1,18 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import importlib, warnings
import os, sys, time, numpy as np
class PrintLogger(object):
def __init__(self):
"""Create a summary writer logging to log_dir."""
self.name = 'PrintLogger'
def log(self, string):
print (string)
def close(self):
print ('-'*30 + ' close printer ' + '-'*30)

View File

@@ -0,0 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import time, sys
import numpy as np
def time_for_file():
ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
return '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
def time_string():
ISOTIMEFORMAT='%Y-%m-%d %X'
string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_string_short():
ISOTIMEFORMAT='%Y%m%d'
string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_print(string, is_print=True):
if (is_print):
print('{} : {}'.format(time_string(), string))
def convert_size2str(torch_size):
dims = len(torch_size)
string = '['
for idim in range(dims):
string = string + ' {}'.format(torch_size[idim])
return string + ']'
def convert_secs2time(epoch_time, return_str=False):
need_hour = int(epoch_time / 3600)
need_mins = int((epoch_time - 3600*need_hour) / 60)
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
if return_str:
str = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
return str
else:
return need_hour, need_mins, need_secs
def print_log(print_string, log):
#if isinstance(log, Logger): log.log('{:}'.format(print_string))
if hasattr(log, 'log'): log.log('{:}'.format(print_string))
else:
print("{:}".format(print_string))
if log is not None:
log.write('{:}\n'.format(print_string))
log.flush()

105
lib/models/CifarDenseNet.py Normal file
View File

@@ -0,0 +1,105 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
class Bottleneck(nn.Module):
def __init__(self, nChannels, growthRate):
super(Bottleneck, self).__init__()
interChannels = 4*growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
out = torch.cat((x, out), 1)
return out
class SingleLayer(nn.Module):
def __init__(self, nChannels, growthRate):
super(SingleLayer, self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = torch.cat((x, out), 1)
return out
class Transition(nn.Module):
def __init__(self, nChannels, nOutChannels):
super(Transition, self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = F.avg_pool2d(out, 2)
return out
class DenseNet(nn.Module):
def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
super(DenseNet, self).__init__()
if bottleneck: nDenseBlocks = int( (depth-4) / 6 )
else : nDenseBlocks = int( (depth-4) / 3 )
self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses)
nChannels = 2*growthRate
self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)
self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.trans1 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.trans2 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
self.act = nn.Sequential(
nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True),
nn.AvgPool2d(8))
self.fc = nn.Linear(nChannels, nClasses)
self.apply(initialize_resnet)
def get_message(self):
return self.message
def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels, growthRate))
else:
layers.append(SingleLayer(nChannels, growthRate))
nChannels += growthRate
return nn.Sequential(*layers)
def forward(self, inputs):
out = self.conv1( inputs )
out = self.trans1(self.dense1(out))
out = self.trans2(self.dense2(out))
out = self.dense3(out)
features = self.act(out)
features = features.view(features.size(0), -1)
out = self.fc(features)
return features, out

160
lib/models/CifarResNet.py Normal file
View File

@@ -0,0 +1,160 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
from .SharedUtils import additive_func
class Downsample(nn.Module):
def __init__(self, nIn, nOut, stride):
super(Downsample, self).__init__()
assert stride == 2 and nOut == 2*nIn, 'stride:{} IO:{},{}'.format(stride, nIn, nOut)
self.in_dim = nIn
self.out_dim = nOut
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False)
def forward(self, x):
x = self.avg(x)
out = self.conv(x)
return out
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias)
self.bn = nn.BatchNorm2d(nOut)
if relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
self.out_dim = nOut
self.num_conv = 1
def forward(self, x):
conv = self.conv( x )
bn = self.bn( conv )
if self.relu: return self.relu( bn )
else : return bn
class ResNetBasicblock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, False)
if stride == 2:
self.downsample = Downsample(inplanes, planes, stride)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False)
else:
self.downsample = None
self.out_dim = planes
self.num_conv = 2
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, basicblock)
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, False)
if stride == 2:
self.downsample = Downsample(inplanes, planes*self.expansion, stride)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.num_conv = 3
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, bottleneck)
return F.relu(out, inplace=True)
class CifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes, zero_init_residual):
super(CifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}'.format(block_name, depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, True) ] )
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
class WideBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, dropout=False):
super(WideBasicblock, self).__init__()
self.bn_a = nn.BatchNorm2d(inplanes)
self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn_b = nn.BatchNorm2d(planes)
if dropout:
self.dropout = nn.Dropout2d(p=0.5, inplace=True)
else:
self.dropout = None
self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
if inplanes != planes:
self.downsample = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False)
else:
self.downsample = None
def forward(self, x):
basicblock = self.bn_a(x)
basicblock = F.relu(basicblock)
basicblock = self.conv_a(basicblock)
basicblock = self.bn_b(basicblock)
basicblock = F.relu(basicblock)
if self.dropout is not None:
basicblock = self.dropout(basicblock)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
x = self.downsample(x)
return x + basicblock
class CifarWideResNet(nn.Module):
"""
ResNet optimized for the Cifar dataset, as specified in
https://arxiv.org/abs/1512.03385.pdf
"""
def __init__(self, depth, widen_factor, num_classes, dropout):
super(CifarWideResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
assert (depth - 4) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 4) // 6
print ('CifarPreResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks))
self.num_classes = num_classes
self.dropout = dropout
self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.message = 'Wide ResNet : depth={:}, widen_factor={:}, class={:}'.format(depth, widen_factor, num_classes)
self.inplanes = 16
self.stage_1 = self._make_layer(WideBasicblock, 16*widen_factor, layer_blocks, 1)
self.stage_2 = self._make_layer(WideBasicblock, 32*widen_factor, layer_blocks, 2)
self.stage_3 = self._make_layer(WideBasicblock, 64*widen_factor, layer_blocks, 2)
self.lastact = nn.Sequential(nn.BatchNorm2d(64*widen_factor), nn.ReLU(inplace=True))
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(64*widen_factor, num_classes)
self.apply(initialize_resnet)
def get_message(self):
return self.message
def _make_layer(self, block, planes, blocks, stride):
layers = []
layers.append(block(self.inplanes, planes, stride, self.dropout))
self.inplanes = planes
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, 1, self.dropout))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv_3x3(x)
x = self.stage_1(x)
x = self.stage_2(x)
x = self.stage_3(x)
x = self.lastact(x)
x = self.avgpool(x)
features = x.view(x.size(0), -1)
outs = self.classifier(features)
return features, outs

View File

@@ -0,0 +1,172 @@
# Deep Residual Learning for Image Recognition, CVPR 2016
import torch.nn as nn
from .initialization import initialize_resnet
def conv3x3(in_planes, out_planes, stride=1, groups=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64):
super(BasicBlock, self).__init__()
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64):
super(Bottleneck, self).__init__()
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = conv3x3(width, width, stride, groups)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block_name, layers, deep_stem, num_classes, zero_init_residual, groups, width_per_group):
super(ResNet, self).__init__()
#planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
if block_name == 'BasicBlock' : block= BasicBlock
elif block_name == 'Bottleneck': block= Bottleneck
else : raise ValueError('invalid block-name : {:}'.format(block_name))
if not deep_stem:
self.conv = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64), nn.ReLU(inplace=True))
else:
self.conv = nn.Sequential(
nn.Conv2d( 3, 32, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64), nn.ReLU(inplace=True))
self.inplanes = 64
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64 , layers[0], stride=1, groups=groups, base_width=width_per_group)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.message = 'block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}'.format(block, layers, deep_stem, num_classes)
self.apply( initialize_resnet )
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride, groups, base_width):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
if stride == 2:
downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
conv1x1(self.inplanes, planes * block.expansion, 1),
nn.BatchNorm2d(planes * block.expansion),
)
elif stride == 1:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
nn.BatchNorm2d(planes * block.expansion),
)
else: raise ValueError('invalid stride [{:}] for downsample'.format(stride))
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, groups, base_width))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, 1, None, groups, base_width))
return nn.Sequential(*layers)
def get_message(self):
return self.message
def forward(self, x):
x = self.conv(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.fc(features)
return features, logits

101
lib/models/MobileNet.py Normal file
View File

@@ -0,0 +1,101 @@
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from .initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
self.bn = nn.BatchNorm2d(out_planes)
self.relu = nn.ReLU6(inplace=True)
def forward(self, x):
out = self.conv( x )
out = self.bn ( out )
out = self.relu( out )
return out
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self, num_classes, width_mult, input_channel, last_channel, block_name, dropout):
super(MobileNetV2, self).__init__()
if block_name == 'InvertedResidual':
block = InvertedResidual
else:
raise ValueError('invalid block name : {:}'.format(block_name))
inverted_residual_setting = [
# t, c, n, s
[1, 16 , 1, 1],
[6, 24 , 2, 2],
[6, 32 , 3, 2],
[6, 64 , 4, 2],
[6, 96 , 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# building first layer
input_channel = int(input_channel * width_mult)
self.last_channel = int(last_channel * max(1.0, width_mult))
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = int(c * width_mult)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(self.last_channel, num_classes),
)
self.message = 'MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}'.format(width_mult, input_channel, last_channel, block_name, dropout)
# weight initialization
self.apply( initialize_resnet )
def get_message(self):
return self.message
def forward(self, inputs):
features = self.features(inputs)
vectors = features.mean([2, 3])
predicts = self.classifier(vectors)
return features, predicts

31
lib/models/SharedUtils.py Normal file
View File

@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
def additive_func(A, B):
assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
C = min(A.size(1), B.size(1))
if A.size(1) == B.size(1):
return A + B
elif A.size(1) < B.size(1):
out = B.clone()
out[:,:C] += A
return out
else:
out = A.clone()
out[:,:C] += B
return out
def change_key(key, value):
def func(m):
if hasattr(m, key):
setattr(m, key, value)
return func
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks

133
lib/models/ShuffleNetV2.py Normal file
View File

@@ -0,0 +1,133 @@
import functools
import torch
import torch.nn as nn
__all__ = ['ShuffleNetV2']
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride):
super(InvertedResidual, self).__init__()
if not (1 <= stride <= 3):
raise ValueError('illegal stride value')
self.stride = stride
branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)
pw_conv11 = functools.partial(nn.Conv2d, kernel_size=1, stride=1, padding=0, bias=False)
dw_conv33 = functools.partial(self.depthwise_conv, kernel_size=3, stride=self.stride, padding=1)
if self.stride > 1:
self.branch1 = nn.Sequential(
dw_conv33(inp, inp),
nn.BatchNorm2d(inp),
pw_conv11(inp, branch_features),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
self.branch2 = nn.Sequential(
pw_conv11(inp if (self.stride > 1) else branch_features, branch_features),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
dw_conv33(branch_features, branch_features),
nn.BatchNorm2d(branch_features),
pw_conv11(branch_features, branch_features),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
@staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
def forward(self, x):
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
class ShuffleNetV2(nn.Module):
def __init__(self, num_classes, stages):
super(ShuffleNetV2, self).__init__()
self.stage_out_channels = stages
assert len(stages) == 5, 'invalid stages : {:}'.format(stages)
self.message = 'stages: ' + ' '.join([str(x) for x in stages])
input_channels = 3
output_channels = self.stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
input_channels = output_channels
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
stage_names = ['stage{:}'.format(i) for i in [2, 3, 4]]
stage_repeats = [4, 8, 4]
for name, repeats, output_channels in zip(
stage_names, stage_repeats, self.stage_out_channels[1:]):
seq = [InvertedResidual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(InvertedResidual(output_channels, output_channels, 1))
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self.stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
self.fc = nn.Linear(output_channels, num_classes)
def get_message(self):
return self.message
def forward(self, inputs):
x = self.conv1( inputs )
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.conv5(x)
features = x.mean([2, 3]) # globalpool
predicts = self.fc(features)
return features, predicts
#@staticmethod
#def _getStages(mult):
# stages = {
# '0.5': [24, 48, 96 , 192, 1024],
# '1.0': [24, 116, 232, 464, 1024],
# '1.5': [24, 176, 352, 704, 1024],
# '2.0': [24, 244, 488, 976, 2048],
# }
# return stages[str(mult)]

123
lib/models/__init__.py Normal file
View File

@@ -0,0 +1,123 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from os import path as osp
# our modules
from config_utils import dict2config
from .SharedUtils import change_key
from .clone_weights import init_from_model
def get_cifar_models(config):
from .CifarResNet import CifarResNet
from .CifarDenseNet import DenseNet
from .CifarWideResNet import CifarWideResNet
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
if config.arch == 'resnet':
return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
elif config.arch == 'densenet':
return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck)
elif config.arch == 'wideresnet':
return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout)
else:
raise ValueError('invalid module type : {:}'.format(config.arch))
elif super_type.startswith('infer'):
from .infers import InferWidthCifarResNet
from .infers import InferDepthCifarResNet
from .infers import InferCifarResNet
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
infer_mode = super_type.split('-')[1]
if infer_mode == 'width':
return InferWidthCifarResNet(config.module, config.depth, config.xchannels, config.class_num, config.zero_init_residual)
elif infer_mode == 'depth':
return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual)
elif infer_mode == 'shape':
return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual)
else:
raise ValueError('invalid infer-mode : {:}'.format(infer_mode))
else:
raise ValueError('invalid super-type : {:}'.format(super_type))
def get_imagenet_models(config):
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
return get_imagenet_models_basic(config)
# NAS searched architecture
elif super_type.startswith('infer'):
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
infer_mode = super_type.split('-')[1]
if infer_mode == 'shape':
from .infers import InferImagenetResNet
from .infers import InferMobileNetV2
if config.arch == 'resnet':
return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual)
elif config.arch == "MobileNetV2":
return InferMobileNetV2(config.class_num, config.xchannels, config.xblocks, config.dropout)
else:
raise ValueError('invalid arch-mode : {:}'.format(config.arch))
else:
raise ValueError('invalid infer-mode : {:}'.format(infer_mode))
else:
raise ValueError('invalid super-type : {:}'.format(super_type))
def get_imagenet_models_basic(config):
from .ImagenetResNet import ResNet
from .MobileNet import MobileNetV2
from .ShuffleNetV2 import ShuffleNetV2
if config.arch == 'resnet':
return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
elif config.arch == 'MobileNetV2':
return MobileNetV2(config.class_num, config.width_mult, config.input_channel, config.last_channel, config.block_name, config.dropout)
elif config.arch == 'ShuffleNetV2':
return ShuffleNetV2(config.class_num, config.stages)
else:
raise ValueError('invalid arch : {:}'.format( config.arch ))
def obtain_model(config):
if config.dataset == 'cifar':
return get_cifar_models(config)
elif config.dataset == 'imagenet':
return get_imagenet_models(config)
else:
raise ValueError('invalid dataset in the model config : {:}'.format(config))
def obtain_search_model(config):
if config.dataset == 'cifar':
if config.arch == 'resnet':
from .searchs import SearchWidthCifarResNet
from .searchs import SearchDepthCifarResNet
from .searchs import SearchShapeCifarResNet
if config.search_mode == 'width':
return SearchWidthCifarResNet(config.module, config.depth, config.class_num)
elif config.search_mode == 'depth':
return SearchDepthCifarResNet(config.module, config.depth, config.class_num)
elif config.search_mode == 'shape':
return SearchShapeCifarResNet(config.module, config.depth, config.class_num)
else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
else:
raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset))
elif config.dataset == 'imagenet':
from .searchs import SearchShapeImagenetResNet
assert config.search_mode == 'shape', 'invalid search-mode : {:}'.format( config.search_mode )
if config.arch == 'resnet':
return SearchShapeImagenetResNet(config.block_name, config.layers, config.deep_stem, config.class_num)
else:
raise ValueError('invalid model config : {:}'.format(config))
else:
raise ValueError('invalid dataset in the model config : {:}'.format(config))
def load_net_from_checkpoint(checkpoint):
assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
checkpoint = torch.load(checkpoint)
model_config = dict2config(checkpoint['model-config'], None)
model = obtain_model(model_config)
model.load_state_dict(checkpoint['base-model'])
return model

View File

@@ -0,0 +1,62 @@
import torch
import torch.nn as nn
def copy_conv(module, init):
assert isinstance(module, nn.Conv2d), 'invalid module : {:}'.format(module)
assert isinstance(init , nn.Conv2d), 'invalid module : {:}'.format(init)
new_i, new_o = module.in_channels, module.out_channels
module.weight.copy_( init.weight.detach()[:new_o, :new_i] )
if module.bias is not None:
module.bias.copy_( init.bias.detach()[:new_o] )
def copy_bn (module, init):
assert isinstance(module, nn.BatchNorm2d), 'invalid module : {:}'.format(module)
assert isinstance(init , nn.BatchNorm2d), 'invalid module : {:}'.format(init)
num_features = module.num_features
if module.weight is not None:
module.weight.copy_( init.weight.detach()[:num_features] )
if module.bias is not None:
module.bias.copy_( init.bias.detach()[:num_features] )
if module.running_mean is not None:
module.running_mean.copy_( init.running_mean.detach()[:num_features] )
if module.running_var is not None:
module.running_var.copy_( init.running_var.detach()[:num_features] )
def copy_fc (module, init):
assert isinstance(module, nn.Linear), 'invalid module : {:}'.format(module)
assert isinstance(init , nn.Linear), 'invalid module : {:}'.format(init)
new_i, new_o = module.in_features, module.out_features
module.weight.copy_( init.weight.detach()[:new_o, :new_i] )
if module.bias is not None:
module.bias.copy_( init.bias.detach()[:new_o] )
def copy_base(module, init):
assert type(module).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format(module)
assert type( init).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format( init)
if module.conv is not None:
copy_conv(module.conv, init.conv)
if module.bn is not None:
copy_bn (module.bn, init.bn)
def copy_basic(module, init):
copy_base(module.conv_a, init.conv_a)
copy_base(module.conv_b, init.conv_b)
if module.downsample is not None:
if init.downsample is not None:
copy_base(module.downsample, init.downsample)
#else:
# import pdb; pdb.set_trace()
def init_from_model(network, init_model):
with torch.no_grad():
copy_fc(network.classifier, init_model.classifier)
for base, target in zip(init_model.layers, network.layers):
assert type(base).__name__ == type(target).__name__, 'invalid type : {:} vs {:}'.format(base, target)
if type(base).__name__ == 'ConvBNReLU':
copy_base(target, base)
elif type(base).__name__ == 'ResNetBasicblock':
copy_basic(target, base)
else:
raise ValueError('unknown type name : {:}'.format( type(base).__name__ ))

View File

@@ -0,0 +1,166 @@
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual):
super(InferCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,149 @@
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
else:
self.downsample = None
self.out_dim = planes*self.expansion
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferDepthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
super(InferDepthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.channels = [16]
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, planes, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.channels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,159 @@
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
super(InferWidthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,169 @@
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=True, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferImagenetResNet(nn.Module):
def __init__(self, block_name, layers, xblocks, xchannels, deep_stem, num_classes, zero_init_residual):
super(InferImagenetResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'BasicBlock':
block = ResNetBasicblock
elif block_name == 'Bottleneck':
block = ResNetBottleneck
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == len(layers), 'invalid layers : {:} vs xblocks : {:}'.format(layers, xblocks)
self.message = 'InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}'.format(sum(layers)*block.num_conv, sum(xblocks)*block.num_conv, xblocks)
self.num_classes = num_classes
self.xchannels = xchannels
if not deep_stem:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
else:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True)
,ConvBNReLU(xchannels[1], xchannels[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 2
self.layers.append( nn.MaxPool2d(kernel_size=3, stride=2, padding=1) )
for stage, layer_blocks in enumerate(layers):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
assert last_channel_idx + 1 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,119 @@
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func, parse_channel_info
class ConvBNReLU(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
if has_bn: self.bn = nn.BatchNorm2d(out_planes)
else : self.bn = None
if has_relu: self.relu = nn.ReLU6(inplace=True)
else : self.relu = None
def forward(self, x):
out = self.conv( x )
if self.bn: out = self.bn ( out )
if self.relu: out = self.relu( out )
return out
class InvertedResidual(nn.Module):
def __init__(self, channels, stride, expand_ratio, additive):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2], 'invalid stride : {:}'.format(stride)
assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels)
if len(channels) == 2:
layers = []
else:
layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
layers.extend([
# dw
ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
# pw-linear
ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
])
self.conv = nn.Sequential(*layers)
self.additive = additive
if self.additive and channels[0] != channels[-1]:
self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
else:
self.shortcut = None
self.out_dim = channels[-1]
def forward(self, x):
out = self.conv(x)
# if self.additive: return additive_func(out, x)
if self.shortcut: return out + self.shortcut(x)
else : return out
class InferMobileNetV2(nn.Module):
def __init__(self, num_classes, xchannels, xblocks, dropout):
super(InferMobileNetV2, self).__init__()
block = InvertedResidual
inverted_residual_setting = [
# t, c, n, s
[1, 16 , 1, 1],
[6, 24 , 2, 2],
[6, 32 , 3, 2],
[6, 64 , 4, 2],
[6, 96 , 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks))
for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting)
xchannels = parse_channel_info(xchannels)
#for i, chs in enumerate(xchannels):
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
self.xchannels = xchannels
self.message = 'InferMobileNetV2 : xblocks={:}'.format(xblocks)
# building first layer
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
last_channel_idx = 1
# building inverted residual blocks
for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
for i in range(n):
stride = s if i == 0 else 1
additv = True if i > 0 else False
module = block(self.xchannels[last_channel_idx], stride, t, additv)
features.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c)
last_channel_idx += 1
if i + 1 == xblocks[stage]:
out_channel = module.out_dim
for iiL in range(i+1, n):
last_channel_idx += 1
self.xchannels[last_channel_idx][0] = module.out_dim
break
# building last several layers
features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1))
assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
)
# weight initialization
self.apply( initialize_resnet )
def get_message(self):
return self.message
def forward(self, inputs):
features = self.features(inputs)
vectors = features.mean([2, 3])
predicts = self.classifier(vectors)
return features, predicts

View File

@@ -0,0 +1,5 @@
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2

View File

@@ -0,0 +1,7 @@
# Xuanyi Dong
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks

View File

@@ -0,0 +1,18 @@
import torch
import torch.nn as nn
def initialize_resnet(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)

View File

@@ -0,0 +1,502 @@
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices
def get_depth_choices(nDepth, return_num):
if nDepth == 2:
choices = (1, 2)
elif nDepth == 3:
choices = (1, 2, 3)
elif nDepth > 3:
choices = list(range(1, nDepth+1, 2))
if choices[-1] < nDepth: choices.append(nDepth)
else:
raise ValueError('invalid nDepth : {:}'.format(nDepth))
if return_num: return len(choices)
else : return choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:,:oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_width_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
#if has_bn : self.bn = nn.BatchNorm2d(nOut)
#else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = 'basic'
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, 'invalid length : {:}'.format(index)
# compute expected flop
#coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg : out = self.avg( inputs )
else : out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
#out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu: out = self.relu( out )
else : out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.has_bn:out= self.BNs[-1]( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_range(self):
return self.conv_a.get_range() + self.conv_b.get_range()
def get_flops(self, channels):
assert len(channels) == 3, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1]
return flop_A + flop_B + flop_C
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
out_a, expected_inC_a, expected_flop_a = self.conv_a( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_b, expected_inC_b, expected_flop_b = self.conv_b( (out_a , expected_inC_a, probability[1], indexes[1], probs[1]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[1], indexes[1], probs[1]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_b)
return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
def basic_forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = 'basic'
def get_range(self):
return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()
def get_flops(self, channels):
assert len(channels) == 4, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
if hasattr(self.downsample, 'get_flops'):
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_D = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1]
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def basic_forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( (out_1x1,expected_inC_1x1, probability[1], indexes[1], probs[1]) )
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( (out_3x3,expected_inC_3x3, probability[2], indexes[2], probs[2]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[2], indexes[2], probs[2]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_1x4)
return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
class SearchShapeCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchShapeCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.InShape = None
self.depth_info = OrderedDict()
self.depth_at_i = OrderedDict()
for stage in range(3):
cur_block_choices = get_depth_choices(layer_blocks, False)
assert cur_block_choices[-1] == layer_blocks, 'stage={:}, {:} vs {:}'.format(stage, cur_block_choices, layer_blocks)
self.message += "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(stage, cur_block_choices, layer_blocks)
block_choices, xstart = [], len(self.layers)
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
# added for depth
layer_index = len(self.layers) - 1
if iL + 1 in cur_block_choices: block_choices.append( layer_index )
if iL + 1 == layer_blocks:
self.depth_info[layer_index] = {'choices': block_choices,
'stage' : stage,
'xstart' : xstart}
self.depth_info_list = []
for xend, info in self.depth_info.items():
self.depth_info_list.append( (xend, info) )
xstart, xstage = info['xstart'], info['stage']
for ilayer in range(xstart, xend+1):
idx = bisect_right(info['choices'], ilayer-1)
self.depth_at_i[ilayer] = (xstage, idx)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append( (start_index, len(self.Ranges)) )
assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth)
self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))))
self.register_parameter('depth_attentions', nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))))
nn.init.normal_(self.width_attentions, 0, 0.01)
nn.init.normal_(self.depth_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self, LR=None):
if LR is None:
return [self.width_attentions, self.depth_attentions]
else:
return [
{"params": self.width_attentions, "lr": LR},
{"params": self.depth_attentions, "lr": LR},
]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
# select channels
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == 'genotype':
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][ torch.argmax(probe).item() ]
elif mode == 'max':
C = self.Ranges[i][-1]
elif mode == 'fix':
C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
elif mode == 'random':
assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
for j in range(prob.size(0)):
prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2)
C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ]
else:
raise ValueError('invalid mode : {:}'.format(mode))
channels.append( C )
# select depth
if mode == 'genotype':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
elif mode == 'max' or mode == 'fix':
choices = [depth_probs.size(1)-1 for _ in range(depth_probs.size(0))]
elif mode == 'random':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
else:
raise ValueError('invalid mode : {:}'.format(mode))
selected_layers = []
for choice, xvalue in zip(choices, self.depth_info_list):
xtemp = xvalue[1]['choices'][choice] - xvalue[1]['xstart'] + 1
selected_layers.append(xtemp)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple( channels[s:e+1] )
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
if xatti <= choices[xstagei]: # leave this depth
flop+= layer.get_flops(xchl)
else:
flop+= 0 # do not use this layer
else:
flop+= layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xchannels'] = channels
config_dict['xblocks'] = selected_layers
config_dict['super_type'] = 'infer-shape'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for depth and width, there are {:} + {:} attention probabilities.".format(len(self.depth_attentions), len(self.width_attentions))
string+= '\n{:}'.format(self.depth_info)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.depth_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.depth_attentions), ' '.join(prob))
logt = ['{:.4f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:17s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || discrepancy={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
string += '\n-----------------------------------------------'
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob))
logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:52s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
flop_depth_probs = torch.flip( torch.cumsum( torch.flip(flop_depth_probs, [1]), 1 ), [1] )
selected_widths, selected_width_probs = select2withP(self.width_attentions, self.tau)
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
feature_maps = []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths [last_channel_idx: last_channel_idx+layer.num_conv]
selected_w_probs = selected_width_probs[last_channel_idx: last_channel_idx+layer.num_conv]
layer_prob = flop_width_probs [last_channel_idx: last_channel_idx+layer.num_conv]
x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
feature_maps.append( x )
last_channel_idx += layer.num_conv
if i in self.depth_info: # aggregate the information
choices = self.depth_info[i]['choices']
xstagei = self.depth_info[i]['stage']
#print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
#for A, W in zip(choices, selected_depth_probs[xstagei]):
# print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
possible_tensors = []
max_C = max( feature_maps[A].size(1) for A in choices )
for tempi, A in enumerate(choices):
xtensor = ChannelWiseInter(feature_maps[A], max_C)
#drop_ratio = 1-(tempi+1.0)/len(choices)
#xtensor = drop_path(xtensor, drop_ratio)
possible_tensors.append( xtensor )
weighted_sum = sum( xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei]) )
x = weighted_sum
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
else:
x_expected_flop = expected_flop
flops.append( x_expected_flop )
flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,337 @@
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices
def get_depth_choices(nDepth, return_num):
if nDepth == 2:
choices = (1, 2)
elif nDepth == 3:
choices = (1, 2, 3)
elif nDepth > 3:
choices = list(range(1, nDepth+1, 2))
if choices[-1] < nDepth: choices.append(nDepth)
else:
raise ValueError('invalid nDepth : {:}'.format(nDepth))
if return_num: return len(choices)
else : return choices
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_width_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=False)
else : self.relu = None
self.in_dim = nIn
self.out_dim = nOut
def get_flops(self, divide=1):
iC, oC = self.in_dim, self.out_dim
assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_flops(self, divide=1):
flop_A = self.conv_a.get_flops(divide)
flop_B = self.conv_b.get_flops(divide)
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops(divide)
else:
flop_C = 0
return flop_A + flop_B + flop_C
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = 'basic'
def get_range(self):
return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()
def get_flops(self, divide):
flop_A = self.conv_1x1.get_flops(divide)
flop_B = self.conv_3x3.get_flops(divide)
flop_C = self.conv_1x4.get_flops(divide)
if hasattr(self.downsample, 'get_flops'):
flop_D = self.downsample.get_flops(divide)
else:
flop_D = 0
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
class SearchDepthCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchDepthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.InShape = None
self.depth_info = OrderedDict()
self.depth_at_i = OrderedDict()
for stage in range(3):
cur_block_choices = get_depth_choices(layer_blocks, False)
assert cur_block_choices[-1] == layer_blocks, 'stage={:}, {:} vs {:}'.format(stage, cur_block_choices, layer_blocks)
self.message += "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(stage, cur_block_choices, layer_blocks)
block_choices, xstart = [], len(self.layers)
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
# added for depth
layer_index = len(self.layers) - 1
if iL + 1 in cur_block_choices: block_choices.append( layer_index )
if iL + 1 == layer_blocks:
self.depth_info[layer_index] = {'choices': block_choices,
'stage' : stage,
'xstart' : xstart}
self.depth_info_list = []
for xend, info in self.depth_info.items():
self.depth_info_list.append( (xend, info) )
xstart, xstage = info['xstart'], info['stage']
for ilayer in range(xstart, xend+1):
idx = bisect_right(info['choices'], ilayer-1)
self.depth_at_i[ilayer] = (xstage, idx)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
self.register_parameter('depth_attentions', nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))))
nn.init.normal_(self.depth_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.depth_attentions]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
# select depth
if mode == 'genotype':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
elif mode == 'max':
choices = [depth_probs.size(1)-1 for _ in range(depth_probs.size(0))]
elif mode == 'random':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
else:
raise ValueError('invalid mode : {:}'.format(mode))
selected_layers = []
for choice, xvalue in zip(choices, self.depth_info_list):
xtemp = xvalue[1]['choices'][choice] - xvalue[1]['xstart'] + 1
selected_layers.append(xtemp)
flop = 0
for i, layer in enumerate(self.layers):
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
if xatti <= choices[xstagei]: # leave this depth
flop+= layer.get_flops()
else:
flop+= 0 # do not use this layer
else:
flop+= layer.get_flops()
# the last fc layer
flop += self.classifier.in_features * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xblocks'] = selected_layers
config_dict['super_type'] = 'infer-depth'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for depth, there are {:} attention probabilities.".format(len(self.depth_attentions))
string+= '\n{:}'.format(self.depth_info)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.depth_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.depth_attentions), ' '.join(prob))
logt = ['{:.4f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:17s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || discrepancy={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
flop_depth_probs = torch.flip( torch.cumsum( torch.flip(flop_depth_probs, [1]), 1 ), [1] )
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
x, flops = inputs, []
feature_maps = []
for i, layer in enumerate(self.layers):
layer_i = layer( x )
feature_maps.append( layer_i )
if i in self.depth_info: # aggregate the information
choices = self.depth_info[i]['choices']
xstagei = self.depth_info[i]['stage']
possible_tensors = []
for tempi, A in enumerate(choices):
xtensor = feature_maps[A]
possible_tensors.append( xtensor )
weighted_sum = sum( xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei]) )
x = weighted_sum
else:
x = layer_i
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
#print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6)))
x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops(1e6)
else:
x_expected_flop = layer.get_flops(1e6)
flops.append( x_expected_flop )
flops.append( (self.classifier.in_features * self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,391 @@
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:,:oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
#if has_bn : self.bn = nn.BatchNorm2d(nOut)
#else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = 'basic'
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, 'invalid length : {:}'.format(index)
# compute expected flop
#coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg : out = self.avg( inputs )
else : out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
#out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu: out = self.relu( out )
else : out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.has_bn:out= self.BNs[-1]( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_range(self):
return self.conv_a.get_range() + self.conv_b.get_range()
def get_flops(self, channels):
assert len(channels) == 3, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1]
return flop_A + flop_B + flop_C
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
out_a, expected_inC_a, expected_flop_a = self.conv_a( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_b, expected_inC_b, expected_flop_b = self.conv_b( (out_a , expected_inC_a, probability[1], indexes[1], probs[1]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[1], indexes[1], probs[1]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_b)
return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
def basic_forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = 'basic'
def get_range(self):
return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()
def get_flops(self, channels):
assert len(channels) == 4, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
if hasattr(self.downsample, 'get_flops'):
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_D = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1]
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def basic_forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( (out_1x1,expected_inC_1x1, probability[1], indexes[1], probs[1]) )
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( (out_3x3,expected_inC_3x3, probability[2], indexes[2], probs[2]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[2], indexes[2], probs[2]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_1x4)
return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
class SearchWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchWidthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.InShape = None
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append( (start_index, len(self.Ranges)) )
assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth)
self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))))
nn.init.normal_(self.width_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.width_attentions]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
#weights = [F.softmax(x, dim=0) for x in self.width_attentions]
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == 'genotype':
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][ torch.argmax(probe).item() ]
elif mode == 'max':
C = self.Ranges[i][-1]
elif mode == 'fix':
C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
elif mode == 'random':
assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
for j in range(prob.size(0)):
prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2)
C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ]
else:
raise ValueError('invalid mode : {:}'.format(mode))
channels.append( C )
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple( channels[s:e+1] )
flop+= layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xchannels'] = channels
config_dict['super_type'] = 'infer-width'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for width, there are {:} attention probabilities.".format(len(self.width_attentions))
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob))
logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:52s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv]
selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv]
layer_prob = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv]
x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
last_channel_idx += layer.num_conv
flops.append( expected_flop )
flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,483 @@
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices
def get_depth_choices(layers):
min_depth = min(layers)
info = {'num': min_depth}
for i, depth in enumerate(layers):
choices = []
for j in range(1, min_depth+1):
choices.append( int( float(depth)*j/min_depth ) )
info[i] = choices
return info
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:,:oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu, last_max_pool=False):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_width_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
#if has_bn : self.bn = nn.BatchNorm2d(nOut)
#else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
if last_max_pool: self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
else : self.maxpool = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = 'basic'
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, 'invalid length : {:}'.format(index)
# compute expected flop
#coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg : out = self.avg( inputs )
else : out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
#out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu : out = self.relu( out )
if self.maxpool: out = self.maxpool(out)
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.has_bn:out= self.BNs[-1]( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
if self.maxpool: out = self.maxpool(out)
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True, has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_range(self):
return self.conv_a.get_range() + self.conv_b.get_range()
def get_flops(self, channels):
assert len(channels) == 3, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1]
return flop_A + flop_B + flop_C
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
#import pdb; pdb.set_trace()
out_a, expected_inC_a, expected_flop_a = self.conv_a( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_b, expected_inC_b, expected_flop_b = self.conv_b( (out_a , expected_inC_a, probability[1], indexes[1], probs[1]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[1], indexes[1], probs[1]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_b)
return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
def basic_forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True, has_relu=False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = 'basic'
def get_range(self):
return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()
def get_flops(self, channels):
assert len(channels) == 4, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
if hasattr(self.downsample, 'get_flops'):
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_D = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1]
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def basic_forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( (out_1x1,expected_inC_1x1, probability[1], indexes[1], probs[1]) )
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( (out_3x3,expected_inC_3x3, probability[2], indexes[2], probs[2]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[2], indexes[2], probs[2]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_1x4)
return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
class SearchShapeImagenetResNet(nn.Module):
def __init__(self, block_name, layers, deep_stem, num_classes):
super(SearchShapeImagenetResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'BasicBlock':
block = ResNetBasicblock
elif block_name == 'Bottleneck':
block = ResNetBottleneck
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}'.format(sum(layers)*block.num_conv, layers)
self.num_classes = num_classes
if not deep_stem:
self.layers = nn.ModuleList( [ ConvBNReLU(3, 64, 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True, last_max_pool=True) ] )
self.channels = [64]
else:
self.layers = nn.ModuleList( [ ConvBNReLU(3, 32, 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True)
,ConvBNReLU(32,64, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True, last_max_pool=True) ] )
self.channels = [32, 64]
meta_depth_info = get_depth_choices(layers)
self.InShape = None
self.depth_info = OrderedDict()
self.depth_at_i = OrderedDict()
for stage, layer_blocks in enumerate(layers):
cur_block_choices = meta_depth_info[stage]
assert cur_block_choices[-1] == layer_blocks, 'stage={:}, {:} vs {:}'.format(stage, cur_block_choices, layer_blocks)
block_choices, xstart = [], len(self.layers)
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 64 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
# added for depth
layer_index = len(self.layers) - 1
if iL + 1 in cur_block_choices: block_choices.append( layer_index )
if iL + 1 == layer_blocks:
self.depth_info[layer_index] = {'choices': block_choices,
'stage' : stage,
'xstart' : xstart}
self.depth_info_list = []
for xend, info in self.depth_info.items():
self.depth_info_list.append( (xend, info) )
xstart, xstage = info['xstart'], info['stage']
for ilayer in range(xstart, xend+1):
idx = bisect_right(info['choices'], ilayer-1)
self.depth_at_i[ilayer] = (xstage, idx)
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append( (start_index, len(self.Ranges)) )
self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))))
self.register_parameter('depth_attentions', nn.Parameter(torch.Tensor(len(layers), meta_depth_info['num'])))
nn.init.normal_(self.width_attentions, 0, 0.01)
nn.init.normal_(self.depth_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self, LR=None):
if LR is None:
return [self.width_attentions, self.depth_attentions]
else:
return [
{"params": self.width_attentions, "lr": LR},
{"params": self.depth_attentions, "lr": LR},
]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
# select channels
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == 'genotype':
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][ torch.argmax(probe).item() ]
else:
raise ValueError('invalid mode : {:}'.format(mode))
channels.append( C )
# select depth
if mode == 'genotype':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
else:
raise ValueError('invalid mode : {:}'.format(mode))
selected_layers = []
for choice, xvalue in zip(choices, self.depth_info_list):
xtemp = xvalue[1]['choices'][choice] - xvalue[1]['xstart'] + 1
selected_layers.append(xtemp)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple( channels[s:e+1] )
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
if xatti <= choices[xstagei]: # leave this depth
flop+= layer.get_flops(xchl)
else:
flop+= 0 # do not use this layer
else:
flop+= layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xchannels'] = channels
config_dict['xblocks'] = selected_layers
config_dict['super_type'] = 'infer-shape'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for depth and width, there are {:} + {:} attention probabilities.".format(len(self.depth_attentions), len(self.width_attentions))
string+= '\n{:}'.format(self.depth_info)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.depth_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.depth_attentions), ' '.join(prob))
logt = ['{:.4f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:17s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || discrepancy={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
string += '\n-----------------------------------------------'
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob))
logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:52s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
flop_depth_probs = torch.flip( torch.cumsum( torch.flip(flop_depth_probs, [1]), 1 ), [1] )
selected_widths, selected_width_probs = select2withP(self.width_attentions, self.tau)
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
feature_maps = []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths [last_channel_idx: last_channel_idx+layer.num_conv]
selected_w_probs = selected_width_probs[last_channel_idx: last_channel_idx+layer.num_conv]
layer_prob = flop_width_probs [last_channel_idx: last_channel_idx+layer.num_conv]
x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
feature_maps.append( x )
last_channel_idx += layer.num_conv
if i in self.depth_info: # aggregate the information
choices = self.depth_info[i]['choices']
xstagei = self.depth_info[i]['stage']
#print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
#for A, W in zip(choices, selected_depth_probs[xstagei]):
# print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
possible_tensors = []
max_C = max( feature_maps[A].size(1) for A in choices )
for tempi, A in enumerate(choices):
xtensor = ChannelWiseInter(feature_maps[A], max_C)
possible_tensors.append( xtensor )
weighted_sum = sum( xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei]) )
x = weighted_sum
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
else:
x_expected_flop = expected_flop
flops.append( x_expected_flop )
flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,108 @@
import math, torch
import torch.nn as nn
def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
if tau <= 0:
new_logits = logits
probs = nn.functional.softmax(new_logits, dim=1)
else :
while True: # a trick to avoid the gumbels bug
gumbels = -torch.empty_like(logits).exponential_().log()
new_logits = (logits + gumbels) / tau
probs = nn.functional.softmax(new_logits, dim=1)
if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break
if just_prob: return probs
#with torch.no_grad(): # add eps for unexpected torch error
# probs = nn.functional.softmax(new_logits, dim=1)
# selected_index = torch.multinomial(probs + eps, 2, False)
with torch.no_grad(): # add eps for unexpected torch error
probs = probs.cpu()
selected_index = torch.multinomial(probs + eps, num, False).to(logits.device)
selected_logit = torch.gather(new_logits, 1, selected_index)
selcted_probs = nn.functional.softmax(selected_logit, dim=1)
return selected_index, selcted_probs
def ChannelWiseInter(inputs, oC, mode='v2'):
if mode == 'v1':
return ChannelWiseInterV1(inputs, oC)
elif mode == 'v2':
return ChannelWiseInterV2(inputs, oC)
else:
raise ValueError('invalid mode : {:}'.format(mode))
def ChannelWiseInterV1(inputs, oC):
assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size())
def start_index(a, b, c):
return int( math.floor(float(a * c) / b) )
def end_index(a, b, c):
return int( math.ceil(float((a + 1) * c) / b) )
batch, iC, H, W = inputs.size()
outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device)
if iC == oC: return inputs
for ot in range(oC):
istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC)
values = inputs[:, istartT:iendT].mean(dim=1)
outputs[:, ot, :, :] = values
return outputs
def ChannelWiseInterV2(inputs, oC):
assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size())
batch, C, H, W = inputs.size()
if C == oC: return inputs
else : return nn.functional.adaptive_avg_pool3d(inputs, (oC,H,W))
#inputs_5D = inputs.view(batch, 1, C, H, W)
#otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None)
#otputs = otputs_5D.view(batch, oC, H, W)
#otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False)
#return otputs
def linear_forward(inputs, linear):
if linear is None: return inputs
iC = inputs.size(1)
weight = linear.weight[:, :iC]
if linear.bias is None: bias = None
else : bias = linear.bias
return nn.functional.linear(inputs, weight, bias)
def get_width_choices(nOut):
xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
if nOut is None:
return len(xsrange)
else:
Xs = [int(nOut * i) for i in xsrange]
#xs = [ int(nOut * i // 10) for i in range(2, 11)]
#Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1]
Xs = sorted( list( set(Xs) ) )
return tuple(Xs)
def get_depth_choices(nDepth):
if nDepth is None:
return 3
else:
assert nDepth >= 3, 'nDepth should be greater than 2 vs {:}'.format(nDepth)
if nDepth == 1 : return (1, 1, 1)
elif nDepth == 2: return (1, 1, 2)
elif nDepth >= 3:
return (nDepth//3, nDepth*2//3, nDepth)
else:
raise ValueError('invalid Depth : {:}'.format(nDepth))
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = x * (mask / keep_prob)
#x.div_(keep_prob)
#x.mul_(mask)
return x

View File

@@ -0,0 +1,4 @@
from .SearchCifarResNet_width import SearchWidthCifarResNet
from .SearchCifarResNet_depth import SearchDepthCifarResNet
from .SearchCifarResNet import SearchShapeCifarResNet
from .SearchImagenetResNet import SearchShapeImagenetResNet

View File

@@ -0,0 +1,17 @@
import torch
import torch.nn as nn
from SoftSelect import ChannelWiseInter
if __name__ == '__main__':
tensors = torch.rand((16, 128, 7, 7))
for oc in range(200, 210):
out_v1 = ChannelWiseInter(tensors, oc, 'v1')
out_v2 = ChannelWiseInter(tensors, oc, 'v2')
assert (out_v1 == out_v2).any().item() == 1
for oc in range(48, 160):
out_v1 = ChannelWiseInter(tensors, oc, 'v1')
out_v2 = ChannelWiseInter(tensors, oc, 'v2')
assert (out_v1 == out_v2).any().item() == 1

139
lib/models/sphereface.py Normal file
View File

@@ -0,0 +1,139 @@
# SphereFace: Deep Hypersphere Embedding for Face Recognition
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def myphi(x,m):
x = x * m
return 1-x**2/math.factorial(2)+x**4/math.factorial(4)-x**6/math.factorial(6) + \
x**8/math.factorial(8) - x**9/math.factorial(9)
class AngleLinear(nn.Module):
def __init__(self, in_features, out_features, m = 4, phiflag=True):
super(AngleLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(in_features,out_features))
self.weight.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
self.phiflag = phiflag
self.m = m
self.mlambda = [
lambda x: x**0,
lambda x: x**1,
lambda x: 2*x**2-1,
lambda x: 4*x**3-3*x,
lambda x: 8*x**4-8*x**2+1,
lambda x: 16*x**5-20*x**3+5*x
]
def forward(self, input):
x = input # size=(B,F) F is feature len
w = self.weight # size=(F,Classnum) F=in_features Classnum=out_features
ww = w.renorm(2,1,1e-5).mul(1e5)
xlen = x.pow(2).sum(1).pow(0.5) # size=B
wlen = ww.pow(2).sum(0).pow(0.5) # size=Classnum
cos_theta = x.mm(ww) # size=(B,Classnum)
cos_theta = cos_theta / xlen.view(-1,1) / wlen.view(1,-1)
cos_theta = cos_theta.clamp(-1,1)
if self.phiflag:
cos_m_theta = self.mlambda[self.m](cos_theta)
with torch.no_grad():
theta = cos_theta.acos()
k = (self.m*theta/3.14159265).floor()
n_one = k*0.0 - 1
phi_theta = (n_one**k) * cos_m_theta - 2*k
else:
theta = cos_theta.acos()
phi_theta = myphi(theta,self.m)
phi_theta = phi_theta.clamp(-1*self.m,1)
cos_theta = cos_theta * xlen.view(-1,1)
phi_theta = phi_theta * xlen.view(-1,1)
output = (cos_theta,phi_theta)
return output # size=(B,Classnum,2)
class SphereFace20(nn.Module):
def __init__(self, classnum=10574):
super(SphereFace20, self).__init__()
self.classnum = classnum
#input = B*3*112*96
self.conv1_1 = nn.Conv2d(3,64,3,2,1) #=>B*64*56*48
self.relu1_1 = nn.PReLU(64)
self.conv1_2 = nn.Conv2d(64,64,3,1,1)
self.relu1_2 = nn.PReLU(64)
self.conv1_3 = nn.Conv2d(64,64,3,1,1)
self.relu1_3 = nn.PReLU(64)
self.conv2_1 = nn.Conv2d(64,128,3,2,1) #=>B*128*28*24
self.relu2_1 = nn.PReLU(128)
self.conv2_2 = nn.Conv2d(128,128,3,1,1)
self.relu2_2 = nn.PReLU(128)
self.conv2_3 = nn.Conv2d(128,128,3,1,1)
self.relu2_3 = nn.PReLU(128)
self.conv2_4 = nn.Conv2d(128,128,3,1,1) #=>B*128*28*24
self.relu2_4 = nn.PReLU(128)
self.conv2_5 = nn.Conv2d(128,128,3,1,1)
self.relu2_5 = nn.PReLU(128)
self.conv3_1 = nn.Conv2d(128,256,3,2,1) #=>B*256*14*12
self.relu3_1 = nn.PReLU(256)
self.conv3_2 = nn.Conv2d(256,256,3,1,1)
self.relu3_2 = nn.PReLU(256)
self.conv3_3 = nn.Conv2d(256,256,3,1,1)
self.relu3_3 = nn.PReLU(256)
self.conv3_4 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12
self.relu3_4 = nn.PReLU(256)
self.conv3_5 = nn.Conv2d(256,256,3,1,1)
self.relu3_5 = nn.PReLU(256)
self.conv3_6 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12
self.relu3_6 = nn.PReLU(256)
self.conv3_7 = nn.Conv2d(256,256,3,1,1)
self.relu3_7 = nn.PReLU(256)
self.conv3_8 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12
self.relu3_8 = nn.PReLU(256)
self.conv3_9 = nn.Conv2d(256,256,3,1,1)
self.relu3_9 = nn.PReLU(256)
self.conv4_1 = nn.Conv2d(256,512,3,2,1) #=>B*512*7*6
self.relu4_1 = nn.PReLU(512)
self.conv4_2 = nn.Conv2d(512,512,3,1,1)
self.relu4_2 = nn.PReLU(512)
self.conv4_3 = nn.Conv2d(512,512,3,1,1)
self.relu4_3 = nn.PReLU(512)
self.fc5 = nn.Linear(512*7*6,512)
self.fc6 = AngleLinear(512, self.classnum)
def forward(self, x):
x = self.relu1_1(self.conv1_1(x))
x = x + self.relu1_3(self.conv1_3(self.relu1_2(self.conv1_2(x))))
x = self.relu2_1(self.conv2_1(x))
x = x + self.relu2_3(self.conv2_3(self.relu2_2(self.conv2_2(x))))
x = x + self.relu2_5(self.conv2_5(self.relu2_4(self.conv2_4(x))))
x = self.relu3_1(self.conv3_1(x))
x = x + self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(x))))
x = x + self.relu3_5(self.conv3_5(self.relu3_4(self.conv3_4(x))))
x = x + self.relu3_7(self.conv3_7(self.relu3_6(self.conv3_6(x))))
x = x + self.relu3_9(self.conv3_9(self.relu3_8(self.conv3_8(x))))
x = self.relu4_1(self.conv4_1(x))
x = x + self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(x))))
x = x.view(x.size(0),-1)
features = self.fc5(x)
logits = self.fc6(features)
return features, logits

View File

@@ -1,89 +0,0 @@
import torch
import torch.nn as nn
from .construct_utils import Cell, Transition
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
class NetworkCIFAR(nn.Module):
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkCIFAR, self).__init__()
self._layers = layers
stem_multiplier = 3
C_curr = stem_multiplier*C
self.stem = nn.Sequential(
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
for i in range(layers):
if i in [layers//3, 2*layers//3]:
C_curr *= 2
reduction = True
else:
reduction = False
if reduction and genotype.reduce is None:
cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev)
else:
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells.append( cell )
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
if i == 2*layers//3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
else:
self.auxiliary_head = None
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.drop_path_prob = -1
def update_drop_path(self, drop_path_prob):
self.drop_path_prob = drop_path_prob
def auxiliary_param(self):
if self.auxiliary_head is None: return []
else: return list( self.auxiliary_head.parameters() )
def forward(self, inputs):
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2*self._layers//3:
if self.auxiliary_head and self.training:
logits_aux = self.auxiliary_head(s1)
out = self.global_pooling(s1)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
if self.auxiliary_head and self.training:
return logits, logits_aux
else:
return logits

View File

@@ -1,104 +0,0 @@
import torch
import torch.nn as nn
from .construct_utils import Cell, Transition
class AuxiliaryHeadImageNet(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 14x14"""
super(AuxiliaryHeadImageNet, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
# Commenting it out for consistency with the experiments in the paper.
# nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
class NetworkImageNet(nn.Module):
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkImageNet, self).__init__()
self._layers = layers
self.stem0 = nn.Sequential(
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C // 2),
nn.ReLU(inplace=True),
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_prev_prev, C_prev, C_curr = C, C, C
self.cells = nn.ModuleList()
reduction_prev = True
for i in range(layers):
if i in [layers // 3, 2 * layers // 3]:
C_curr *= 2
reduction = True
else:
reduction = False
if reduction and genotype.reduce is None:
cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev)
else:
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
if i == 2 * layers // 3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
else:
self.auxiliary_head = None
self.global_pooling = nn.AvgPool2d(7)
self.classifier = nn.Linear(C_prev, num_classes)
self.drop_path_prob = -1
def update_drop_path(self, drop_path_prob):
self.drop_path_prob = drop_path_prob
def get_drop_path(self):
return self.drop_path_prob
def auxiliary_param(self):
if self.auxiliary_head is None: return []
else: return list( self.auxiliary_head.parameters() )
def forward(self, input):
s0 = self.stem0(input)
s1 = self.stem1(s0)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
#print ('{:} : {:} - {:}'.format(i, s0.size(), s1.size()))
if i == 2 * self._layers // 3:
if self.auxiliary_head and self.training:
logits_aux = self.auxiliary_head(s1)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
if self.auxiliary_head and self.training:
return logits, logits_aux
else:
return logits

View File

@@ -1,27 +0,0 @@
import torch
import torch.nn as nn
# Squeeze and Excitation module
class SqEx(nn.Module):
def __init__(self, n_features, reduction=16):
super(SqEx, self).__init__()
if n_features % reduction != 0:
raise ValueError('n_features must be divisible by reduction (default = 16)')
self.linear1 = nn.Linear(n_features, n_features // reduction, bias=True)
self.nonlin1 = nn.ReLU(inplace=True)
self.linear2 = nn.Linear(n_features // reduction, n_features, bias=True)
self.nonlin2 = nn.Sigmoid()
def forward(self, x):
y = F.avg_pool2d(x, kernel_size=x.size()[2:4])
y = y.permute(0, 2, 3, 1)
y = self.nonlin1(self.linear1(y))
y = self.nonlin2(self.linear2(y))
y = y.permute(0, 3, 1, 2)
y = x * y
return y

View File

@@ -1,10 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .CifarNet import NetworkCIFAR
from .ImageNet import NetworkImageNet
# genotypes
from .genotypes import model_types
from .construct_utils import return_alphas_str

View File

@@ -1,152 +0,0 @@
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from .operations import OPS, FactorizedReduce, ReLUConvBN, Identity
def random_select(length, ratio):
clist = []
index = random.randint(0, length-1)
for i in range(length):
if i == index or random.random() < ratio:
clist.append( 1 )
else:
clist.append( 0 )
return clist
def all_select(length):
return [1 for i in range(length)]
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x.div_(keep_prob)
x.mul_(mask)
return x
def return_alphas_str(basemodel):
string = 'normal : {:}'.format( F.softmax(basemodel.alphas_normal, dim=-1) )
if hasattr(basemodel, 'alphas_reduce'):
string = string + '\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) )
return string
class Cell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
print(C_prev_prev, C_prev, C)
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
if reduction:
op_names, indices, values = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices, values = zip(*genotype.normal)
concat = genotype.normal_concat
self._compile(C, op_names, indices, values, concat, reduction)
def _compile(self, C, op_names, indices, values, concat, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
self._concat = concat
self.multiplier = len(concat)
self._ops = nn.ModuleList()
for name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = OPS[name](C, stride, True)
self._ops.append( op )
self._indices = indices
self._values = values
def forward(self, s0, s1, drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2*i]]
h2 = states[self._indices[2*i+1]]
op1 = self._ops[2*i]
op2 = self._ops[2*i+1]
h1 = op1(h1)
h2 = op2(h2)
if self.training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
s = h1 + h2
states += [s]
return torch.cat([states[i] for i in self._concat], dim=1)
class Transition(nn.Module):
def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier=4):
super(Transition, self).__init__()
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
self.multiplier = multiplier
self.reduction = True
self.ops1 = nn.ModuleList(
[nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True))])
self.ops2 = nn.ModuleList(
[nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
nn.BatchNorm2d(C, affine=True))])
def forward(self, s0, s1, drop_prob = -1):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
X0 = self.ops1[0] (s0)
X1 = self.ops1[1] (s1)
if self.training and drop_prob > 0.:
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
#X2 = self.ops2[0] (X0+X1)
X2 = self.ops2[0] (s0)
X3 = self.ops2[1] (s1)
if self.training and drop_prob > 0.:
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
return torch.cat([X0, X1, X2, X3], dim=1)

View File

@@ -1,245 +0,0 @@
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
PRIMITIVES = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5'
]
NASNet = Genotype(
normal = [
('sep_conv_5x5', 1, 1.0),
('sep_conv_3x3', 0, 1.0),
('sep_conv_5x5', 0, 1.0),
('sep_conv_3x3', 0, 1.0),
('avg_pool_3x3', 1, 1.0),
('skip_connect', 0, 1.0),
('avg_pool_3x3', 0, 1.0),
('avg_pool_3x3', 0, 1.0),
('sep_conv_3x3', 1, 1.0),
('skip_connect', 1, 1.0),
],
normal_concat = [2, 3, 4, 5, 6],
reduce = [
('sep_conv_5x5', 1, 1.0),
('sep_conv_7x7', 0, 1.0),
('max_pool_3x3', 1, 1.0),
('sep_conv_7x7', 0, 1.0),
('avg_pool_3x3', 1, 1.0),
('sep_conv_5x5', 0, 1.0),
('skip_connect', 3, 1.0),
('avg_pool_3x3', 2, 1.0),
('sep_conv_3x3', 2, 1.0),
('max_pool_3x3', 1, 1.0),
],
reduce_concat = [4, 5, 6],
)
AmoebaNet = Genotype(
normal = [
('avg_pool_3x3', 0, 1.0),
('max_pool_3x3', 1, 1.0),
('sep_conv_3x3', 0, 1.0),
('sep_conv_5x5', 2, 1.0),
('sep_conv_3x3', 0, 1.0),
('avg_pool_3x3', 3, 1.0),
('sep_conv_3x3', 1, 1.0),
('skip_connect', 1, 1.0),
('skip_connect', 0, 1.0),
('avg_pool_3x3', 1, 1.0),
],
normal_concat = [4, 5, 6],
reduce = [
('avg_pool_3x3', 0, 1.0),
('sep_conv_3x3', 1, 1.0),
('max_pool_3x3', 0, 1.0),
('sep_conv_7x7', 2, 1.0),
('sep_conv_7x7', 0, 1.0),
('avg_pool_3x3', 1, 1.0),
('max_pool_3x3', 0, 1.0),
('max_pool_3x3', 1, 1.0),
('conv_7x1_1x7', 0, 1.0),
('sep_conv_3x3', 5, 1.0),
],
reduce_concat = [3, 4, 6]
)
DARTS_V1 = Genotype(
normal=[
('sep_conv_3x3', 1, 1.0),
('sep_conv_3x3', 0, 1.0),
('skip_connect', 0, 1.0),
('sep_conv_3x3', 1, 1.0),
('skip_connect', 0, 1.0),
('sep_conv_3x3', 1, 1.0),
('sep_conv_3x3', 0, 1.0),
('skip_connect', 2, 1.0)],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0, 1.0),
('max_pool_3x3', 1, 1.0),
('skip_connect', 2, 1.0),
('max_pool_3x3', 0, 1.0),
('max_pool_3x3', 0, 1.0),
('skip_connect', 2, 1.0),
('skip_connect', 2, 1.0),
('avg_pool_3x3', 0, 1.0)],
reduce_concat=[2, 3, 4, 5]
)
DARTS_V2 = Genotype(
normal=[
('sep_conv_3x3', 0, 1.0),
('sep_conv_3x3', 1, 1.0),
('sep_conv_3x3', 0, 1.0),
('sep_conv_3x3', 1, 1.0),
('sep_conv_3x3', 1, 1.0),
('skip_connect', 0, 1.0),
('skip_connect', 0, 1.0),
('dil_conv_3x3', 2, 1.0)],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0, 1.0),
('max_pool_3x3', 1, 1.0),
('skip_connect', 2, 1.0),
('max_pool_3x3', 1, 1.0),
('max_pool_3x3', 0, 1.0),
('skip_connect', 2, 1.0),
('skip_connect', 2, 1.0),
('max_pool_3x3', 1, 1.0)],
reduce_concat=[2, 3, 4, 5]
)
PNASNet = Genotype(
normal = [
('sep_conv_5x5', 0, 1.0),
('max_pool_3x3', 0, 1.0),
('sep_conv_7x7', 1, 1.0),
('max_pool_3x3', 1, 1.0),
('sep_conv_5x5', 1, 1.0),
('sep_conv_3x3', 1, 1.0),
('sep_conv_3x3', 4, 1.0),
('max_pool_3x3', 1, 1.0),
('sep_conv_3x3', 0, 1.0),
('skip_connect', 1, 1.0),
],
normal_concat = [2, 3, 4, 5, 6],
reduce = [
('sep_conv_5x5', 0, 1.0),
('max_pool_3x3', 0, 1.0),
('sep_conv_7x7', 1, 1.0),
('max_pool_3x3', 1, 1.0),
('sep_conv_5x5', 1, 1.0),
('sep_conv_3x3', 1, 1.0),
('sep_conv_3x3', 4, 1.0),
('max_pool_3x3', 1, 1.0),
('sep_conv_3x3', 0, 1.0),
('skip_connect', 1, 1.0),
],
reduce_concat = [2, 3, 4, 5, 6],
)
# https://arxiv.org/pdf/1802.03268.pdf
ENASNet = Genotype(
normal = [
('sep_conv_3x3', 1, 1.0),
('skip_connect', 1, 1.0),
('sep_conv_5x5', 1, 1.0),
('skip_connect', 0, 1.0),
('avg_pool_3x3', 0, 1.0),
('sep_conv_3x3', 1, 1.0),
('sep_conv_3x3', 0, 1.0),
('avg_pool_3x3', 1, 1.0),
('sep_conv_5x5', 1, 1.0),
('avg_pool_3x3', 0, 1.0),
],
normal_concat = [2, 3, 4, 5, 6],
reduce = [
('sep_conv_5x5', 0, 1.0),
('sep_conv_3x3', 1, 1.0), # 2
('sep_conv_3x3', 1, 1.0),
('avg_pool_3x3', 1, 1.0), # 3
('sep_conv_3x3', 1, 1.0),
('avg_pool_3x3', 1, 1.0), # 4
('avg_pool_3x3', 1, 1.0),
('sep_conv_5x5', 4, 1.0), # 5
('sep_conv_3x3', 5, 1.0),
('sep_conv_5x5', 0, 1.0),
],
reduce_concat = [2, 3, 4, 5, 6],
)
DARTS = DARTS_V2
# Search by normal and reduce
GDAS_V1 = Genotype(
normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)],
normal_concat=range(2, 6),
reduce=[('sep_conv_5x5', 0, 0.12862831354141235), ('sep_conv_3x3', 1, 0.12783904373645782), ('sep_conv_5x5', 2, 0.12725995481014252), ('sep_conv_5x5', 1, 0.12705285847187042), ('dil_conv_5x5', 2, 0.12797553837299347), ('sep_conv_3x3', 1, 0.12737272679805756), ('sep_conv_5x5', 0, 0.12833961844444275), ('sep_conv_5x5', 1, 0.12758426368236542)],
reduce_concat=range(2, 6)
)
# Search by normal and fixing reduction
GDAS_F1 = Genotype(
normal=[('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15)],
normal_concat=[2, 3, 4, 5],
reduce=None,
reduce_concat=[2, 3, 4, 5],
)
# Combine DMS_V1 and DMS_F1
GDAS_GF = Genotype(
normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)],
normal_concat=range(2, 6),
reduce=None,
reduce_concat=range(2, 6)
)
GDAS_FG = Genotype(
normal=[('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15)],
normal_concat=range(2, 6),
reduce=[('sep_conv_5x5', 0, 0.12862831354141235), ('sep_conv_3x3', 1, 0.12783904373645782), ('sep_conv_5x5', 2, 0.12725995481014252), ('sep_conv_5x5', 1, 0.12705285847187042), ('dil_conv_5x5', 2, 0.12797553837299347), ('sep_conv_3x3', 1, 0.12737272679805756), ('sep_conv_5x5', 0, 0.12833961844444275), ('sep_conv_5x5', 1, 0.12758426368236542)],
reduce_concat=range(2, 6)
)
PDARTS = Genotype(
normal=[
('skip_connect', 0, 1.0),
('dil_conv_3x3', 1, 1.0),
('skip_connect', 0, 1.0),
('sep_conv_3x3', 1, 1.0),
('sep_conv_3x3', 1, 1.0),
('sep_conv_3x3', 3, 1.0),
('sep_conv_3x3', 0, 1.0),
('dil_conv_5x5', 4, 1.0)],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 0, 1.0),
('sep_conv_5x5', 1, 1.0),
('sep_conv_3x3', 0, 1.0),
('dil_conv_5x5', 2, 1.0),
('max_pool_3x3', 0, 1.0),
('dil_conv_3x3', 1, 1.0),
('dil_conv_3x3', 1, 1.0),
('dil_conv_5x5', 3, 1.0)],
reduce_concat=range(2, 6)
)
model_types = {'DARTS_V1': DARTS_V1,
'DARTS_V2': DARTS_V2,
'NASNet' : NASNet,
'PNASNet' : PNASNet,
'AmoebaNet': AmoebaNet,
'ENASNet' : ENASNet,
'PDARTS' : PDARTS,
'GDAS_V1' : GDAS_V1,
'GDAS_F1' : GDAS_F1,
'GDAS_GF' : GDAS_GF,
'GDAS_FG' : GDAS_FG}

View File

@@ -1,19 +0,0 @@
import torch
import torch.nn as nn
class ImageNetHEAD(nn.Sequential):
def __init__(self, C, stride=2):
super(ImageNetHEAD, self).__init__()
self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False))
self.add_module('bn1' , nn.BatchNorm2d(C // 2))
self.add_module('relu1', nn.ReLU(inplace=True))
self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False))
self.add_module('bn2' , nn.BatchNorm2d(C))
class CifarHEAD(nn.Sequential):
def __init__(self, C):
super(CifarHEAD, self).__init__()
self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
self.add_module('bn', nn.BatchNorm2d(C))

View File

@@ -1,122 +0,0 @@
import torch
import torch.nn as nn
OPS = {
'none' : lambda C, stride, affine: Zero(stride),
'avg_pool_3x3' : lambda C, stride, affine: nn.Sequential(
nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
nn.BatchNorm2d(C, affine=False) ),
'max_pool_3x3' : lambda C, stride, affine: nn.Sequential(
nn.MaxPool2d(3, stride=stride, padding=1),
nn.BatchNorm2d(C, affine=False) ),
'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
'conv_7x1_1x7' : lambda C, stride, affine: Conv717(C, C, stride, affine),
}
class Conv717(nn.Module):
def __init__(self, C_in, C_out, stride, affine):
super(Conv717, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in , C_out, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
nn.Conv2d(C_out, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.op(x)
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.op(x)
class DilConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super(DilConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, stride):
super(Zero, self).__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x.mul(0.)
return x[:,:,::self.stride,::self.stride].mul(0.)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__()
assert C_out % 2 == 0
self.relu = nn.ReLU(inplace=False)
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.conv_1(x), self.conv_2(y[:,:,1:,1:])], dim=1)
out = self.bn(out)
return out

View File

@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
from .construct_utils import drop_path
from .head_utils import CifarHEAD, AuxiliaryHeadCIFAR
from .base_cells import InferCell
class NetworkCIFAR(nn.Module):
def __init__(self, C, N, stem_multiplier, auxiliary, genotype, num_classes):
super(NetworkCIFAR, self).__init__()
self._C = C
self._layerN = N
self._stem_multiplier = stem_multiplier
C_curr = self._stem_multiplier * C
self.stem = CifarHEAD(C_curr)
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
block_indexs = [0 ] * N + [-1 ] + [1 ] * N + [-1 ] + [2 ] * N
block2index = {0:[], 1:[], 2:[]}
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
reduction_prev, spatial, dims = False, 1, []
self.auxiliary_index = None
self.auxiliary_head = None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells.append( cell )
C_prev_prev, C_prev = C_prev, cell._multiplier*C_curr
if reduction and C_curr == C*4:
if auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
self.auxiliary_index = index
if reduction: spatial *= 2
dims.append( (C_prev, spatial) )
self._Layer= len(self.cells)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.drop_path_prob = -1
def update_drop_path(self, drop_path_prob):
self.drop_path_prob = drop_path_prob
def auxiliary_param(self):
if self.auxiliary_head is None: return []
else: return list( self.auxiliary_head.parameters() )
def get_message(self):
return self.extra_repr()
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer}, stem={_stem_multiplier}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
stem_feature, logits_aux = self.stem(inputs), None
cell_results = [stem_feature, stem_feature]
for i, cell in enumerate(self.cells):
cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
cell_results.append( cell_feature )
if self.auxiliary_index is not None and i == self.auxiliary_index and self.training:
logits_aux = self.auxiliary_head( cell_results[-1] )
out = self.global_pooling( cell_results[-1] )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
if logits_aux is None: return out, logits
else : return out, [logits, logits_aux]

View File

@@ -0,0 +1,77 @@
import torch
import torch.nn as nn
from .construct_utils import drop_path
from .base_cells import InferCell
from .head_utils import ImageNetHEAD, AuxiliaryHeadImageNet
class NetworkImageNet(nn.Module):
def __init__(self, C, N, auxiliary, genotype, num_classes):
super(NetworkImageNet, self).__init__()
self._C = C
self._layerN = N
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
self.stem0 = nn.Sequential(
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C // 2),
nn.ReLU(inplace=True),
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_prev_prev, C_prev, C_curr, reduction_prev = C, C, C, True
self.cells = nn.ModuleList()
self.auxiliary_index = None
for i, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell._multiplier * C_curr
if reduction and C_curr == C*4:
C_to_auxiliary = C_prev
self.auxiliary_index = i
self._NNN = len(self.cells)
if auxiliary:
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
else:
self.auxiliary_head = None
self.global_pooling = nn.AvgPool2d(7)
self.classifier = nn.Linear(C_prev, num_classes)
self.drop_path_prob = -1
def update_drop_path(self, drop_path_prob):
self.drop_path_prob = drop_path_prob
def extra_repr(self):
return ('{name}(C={_C}, N=[{_layerN}, {_NNN}], aux-index={auxiliary_index}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__))
def get_message(self):
return self.extra_repr()
def auxiliary_param(self):
if self.auxiliary_head is None: return []
else: return list( self.auxiliary_head.parameters() )
def forward(self, inputs):
s0 = self.stem0(inputs)
s1 = self.stem1(s0)
logits_aux = None
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == self.auxiliary_index and self.auxiliary_head and self.training:
logits_aux = self.auxiliary_head(s1)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
if logits_aux is None: return out, logits
else : return out, [logits, logits_aux]

View File

@@ -0,0 +1,4 @@
# Performance-Aware Template Network for One-Shot Neural Architecture Search
from .CifarNet import NetworkCIFAR as CifarNet
from .ImageNet import NetworkImageNet as ImageNet
from .genotypes import Networks

View File

@@ -0,0 +1,173 @@
import math
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .construct_utils import drop_path
from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN
class MixedOp(nn.Module):
def __init__(self, C, stride, PRIMITIVES):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
self.name2idx = {}
for idx, primitive in enumerate(PRIMITIVES):
op = OPS[primitive](C, C, stride, False)
self._ops.append(op)
assert primitive not in self.name2idx, '{:} has already in'.format(primitive)
self.name2idx[primitive] = idx
def forward(self, x, weights, op_name):
if op_name is None:
if weights is None:
return [op(x) for op in self._ops]
else:
return sum(w * op(x) for w, op in zip(weights, self._ops))
else:
op_index = self.name2idx[op_name]
return self._ops[op_index](x)
class SearchCell(nn.Module):
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
super(SearchCell, self).__init__()
self.reduction = reduction
self.PRIMITIVES = deepcopy(PRIMITIVES)
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
self._steps = steps
self._multiplier = multiplier
self._use_residual = use_residual
self._ops = nn.ModuleList()
for i in range(self._steps):
for j in range(2+i):
stride = 2 if reduction and j < 2 else 1
op = MixedOp(C, stride, self.PRIMITIVES)
self._ops.append(op)
def extra_repr(self):
return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
if modes[0] is None:
if modes[1] == 'normal':
output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
elif modes[1] == 'only_W':
output = self.__forwardOnlyW(S0, S1, drop_prob)
else:
test_genotype = modes[0]
if self.reduction: operations, concats = test_genotype.reduce, test_genotype.reduce_concat
else : operations, concats = test_genotype.normal, test_genotype.normal_concat
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
states, offset = [s0, s1], 0
assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations))
for i, (opA, opB) in enumerate(operations):
A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0])
B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0])
state = A + B
offset += len(states)
states.append(state)
output = torch.cat([states[i] for i in concats], dim=1)
if self._use_residual and S1.size() == output.size():
return S1 + output
else: return output
def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
states, offset = [s0, s1], 0
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
x = self._ops[offset+j](h, weights[offset+j], None)
if self.training and drop_prob > 0.:
x = drop_path(x, math.pow(drop_prob, 1./len(states)))
clist.append( x )
connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
state = sum(w * node for w, node in zip(connection, clist))
offset += len(states)
states.append(state)
return torch.cat(states[-self._multiplier:], dim=1)
def __forwardOnlyW(self, S0, S1, drop_prob):
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
states, offset = [s0, s1], 0
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
xs = self._ops[offset+j](h, None, None)
clist += xs
if self.training and drop_prob > 0.:
xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist]
else: xlist = clist
state = sum(xlist) * 2 / len(xlist)
offset += len(states)
states.append(state)
return torch.cat(states[-self._multiplier:], dim=1)
class InferCell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(InferCell, self).__init__()
print(C_prev_prev, C_prev, C)
if reduction_prev is None:
self.preprocess0 = Identity()
elif reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
else : step_ops, concat = genotype.normal, genotype.normal_concat
self._steps = len(step_ops)
self._concat = concat
self._multiplier = len(concat)
self._ops = nn.ModuleList()
self._indices = []
for operations in step_ops:
for name, index in operations:
stride = 2 if reduction and index < 2 else 1
if reduction_prev is None and index == 0:
op = OPS[name](C_prev_prev, C, stride, True)
else:
op = OPS[name](C , C, stride, True)
self._ops.append( op )
self._indices.append( index )
def extra_repr(self):
return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, S0, S1, drop_prob):
s0 = self.preprocess0(S0)
s1 = self.preprocess1(S1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2*i]]
h2 = states[self._indices[2*i+1]]
op1 = self._ops[2*i]
op2 = self._ops[2*i+1]
h1 = op1(h1)
h2 = op2(h2)
if self.training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
state = h1 + h2
states += [state]
output = torch.cat([states[i] for i in self._concat], dim=1)
return output

View File

@@ -0,0 +1,60 @@
import torch
import torch.nn.functional as F
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = torch.div(x, keep_prob)
x.mul_(mask)
return x
def return_alphas_str(basemodel):
if hasattr(basemodel, 'alphas_normal'):
string = 'normal [{:}] : \n-->>{:}'.format(basemodel.alphas_normal.size(), F.softmax(basemodel.alphas_normal, dim=-1) )
else: string = ''
if hasattr(basemodel, 'alphas_reduce'):
string = string + '\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) )
if hasattr(basemodel, 'get_adjacency'):
adjacency = basemodel.get_adjacency()
for i in range( len(adjacency) ):
weight = F.softmax( basemodel.connect_normal[str(i)], dim=-1 )
adj = torch.mm(weight, adjacency[i]).view(-1)
adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
string = string + '\nnormal--{:}-->{:}'.format(i, ', '.join(adj))
for i in range( len(adjacency) ):
weight = F.softmax( basemodel.connect_reduce[str(i)], dim=-1 )
adj = torch.mm(weight, adjacency[i]).view(-1)
adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
string = string + '\nreduce--{:}-->{:}'.format(i, ', '.join(adj))
if hasattr(basemodel, 'alphas_connect'):
weight = F.softmax(basemodel.alphas_connect, dim=-1).cpu()
ZERO = ['{:.3f}'.format(x) for x in weight[:,0].tolist()]
IDEN = ['{:.3f}'.format(x) for x in weight[:,1].tolist()]
string = string + '\nconnect [{:}] : \n ->{:}\n ->{:}'.format( list(basemodel.alphas_connect.size()), ZERO, IDEN )
else:
string = string + '\nconnect = None'
if hasattr(basemodel, 'get_gcn_out'):
outputs = basemodel.get_gcn_out(True)
for i, output in enumerate(outputs):
string = string + '\nnormal:[{:}] : {:}'.format(i, F.softmax(output, dim=-1) )
return string
def remove_duplicate_archs(all_archs):
archs = []
str_archs = ['{:}'.format(x) for x in all_archs]
for i, arch_x in enumerate(str_archs):
choose = True
for j in range(i):
if arch_x == str_archs[j]:
choose = False; break
if choose: archs.append(all_archs[i])
return archs

View File

@@ -0,0 +1,172 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat connectN connects')
#Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
PRIMITIVES_small = [
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'conv_3x1_1x3',
]
PRIMITIVES_large = [
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
'conv_3x1_1x3',
]
PRIMITIVES_huge = [
'skip_connect',
'nor_conv_1x1',
'max_pool_3x3',
'avg_pool_3x3',
'nor_conv_3x3',
'sep_conv_3x3',
'dil_conv_3x3',
'conv_3x1_1x3',
'sep_conv_5x5',
'dil_conv_5x5',
'sep_conv_7x7',
'conv_7x1_1x7',
'att_squeeze',
]
PRIMITIVES = {'small': PRIMITIVES_small,
'large': PRIMITIVES_large,
'huge' : PRIMITIVES_huge}
NASNet = Genotype(
normal = [
(('sep_conv_5x5', 1), ('sep_conv_3x3', 0)),
(('sep_conv_5x5', 0), ('sep_conv_3x3', 0)),
(('avg_pool_3x3', 1), ('skip_connect', 0)),
(('avg_pool_3x3', 0), ('avg_pool_3x3', 0)),
(('sep_conv_3x3', 1), ('skip_connect', 1)),
],
normal_concat = [2, 3, 4, 5, 6],
reduce = [
(('sep_conv_5x5', 1), ('sep_conv_7x7', 0)),
(('max_pool_3x3', 1), ('sep_conv_7x7', 0)),
(('avg_pool_3x3', 1), ('sep_conv_5x5', 0)),
(('skip_connect', 3), ('avg_pool_3x3', 2)),
(('sep_conv_3x3', 2), ('max_pool_3x3', 1)),
],
reduce_concat = [4, 5, 6],
connectN=None, connects=None,
)
PNASNet = Genotype(
normal = [
(('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
(('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
(('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
(('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
(('sep_conv_3x3', 0), ('skip_connect', 1)),
],
normal_concat = [2, 3, 4, 5, 6],
reduce = [
(('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
(('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
(('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
(('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
(('sep_conv_3x3', 0), ('skip_connect', 1)),
],
reduce_concat = [2, 3, 4, 5, 6],
connectN=None, connects=None,
)
DARTS_V1 = Genotype(
normal=[
(('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1
(('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2
(('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3
(('sep_conv_3x3', 0), ('skip_connect', 2)) # step 4
],
normal_concat=[2, 3, 4, 5],
reduce=[
(('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
(('skip_connect', 2), ('max_pool_3x3', 0)), # step 2
(('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
(('skip_connect', 2), ('avg_pool_3x3', 0)) # step 4
],
reduce_concat=[2, 3, 4, 5],
connectN=None, connects=None,
)
# DARTS: Differentiable Architecture Search, ICLR 2019
DARTS_V2 = Genotype(
normal=[
(('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1
(('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2
(('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3
(('skip_connect', 0), ('dil_conv_3x3', 2)) # step 4
],
normal_concat=[2, 3, 4, 5],
reduce=[
(('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
(('skip_connect', 2), ('max_pool_3x3', 1)), # step 2
(('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
(('skip_connect', 2), ('max_pool_3x3', 1)) # step 4
],
reduce_concat=[2, 3, 4, 5],
connectN=None, connects=None,
)
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
SETN = Genotype(
normal=[
(('skip_connect', 0), ('sep_conv_5x5', 1)),
(('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
(('sep_conv_5x5', 1), ('sep_conv_5x5', 3)),
(('max_pool_3x3', 1), ('conv_3x1_1x3', 4))],
normal_concat=[2, 3, 4, 5],
reduce=[
(('sep_conv_3x3', 0), ('sep_conv_5x5', 1)),
(('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
(('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
(('avg_pool_3x3', 0), ('skip_connect', 1))],
reduce_concat=[2, 3, 4, 5],
connectN=None, connects=None
)
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
GDAS_V1 = Genotype(
normal=[
(('skip_connect', 0), ('skip_connect', 1)),
(('skip_connect', 0), ('sep_conv_5x5', 2)),
(('sep_conv_3x3', 3), ('skip_connect', 0)),
(('sep_conv_5x5', 4), ('sep_conv_3x3', 3))],
normal_concat=[2, 3, 4, 5],
reduce=[
(('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
(('sep_conv_5x5', 2), ('sep_conv_5x5', 1)),
(('dil_conv_5x5', 2), ('sep_conv_3x3', 1)),
(('sep_conv_5x5', 0), ('sep_conv_5x5', 1))],
reduce_concat=[2, 3, 4, 5],
connectN=None, connects=None
)
Networks = {'DARTS_V1': DARTS_V1,
'DARTS_V2': DARTS_V2,
'DARTS' : DARTS_V2,
'NASNet' : NASNet,
'GDAS_V1' : GDAS_V1,
'PNASNet' : PNASNet,
'SETN' : SETN,
}

View File

@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
class ImageNetHEAD(nn.Sequential):
def __init__(self, C, stride=2):
super(ImageNetHEAD, self).__init__()
self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False))
self.add_module('bn1' , nn.BatchNorm2d(C // 2))
self.add_module('relu1', nn.ReLU(inplace=True))
self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False))
self.add_module('bn2' , nn.BatchNorm2d(C))
class CifarHEAD(nn.Sequential):
def __init__(self, C):
super(CifarHEAD, self).__init__()
self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
self.add_module('bn', nn.BatchNorm2d(C))
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
class AuxiliaryHeadImageNet(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 14x14"""
super(AuxiliaryHeadImageNet, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x

View File

@@ -0,0 +1,16 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
def obtain_nas_infer_model(config):
if config.arch == 'dxys':
from .DXYs import CifarNet, ImageNet, Networks
genotype = Networks[config.genotype]
if config.dataset == 'cifar':
return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
elif config.dataset == 'imagenet':
return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num)
else: raise ValueError('invalid dataset : {:}'.format(config.dataset))
else:
raise ValueError('invalid nas arch type : {:}'.format(config.arch))

View File

@@ -0,0 +1,180 @@
import torch
import torch.nn as nn
OPS = {
'none' : lambda C_in, C_out, stride, affine: Zero(stride),
'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), affine),
'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), affine),
'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), affine),
'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine),
'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine),
'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine),
'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine),
'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: Conv717(C_in, C_out, stride, affine),
'conv_3x1_1x3' : lambda C_in, C_out, stride, affine: Conv313(C_in, C_out, stride, affine)
}
class POOLING(nn.Module):
def __init__(self, C_in, C_out, stride, mode):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0)
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
def forward(self, inputs):
if self.preprocess is not None:
x = self.preprocess(inputs)
else: x = inputs
return self.op(x)
class Conv313(nn.Module):
def __init__(self, C_in, C_out, stride, affine):
super(Conv313, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in , C_out, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
nn.Conv2d(C_out, C_out, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.op(x)
class Conv717(nn.Module):
def __init__(self, C_in, C_out, stride, affine):
super(Conv717, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in , C_out, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
nn.Conv2d(C_out, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.op(x)
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.op(x)
class DilConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super(DilConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride= 1, padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, stride):
super(Zero, self).__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x.mul(0.)
return x[:,:,::self.stride,::self.stride].mul(0.)
def extra_repr(self):
return 'stride={stride}'.format(**self.__dict__)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride, affine=True):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) )
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
elif stride == 4:
assert C_out % 4 == 0, 'C_out : {:}'.format(C_out)
self.convs = nn.ModuleList()
for i in range(4):
self.convs.append( nn.Conv2d(C_in, C_out // 4, 1, stride=stride, padding=0, bias=False) )
self.pad = nn.ConstantPad2d((0, 3, 0, 3), 0)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
if self.stride == 2:
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
else:
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:-2,1:-2]),
self.convs[2](y[:,:,2:-1,2:-1]), self.convs[3](y[:,:,3:,3:])], dim=1)
out = self.bn(out)
return out
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)

View File

@@ -1,9 +0,0 @@
# utils
from .utils import batchify, get_batch, repackage_hidden
# models
from .model_search import RNNModelSearch
from .model_search import DARTSCellSearch
from .basemodel import DARTSCell, RNNModel
# architecture
from .genotypes import DARTS_V1, DARTS_V2
from .genotypes import GDAS

View File

@@ -1,181 +0,0 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .genotypes import STEPS
from .utils import mask2d, LockedDropout, embedded_dropout
INITRANGE = 0.04
def none_func(x):
return x * 0
class DARTSCell(nn.Module):
def __init__(self, ninp, nhid, dropouth, dropoutx, genotype):
super(DARTSCell, self).__init__()
self.nhid = nhid
self.dropouth = dropouth
self.dropoutx = dropoutx
self.genotype = genotype
# genotype is None when doing arch search
steps = len(self.genotype.recurrent) if self.genotype is not None else STEPS
self._W0 = nn.Parameter(torch.Tensor(ninp+nhid, 2*nhid).uniform_(-INITRANGE, INITRANGE))
self._Ws = nn.ParameterList([
nn.Parameter(torch.Tensor(nhid, 2*nhid).uniform_(-INITRANGE, INITRANGE)) for i in range(steps)
])
def forward(self, inputs, hidden, arch_probs):
T, B = inputs.size(0), inputs.size(1)
if self.training:
x_mask = mask2d(B, inputs.size(2), keep_prob=1.-self.dropoutx)
h_mask = mask2d(B, hidden.size(2), keep_prob=1.-self.dropouth)
else:
x_mask = h_mask = None
hidden = hidden[0]
hiddens = []
for t in range(T):
hidden = self.cell(inputs[t], hidden, x_mask, h_mask, arch_probs)
hiddens.append(hidden)
hiddens = torch.stack(hiddens)
return hiddens, hiddens[-1].unsqueeze(0)
def _compute_init_state(self, x, h_prev, x_mask, h_mask):
if self.training:
xh_prev = torch.cat([x * x_mask, h_prev * h_mask], dim=-1)
else:
xh_prev = torch.cat([x, h_prev], dim=-1)
c0, h0 = torch.split(xh_prev.mm(self._W0), self.nhid, dim=-1)
c0 = c0.sigmoid()
h0 = h0.tanh()
s0 = h_prev + c0 * (h0-h_prev)
return s0
def _get_activation(self, name):
if name == 'tanh':
f = torch.tanh
elif name == 'relu':
f = torch.relu
elif name == 'sigmoid':
f = torch.sigmoid
elif name == 'identity':
f = lambda x: x
elif name == 'none':
f = none_func
else:
raise NotImplementedError
return f
def cell(self, x, h_prev, x_mask, h_mask, _):
s0 = self._compute_init_state(x, h_prev, x_mask, h_mask)
states = [s0]
for i, (name, pred) in enumerate(self.genotype.recurrent):
s_prev = states[pred]
if self.training:
ch = (s_prev * h_mask).mm(self._Ws[i])
else:
ch = s_prev.mm(self._Ws[i])
c, h = torch.split(ch, self.nhid, dim=-1)
c = c.sigmoid()
fn = self._get_activation(name)
h = fn(h)
s = s_prev + c * (h-s_prev)
states += [s]
output = torch.mean(torch.stack([states[i] for i in self.genotype.concat], -1), -1)
return output
class RNNModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""
def __init__(self, ntoken, ninp, nhid, nhidlast,
dropout=0.5, dropouth=0.5, dropoutx=0.5, dropouti=0.5, dropoute=0.1,
cell_cls=None, genotype=None):
super(RNNModel, self).__init__()
self.lockdrop = LockedDropout()
self.encoder = nn.Embedding(ntoken, ninp)
assert ninp == nhid == nhidlast
if cell_cls == DARTSCell:
assert genotype is not None
rnns = [cell_cls(ninp, nhid, dropouth, dropoutx, genotype)]
else:
assert genotype is None
rnns = [cell_cls(ninp, nhid, dropouth, dropoutx)]
self.rnns = torch.nn.ModuleList(rnns)
self.decoder = nn.Linear(ninp, ntoken)
self.decoder.weight = self.encoder.weight
self.init_weights()
self.arch_weights = None
self.ninp = ninp
self.nhid = nhid
self.nhidlast = nhidlast
self.dropout = dropout
self.dropouti = dropouti
self.dropoute = dropoute
self.ntoken = ntoken
self.cell_cls = cell_cls
# acceleration
self.tau = None
self.use_gumbel = False
def set_gumbel(self, use_gumbel, set_check):
self.use_gumbel = use_gumbel
for i, rnn in enumerate(self.rnns):
rnn.set_check(set_check)
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def init_weights(self):
self.encoder.weight.data.uniform_(-INITRANGE, INITRANGE)
self.decoder.bias.data.fill_(0)
self.decoder.weight.data.uniform_(-INITRANGE, INITRANGE)
def forward(self, input, hidden, return_h=False):
batch_size = input.size(1)
emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
emb = self.lockdrop(emb, self.dropouti)
raw_output = emb
new_hidden = []
raw_outputs = []
outputs = []
if self.arch_weights is None:
arch_probs = None
else:
if self.use_gumbel: arch_probs = F.gumbel_softmax(self.arch_weights, self.tau, False)
else : arch_probs = F.softmax(self.arch_weights, dim=-1)
for l, rnn in enumerate(self.rnns):
current_input = raw_output
raw_output, new_h = rnn(raw_output, hidden[l], arch_probs)
new_hidden.append(new_h)
raw_outputs.append(raw_output)
hidden = new_hidden
output = self.lockdrop(raw_output, self.dropout)
outputs.append(output)
logit = self.decoder(output.view(-1, self.ninp))
log_prob = nn.functional.log_softmax(logit, dim=-1)
model_output = log_prob
model_output = model_output.view(-1, batch_size, self.ntoken)
if return_h: return model_output, hidden, raw_outputs, outputs
else : return model_output, hidden
def init_hidden(self, bsz):
weight = next(self.parameters()).clone()
return [weight.new(1, bsz, self.nhid).zero_()]

View File

@@ -1,55 +0,0 @@
from collections import namedtuple
Genotype = namedtuple('Genotype', 'recurrent concat')
PRIMITIVES = [
'none',
'tanh',
'relu',
'sigmoid',
'identity'
]
STEPS = 8
CONCAT = 8
ENAS = Genotype(
recurrent = [
('tanh', 0),
('tanh', 1),
('relu', 1),
('tanh', 3),
('tanh', 3),
('relu', 3),
('relu', 4),
('relu', 7),
('relu', 8),
('relu', 8),
('relu', 8),
],
concat = [2, 5, 6, 9, 10, 11]
)
DARTS_V1 = Genotype(
recurrent = [
('relu', 0),
('relu', 1),
('tanh', 2),
('relu', 3), ('relu', 4), ('identity', 1), ('relu', 5), ('relu', 1)
],
concat=range(1, 9)
)
DARTS_V2 = Genotype(
recurrent = [
('sigmoid', 0), ('relu', 1), ('relu', 1),
('identity', 1), ('tanh', 2), ('sigmoid', 5),
('tanh', 3), ('relu', 5)
],
concat=range(1, 9)
)
GDAS = Genotype(
recurrent=[('relu', 0), ('relu', 0), ('identity', 1), ('relu', 1), ('tanh', 0), ('relu', 2), ('identity', 4), ('identity', 2)],
concat=range(1, 9)
)

View File

@@ -1,104 +0,0 @@
import copy, torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
from .genotypes import PRIMITIVES, STEPS, CONCAT, Genotype
from .basemodel import DARTSCell, RNNModel
class DARTSCellSearch(DARTSCell):
def __init__(self, ninp, nhid, dropouth, dropoutx):
super(DARTSCellSearch, self).__init__(ninp, nhid, dropouth, dropoutx, genotype=None)
self.bn = nn.BatchNorm1d(nhid, affine=False)
self.check_zero = False
def set_check(self, check_zero):
self.check_zero = check_zero
def cell(self, x, h_prev, x_mask, h_mask, arch_probs):
s0 = self._compute_init_state(x, h_prev, x_mask, h_mask)
s0 = self.bn(s0)
if self.check_zero:
arch_probs_cpu = arch_probs.cpu().tolist()
#arch_probs = F.softmax(self.weights, dim=-1)
offset = 0
states = s0.unsqueeze(0)
for i in range(STEPS):
if self.training:
masked_states = states * h_mask.unsqueeze(0)
else:
masked_states = states
ch = masked_states.view(-1, self.nhid).mm(self._Ws[i]).view(i+1, -1, 2*self.nhid)
c, h = torch.split(ch, self.nhid, dim=-1)
c = c.sigmoid()
s = torch.zeros_like(s0)
for k, name in enumerate(PRIMITIVES):
if name == 'none':
continue
fn = self._get_activation(name)
unweighted = states + c * (fn(h) - states)
if self.check_zero:
INDEX, INDDX = [], []
for jj in range(offset, offset+i+1):
if arch_probs_cpu[jj][k] > 0:
INDEX.append(jj)
INDDX.append(jj-offset)
if len(INDEX) == 0: continue
s += torch.sum(arch_probs[INDEX, k].unsqueeze(-1).unsqueeze(-1) * unweighted[INDDX, :, :], dim=0)
else:
s += torch.sum(arch_probs[offset:offset+i+1, k].unsqueeze(-1).unsqueeze(-1) * unweighted, dim=0)
s = self.bn(s)
states = torch.cat([states, s.unsqueeze(0)], 0)
offset += i+1
output = torch.mean(states[-CONCAT:], dim=0)
return output
class RNNModelSearch(RNNModel):
def __init__(self, *args):
super(RNNModelSearch, self).__init__(*args)
self._args = copy.deepcopy( args )
k = sum(i for i in range(1, STEPS+1))
self.arch_weights = nn.Parameter(torch.Tensor(k, len(PRIMITIVES)))
nn.init.normal_(self.arch_weights, 0, 0.001)
def base_parameters(self):
lists = list(self.lockdrop.parameters())
lists += list(self.encoder.parameters())
lists += list(self.rnns.parameters())
lists += list(self.decoder.parameters())
return lists
def arch_parameters(self):
return [self.arch_weights]
def genotype(self):
def _parse(probs):
gene = []
start = 0
for i in range(STEPS):
end = start + i + 1
W = probs[start:end].copy()
#j = sorted(range(i + 1), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[0]
j = sorted(range(i + 1), key=lambda x: -max(W[x][k] for k in range(len(W[x])) ))[0]
k_best = None
for k in range(len(W[j])):
#if k != PRIMITIVES.index('none'):
# if k_best is None or W[j][k] > W[j][k_best]:
# k_best = k
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
gene.append((PRIMITIVES[k_best], j))
start = end
return gene
with torch.no_grad():
gene = _parse(F.softmax(self.arch_weights, dim=-1).cpu().numpy())
genotype = Genotype(recurrent=gene, concat=list(range(STEPS+1)[-CONCAT:]))
return genotype

View File

@@ -1,66 +0,0 @@
import torch
import torch.nn as nn
import os, shutil
import numpy as np
def repackage_hidden(h):
if isinstance(h, torch.Tensor):
return h.detach()
else:
return tuple(repackage_hidden(v) for v in h)
def batchify(data, bsz, use_cuda):
nbatch = data.size(0) // bsz
data = data.narrow(0, 0, nbatch * bsz)
data = data.view(bsz, -1).t().contiguous()
if use_cuda: return data.cuda()
else : return data
def get_batch(source, i, seq_len):
seq_len = min(seq_len, len(source) - 1 - i)
data = source[i:i+seq_len].clone()
target = source[i+1:i+1+seq_len].clone()
return data, target
def embedded_dropout(embed, words, dropout=0.1, scale=None):
if dropout:
mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
mask.requires_grad_(True)
masked_embed_weight = mask * embed.weight
else:
masked_embed_weight = embed.weight
if scale:
masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
padding_idx = embed.padding_idx
if padding_idx is None:
padding_idx = -1
X = torch.nn.functional.embedding(
words, masked_embed_weight,
padding_idx, embed.max_norm, embed.norm_type,
embed.scale_grad_by_freq, embed.sparse)
return X
class LockedDropout(nn.Module):
def __init__(self):
super(LockedDropout, self).__init__()
def forward(self, x, dropout=0.5):
if not self.training or not dropout:
return x
m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
mask = m.div_(1 - dropout).detach()
mask = mask.expand_as(x)
return mask * x
def mask2d(B, D, keep_prob, cuda=True):
m = torch.floor(torch.rand(B, D) + keep_prob) / keep_prob
if cuda: return m.cuda()
else : return m

View File

@@ -0,0 +1,22 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .starts import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
from .optimizers import get_optim_scheduler
def get_procedures(procedure):
from .basic_main import basic_train, basic_valid
from .search_main import search_train, search_valid
from .search_main_v2 import search_train_v2
from .simple_KD_main import simple_KD_train, simple_KD_valid
train_funcs = {'basic' : basic_train, \
'search': search_train,'Simple-KD': simple_KD_train, \
'search-v2': search_train_v2}
valid_funcs = {'basic' : basic_valid, \
'search': search_valid,'Simple-KD': simple_KD_valid, \
'search-v2': search_valid}
train_func = train_funcs[procedure]
valid_func = valid_funcs[procedure]
return train_func, valid_func

View File

@@ -0,0 +1,75 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
loss, acc1, acc5 = procedure(xloader, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
return loss, acc1, acc5
def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
with torch.no_grad():
loss, acc1, acc5 = procedure(xloader, network, criterion, None, None, 'valid', None, extra_info, print_freq, logger)
return loss, acc1, acc5
def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
if mode == 'train':
network.train()
elif mode == 'valid':
network.eval()
else: raise ValueError("The mode is not right : {:}".format(mode))
#logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
logger.log('[{:5s}] config :: auxiliary={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1))
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
targets = targets.cuda(non_blocking=True)
if mode == 'train': optimizer.zero_grad()
features, logits = network(inputs)
if isinstance(logits, list):
assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits))
logits, logits_aux = logits
else:
logits, logits_aux = logits, None
loss = criterion(logits, targets)
if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
loss_aux = criterion(logits_aux, targets)
loss += config.auxiliary * loss_aux
if mode == 'train':
loss.backward()
optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or (i+1) == len(xloader):
Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
if scheduler is not None:
Sstr += ' {:}'.format(scheduler.get_min_info())
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
Istr = 'Size={:}'.format(list(inputs.size()))
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
return losses.avg, top1.avg, top5.avg

View File

@@ -0,0 +1,201 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer
class _LRScheduler(object):
def __init__(self, optimizer, warmup_epochs, epochs):
if not isinstance(optimizer, Optimizer):
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
self.optimizer = optimizer
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
self.max_epochs = epochs
self.warmup_epochs = warmup_epochs
self.current_epoch = 0
self.current_iter = 0
def extra_repr(self):
return ''
def __repr__(self):
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
+ ', {:})'.format(self.extra_repr()))
def state_dict(self):
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
def get_lr(self):
raise NotImplementedError
def get_min_info(self):
lrs = self.get_lr()
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
def get_min_lr(self):
return min( self.get_lr() )
def update(self, cur_epoch, cur_iter):
if cur_epoch is not None:
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
self.current_epoch = cur_epoch
if cur_iter is not None:
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
self.current_iter = cur_iter
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
class CosineAnnealingLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
self.T_max = T_max
self.eta_min = eta_min
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
if last_epoch < self.T_max:
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
else:
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class MultiStepLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
self.milestones = milestones
self.gammas = gammas
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
idx = bisect_right(self.milestones, last_epoch)
lr = base_lr
for x in self.gammas[:idx]: lr *= x
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class ExponentialLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
self.gamma = gamma
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
lr = base_lr * (self.gamma ** last_epoch)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class LinearLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
self.max_LR = max_LR
self.min_LR = min_LR
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
lr = base_lr * (1-ratio)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
def get_optim_scheduler(parameters, config):
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
if config.optim == 'SGD':
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
elif config.optim == 'RMSprop':
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
else:
raise ValueError('invalid optim : {:}'.format(config.optim))
if config.scheduler == 'cos':
T_max = getattr(config, 'T_max', config.epochs)
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
elif config.scheduler == 'multistep':
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
elif config.scheduler == 'exponential':
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
elif config.scheduler == 'linear':
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
else:
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
if config.criterion == 'Softmax':
criterion = torch.nn.CrossEntropyLoss()
elif config.criterion == 'SmoothSoftmax':
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
else:
raise ValueError('invalid criterion : {:}'.format(config.criterion))
return optim, scheduler, criterion

View File

@@ -0,0 +1,126 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
from models import change_key
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
expected_flop = torch.mean( expected_flop )
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
loss = - torch.log( expected_flop )
#elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
elif flop_cur > flop_need: # Too Large FLOP
loss = torch.log( expected_flop )
else: # Required FLOP
loss = None
if loss is None: return 0, 0
else : return loss, loss.item()
def search_train(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']
network.train()
logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
end = time.time()
network.apply( change_key('search_mode', 'search') )
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
scheduler.update(None, 1.0 * step / len(search_loader))
# calculate prediction and loss
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
base_optimizer.zero_grad()
logits, expected_flop = network(base_inputs)
#network.apply( change_key('search_mode', 'basic') )
#features, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
base_optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
top1.update (prec1.item(), base_inputs.size(0))
top5.update (prec5.item(), base_inputs.size(0))
# update the architecture
arch_optimizer.zero_grad()
logits, expected_flop = network(arch_inputs)
flop_cur = network.module.get_flop('genotype', None, None)
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
acls_loss = criterion(logits, arch_targets)
arch_loss = acls_loss + flop_loss * flop_weight
arch_loss.backward()
arch_optimizer.step()
# record
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or (step+1) == len(search_loader):
Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
#Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
#print(network.module.get_arch_info())
#print(network.module.width_attentions[0])
#print(network.module.width_attentions[1])
logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
network.apply( change_key('search_mode', 'search') )
end = time.time()
#logger.log('Starting evaluating {:}'.format(epoch_info))
with torch.no_grad():
for i, (inputs, targets) in enumerate(xloader):
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
targets = targets.cuda(non_blocking=True)
logits, expected_flop = network(inputs)
loss = criterion(logits, targets)
# record
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or (i+1) == len(xloader):
Sstr = '**VALID** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
Istr = 'Size={:}'.format(list(inputs.size()))
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
logger.log(' **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
return losses.avg, top1.avg, top5.avg

View File

@@ -0,0 +1,87 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
from models import change_key
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
expected_flop = torch.mean( expected_flop )
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
loss = - torch.log( expected_flop )
#elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
elif flop_cur > flop_need: # Too Large FLOP
loss = torch.log( expected_flop )
else: # Required FLOP
loss = None
if loss is None: return 0, 0
else : return loss, loss.item()
def search_train_v2(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']
network.train()
logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
end = time.time()
network.apply( change_key('search_mode', 'search') )
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
scheduler.update(None, 1.0 * step / len(search_loader))
# calculate prediction and loss
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
base_optimizer.zero_grad()
logits, expected_flop = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
base_optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
top1.update (prec1.item(), base_inputs.size(0))
top5.update (prec5.item(), base_inputs.size(0))
# update the architecture
arch_optimizer.zero_grad()
logits, expected_flop = network(arch_inputs)
flop_cur = network.module.get_flop('genotype', None, None)
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
acls_loss = criterion(logits, arch_targets)
arch_loss = acls_loss + flop_loss * flop_weight
arch_loss.backward()
arch_optimizer.step()
# record
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or (step+1) == len(search_loader):
Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
#num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
#Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
#print(network.module.get_arch_info())
#print(network.module.width_attentions[0])
#print(network.module.width_attentions[1])
logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg

View File

@@ -0,0 +1,94 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
import torch.nn.functional as F
# our modules
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
def simple_KD_train(xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
return loss, acc1, acc5
def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
with torch.no_grad():
loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, None, None, 'valid', optim_config, extra_info, print_freq, logger)
return loss, acc1, acc5
def loss_KD_fn(criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature):
basic_loss = criterion(student_logits, targets) * (1. - alpha)
log_student= F.log_softmax(student_logits / temperature, dim=1)
sof_teacher= F.softmax (teacher_logits / temperature, dim=1)
KD_loss = F.kl_div(log_student, sof_teacher, reduction='batchmean') * (alpha * temperature * temperature)
return basic_loss + KD_loss
def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
Ttop1, Ttop5 = AverageMeter(), AverageMeter()
if mode == 'train':
network.train()
elif mode == 'valid':
network.eval()
else: raise ValueError("The mode is not right : {:}".format(mode))
teacher.eval()
logger.log('[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, config.KD_alpha, config.KD_temperature))
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
targets = targets.cuda(non_blocking=True)
if mode == 'train': optimizer.zero_grad()
student_f, logits = network(inputs)
if isinstance(logits, list):
assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits))
logits, logits_aux = logits
else:
logits, logits_aux = logits, None
with torch.no_grad():
teacher_f, teacher_logits = teacher(inputs)
loss = loss_KD_fn(criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature)
if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
loss_aux = criterion(logits_aux, targets)
loss += config.auxiliary * loss_aux
if mode == 'train':
loss.backward()
optimizer.step()
# record
sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (sprec1.item(), inputs.size(0))
top5.update (sprec5.item(), inputs.size(0))
# teacher
tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
Ttop1.update (tprec1.item(), inputs.size(0))
Ttop5.update (tprec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or (i+1) == len(xloader):
Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
if scheduler is not None:
Sstr += ' {:}'.format(scheduler.get_min_info())
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
Lstr+= ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(Ttop1.avg, Ttop5.avg)
Istr = 'Size={:}'.format(list(inputs.size()))
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg))
logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
return losses.avg, top1.avg, top5.avg

67
lib/procedures/starts.py Normal file
View File

@@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, sys, time, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile
def prepare_seed(rand_seed):
random.seed(rand_seed)
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)
torch.cuda.manual_seed(rand_seed)
torch.cuda.manual_seed_all(rand_seed)
def prepare_logger(xargs):
args = copy.deepcopy( xargs )
from log_utils import Logger
logger = Logger(args.save_dir, args.rand_seed)
logger.log('Main Function with logger : {:}'.format(logger))
logger.log('Arguments : -------------------------------')
for name, value in args._get_kwargs():
logger.log('{:16} : {:}'.format(name, value))
logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
logger.log("Pillow Version : {:}".format(PIL.__version__))
logger.log("PyTorch Version : {:}".format(torch.__version__))
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
return logger
def get_machine_info():
info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
info+= "\nPillow Version : {:}".format(PIL.__version__)
info+= "\nPyTorch Version : {:}".format(torch.__version__)
info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
if 'CUDA_VISIBLE_DEVICES' in os.environ:
info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
else:
info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
return info
def save_checkpoint(state, filename, logger):
if osp.isfile(filename):
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(filename))
os.remove(filename)
torch.save(state, filename)
assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename)
if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
return filename
def copy_checkpoint(src, dst, logger):
if osp.isfile(dst):
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(dst))
os.remove(dst)
copyfile(src, dst)
if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))

View File

@@ -1,5 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .utils import load_config
from .scheduler import MultiStepLR, obtain_scheduler

View File

@@ -1,32 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from bisect import bisect_right
class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
if not list(milestones) == sorted(milestones):
raise ValueError('Milestones should be a list of'
' increasing integers. Got {:}', milestones)
assert len(milestones) == len(gammas), '{:} vs {:}'.format(milestones, gammas)
self.milestones = milestones
self.gammas = gammas
super(MultiStepLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
LR = 1
for x in self.gammas[:bisect_right(self.milestones, self.last_epoch)]: LR = LR * x
return [base_lr * LR for base_lr in self.base_lrs]
def obtain_scheduler(config, optimizer):
if config.type == 'multistep':
scheduler = MultiStepLR(optimizer, milestones=config.milestones, gammas=config.gammas)
elif config.type == 'cosine':
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)
else:
raise ValueError('Unknown learning rate scheduler type : {:}'.format(config.type))
return scheduler

View File

@@ -1,42 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, json
from pathlib import Path
from collections import namedtuple
support_types = ('str', 'int', 'bool', 'float')
def convert_param(original_lists):
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
ctype, value = original_lists[0], original_lists[1]
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
is_list = isinstance(value, list)
if not is_list: value = [value]
outs = []
for x in value:
if ctype == 'int':
x = int(x)
elif ctype == 'str':
x = str(x)
elif ctype == 'bool':
x = bool(int(x))
elif ctype == 'float':
x = float(x)
else:
raise TypeError('Does not know this type : {:}'.format(ctype))
outs.append(x)
if not is_list: outs = outs[0]
return outs
def load_config(path):
path = str(path)
assert os.path.exists(path), 'Can not find {:}'.format(path)
# Reading data back
with open(path, 'r') as f:
data = json.load(f)
f.close()
content = { k: convert_param(v) for k,v in data.items()}
Arguments = namedtuple('Configure', ' '.join(content.keys()))
content = Arguments(**content)
return content

View File

@@ -1,16 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .utils import AverageMeter, RecorderMeter, convert_secs2time
from .utils import time_file_str, time_string
from .utils import test_imagenet_data
from .utils import print_log
from .evaluation_utils import obtain_accuracy
#from .draw_pts import draw_points
from .gpu_manager import GPUManager
from .save_meta import Save_Meta
from .model_utils import count_parameters_in_MB
from .model_utils import Cutout
from .flop_benchmark import print_FLOPs
from .gpu_manager import GPUManager
from .flop_benchmark import get_model_infos

View File

@@ -1,41 +0,0 @@
import os, sys, time
import numpy as np
import matplotlib
import random
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
def draw_points(points, labels, save_path):
title = 'the visualized features'
dpi = 100
width, height = 1000, 1000
legend_fontsize = 10
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
classes = np.unique(labels).tolist()
colors = cm.rainbow(np.linspace(0, 1, len(classes)))
legends = []
legendnames = []
for cls, c in zip(classes, colors):
indexes = labels == cls
ptss = points[indexes, :]
x = ptss[:,0]
y = ptss[:,1]
if cls % 2 == 0: marker = 'x'
else: marker = 'o'
legend = plt.scatter(x, y, color=c, s=1, marker=marker)
legendname = '{:02d}'.format(cls+1)
legends.append( legend )
legendnames.append( legendname )
plt.legend(legends, legendnames, scatterpoints=1, ncol=5, fontsize=8)
if save_path is not None:
fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
print ('---- save figure {} into {}'.format(title, save_path))
plt.close(fig)

View File

@@ -3,21 +3,44 @@
##################################################
# modified from https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py
import copy, torch
import torch.nn as nn
import numpy as np
def print_FLOPs(model, shape, logs):
print_log, log = logs
model = copy.deepcopy( model )
def count_parameters_in_MB(model):
if isinstance(model, nn.Module):
return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
else:
return np.sum(np.prod(v.size()) for v in model)/1e6
def get_model_infos(model, shape):
#model = copy.deepcopy( model )
model = add_flops_counting_methods(model)
model = model.cuda()
#model = model.cuda()
model.eval()
cache_inputs = torch.zeros(*shape).cuda()
#cache_inputs = torch.zeros(*shape).cuda()
#cache_inputs = torch.zeros(*shape)
cache_inputs = torch.rand(*shape)
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
_ = model(cache_inputs)
with torch.no_grad():
_____ = model(cache_inputs)
FLOPs = compute_average_flops_cost( model ) / 1e6
print_log('FLOPs : {:} MB'.format(FLOPs), log)
Param = count_parameters_in_MB(model)
if hasattr(model, 'auxiliary_param'):
aux_params = count_parameters_in_MB(model.auxiliary_param())
print ('The auxiliary params of this model is : {:}'.format(aux_params))
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
Param = Param - aux_params
#print_log('FLOPs : {:} MB'.format(FLOPs), log)
torch.cuda.empty_cache()
model.apply( remove_hook_function )
return FLOPs, Param
# ---- Public functions
@@ -37,8 +60,11 @@ def compute_average_flops_cost(model):
"""
batches_count = model.__batch_counter__
flops_sum = 0
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
for module in model.modules():
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or hasattr(module, 'calculate_flop_self'):
flops_sum += module.__flops__
return flops_sum / batches_count
@@ -54,6 +80,11 @@ def pool_flops_counter_hook(pool_module, inputs, output):
pool_module.__flops__ += overall_flops
def self_calculate_flops_counter_hook(self_module, inputs, output):
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
self_module.__flops__ += overall_flops
def fc_flops_counter_hook(fc_module, inputs, output):
batch_size = inputs[0].size(0)
xin, xout = fc_module.in_features, fc_module.out_features
@@ -64,7 +95,24 @@ def fc_flops_counter_hook(fc_module, inputs, output):
fc_module.__flops__ += overall_flops
def conv_flops_counter_hook(conv_module, inputs, output):
def conv1d_flops_counter_hook(conv_module, inputs, outputs):
batch_size = inputs[0].size(0)
outL = outputs.shape[-1]
[kernel] = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel * in_channels * out_channels / groups
active_elements_count = batch_size * outL
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def conv2d_flops_counter_hook(conv_module, inputs, output):
batch_size = inputs[0].size(0)
output_height, output_width = output.shape[2:]
@@ -97,14 +145,20 @@ def add_batch_counter_hook_function(module):
def add_flops_counter_variable_or_reset(module):
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
or isinstance(module, torch.nn.Conv1d) \
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
or hasattr(module, 'calculate_flop_self'):
module.__flops__ = 0
def add_flops_counter_hook_function(module):
if isinstance(module, torch.nn.Conv2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv_flops_counter_hook)
handle = module.register_forward_hook(conv2d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Conv1d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv1d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Linear):
if not hasattr(module, '__flops_handle__'):
@@ -114,3 +168,18 @@ def add_flops_counter_hook_function(module):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(pool_flops_counter_hook)
module.__flops_handle__ = handle
elif hasattr(module, 'calculate_flop_self'): # self-defined module
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(self_calculate_flops_counter_hook)
module.__flops_handle__ = handle
def remove_hook_function(module):
hookers = ['__batch_counter_handle__', '__flops_handle__']
for hooker in hookers:
if hasattr(module, hooker):
handle = getattr(module, hooker)
handle.remove()
keys = ['__flops__', '__batch_counter__', '__flops__'] + hookers
for ckey in keys:
if hasattr(module, ckey): delattr(module, ckey)

View File

@@ -1,35 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
def count_parameters_in_MB(model):
if isinstance(model, nn.Module):
return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
else:
return np.sum(np.prod(v.size()) for v in model)/1e6
class Cutout(object):
def __init__(self, length):
self.length = length
def __repr__(self):
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img

View File

@@ -1,53 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import os, sys
import os.path as osp
import numpy as np
def tensor2np(x):
if isinstance(x, np.ndarray): return x
if x.is_cuda: x = x.cpu()
return x.numpy()
class Save_Meta():
def __init__(self):
self.reset()
def __repr__(self):
return ('{name}'.format(name=self.__class__.__name__)+'(number of data = {})'.format(len(self)))
def reset(self):
self.predictions = []
self.groundtruth = []
def __len__(self):
return len(self.predictions)
def append(self, _pred, _ground):
_pred, _ground = tensor2np(_pred), tensor2np(_ground)
assert _ground.shape[0] == _pred.shape[0] and len(_pred.shape) == 2 and len(_ground.shape) == 1, 'The shapes are wrong : {} & {}'.format(_pred.shape, _ground.shape)
self.predictions.append(_pred)
self.groundtruth.append(_ground)
def save(self, save_dir, filename, test=True):
meta = {'predictions': self.predictions,
'groundtruth': self.groundtruth}
filename = osp.join(save_dir, filename)
torch.save(meta, filename)
if test:
predictions = np.concatenate(self.predictions)
groundtruth = np.concatenate(self.groundtruth)
predictions = np.argmax(predictions, axis=1)
accuracy = np.sum(groundtruth==predictions) * 100.0 / predictions.size
else:
accuracy = None
print ('save save_meta into {} with accuracy = {}'.format(filename, accuracy))
def load(self, filename):
assert os.path.isfile(filename), '{} is not a file'.format(filename)
checkpoint = torch.load(filename)
self.predictions = checkpoint['predictions']
self.groundtruth = checkpoint['groundtruth']

1
lib/xvision/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .affine_utils import normalize_points, denormalize_points

132
lib/xvision/affine_utils.py Normal file
View File

@@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
#
# functions for affine transformation
import math, torch
import numpy as np
import torch.nn.functional as F
def identity2affine(full=False):
if not full:
parameters = torch.zeros((2,3))
parameters[0, 0] = parameters[1, 1] = 1
else:
parameters = torch.zeros((3,3))
parameters[0, 0] = parameters[1, 1] = parameters[2, 2] = 1
return parameters
def normalize_L(x, L):
return -1. + 2. * x / (L-1)
def denormalize_L(x, L):
return (x + 1.0) / 2.0 * (L-1)
def crop2affine(crop_box, W, H):
assert len(crop_box) == 4, 'Invalid crop-box : {:}'.format(crop_box)
parameters = torch.zeros(3,3)
x1, y1 = normalize_L(crop_box[0], W), normalize_L(crop_box[1], H)
x2, y2 = normalize_L(crop_box[2], W), normalize_L(crop_box[3], H)
parameters[0,0] = (x2-x1)/2
parameters[0,2] = (x2+x1)/2
parameters[1,1] = (y2-y1)/2
parameters[1,2] = (y2+y1)/2
parameters[2,2] = 1
return parameters
def scale2affine(scalex, scaley):
parameters = torch.zeros(3,3)
parameters[0,0] = scalex
parameters[1,1] = scaley
parameters[2,2] = 1
return parameters
def offset2affine(offx, offy):
parameters = torch.zeros(3,3)
parameters[0,0] = parameters[1,1] = parameters[2,2] = 1
parameters[0,2] = offx
parameters[1,2] = offy
return parameters
def horizontalmirror2affine():
parameters = torch.zeros(3,3)
parameters[0,0] = -1
parameters[1,1] = parameters[2,2] = 1
return parameters
# clockwise rotate image = counterclockwise rotate the rectangle
# degree is between [0, 360]
def rotate2affine(degree):
assert degree >= 0 and degree <= 360, 'Invalid degree : {:}'.format(degree)
degree = degree / 180 * math.pi
parameters = torch.zeros(3,3)
parameters[0,0] = math.cos(-degree)
parameters[0,1] = -math.sin(-degree)
parameters[1,0] = math.sin(-degree)
parameters[1,1] = math.cos(-degree)
parameters[2,2] = 1
return parameters
# shape is a tuple [H, W]
def normalize_points(shape, points):
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), 'points are wrong : {:}'.format(points.shape)
(H, W), points = shape, points.clone()
points[0, :] = normalize_L(points[0,:], W)
points[1, :] = normalize_L(points[1,:], H)
return points
# shape is a tuple [H, W]
def normalize_points_batch(shape, points):
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
assert isinstance(points, torch.Tensor) and (points.size(-1) == 2), 'points are wrong : {:}'.format(points.shape)
(H, W), points = shape, points.clone()
x = normalize_L(points[...,0], W)
y = normalize_L(points[...,1], H)
return torch.stack((x,y), dim=-1)
# shape is a tuple [H, W]
def denormalize_points(shape, points):
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), 'points are wrong : {:}'.format(points.shape)
(H, W), points = shape, points.clone()
points[0, :] = denormalize_L(points[0,:], W)
points[1, :] = denormalize_L(points[1,:], H)
return points
# shape is a tuple [H, W]
def denormalize_points_batch(shape, points):
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
assert isinstance(points, torch.Tensor) and (points.shape[-1] == 2), 'points are wrong : {:}'.format(points.shape)
(H, W), points = shape, points.clone()
x = denormalize_L(points[...,0], W)
y = denormalize_L(points[...,1], H)
return torch.stack((x,y), dim=-1)
# make target * theta = source
def solve2theta(source, target):
source, target = source.clone(), target.clone()
oks = source[2, :] == 1
assert torch.sum(oks).item() >= 3, 'valid points : {:} is short'.format(oks)
if target.size(0) == 2: target = torch.cat((target, oks.unsqueeze(0).float()), dim=0)
source, target = source[:, oks], target[:, oks]
source, target = source.transpose(1,0), target.transpose(1,0)
assert source.size(1) == target.size(1) == 3
#X, residual, rank, s = np.linalg.lstsq(target.numpy(), source.numpy())
#theta = torch.Tensor(X.T[:2, :])
X_, qr = torch.gels(source, target)
theta = X_[:3, :2].transpose(1, 0)
return theta
# shape = [H,W]
def affine2image(image, theta, shape):
C, H, W = image.size()
theta = theta[:2, :].unsqueeze(0)
grid_size = torch.Size([1, C, shape[0], shape[1]])
grid = F.affine_grid(theta, grid_size)
affI = F.grid_sample(image.unsqueeze(0), grid, mode='bilinear', padding_mode='border')
return affI.squeeze(0)