add aircraft in utilities
This commit is contained in:
parent
4df5615380
commit
24f15ad0fe
@ -13,7 +13,8 @@ Dataset2Class = {'cifar10': 10,
|
|||||||
'ImageNet16' : 1000,
|
'ImageNet16' : 1000,
|
||||||
'ImageNet16-120': 120,
|
'ImageNet16-120': 120,
|
||||||
'ImageNet16-150': 150,
|
'ImageNet16-150': 150,
|
||||||
'ImageNet16-200': 200}
|
'ImageNet16-200': 200,
|
||||||
|
'aircraft': 100}
|
||||||
|
|
||||||
class RandChannel(object):
|
class RandChannel(object):
|
||||||
# randomly pick channels from input
|
# randomly pick channels from input
|
||||||
@ -46,6 +47,10 @@ def get_datasets(name, root, input_size, cutout=-1):
|
|||||||
elif name.startswith('ImageNet16'):
|
elif name.startswith('ImageNet16'):
|
||||||
mean = [0.481098, 0.45749, 0.407882]
|
mean = [0.481098, 0.45749, 0.407882]
|
||||||
std = [0.247922, 0.240235, 0.255255]
|
std = [0.247922, 0.240235, 0.255255]
|
||||||
|
elif name == 'aircraft':
|
||||||
|
mean = [0.4785, 0.5100, 0.5338]
|
||||||
|
std = [0.1845, 0.1830, 0.2060]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("Unknow dataset : {:}".format(name))
|
raise TypeError("Unknow dataset : {:}".format(name))
|
||||||
|
|
||||||
@ -55,6 +60,12 @@ def get_datasets(name, root, input_size, cutout=-1):
|
|||||||
if cutout > 0 : lists += [CUTOUT(cutout)]
|
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||||
train_transform = transforms.Compose(lists)
|
train_transform = transforms.Compose(lists)
|
||||||
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||||
|
elif name == 'aircraft':
|
||||||
|
lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||||
|
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||||
|
train_transform = transforms.Compose(lists)
|
||||||
|
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||||
|
|
||||||
elif name.startswith('ImageNet16'):
|
elif name.startswith('ImageNet16'):
|
||||||
lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
|
lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
|
||||||
if cutout > 0 : lists += [CUTOUT(cutout)]
|
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||||
@ -86,9 +97,12 @@ def get_datasets(name, root, input_size, cutout=-1):
|
|||||||
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
|
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
|
||||||
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
|
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
|
||||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||||
|
elif name == 'aircraft':
|
||||||
|
train_data = dset.ImageFolder(osp.join(root, 'train_sorted_images'), train_transform)
|
||||||
|
test_data = dset.ImageFolder(osp.join(root, 'test_sorted_images'), test_transform)
|
||||||
elif name.startswith('imagenet-1k'):
|
elif name.startswith('imagenet-1k'):
|
||||||
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
||||||
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
|
test_data = dset.ImageFolder(osp.join(root, 'test'), test_transform)
|
||||||
elif name == 'ImageNet16':
|
elif name == 'ImageNet16':
|
||||||
root = osp.join(root, 'ImageNet16')
|
root = osp.join(root, 'ImageNet16')
|
||||||
train_data = ImageNet16(root, True , train_transform)
|
train_data = ImageNet16(root, True , train_transform)
|
||||||
|
Loading…
Reference in New Issue
Block a user