update to oxford and aircraft
This commit is contained in:
		| @@ -71,7 +71,7 @@ def parse_arguments(): | |||||||
|     parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file') |     parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file') | ||||||
|     parser.add_argument('--start', type=int, default=0, help='start index') |     parser.add_argument('--start', type=int, default=0, help='start index') | ||||||
|     parser.add_argument('--end', type=int, default=0, help='end index') |     parser.add_argument('--end', type=int, default=0, help='end index') | ||||||
|     parser.add_argument('--noacc', default=False, action='store_true', |     parser.add_argument('--noacc', default=True, action='store_true', | ||||||
|                         help='avoid loading NASBench2 api an instead load a pickle file with tuple (index, arch_str)') |                         help='avoid loading NASBench2 api an instead load a pickle file with tuple (index, arch_str)') | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") |     args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") | ||||||
| @@ -94,7 +94,14 @@ if __name__ == '__main__': | |||||||
|     x, y = next(iter(train_loader)) |     x, y = next(iter(train_loader)) | ||||||
|  |  | ||||||
|     cached_res = [] |     cached_res = [] | ||||||
|     pre = 'cf' if 'cifar' in args.dataset else 'im' |     if 'cifar' in args.dataset : | ||||||
|  |         pre = 'cf' | ||||||
|  |     elif 'Image' in args.dataset: | ||||||
|  |         pre = 'im' | ||||||
|  |     elif 'oxford' in args.dataset: | ||||||
|  |         pre = 'ox' | ||||||
|  |     elif 'air' in args.dataset: | ||||||
|  |         pre = 'ai' | ||||||
|     pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p' |     pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p' | ||||||
|     op = os.path.join(args.outdir, pfn) |     op = os.path.join(args.outdir, pfn) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -18,6 +18,7 @@ | |||||||
| from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN | from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN | ||||||
| from torchvision.transforms import Compose, ToTensor, Normalize | from torchvision.transforms import Compose, ToTensor, Normalize | ||||||
| from torchvision import transforms | from torchvision import transforms | ||||||
|  | import torchvision.datasets as dset | ||||||
| from torch.utils.data import TensorDataset, DataLoader | from torch.utils.data import TensorDataset, DataLoader | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
| @@ -44,6 +45,14 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker | |||||||
|         mean = (0.485, 0.456, 0.406) |         mean = (0.485, 0.456, 0.406) | ||||||
|         std  = (0.229, 0.224, 0.225) |         std  = (0.229, 0.224, 0.225) | ||||||
|         #resize = 256 |         #resize = 256 | ||||||
|  |     elif dataset == 'aircraft': | ||||||
|  |         mean = (0.4785, 0.5100, 0.5338) | ||||||
|  |         std  = (0.1845, 0.1830, 0.2060) | ||||||
|  |         size, pad = 224, 2 | ||||||
|  |     elif dataset == 'oxford': | ||||||
|  |         mean = (0.4811, 0.4492, 0.3957) | ||||||
|  |         std  = (0.2260, 0.2231, 0.2249) | ||||||
|  |         size, pad = 32, 0 | ||||||
|     elif 'random' in dataset: |     elif 'random' in dataset: | ||||||
|         mean = (0.5, 0.5, 0.5) |         mean = (0.5, 0.5, 0.5) | ||||||
|         std = (1, 1, 1) |         std = (1, 1, 1) | ||||||
| @@ -65,6 +74,7 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker | |||||||
|         transforms.ToTensor(), |         transforms.ToTensor(), | ||||||
|         transforms.Normalize(mean,std), |         transforms.Normalize(mean,std), | ||||||
|     ]) |     ]) | ||||||
|  |     root = '/nfs/data3/hanzhang/MeCo/data' | ||||||
|  |  | ||||||
|     if dataset == 'cifar10': |     if dataset == 'cifar10': | ||||||
|         train_dataset = CIFAR10(datadir, True, train_transform, download=True) |         train_dataset = CIFAR10(datadir, True, train_transform, download=True) | ||||||
| @@ -72,6 +82,40 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker | |||||||
|     elif dataset == 'cifar100': |     elif dataset == 'cifar100': | ||||||
|         train_dataset = CIFAR100(datadir, True, train_transform, download=True) |         train_dataset = CIFAR100(datadir, True, train_transform, download=True) | ||||||
|         test_dataset = CIFAR100(datadir, False, test_transform, download=True) |         test_dataset = CIFAR100(datadir, False, test_transform, download=True) | ||||||
|  |     elif dataset == 'aircraft':  | ||||||
|  |         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||||
|  |         # if resize != None :  | ||||||
|  |         #     print(resize) | ||||||
|  |         #     lists += [CUTOUT(resize)] | ||||||
|  |         train_transform = transforms.Compose(lists) | ||||||
|  |         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||||
|  |         train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform) | ||||||
|  |         test_data  = dset.ImageFolder(os.path.join(root, 'test_sorted_images'),  test_transform) | ||||||
|  |     elif dataset == 'oxford': | ||||||
|  |         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||||
|  |         # if resize != None :  | ||||||
|  |         #     print(resize) | ||||||
|  |         #     lists += [CUTOUT(resize)] | ||||||
|  |         train_transform = transforms.Compose(lists) | ||||||
|  |         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||||
|  |  | ||||||
|  |         train_data = torch.load(os.path.join(root, 'train85.pth')) | ||||||
|  |         test_data  = torch.load(os.path.join(root, 'test15.pth')) | ||||||
|  |  | ||||||
|  |         train_tensor_data = [(image, label) for image, label in train_data] | ||||||
|  |         test_tensor_data = [(image, label) for image, label in test_data] | ||||||
|  |         sum_data = train_tensor_data + test_tensor_data | ||||||
|  |  | ||||||
|  |         train_images = [image for image, label in train_tensor_data] | ||||||
|  |         train_labels = torch.tensor([label for image, label in train_tensor_data]) | ||||||
|  |         test_images = [image for image, label in test_tensor_data] | ||||||
|  |         test_labels = torch.tensor([label for image, label in test_tensor_data]) | ||||||
|  |  | ||||||
|  |         train_tensors = torch.stack([train_transform(image) for image in train_images]) | ||||||
|  |         test_tensors = torch.stack([test_transform(image) for image in test_images]) | ||||||
|  |  | ||||||
|  |         train_dataset = TensorDataset(train_tensors, train_labels) | ||||||
|  |         test_dataset = TensorDataset(test_tensors, test_labels) | ||||||
|     elif dataset == 'svhn': |     elif dataset == 'svhn': | ||||||
|         train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True) |         train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True) | ||||||
|         test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True) |         test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True) | ||||||
| @@ -97,8 +141,6 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker | |||||||
|         shuffle=False, |         shuffle=False, | ||||||
|         num_workers=num_workers, |         num_workers=num_workers, | ||||||
|         pin_memory=True) |         pin_memory=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|     return train_loader, test_loader |     return train_loader, test_loader | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user