add aircraft root
This commit is contained in:
		| @@ -25,6 +25,32 @@ import torch | ||||
| from .imagenet16 import * | ||||
|  | ||||
|  | ||||
| class CUTOUT(object): | ||||
|     def __init__(self, length): | ||||
|         self.length = length | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(length={length})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, img): | ||||
|         h, w = img.size(1), img.size(2) | ||||
|         mask = np.ones((h, w), np.float32) | ||||
|         y = np.random.randint(h) | ||||
|         x = np.random.randint(w) | ||||
|  | ||||
|         y1 = np.clip(y - self.length // 2, 0, h) | ||||
|         y2 = np.clip(y + self.length // 2, 0, h) | ||||
|         x1 = np.clip(x - self.length // 2, 0, w) | ||||
|         x2 = np.clip(x + self.length // 2, 0, w) | ||||
|  | ||||
|         mask[y1:y2, x1:x2] = 0.0 | ||||
|         mask = torch.from_numpy(mask) | ||||
|         mask = mask.expand_as(img) | ||||
|         img *= mask | ||||
|         return img | ||||
|      | ||||
| def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'): | ||||
|     # print(dataset) | ||||
|     if 'ImageNet16' in dataset: | ||||
| @@ -74,7 +100,8 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Normalize(mean,std), | ||||
|     ]) | ||||
|     root = '/nfs/data3/hanzhang/MeCo/data' | ||||
|     root = '/home/iicd/MeCo/data' | ||||
|     aircraft_dataset_root = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data' | ||||
|  | ||||
|     if dataset == 'cifar10': | ||||
|         train_dataset = CIFAR10(datadir, True, train_transform, download=True) | ||||
| @@ -84,18 +111,18 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker | ||||
|         test_dataset = CIFAR100(datadir, False, test_transform, download=True) | ||||
|     elif dataset == 'aircraft':  | ||||
|         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         # if resize != None :  | ||||
|         #     print(resize) | ||||
|         #     lists += [CUTOUT(resize)] | ||||
|         if resize != None :  | ||||
|             print(resize) | ||||
|             lists += [CUTOUT(resize)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|         train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform) | ||||
|         test_data  = dset.ImageFolder(os.path.join(root, 'test_sorted_images'),  test_transform) | ||||
|         train_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'train_sorted_images'), train_transform) | ||||
|         test_data  = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'test_sorted_images'),  test_transform) | ||||
|     elif dataset == 'oxford': | ||||
|         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         # if resize != None :  | ||||
|         #     print(resize) | ||||
|         #     lists += [CUTOUT(resize)] | ||||
|         if resize != None :  | ||||
|             print(resize) | ||||
|             lists += [CUTOUT(resize)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|  | ||||
| @@ -172,4 +199,4 @@ if __name__ == '__main__': | ||||
|     tr, te = get_cifar_dataloaders(64, 64, 'random', 2, resize=None, datadir='_dataset') | ||||
|     for x, y in tr: | ||||
|         print(x.size(), y.size()) | ||||
|         break | ||||
|         break | ||||
|   | ||||
		Reference in New Issue
	
	Block a user