2024-03-31 14:51:57 +02:00
|
|
|
|
{
|
|
|
|
|
"cells": [
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"execution_count": 17,
|
2024-03-31 14:51:57 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"import gc\n",
|
|
|
|
|
"import os\n",
|
|
|
|
|
"import cv2\n",
|
|
|
|
|
"import math\n",
|
|
|
|
|
"import base64\n",
|
|
|
|
|
"import random\n",
|
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
"from PIL import Image \n",
|
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
|
"from datetime import datetime\n",
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"import torch\n",
|
|
|
|
|
"import torch.nn as nn\n",
|
|
|
|
|
"from torch.cuda import amp\n",
|
|
|
|
|
"import torch.nn.functional as F\n",
|
|
|
|
|
"from torch.optim import Adam, AdamW\n",
|
|
|
|
|
"from torch.utils.data import Dataset, DataLoader\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"import torchvision\n",
|
|
|
|
|
"import torchvision.transforms as TF\n",
|
|
|
|
|
"import torchvision.datasets as datasets\n",
|
|
|
|
|
"from torchvision.utils import make_grid\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"from torchmetrics import MeanMetric\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"from IPython.display import display, HTML, clear_output\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Helper functions"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-04-01 12:38:50 +02:00
|
|
|
|
"execution_count": 45,
|
2024-03-31 14:51:57 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"def to_device(data, device):\n",
"    \"\"\"\n",
"    Move a tensor, or a (possibly nested) list/tuple/dict of tensors, to `device`.\n",
"\n",
"    Lists and tuples come back as lists (e.g. an `[images, labels]` batch from a\n",
"    DataLoader); dicts come back as dicts with the same keys.\n",
"    \"\"\"\n",
"    if isinstance(data, (list, tuple)):\n",
"        return [to_device(x, device) for x in data]\n",
"    if isinstance(data, dict):\n",
"        return {key: to_device(value, device) for key, value in data.items()}\n",
"    # non_blocking=True makes host-to-GPU copies asynchronous when the source is pinned\n",
"    return data.to(device, non_blocking=True)"
|
2024-03-31 14:51:57 +02:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"execution_count": 19,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"class DeviceDataLoader:\n",
"    \"\"\"Wrap a DataLoader so that every batch it yields is moved to `device`.\"\"\"\n",
"\n",
"    def __init__(self, dl, device):\n",
"        # The wrapped iterable of batches and the device each batch is sent to\n",
"        self.dl, self.device = dl, device\n",
"\n",
"    def __iter__(self):\n",
"        \"\"\"Yield the wrapped loader's batches, each already moved to `self.device`.\"\"\"\n",
"        yield from (to_device(batch, self.device) for batch in self.dl)\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of batches, taken from the wrapped loader.\"\"\"\n",
"        return len(self.dl)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 20,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"def get_default_device():\n",
"    \"\"\"Return the CUDA device when one is available, otherwise the CPU.\"\"\"\n",
"    if torch.cuda.is_available():\n",
"        return torch.device(\"cuda\")\n",
"    return torch.device(\"cpu\")"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 21,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"def save_images(images, path, **kwargs):\n",
"    \"\"\"\n",
"    Tile a batch of images into one grid with `make_grid` and save it to `path`.\n",
"\n",
"    Extra keyword arguments are forwarded to `make_grid`. Pixel values are\n",
"    assumed to be on the 0-255 scale (e.g. the output of `inverse_transform`);\n",
"    they are clamped to [0, 255] and stored as uint8.\n",
"    \"\"\"\n",
"    grid = make_grid(images, **kwargs)\n",
"    # PIL's Image.fromarray cannot build an RGB image from a float array, so\n",
"    # convert to uint8 (a no-op for uint8 input) after clamping to the pixel range\n",
"    ndarr = grid.clamp(0, 255).permute(1, 2, 0).to(\"cpu\", torch.uint8).numpy()\n",
"    im = Image.fromarray(ndarr)\n",
"    im.save(path)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 22,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"def get(element: torch.Tensor, t: torch.Tensor):\n",
"    \"\"\"\n",
"    Pick the entries of `element` at the index positions `t` (along the last\n",
"    dimension) and reshape the result to (-1, 1, 1, 1) so it broadcasts\n",
"    against a batch of images.\n",
"    \"\"\"\n",
"    picked = torch.gather(element, -1, t)\n",
"    return picked.reshape(-1, 1, 1, 1)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 23,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"# Toy inputs for `get`: three rows of scores and one index per row\n",
"element = torch.tensor([\n",
"    [1, 2, 3, 4, 5],\n",
"    [2, 3, 4, 5, 6],\n",
"    [3, 4, 5, 6, 7],\n",
"])\n",
"t = torch.tensor([1, 2, 0])[:, None]"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 24,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"tensor([[1, 2, 3, 4, 5],\n",
|
|
|
|
|
" [2, 3, 4, 5, 6],\n",
|
|
|
|
|
" [3, 4, 5, 6, 7]])\n",
|
|
|
|
|
"tensor([[1],\n",
|
|
|
|
|
" [2],\n",
|
|
|
|
|
" [0]])\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
"# Show the toy inputs: the score table, then the per-row indices\n",
"print(element, t, sep=\"\\n\")"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 25,
|
2024-03-31 14:51:57 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"tensor([[[[2]]],\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" [[[4]]],\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" [[[3]]]])\n"
|
2024-03-31 14:51:57 +02:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"source": [
"# Gather one score per row of `element` at the indices in `t`\n",
"extracted_scores = get(element, t)\n",
"print(extracted_scores)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 26,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"def setup_log_directory(config):\n",
"    \"\"\"\n",
"    Create a fresh, versioned log directory and model-checkpoint directory.\n",
"\n",
"    If `config.root_log_dir` already exists, the new version is one past the\n",
"    highest existing `version_<n>` folder in it (entries that do not follow that\n",
"    pattern are ignored; a folder without any yields `version_0`). Otherwise\n",
"    `config.log_dir` is used as the version name.\n",
"\n",
"    Returns:\n",
"        (log_dir, checkpoint_dir): paths of the two (now existing) directories.\n",
"    \"\"\"\n",
"    prefix = \"version_\"\n",
"\n",
"    if os.path.isdir(config.root_log_dir):\n",
"        # Version numbers of all `version_<n>` folders in root_log_dir\n",
"        folder_numbers = [\n",
"            int(folder[len(prefix):])\n",
"            for folder in os.listdir(config.root_log_dir)\n",
"            if folder.startswith(prefix) and folder[len(prefix):].isdecimal()\n",
"        ]\n",
"\n",
"        # Next version after the latest one present (0 if there is none yet)\n",
"        next_version = max(folder_numbers) + 1 if folder_numbers else 0\n",
"\n",
"        # Keep the `version_<n>` naming so later runs can parse it back\n",
"        version_name = f\"{prefix}{next_version}\"\n",
"\n",
"    else:\n",
"        version_name = config.log_dir\n",
"\n",
"    # Log and checkpoint directories share the same version name\n",
"    log_dir = os.path.join(config.root_log_dir, version_name)\n",
"    checkpoint_dir = os.path.join(config.root_checkpoint_dir, version_name)\n",
"\n",
"    # Create the directories for this experiment version\n",
"    os.makedirs(log_dir, exist_ok=True)\n",
"    os.makedirs(checkpoint_dir, exist_ok=True)\n",
"\n",
"    print(f\"Logging at: {log_dir}\")\n",
"    print(f\"Model Checkpoint at: {checkpoint_dir}\")\n",
"\n",
"    return log_dir, checkpoint_dir\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-04-01 12:38:50 +02:00
|
|
|
|
"execution_count": 28,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"def frames2vid(images, save_path, fps=25):\n",
"    \"\"\"\n",
"    Write a sequence of frames to a video file at `save_path` using OpenCV.\n",
"\n",
"    The first frame fixes the video size; frames are presumably uint8 arrays of\n",
"    shape (height, width, 3) in OpenCV's BGR order -- TODO confirm with callers.\n",
"\n",
"    Args:\n",
"        images: non-empty, indexable sequence of frames.\n",
"        save_path: output path; frames are encoded with the 'mp4v' codec.\n",
"        fps: frame rate of the output video (25 by default, as before).\n",
"\n",
"    Raises:\n",
"        ValueError: if `images` is empty.\n",
"    \"\"\"\n",
"    if len(images) == 0:\n",
"        raise ValueError(\"frames2vid needs at least one frame\")\n",
"\n",
"    HEIGHT, WIDTH = images[0].shape[:2]\n",
"\n",
"    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
"    video = cv2.VideoWriter(save_path, fourcc, fps, (WIDTH, HEIGHT))\n",
"    try:\n",
"        # Append the frames to the video one by one\n",
"        for image in images:\n",
"            video.write(image)\n",
"    finally:\n",
"        # Always release the writer so the file is finalized, even if a write fails\n",
"        video.release()\n",
|
|
|
|
|
"\n",
"def display_gif(gif_path):\n",
"    \"\"\"Render the GIF at `gif_path` inline in the notebook as a base64 data-URI image.\"\"\"\n",
"    # Read inside a context manager so the file handle is closed right away\n",
"    with open(gif_path, 'rb') as gif_file:\n",
"        b64 = base64.b64encode(gif_file.read()).decode('ascii')\n",
"    display(HTML(f'<img src=\"data:image/gif;base64,{b64}\" />'))\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Configurations"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 29,
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
2024-03-31 14:51:57 +02:00
|
|
|
|
"source": [
"from dataclasses import dataclass\n",
"\n",
"\n",
"@dataclass\n",
"class BaseConfig:\n",
"    # Device and dataset used throughout the notebook\n",
"    DEVICE = get_default_device()\n",
"    DATASET = \"Flowers\"  # one of \"MNIST\", \"Cifar-10\", \"Cifar-100\", \"Flowers\"\n",
"\n",
"    # Root folders for inference logs and saved checkpoints\n",
"    root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n",
"    root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\", \"checkpoints\")\n",
"\n",
"    # Current log and checkpoint sub-directory names\n",
"    log_dir = \"version_0\"\n",
"    checkpoint_dir = \"version_0\"\n",
"\n",
"\n",
"@dataclass\n",
"class TrainingConfig:\n",
"    TIMESTEPS = 1000\n",
"    # (channels, height, width): MNIST is single-channel, the other datasets are RGB\n",
"    IMG_SHAPE = (1 if BaseConfig.DATASET == \"MNIST\" else 3, 32, 32)\n",
"    NUM_EPOCHS = 800\n",
"    BATCH_SIZE = 32\n",
"    LR = 2e-4\n",
"    NUM_WORKERS = 2"
|
|
|
|
|
]
|
2024-04-01 12:38:50 +02:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Load Dataset & Build Dataloader"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 38,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"def get_dataset(dataset_name='MNIST'):\n",
"    \"\"\"\n",
"    Return the torchvision dataset object that will be passed to the DataLoader.\n",
"\n",
"    Every image gets three preprocessing transforms (ToTensor, bicubic resize\n",
"    to 32x32, the (t * 2) - 1 rescale towards [-1, 1]) and one augmentation\n",
"    (random horizontal flip).\n",
"\n",
"    Args:\n",
"        dataset_name: \"MNIST\", \"Cifar-10\", \"Cifar-100\" or \"Flowers\"\n",
"            (matched case-insensitively).\n",
"\n",
"    Raises:\n",
"        ValueError: if `dataset_name` is not one of the supported datasets.\n",
"    \"\"\"\n",
"    name = dataset_name.upper()\n",
"    # Validate first so an unknown name fails with a clear error instead of\n",
"    # an UnboundLocalError at the return statement.\n",
"    if name not in (\"MNIST\", \"CIFAR-10\", \"CIFAR-100\", \"FLOWERS\"):\n",
"        raise ValueError(f\"Unknown dataset_name: {dataset_name!r}\")\n",
"\n",
"    transforms = TF.Compose(\n",
"        [\n",
"            TF.ToTensor(),\n",
"            TF.Resize((32, 32),\n",
"                      interpolation=TF.InterpolationMode.BICUBIC,\n",
"                      antialias=True),\n",
"            TF.RandomHorizontalFlip(),\n",
"            TF.Lambda(lambda t: (t * 2) - 1)  # scale between [-1, 1]\n",
"        ]\n",
"    )\n",
"\n",
"    if name == \"MNIST\":\n",
"        return datasets.MNIST(root=\"data\", train=True, download=True, transform=transforms)\n",
"    if name == \"CIFAR-10\":\n",
"        return datasets.CIFAR10(root=\"data\", train=True, download=True, transform=transforms)\n",
"    if name == \"CIFAR-100\":\n",
"        # CIFAR-100 needs the CIFAR100 class, not CIFAR10\n",
"        return datasets.CIFAR100(root=\"data\", train=True, download=True, transform=transforms)\n",
"    return datasets.ImageFolder(root=\"data/flowers\", transform=transforms)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 40,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"def get_dataloader(dataset_name='MNIST',\n",
"                   batch_size=32,\n",
"                   pin_memory=False,\n",
"                   shuffle=True,\n",
"                   num_workers=0,\n",
"                   device=\"cpu\"\n",
"                   ):\n",
"    \"\"\"\n",
"    Build the dataset named `dataset_name`, wrap it in a DataLoader, and wrap\n",
"    that in a DeviceDataLoader so every batch is moved to `device`.\n",
"    \"\"\"\n",
"    loader = DataLoader(\n",
"        get_dataset(dataset_name=dataset_name),\n",
"        batch_size=batch_size,\n",
"        shuffle=shuffle,\n",
"        num_workers=num_workers,\n",
"        pin_memory=pin_memory,\n",
"    )\n",
"    return DeviceDataLoader(loader, device)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 41,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"def inverse_transform(tensors):\n",
"    \"\"\"\n",
"    Map values from the model range [-1., 1.] back to the pixel range [0., 255.].\n",
"\n",
"    Inputs are clamped to [-1, 1] first, so out-of-range values saturate.\n",
"    \"\"\"\n",
"    clamped = tensors.clamp(-1, 1)\n",
"    unit_range = (clamped + 1.0) / 2.0\n",
"    return unit_range * 255.0"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 39,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"Dataset ImageFolder\n",
|
|
|
|
|
" Number of datapoints: 4317\n",
|
|
|
|
|
" Root location: data/flowers\n",
|
|
|
|
|
" StandardTransform\n",
|
|
|
|
|
"Transform: Compose(\n",
|
|
|
|
|
" ToTensor()\n",
|
|
|
|
|
" Resize(size=(32, 32), interpolation=bicubic, max_size=None, antialias=True)\n",
|
|
|
|
|
" RandomHorizontalFlip(p=0.5)\n",
|
|
|
|
|
" Lambda()\n",
|
|
|
|
|
" )\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
"# Build the Flowers dataset and print its summary (size, root, transforms)\n",
"dataset = get_dataset(dataset_name=\"Flowers\")\n",
"print(dataset)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Visualize Dataset"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 43,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"# Preview loader: batches of 128 images from the configured dataset, kept on the CPU\n",
"loader = get_dataloader(dataset_name=BaseConfig.DATASET, batch_size=128, device='cpu')"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 46,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA64AAAHiCAYAAADoA5FMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9Waxta5bfCf3G18xuNbs/7W0jbtyIyN5dpl0yKhvbVWWgpILCpMAq8QBP1AvYCIH8YCEheECinigBT5RUUBRNQSEol6tcWFhygpts7MyMjIz2tqfd7epm83U8fHOuvc+NyHREOlLkwx5H66zdrD3XXHN+zfiP8R//ISmlxL3d273d273d273d273d273d273d2x9RU///PoF7u7d7u7d7u7d7u7d7u7d7u7d7u7ffz+6B673d273d273d273d273d273d2739kbZ74Hpv93Zv93Zv93Zv93Zv93Zv93Zvf6TtHrje273d273d273d273d273d273d2x9puweu93Zv93Zv93Zv93Zv93Zv93Zv9/ZH2u6B673d273d273d273d273d273d2739kbZ74Hpv93Zv93Zv93Zv93Zv93Zv93Zvf6TtHrje273d273d273d273d273d273d2x9puweu93Zv93Zv93Zv93Zv93Zv93Zv9/ZH2u6B673d273d273d273d273d273d2739kTbz47x4d9ny8nfO6WjppcXRE1MgxoQPga7vEQERQRAAQowkEiklIJEAhSAiGGNQSmGMRkShRJFSIqbEMPTEGAkhIKLyMUV+8KQkv1N+XxAlKCUopYDxPcf3hkQ+LaEoS7TRxOjyYZTw4Owxbz99nxfnN6y3PXH8DIzv+8b/Al/44u5J3T29P7Dl895/94Uv0+1P0ptf73/7xuvy/0rg3ccnpDDwW7/zawyDI/hAiIkUQVQ+adGAKJRSxJiICVKMkCCGiChQSkjT5/Mp34fxuicgxURKt58j36dxbKQ03YrxvMYYioDRBlFCCJ6UIEoaby6IJJQCYw0pQQgJQiLFRAyRmPL9Pj054+d/9o/T3ryiXb2mX6+RlDh6fILSioTgO0ccPLoqAdjdXEMMEAK2rDFFSQoRRFBlSXCOfrfNr4kRsQalNbYocc7T7Xr6rsc7T1UqrNXUsxnRB4be0fUO7wNVU2KsoSwrxBiUtSiVP19MPl+vCESQmD+vaEHPG8RYxNb5c0YwZkFKwsff+VWGvgOEGBIhRIxSKBEg4UOi7SJaa6zRNLVFa2HwAzEknI9one93iJEUIYS4n9sx5vsp4z0PxP39slqhteDbSPSJ4G9Hq4zz04hgreX07AxjNEYr2q5ncJ6mmWGtoSwL8mVIeO8JIdB3PdoWnL3/JZJS+NspzDTEp2kyfX37nG5/l+5+PX5PeuNrEsT9/JHxb9J+/E7H2K9l4xtN7zl98QPz8vYHt1M4xfG1cTxG5PRwwdGi5jf+6a9yfXOF1qCUQmuNVgYlirbtCCGQ1Hh+IYHK0yOE/OZKBCWCFmGmFKUo/OCQmKh8wmpFaQyH2tAoxefBsUmR58HTJOEIRZESJp/d/rMbBCuKmxTYkbiS/PnqkLAIpQiGvL4PRqiU4WlZEwEPPIsdnVH81C/8KWKCzz/7LF+HOzYt1enO+vZDLt+4rgqi8nqidN4f8jqSN4O8Z+RrmH+kfuA4KeYBpSQfS43rDEnwIREi+JiYFqqAwkfFxglNoTmeG5aNoSo029bT+cTFFg4bw9OTkm984xs8e/Z8f77TfvXmfnHnA057jZL9ngayH7tffPn+WdJ0+NtnQJLs11xR5LV3v2Wl/LPxb/LP07iPTudzZ8xGxvOX/bj84N2fQkTzD/7Br9APAyne3sMfxZS63bd1TEiCIoFKoKfTvHO1pufEfokkQt6HlKCMxjY1pigoqpoYIzEGVqt1XqO74Y33FwFr9f46hxDx/s0xeXJywi/+4i/y+uo1L86f5z
VK8l4nIih167+IUuM1vjMe99dsvA/jOQXviCkSUyRN407y2hBTGOfG7djLd08Qye+HqPFPBFF6/77qzvtBIgSPDwHn3bgnp/1cEHXrHz05e4t5PefXfu0f0bZbqsrQlCWzekbf9zjviSpijGE5n9F3nnbbs+vy72yhsVazPGywRmOMwg0R7wLnF2tSgsKa/bo7OJf3mBRJKe9RSgvaCHVVoZUmBpevh7j8GcVgVEGMcH21hgRGK7ROKJ2wRqEUxKQgCRIVZZkoioSThC0bPnz/T/Lq1Wt+/dd//UcbpH9A0zqP4P3eAG/sJX+Y9sEHH/Dhhx/yq7/2q7x+/fp2QZje9/fwS2VcdLL/LKM/d+fvUva13vhMk4/9Qz+T/N7v9UNP5I63+sWFX958vTWGP/1LfxqRwHe+9xtYESpjMIWgdGLoPTEJPpXoskSXNZcXV3RtS1MZmqrg0elhni+S8KP/GGMixEjbDwTvCd5zcrikspa27wkx4kIgSt6l3BAAOGiq0d8CHyEgzJoaozXeu4xtYn6fECO79ZbgHMp5VEpoIMY3fY4YE4uTRzz68s/yO9/8Bh99/DFK5XkfY3zTzxKYN3kt2/Vhf+VSHB9MfrigDGiT55rWwsHcZH8+gQ8J7xNa5z1Rq3zebRtRCpROlEV+/WbjCWHEB+Pnq6sGqw0hJmL0DHFHUarsE+sahabtHCnlcxmcx/nAT3/1Fzg6PP3hg+WLYyf9GDPo+3//U/7D//5/xjP5iOf6U27Sc4a4o+0D6+2WZy9eZOdXC1obSND2HTFGXPD7AW60xmjDwWJOWZbM5jNKW2KMJfjI4Byvzl/T9z1tu8MYi9YGrfXkr+YFHkGPQNXoDKSKUmELQ1mXpOTz4h8diUhKLm9uWnH88DHNfEY/XIMETGH4L/3Lv8x/69/47/J/+U9/ld/6zjMcmoSMG4QA6tZRGoFGdpQUCbWf9G8A3cnBujMF907MfiBNv7uzOOxH3eQk5800Tp5ByhtcGn+Xpu8nZzuGO7+Lo2Oendb/zl/5c/juNf/mX/9lLi8vWa+3dG3EDQlTCsoINBqtDdZWdD7gfCIMnuQCbtejrcLU+jZnvwlIAlPZffDBu5BB0ZCDA0ZpjCi0KGR0UkQllBKssnnzVsJyMaMwhs1ui0+RXkWSFjBCYSPGCAdHc2KE3dYT20DoIl3b4YPHhcBf/HP/Cv/b/9X/gY9/9T/ik9/427z4rd9GRcef/St/DttUeNFsP7uifbli/vZjIpGPf+0fE3Zb2O04efIuy9NHDLselKZ6/JDt9SUvv/dt0q4FN1AcLSmamsMHD7m5XPPZ91/w/LOXrC6veftxxeHBjPc+/ArtuuXy+QWfv7jiZtPy1pcfsThY8ODJWxSLBeXRIabUKCP0bpWBSS/IEFFdoqzAVJrmZ76Enh+iDt4hhIRzgeXyZ/HO8r/8n/7rXLz8FEmWrh3YbXuWZUVpNDEmtrvAR89a5rOGg8WcL727ZDYzXNxc0naO66uBqimpmoJ2N+BcpNs6+sGx2rQMQ8S5hCryPd/iSJJB0+GiYD6z3Hy/Y1gHtqtITEJUCjMCqIUxHB8e8Jf+8l9mOW9YNCWfPHvJ+eU173/5KxwcHPDkyQlGgyGyWa9ot1tefPaS2fEZ/9J/76/ji4p1iOhxioU4LbR57oQEMUJM4GIGoX4E3yEl4hjgcCGNjlLeQMIYIIspMoRATBBEEX0gOD8C+YQfn4PPfxND2B8zB37ywp3ubD4pkU/qzuaeSKTgSTESoyNGT/ADf+mXfoo/8zPv8m/8t/8K//Af/33qOlGUBbPZglnZUJqKjz76jG3bQanxPjLsPLpMKJvYbQIpQKEtldbMreGDouSRsazOb1Cd4+nac1RVPDo84M/UM94rCv73u2t+23f8B+2a96Lhz8SChymwSJEhJQI5iLFQhmNl+VXf8f3k+BWbr8M7beBYFA+0YZkUVoSrecGTesZfffAevQhrgX+v+4xny5J/69
/9f9IPgf/dv/Pv4IPfr3wJSGHc3GO+TvFO4C3GW6dFK4VSGm00WivKqshBGVvsHXtd5MBSWU4BUnMLrFCQEsF5lOT
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 1200x600 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
"# Show the first batch from `loader` as a 16-column image grid\n",
"fig, ax = plt.subplots(figsize=(12, 6), facecolor='white')\n",
"\n",
"b_image, _ = next(iter(loader))\n",
"b_image = inverse_transform(b_image).cpu()\n",
"grid_img = make_grid(b_image / 255.0,\n",
"                     nrow=16,\n",
"                     padding=1,\n",
"                     pad_value=1,\n",
"                     normalize=True)\n",
"ax.imshow(grid_img.permute(1, 2, 0))\n",
"ax.axis(\"off\");"
|
|
|
|
|
]
|
2024-03-31 14:51:57 +02:00
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "DLML",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"version": "3.11.8"
|
2024-03-31 14:51:57 +02:00
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 2
|
|
|
|
|
}
|