add how to prepare dataset
This commit is contained in:
		
							
								
								
									
										139
									
								
								correlation/calculate_dataset_statistics.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										139
									
								
								correlation/calculate_dataset_statistics.py
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,139 @@ | ||||
# import torch
# import torchvision
# import torchvision.transforms as transforms

# # Load the CIFAR-10 dataset
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(root='./datasets', train=True, download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2)

# # Load the entire training split into memory in one batch
# data = next(iter(trainloader))
# images, _ = data

# # Compute the per-channel mean and standard deviation
# mean = images.mean([0, 2, 3])
# std = images.std([0, 2, 3])

# print(f'Mean: {mean}')
# print(f'Std: {std}')

# results:
# Mean: tensor([0.4935, 0.4834, 0.4472])
# Std: tensor([0.2476, 0.2446, 0.2626])
|  | ||||
| import itertools | ||||
| import torch | ||||
| from torchvision import datasets, transforms | ||||
| from torch.utils.data import DataLoader, TensorDataset | ||||
| import argparse | ||||
| import numpy as np | ||||
| import os | ||||
|  | ||||
# Command-line configuration for the statistics script.  All options are
# plain strings; `--data_path` points at the dataset root, and the train/test
# paths are interpreted relative to it (oxford only).
parser = argparse.ArgumentParser(description='Calculate mean and std of dataset')
for flag, default_value, help_text in (
    ('--dataset', 'cifar10', 'dataset name'),
    ('--data_path', './datasets/cifar-10-batches-py', 'path to dataset image folder'),
    ('--train_dataset_path', 'train', 'train dataset path'),
    ('--test_dataset_path', 'test', 'test dataset path'),
):
    parser.add_argument(flag, type=str, default=default_value, help=help_text)

args = parser.parse_args()


# Dataset location and name, pulled out of the parsed arguments.
dataset_path = args.data_path
dataset_name = args.dataset
|  | ||||
# Pick the preprocessing pipeline for the chosen dataset.  CIFAR-10 images
# are a fixed 32x32, so only tensor conversion is needed; aircraft / oxford
# images vary in size and are resized to 224x224 first.
if dataset_name == 'cifar10':
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
elif dataset_name in ('aircraft', 'oxford'):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
else:
    # Fail fast with a clear message instead of leaving `transform`
    # undefined and hitting a confusing NameError further down the script.
    raise ValueError(f'Unsupported dataset: {dataset_name!r}')
|  | ||||
|  | ||||
def to_tensor(pic):
    """Convert a PIL Image (or array-like) to a PyTorch tensor.

    Args:
        pic (PIL.Image.Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image tensor with shape (C, H, W) and pixel
        values in range [0.0, 1.0].
    """

    # Convert the image to a NumPy array
    img = np.array(pic, dtype=np.float32)

    # Handle grayscale images first (no channel dimension).  This must run
    # BEFORE the alpha check below: on a 2-D array, shape[-1] is the width,
    # so a grayscale image that happens to be 4 pixels wide would otherwise
    # be mistaken for RGBA and crash on the 3-D slice.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=-1)

    # If image has an alpha channel, discard it
    if img.shape[-1] == 4:
        img = img[:, :, :3]

    # Transpose the dimensions from (H, W, C) to (C, H, W)
    img = img.transpose((2, 0, 1))

    # Normalize the pixel values to [0.0, 1.0]
    img = img / 255.0

    # Convert the NumPy array to a PyTorch tensor
    return torch.from_numpy(img)
|  | ||||
# Build a DataLoader over the full dataset (train + test for oxford;
# an ImageFolder tree otherwise).  Only `dataloader` is consumed below.
if args.dataset == 'oxford':
    # Oxford splits are serialized sequences of (image, label) pairs.
    # NOTE(review): assumes each image is acceptable input to `transform`
    # (e.g. a PIL image) — confirm against the file that produced them.
    train_data = torch.load(os.path.join(dataset_path, args.train_dataset_path))
    test_data = torch.load(os.path.join(dataset_path, args.test_dataset_path))

    train_images = [image for image, _ in train_data]
    train_labels = torch.tensor([label for _, label in train_data])
    test_images = [image for image, _ in test_data]
    test_labels = torch.tensor([label for _, label in test_data])

    train_tensors = torch.stack([transform(image) for image in train_images])
    test_tensors = torch.stack([transform(image) for image in test_images])
    # Reuse the already-transformed tensors for the combined split instead
    # of re-running `transform` over every image a second time.
    sum_tensors = torch.cat([train_tensors, test_tensors])
    sum_labels = torch.cat([train_labels, test_labels])

    train_dataset = TensorDataset(train_tensors, train_labels)
    test_dataset = TensorDataset(test_tensors, test_labels)
    sum_dataset = TensorDataset(sum_tensors, sum_labels)

    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=False, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
    dataloader = DataLoader(sum_dataset, batch_size=64, shuffle=False, num_workers=4)
else:
    dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)
|  | ||||
# Accumulate per-channel statistics one batch at a time so the whole
# dataset never needs to sit in memory at once.
channel_mean_sum = torch.zeros(3)
channel_std_sum = torch.zeros(3)
total_images = 0

for batch_index, batch in enumerate(dataloader, start=1):
    print(f'Processing batch {batch_index}/{len(dataloader)}', end='\r')
    images = batch[0]
    n_in_batch = images.size(0)
    # Flatten each image's spatial dims: (N, C, H, W) -> (N, C, H*W).
    flat = images.view(n_in_batch, images.size(1), -1)
    channel_mean_sum += flat.mean(2).sum(0)
    channel_std_sum += flat.std(2).sum(0)
    total_images += n_in_batch

# NOTE(review): `std` here is the average of per-image stds, not the pooled
# std over all pixels — the two differ slightly; confirm which is intended.
mean = channel_mean_sum / total_images
std = channel_std_sum / total_images

print(f'Mean: {mean}')
print(f'Std: {std}')
		Reference in New Issue
	
	Block a user