add aircraft in utilities
This commit is contained in:
		| @@ -13,7 +13,8 @@ Dataset2Class = {'cifar10': 10, | ||||
|                  'ImageNet16' : 1000, | ||||
|                  'ImageNet16-120': 120, | ||||
|                  'ImageNet16-150': 150, | ||||
|                  'ImageNet16-200': 200} | ||||
|                  'ImageNet16-200': 200, | ||||
|                  'aircraft': 100} | ||||
|  | ||||
| class RandChannel(object): | ||||
|     # randomly pick channels from input | ||||
| @@ -46,6 +47,10 @@ def get_datasets(name, root, input_size, cutout=-1): | ||||
|     elif name.startswith('ImageNet16'): | ||||
|         mean = [0.481098, 0.45749, 0.407882] | ||||
|         std  = [0.247922, 0.240235, 0.255255] | ||||
|     elif name == 'aircraft': | ||||
|         mean = [0.4785, 0.5100, 0.5338] | ||||
|         std  = [0.1845, 0.1830, 0.2060] | ||||
|  | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
| @@ -55,6 +60,12 @@ def get_datasets(name, root, input_size, cutout=-1): | ||||
|         if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|     elif name == 'aircraft': | ||||
|         lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|  | ||||
|     elif name.startswith('ImageNet16'): | ||||
|         lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])] | ||||
|         if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
| @@ -86,9 +97,12 @@ def get_datasets(name, root, input_size, cutout=-1): | ||||
|         train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) | ||||
|         test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True) | ||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|     elif name == 'aircraft': | ||||
|         train_data = dset.ImageFolder(osp.join(root, 'train_sorted_images'), train_transform) | ||||
|         test_data  = dset.ImageFolder(osp.join(root, 'test_sorted_images'),  test_transform) | ||||
|     elif name.startswith('imagenet-1k'): | ||||
|         train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) | ||||
|         test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform) | ||||
|         test_data  = dset.ImageFolder(osp.join(root, 'test'),   test_transform) | ||||
|     elif name == 'ImageNet16': | ||||
|         root = osp.join(root, 'ImageNet16') | ||||
|         train_data = ImageNet16(root, True , train_transform) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user