Initial commit

datasets/DownsampledImageNet.py (new file, 129 lines)
@@ -0,0 +1,129 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, hashlib, torch
import numpy as np
from PIL import Image
import torch.utils.data as data
if sys.version_info[0] == 2:
  import cPickle as pickle
else:
  import pickle


def calculate_md5(fpath, chunk_size=1024 * 1024):
  md5 = hashlib.md5()
  with open(fpath, 'rb') as f:
    for chunk in iter(lambda: f.read(chunk_size), b''):
      md5.update(chunk)
  return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
  return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
  if not os.path.isfile(fpath): return False
  if md5 is None: return True
  else          : return check_md5(fpath, md5)


class ImageNet16(data.Dataset):
  # http://image-net.org/download-images
  # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
  # https://arxiv.org/pdf/1707.08819.pdf

  train_list = [
        ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
        ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
        ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
        ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
        ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
        ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
        ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
        ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
        ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
        ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
    ]
  valid_list = [
        ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
    ]

  def __init__(self, root, train, transform, use_num_of_class_only=None):
    self.root      = root
    self.transform = transform
    self.train     = train  # training set or valid set
    if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')

    if self.train: downloaded_list = self.train_list
    else         : downloaded_list = self.valid_list
    self.data    = []
    self.targets = []

    # now load the pickled numpy arrays
    for i, (file_name, checksum) in enumerate(downloaded_list):
      file_path = os.path.join(self.root, file_name)
      #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
      with open(file_path, 'rb') as f:
        if sys.version_info[0] == 2:
          entry = pickle.load(f)
        else:
          entry = pickle.load(f, encoding='latin1')
        self.data.append(entry['data'])
        self.targets.extend(entry['labels'])
    self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
    self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
    if use_num_of_class_only is not None:
      assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
      new_data, new_targets = [], []
      for I, L in zip(self.data, self.targets):
        if 1 <= L <= use_num_of_class_only:
          new_data.append( I )
          new_targets.append( L )
      self.data    = new_data
      self.targets = new_targets
    #    self.mean.append(entry['mean'])
    #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
    #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
    #print ('Mean : {:}'.format(self.mean))
    #temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3))
    #std_data  = np.std(temp, axis=0)
    #std_data  = np.mean(np.mean(std_data, axis=0), axis=0)
    #print ('Std  : {:}'.format(std_data))

  def __getitem__(self, index):
    img, target = self.data[index], self.targets[index] - 1  # labels are 1-based on disk

    img = Image.fromarray(img)

    if self.transform is not None:
      img = self.transform(img)

    return img, target

  def __len__(self):
    return len(self.data)

  def _check_integrity(self):
    root = self.root
    for fentry in (self.train_list + self.valid_list):
      filename, md5 = fentry[0], fentry[1]
      fpath = os.path.join(root, filename)
      if not check_integrity(fpath, md5):
        return False
    return True

#
if __name__ == '__main__':
  train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
  valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)

  print ( len(train) )
  print ( len(valid) )
  image, label = train[111]
  trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
  validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200)
  print ( len(trainX) )
  print ( len(validX) )
  #import pdb; pdb.set_trace()
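
A minimal usage sketch (not part of the file above; the root path is a placeholder) pairing the 120-class subset with the ImageNet16 mean/std that get_dataset_with_transform.py uses below:

  import torch
  import torchvision.transforms as transforms
  from datasets.DownsampledImageNet import ImageNet16

  mean = [x / 255 for x in [122.68, 116.66, 104.01]]
  std  = [x / 255 for x in [63.22, 61.26, 65.09]]
  xform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

  train_data = ImageNet16('/path/to/ImageNet16', True, xform, 120)  # 151700 images, labels 0..119
  loader = torch.utils.data.DataLoader(train_data, batch_size=256, shuffle=True, num_workers=4)
  images, labels = next(iter(loader))  # images: [256, 3, 16, 16]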

datasets/LandmarkDataset.py (new file, 191 lines)
							| @@ -0,0 +1,191 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| from os import path as osp | ||||
| from copy import deepcopy as copy | ||||
| from tqdm import tqdm | ||||
| import warnings, time, random, numpy as np | ||||
|  | ||||
| from pts_utils import generate_label_map | ||||
| from xvision import denormalize_points | ||||
| from xvision import identity2affine, solve2theta, affine2image | ||||
| from .dataset_utils import pil_loader | ||||
| from .landmark_utils import PointMeta2V | ||||
| from .augmentation_utils import CutOut | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class LandmarkDataset(data.Dataset): | ||||
|  | ||||
|   def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None): | ||||
|  | ||||
|     self.transform    = transform | ||||
|     self.sigma        = sigma | ||||
|     self.downsample   = downsample | ||||
|     self.heatmap_type = heatmap_type | ||||
|     self.dataset_name = data_indicator | ||||
|     self.shape        = shape # [H,W] | ||||
|     self.use_gray     = use_gray | ||||
|     assert transform is not None, 'transform : {:}'.format(transform) | ||||
|     self.mean_file    = mean_file | ||||
|     if mean_file is None: | ||||
|       self.mean_data  = None | ||||
|       warnings.warn('LandmarkDataset initialized with mean_data = None') | ||||
|     else: | ||||
|       assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file) | ||||
|       self.mean_data  = torch.load(mean_file) | ||||
|     self.reset() | ||||
|     self.cutout       = None | ||||
|     self.cache_images = cache_images | ||||
|     print ('The general dataset initialization done : {:}'.format(self)) | ||||
|     warnings.simplefilter( 'once' ) | ||||
|  | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|  | ||||
|   def set_cutout(self, length): | ||||
|     if length is not None and length >= 1: | ||||
|       self.cutout = CutOut( int(length) ) | ||||
|     else: self.cutout = None | ||||
|  | ||||
|  | ||||
|   def reset(self, num_pts=-1, boxid='default', only_pts=False): | ||||
|     self.NUM_PTS = num_pts | ||||
|     if only_pts: return | ||||
|     self.length  = 0 | ||||
|     self.datas   = [] | ||||
|     self.labels  = [] | ||||
|     self.NormDistances = [] | ||||
|     self.BOXID = boxid | ||||
|     if self.mean_data is None: | ||||
|       self.mean_face = None | ||||
|     else: | ||||
|       self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T) | ||||
|       assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face) | ||||
|     #assert self.dataset_name is not None, 'The dataset name is None' | ||||
|  | ||||
|  | ||||
|   def __len__(self): | ||||
|     assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length) | ||||
|     return self.length | ||||
|  | ||||
|  | ||||
|   def append(self, data, label, distance): | ||||
|     assert osp.isfile(data), 'The image path is not a file : {:}'.format(data) | ||||
|     self.datas.append( data )             ;  self.labels.append( label ) | ||||
|     self.NormDistances.append( distance ) | ||||
|     self.length = self.length + 1 | ||||
|  | ||||
|  | ||||
|   def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset): | ||||
|     if reset: self.reset(num_pts, boxindicator) | ||||
|     else    : assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of point is inconsistance : {:} vs {:}'.format(self.NUM_PTS, num_pts) | ||||
|     if isinstance(file_lists, str): file_lists = [file_lists] | ||||
|     samples = [] | ||||
|     for idx, file_path in enumerate(file_lists): | ||||
|       print (':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path)) | ||||
|       xdata = torch.load(file_path) | ||||
|       if isinstance(xdata, list)  : data = xdata          # image or video dataset list | ||||
|       elif isinstance(xdata, dict): data = xdata['datas'] # multi-view dataset list | ||||
|       else: raise ValueError('Invalid Type Error : {:}'.format( type(xdata) )) | ||||
|       samples = samples + data | ||||
|     # samples is a dict, where the key is the image-path and the value is the annotation | ||||
|     # each annotation is a dict, contains 'points' (3,num_pts), and various box | ||||
|     print ('GeneralDataset-V2 : {:} samples'.format(len(samples))) | ||||
|  | ||||
|     #for index, annotation in enumerate(samples): | ||||
|     for index in tqdm( range( len(samples) ) ): | ||||
|       annotation = samples[index] | ||||
|       image_path  = annotation['current_frame'] | ||||
|       points, box = annotation['points'], annotation['box-{:}'.format(boxindicator)] | ||||
|       label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name) | ||||
|       if normalizeL is None: normDistance = None | ||||
|       else                 : normDistance = annotation['normalizeL-{:}'.format(normalizeL)] | ||||
|       self.append(image_path, label, normDistance) | ||||
|  | ||||
|     assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas)) | ||||
|     assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels)) | ||||
|     assert len(self.NormDistances) == self.length, 'The length and the NormDistances is not right {} vs {}'.format(self.length, len(self.NormDistance)) | ||||
|     print ('Load data done for LandmarkDataset, which has {:} images.'.format(self.length)) | ||||
|  | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|     assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index) | ||||
|     if self.cache_images is not None and self.datas[index] in self.cache_images: | ||||
|       image = self.cache_images[ self.datas[index] ].clone() | ||||
|     else: | ||||
|       image = pil_loader(self.datas[index], self.use_gray) | ||||
|     target = self.labels[index].copy() | ||||
|     return self._process_(image, target, index) | ||||
|  | ||||
|  | ||||
|   def _process_(self, image, target, index): | ||||
|  | ||||
|     # transform the image and points | ||||
|     image, target, theta = self.transform(image, target) | ||||
|     (C, H, W), (height, width) = image.size(), self.shape | ||||
|  | ||||
|     # obtain the visiable indicator vector | ||||
|     if target.is_none(): nopoints = True | ||||
|     else               : nopoints = False | ||||
|     if index == -1: __path = None | ||||
|     else          : __path = self.datas[index] | ||||
|     if isinstance(theta, list) or isinstance(theta, tuple): | ||||
|       affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = [], [], [], [], [], [] | ||||
|       for _theta in theta: | ||||
|         _affineImage, _heatmaps, _mask, _norm_trans_points, _theta, _transpose_theta \ | ||||
|           = self.__process_affine(image, target, _theta, nopoints, 'P[{:}]@{:}'.format(index, __path)) | ||||
|         affineImage.append(_affineImage) | ||||
|         heatmaps.append(_heatmaps) | ||||
|         mask.append(_mask) | ||||
|         norm_trans_points.append(_norm_trans_points) | ||||
|         THETA.append(_theta) | ||||
|         transpose_theta.append(_transpose_theta) | ||||
|       affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \ | ||||
|           torch.stack(affineImage), torch.stack(heatmaps), torch.stack(mask), torch.stack(norm_trans_points), torch.stack(THETA), torch.stack(transpose_theta) | ||||
|     else: | ||||
|       affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, __path)) | ||||
|  | ||||
|     torch_index = torch.IntTensor([index]) | ||||
|     torch_nopoints = torch.ByteTensor( [ nopoints ] ) | ||||
|     torch_shape = torch.IntTensor([H,W]) | ||||
|  | ||||
|     return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape | ||||
|  | ||||
|    | ||||
|   def __process_affine(self, image, target, theta, nopoints, aux_info=None): | ||||
|     image, target, theta = image.clone(), target.copy(), theta.clone() | ||||
|     (C, H, W), (height, width) = image.size(), self.shape | ||||
|     if nopoints: # do not have label | ||||
|       norm_trans_points = torch.zeros((3, self.NUM_PTS)) | ||||
|       heatmaps          = torch.zeros((self.NUM_PTS+1, height//self.downsample, width//self.downsample)) | ||||
|       mask              = torch.ones((self.NUM_PTS+1, 1, 1), dtype=torch.uint8) | ||||
|       transpose_theta   = identity2affine(False) | ||||
|     else: | ||||
|       norm_trans_points = apply_affine2point(target.get_points(), theta, (H,W)) | ||||
|       norm_trans_points = apply_boundary(norm_trans_points) | ||||
|       real_trans_points = norm_trans_points.clone() | ||||
|       real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2,:]) | ||||
|       heatmaps, mask = generate_label_map(real_trans_points.numpy(), height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C | ||||
|       heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor) | ||||
|       mask     = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor) | ||||
|       if self.mean_face is None: | ||||
|         #warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.') | ||||
|         transpose_theta = identity2affine(False) | ||||
|       else: | ||||
|         if torch.sum(norm_trans_points[2,:] == 1) < 3: | ||||
|           warnings.warn('In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'.format(aux_info)) | ||||
|           transpose_theta = identity2affine(False) | ||||
|         else: | ||||
|           transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone()) | ||||
|  | ||||
|     affineImage = affine2image(image, theta, self.shape) | ||||
|     if self.cutout is not None: affineImage = self.cutout( affineImage ) | ||||
|  | ||||
|     return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta | ||||
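
For reference, a shape sketch of one sample returned by __getitem__ in the single-theta case (shapes inferred from the no-label branch of __process_affine above; d = self.downsample):

  # Per-sample return of LandmarkDataset.__getitem__ (single theta):
  #   affineImage       : [C, height, width]                 warped input image
  #   heatmaps          : [NUM_PTS+1, height//d, width//d]   one map per point + background
  #   mask              : [NUM_PTS+1, 1, 1] uint8            which heatmap channels are valid
  #   norm_trans_points : [3, NUM_PTS]                       (x, y, visibility), x/y in [-1, 1]
  #   THETA             : the affine parameters actually applied
  #   transpose_theta   : affine mapping the sample onto the mean face (identity when unavailable)
  #   torch_index / torch_nopoints / torch_shape : bookkeeping tensors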

datasets/SearchDatasetWrap.py (new file, 46 lines)
@@ -0,0 +1,46 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch, copy, random
import torch.utils.data as data


class SearchDataset(data.Dataset):

  def __init__(self, name, data, train_split, valid_split, check=True):
    self.datasetname = name
    if isinstance(data, (list, tuple)): # new type of SearchDataset
      assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
      self.train_data  = data[0]
      self.valid_data  = data[1]
      self.train_split = train_split.copy()
      self.valid_split = valid_split.copy()
      self.mode_str    = 'V2' # new mode
    else:
      self.mode_str    = 'V1' # old mode
      self.data        = data
      self.train_split = train_split.copy()
      self.valid_split = valid_split.copy()
      if check:
        intersection = set(train_split).intersection(set(valid_split))
        assert len(intersection) == 0, 'the split train and validation sets must not intersect'
    self.length      = len(self.train_split)

  def __repr__(self):
    return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))

  def __len__(self):
    return self.length

  def __getitem__(self, index):
    assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
    train_index = self.train_split[index]
    valid_index = random.choice( self.valid_split )
    if self.mode_str == 'V1':
      train_image, train_label = self.data[train_index]
      valid_image, valid_label = self.data[valid_index]
    elif self.mode_str == 'V2':
      train_image, train_label = self.train_data[train_index]
      valid_image, valid_label = self.valid_data[valid_index]
    else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
    return train_image, train_label, valid_image, valid_label
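
A minimal V1-mode sketch (hypothetical toy data): one underlying dataset with disjoint train/valid index splits, where every item pairs a train sample with a randomly drawn valid sample.

  import torch
  import torch.utils.data as data
  from datasets.SearchDatasetWrap import SearchDataset

  class ToyDataset(data.Dataset):
    def __init__(self, n): self.n = n
    def __len__(self): return self.n
    def __getitem__(self, i): return torch.randn(3, 16, 16), i % 10

  base = ToyDataset(100)
  search = SearchDataset('toy', base, list(range(50)), list(range(50, 100)))
  train_img, train_lab, valid_img, valid_lab = search[0]
  print(search)  # SearchDataset(name=toy, train=50, valid=50, version=V1)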

datasets/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset

datasets/get_dataset_with_transform.py (new file, 227 lines)
@@ -0,0 +1,227 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch
import os.path as osp
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
from copy import deepcopy
from PIL import Image

from .DownsampledImageNet import ImageNet16
from .SearchDatasetWrap import SearchDataset
from config_utils import load_config


Dataset2Class = {'cifar10' : 10,
                 'cifar100': 100,
                 'imagenet-1k-s':1000,
                 'imagenet-1k' : 1000,
                 'ImageNet16'  : 1000,
                 'ImageNet16-150': 150,
                 'ImageNet16-120': 120,
                 'ImageNet16-200': 200}


class CUTOUT(object):

  def __init__(self, length):
    self.length = length

  def __repr__(self):
    return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))

  def __call__(self, img):
    h, w = img.size(1), img.size(2)
    mask = np.ones((h, w), np.float32)
    y = np.random.randint(h)
    x = np.random.randint(w)

    y1 = np.clip(y - self.length // 2, 0, h)
    y2 = np.clip(y + self.length // 2, 0, h)
    x1 = np.clip(x - self.length // 2, 0, w)
    x2 = np.clip(x + self.length // 2, 0, w)

    mask[y1: y2, x1: x2] = 0.
    mask = torch.from_numpy(mask)
    mask = mask.expand_as(img)
    img *= mask
    return img


imagenet_pca = {
    'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
    'eigvec': np.asarray([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}


class Lighting(object):
  def __init__(self, alphastd,
         eigval=imagenet_pca['eigval'],
         eigvec=imagenet_pca['eigvec']):
    self.alphastd = alphastd
    assert eigval.shape == (3,)
    assert eigvec.shape == (3, 3)
    self.eigval = eigval
    self.eigvec = eigvec

  def __call__(self, img):
    if self.alphastd == 0.:
      return img
    rnd = np.random.randn(3) * self.alphastd
    rnd = rnd.astype('float32')
    v = rnd
    old_dtype = np.asarray(img).dtype
    v = v * self.eigval
    v = v.reshape((3, 1))
    inc = np.dot(self.eigvec, v).reshape((3,))
    img = np.add(img, inc)
    if old_dtype == np.uint8:
      img = np.clip(img, 0, 255)
    img = Image.fromarray(img.astype(old_dtype), 'RGB')
    return img

  def __repr__(self):
    return self.__class__.__name__ + '()'


def get_datasets(name, root, cutout):

  if name == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std  = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif name == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std  = [x / 255 for x in [68.2, 65.4, 70.4]]
  elif name.startswith('imagenet-1k'):
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
  elif name.startswith('ImageNet16'):
    mean = [x / 255 for x in [122.68, 116.66, 104.01]]
    std  = [x / 255 for x in [63.22,  61.26 , 65.09]]
  else:
    raise TypeError("Unknown dataset : {:}".format(name))

  # Data Augmentation
  if name == 'cifar10' or name == 'cifar100':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if cutout > 0 : lists += [CUTOUT(cutout)]
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 32, 32)
  elif name.startswith('ImageNet16'):
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if cutout > 0 : lists += [CUTOUT(cutout)]
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 16, 16)
  elif name == 'tiered':
    # NOTE: 'tiered' has no mean/std entry above, so this branch is currently
    # unreachable (the first chain raises TypeError for it)
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if cutout > 0 : lists += [CUTOUT(cutout)]
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 32, 32)
  elif name.startswith('imagenet-1k'):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if name == 'imagenet-1k':
      xlists    = [transforms.RandomResizedCrop(224)]
      xlists.append(
        transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.2))
      xlists.append( Lighting(0.1))
    elif name == 'imagenet-1k-s':
      xlists    = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
    else: raise ValueError('invalid name : {:}'.format(name))
    xlists.append( transforms.RandomHorizontalFlip(p=0.5) )
    xlists.append( transforms.ToTensor() )
    xlists.append( normalize )
    train_transform = transforms.Compose(xlists)
    test_transform  = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
    xshape = (1, 3, 224, 224)
  else:
    raise TypeError("Unknown dataset : {:}".format(name))

  if name == 'cifar10':
    train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
    test_data  = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
    assert len(train_data) == 50000 and len(test_data) == 10000
  elif name == 'cifar100':
    train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
    test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
    assert len(train_data) == 50000 and len(test_data) == 10000
  elif name.startswith('imagenet-1k'):
    train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
    test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform)
    assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000)
  elif name == 'ImageNet16':
    train_data = ImageNet16(root, True , train_transform)
    test_data  = ImageNet16(root, False, test_transform)
    assert len(train_data) == 1281167 and len(test_data) == 50000
  elif name == 'ImageNet16-120':
    train_data = ImageNet16(root, True , train_transform, 120)
    test_data  = ImageNet16(root, False, test_transform , 120)
    assert len(train_data) == 151700 and len(test_data) == 6000
  elif name == 'ImageNet16-150':
    train_data = ImageNet16(root, True , train_transform, 150)
    test_data  = ImageNet16(root, False, test_transform , 150)
    assert len(train_data) == 190272 and len(test_data) == 7500
  elif name == 'ImageNet16-200':
    train_data = ImageNet16(root, True , train_transform, 200)
    test_data  = ImageNet16(root, False, test_transform , 200)
    assert len(train_data) == 254775 and len(test_data) == 10000
  else: raise TypeError("Unknown dataset : {:}".format(name))

  class_num = Dataset2Class[name]
  return train_data, test_data, xshape, class_num


def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers):
  if isinstance(batch_size, (list,tuple)):
    batch, test_batch = batch_size
  else:
    batch, test_batch = batch_size, batch_size
  if dataset == 'cifar10':
    #split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config('{:}/cifar-split.txt'.format(config_root), None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
    #logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set
    # To split data
    xvalid_data  = deepcopy(train_data)
    if hasattr(xvalid_data, 'transforms'): # to avoid a print issue
      xvalid_data.transforms = valid_data.transform
    xvalid_data.transform  = deepcopy( valid_data.transform )
    search_data   = SearchDataset(dataset, train_data, train_split, valid_split)
    # data loader
    search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
    train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True)
    valid_loader  = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True)
  elif dataset == 'cifar100':
    cifar100_test_split = load_config('{:}/cifar100-test-split.txt'.format(config_root), None, None)
    search_train_data = train_data
    search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
    search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
    search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
    train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
    valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
  elif dataset == 'ImageNet16-120':
    imagenet_test_split = load_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None)
    search_train_data = train_data
    search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
    search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
    search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
    train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
    valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True)
  else:
    raise ValueError('invalid dataset : {:}'.format(dataset))
  return search_loader, train_loader, valid_loader

#if __name__ == '__main__':
#  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
#  import pdb; pdb.set_trace()
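
A hypothetical end-to-end sketch (placeholder paths; 'configs/nas-benchmark' is the config root referenced in the comments above):

  from datasets import get_datasets, get_nas_search_loaders

  train_data, test_data, xshape, class_num = get_datasets('cifar10', '/path/to/cifar.python', -1)
  # xshape == (1, 3, 32, 32), class_num == 10
  search_loader, train_loader, valid_loader = get_nas_search_loaders(
      train_data, test_data, 'cifar10', 'configs/nas-benchmark', batch_size=64, workers=4)
  for base_img, base_lab, arch_img, arch_lab in search_loader:
    break  # each batch pairs a training sample with a random validation sample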

datasets/landmark_utils/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .point_meta import PointMeta2V, apply_affine2point, apply_boundary

datasets/landmark_utils/point_meta.py (new file, 116 lines)
@@ -0,0 +1,116 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import copy, math, torch, numpy as np
from xvision import normalize_points
from xvision import denormalize_points


class PointMeta():
  # points    : 3 x num_pts (x, y, occlusion)
  # image_size: original [width, height]
  def __init__(self, num_point, points, box, image_path, dataset_name):

    self.num_point = num_point
    if box is not None:
      assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4
      self.box = torch.Tensor(box)
    else: self.box = None
    if points is None:
      self.points = points
    else:
      assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points )
      self.points = torch.Tensor(points.copy())
    self.image_path = image_path
    self.datasets = dataset_name

  def __repr__(self):
    if self.box is None: boxstr = 'None'
    else               : boxstr = 'box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]'.format(*self.box.tolist())
    return ('{name}(points={num_point}, '.format(name=self.__class__.__name__, **self.__dict__) + boxstr + ')')

  def get_box(self, return_diagonal=False):
    if self.box is None: return None
    if not return_diagonal:
      return self.box.clone()
    else:
      W = (self.box[2]-self.box[0]).item()
      H = (self.box[3]-self.box[1]).item()
      return math.sqrt(H*H+W*W)

  def get_points(self, ignore_indicator=False):
    if ignore_indicator: last = 2
    else               : last = 3
    if self.points is not None: return self.points.clone()[:last, :]
    else                      : return torch.zeros((last, self.num_point))

  def is_none(self):
    #assert self.box is not None, 'The box should not be None'
    return self.points is None
    #if self.box is None: return True
    #else               : return self.points is None

  def copy(self):
    return copy.deepcopy(self)

  def visiable_pts_num(self):
    # number of visible points (third row of self.points > 0)
    with torch.no_grad():
      ans = self.points[2,:] > 0
      ans = torch.sum(ans)
      ans = ans.item()
    return ans

  def special_fun(self, indicator):
    if indicator == '68to49': # For 300W or 300VW, convert the default 68 points to 49 points.
      assert self.num_point == 68, 'num-point must be 68 vs. {:}'.format(self.num_point)
      self.num_point = 49
      out = torch.ones((68), dtype=torch.uint8)
      out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = 0
      if self.points is not None: self.points = self.points.clone()[:, out]
    else:
      raise ValueError('Invalid indicator : {:}'.format( indicator ))

  def apply_horizontal_flip(self):
    #self.points[0, :] = width - self.points[0, :] - 1
    # Mugsy-specific or synthetic
    if self.datasets.startswith('HandsyROT'):
      ori = np.array(list(range(0, 42)))
      pos = np.array(list(range(21,42)) + list(range(0,21)))
      self.points[:, pos] = self.points[:, ori]
    elif self.datasets.startswith('face68'):
      ori = np.array(list(range(0, 68)))
      pos = np.array([17,16,15,14,13,12,11,10, 9, 8,7,6,5,4,3,2,1, 27,26,25,24,23,22,21,20,19,18, 28,29,30,31, 36,35,34,33,32, 46,45,44,43,48,47, 40,39,38,37,42,41, 55,54,53,52,51,50,49,60,59,58,57,56,65,64,63,62,61,68,67,66])-1
      self.points[:, ori] = self.points[:, pos]
    else:
      raise ValueError('Does not support {:}'.format(self.datasets))


# The package __init__ imports PointMeta2V from this module; alias it here so
# that import resolves (assumption: the two names refer to the same class in
# this snapshot of the code).
PointMeta2V = PointMeta


# shape = (H,W)
def apply_affine2point(points, theta, shape):
  assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size())
  with torch.no_grad():
    ok_points = points[2,:] == 1
    assert torch.sum(ok_points).item() > 0, 'there is no visible point'
    points[:2,:] = normalize_points(shape, points[:2,:])

    norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()

    # torch.gesv(B, A) solves A @ X = B; it is deprecated in later PyTorch
    # versions in favor of torch.solve / torch.linalg.solve
    trans_points, ___ = torch.gesv(points[:, ok_points], theta)

    norm_trans_points[:, ok_points] = trans_points

  return norm_trans_points


def apply_boundary(norm_trans_points):
  with torch.no_grad():
    norm_trans_points = norm_trans_points.clone()
    oks = torch.stack((norm_trans_points[0]>-1, norm_trans_points[0]<1, norm_trans_points[1]>-1, norm_trans_points[1]<1, norm_trans_points[2]>0))
    oks = torch.sum(oks, dim=0) == 5
    norm_trans_points[2, :] = oks
  return norm_trans_points
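
A small sketch of the boundary rule above (hypothetical coordinates): a point keeps its visibility flag only if it was visible and lies strictly inside the normalized (-1, 1) box.

  import torch
  from datasets.landmark_utils import apply_boundary

  pts = torch.tensor([[0.0,  1.5],    # x
                      [0.0, -0.2],    # y
                      [1.0,  1.0]])   # visibility
  out = apply_boundary(pts)
  # out[2] == [1., 0.]: the second point is zeroed because x = 1.5 >= 1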

datasets/test_utils.py (new file, 20 lines)
@@ -0,0 +1,20 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os


def test_imagenet_data(imagenet):
  # sanity-check an ImageFolder-style ImageNet: every image must live in a
  # folder that consistently maps to one class id, and each file name must
  # start with its folder's (synset) name
  total_length = len(imagenet)
  assert total_length == 1281167 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length)
  map_id = {}
  for index in range(total_length):
    path, target = imagenet.imgs[index]
    folder, image_name = os.path.split(path)
    _, folder = os.path.split(folder)
    if folder not in map_id:
      map_id[folder] = target
    else:
      assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target)
    assert image_name.find(folder) == 0, '{} is wrong.'.format(path)
  print ('Check ImageNet Dataset OK')
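
A hypothetical usage sketch (placeholder path; assumes the standard train layout, where each file name starts with its synset folder name):

  import torchvision.datasets as dset
  from datasets.test_utils import test_imagenet_data

  train_set = dset.ImageFolder('/path/to/imagenet/train')
  test_imagenet_data(train_set)  # prints 'Check ImageNet Dataset OK' on success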