add aircraft in utilities
This commit is contained in:
		| @@ -13,7 +13,8 @@ Dataset2Class = {'cifar10': 10, | |||||||
|                  'ImageNet16' : 1000, |                  'ImageNet16' : 1000, | ||||||
|                  'ImageNet16-120': 120, |                  'ImageNet16-120': 120, | ||||||
|                  'ImageNet16-150': 150, |                  'ImageNet16-150': 150, | ||||||
|                  'ImageNet16-200': 200} |                  'ImageNet16-200': 200, | ||||||
|  |                  'aircraft': 100} | ||||||
|  |  | ||||||
| class RandChannel(object): | class RandChannel(object): | ||||||
|     # randomly pick channels from input |     # randomly pick channels from input | ||||||
| @@ -46,6 +47,10 @@ def get_datasets(name, root, input_size, cutout=-1): | |||||||
|     elif name.startswith('ImageNet16'): |     elif name.startswith('ImageNet16'): | ||||||
|         mean = [0.481098, 0.45749, 0.407882] |         mean = [0.481098, 0.45749, 0.407882] | ||||||
|         std  = [0.247922, 0.240235, 0.255255] |         std  = [0.247922, 0.240235, 0.255255] | ||||||
|  |     elif name == 'aircraft': | ||||||
|  |         mean = [0.4785, 0.5100, 0.5338] | ||||||
|  |         std  = [0.1845, 0.1830, 0.2060] | ||||||
|  |  | ||||||
|     else: |     else: | ||||||
|         raise TypeError("Unknow dataset : {:}".format(name)) |         raise TypeError("Unknow dataset : {:}".format(name)) | ||||||
|  |  | ||||||
| @@ -55,6 +60,12 @@ def get_datasets(name, root, input_size, cutout=-1): | |||||||
|         if cutout > 0 : lists += [CUTOUT(cutout)] |         if cutout > 0 : lists += [CUTOUT(cutout)] | ||||||
|         train_transform = transforms.Compose(lists) |         train_transform = transforms.Compose(lists) | ||||||
|         test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) |         test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||||
|  |     elif name == 'aircraft': | ||||||
|  |         lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||||
|  |         if cutout > 0 : lists += [CUTOUT(cutout)] | ||||||
|  |         train_transform = transforms.Compose(lists) | ||||||
|  |         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||||
|  |  | ||||||
|     elif name.startswith('ImageNet16'): |     elif name.startswith('ImageNet16'): | ||||||
|         lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])] |         lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])] | ||||||
|         if cutout > 0 : lists += [CUTOUT(cutout)] |         if cutout > 0 : lists += [CUTOUT(cutout)] | ||||||
| @@ -86,9 +97,12 @@ def get_datasets(name, root, input_size, cutout=-1): | |||||||
|         train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) |         train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) | ||||||
|         test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True) |         test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True) | ||||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 |         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||||
|  |     elif name == 'aircraft': | ||||||
|  |         train_data = dset.ImageFolder(osp.join(root, 'train_sorted_images'), train_transform) | ||||||
|  |         test_data  = dset.ImageFolder(osp.join(root, 'test_sorted_images'),  test_transform) | ||||||
|     elif name.startswith('imagenet-1k'): |     elif name.startswith('imagenet-1k'): | ||||||
|         train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) |         train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) | ||||||
|         test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform) |         test_data  = dset.ImageFolder(osp.join(root, 'test'),   test_transform) | ||||||
|     elif name == 'ImageNet16': |     elif name == 'ImageNet16': | ||||||
|         root = osp.join(root, 'ImageNet16') |         root = osp.join(root, 'ImageNet16') | ||||||
|         train_data = ImageNet16(root, True , train_transform) |         train_data = ImageNet16(root, True , train_transform) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user