MasterThesis/Generating_MNIST_using_DDPMs.ipynb

1895 lines
3.1 MiB
Plaintext
Raw Normal View History

2024-04-08 11:37:01 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:53.034694Z",
"start_time": "2023-02-23T07:34:50.200674Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:01:33.905754Z",
"iopub.status.busy": "2023-02-22T16:01:33.904879Z",
"iopub.status.idle": "2023-02-22T16:01:36.156631Z",
"shell.execute_reply": "2023-02-22T16:01:36.155423Z",
"shell.execute_reply.started": "2023-02-22T16:01:33.905665Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\vaibh\\miniconda3\\envs\\pytorchx\\lib\\site-packages\\torchvision\\io\\image.py:13: UserWarning: Failed to load image Python extension: [WinError 127] The specified procedure could not be found\n",
" warn(f\"Failed to load image Python extension: {e}\")\n"
]
}
],
"source": [
"import gc\n",
"import os\n",
"import cv2\n",
"import math\n",
"import base64\n",
"import random\n",
"import numpy as np\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"from datetime import datetime\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.cuda import amp\n",
"import torch.nn.functional as F\n",
"from torch.optim import Adam, AdamW\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"import torchvision\n",
"import torchvision.transforms as TF\n",
"import torchvision.datasets as datasets\n",
"from torchvision.utils import make_grid\n",
"\n",
"from torchmetrics import MeanMetric\n",
"\n",
"from IPython.display import display, HTML, clear_output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# UNet Model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:53.142674Z",
"start_time": "2023-02-23T07:34:53.114668Z"
},
"code_folding": [
0,
1,
25,
105
],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:36.165052Z",
"iopub.status.busy": "2023-02-22T16:01:36.164311Z",
"iopub.status.idle": "2023-02-22T16:01:36.199528Z",
"shell.execute_reply": "2023-02-22T16:01:36.198080Z",
"shell.execute_reply.started": "2023-02-22T16:01:36.165002Z"
}
},
"outputs": [],
"source": [
"class SinusoidalPositionEmbeddings(nn.Module):\n",
"    \"\"\"Fixed sinusoidal timestep table followed by a small learnable MLP.\n",
"\n",
"    A (total_time_steps, time_emb_dims) table of sin/cos features is precomputed,\n",
"    looked up by integer timestep via a frozen nn.Embedding, and projected through\n",
"    Linear -> SiLU -> Linear to width time_emb_dims_exp.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, total_time_steps=1000, time_emb_dims=128, time_emb_dims_exp=512):\n",
"        super().__init__()\n",
"\n",
"        half = time_emb_dims // 2\n",
"        # Frequencies form a geometric ladder from 1 down to 1/10000.\n",
"        log_step = math.log(10000) / (half - 1)\n",
"        freqs = torch.exp(-log_step * torch.arange(half, dtype=torch.float32))\n",
"\n",
"        positions = torch.arange(total_time_steps, dtype=torch.float32)\n",
"        angles = positions[:, None] * freqs[None, :]  # (total_time_steps, half)\n",
"        table = torch.cat([angles.sin(), angles.cos()], dim=-1)\n",
"\n",
"        self.time_blocks = nn.Sequential(\n",
"            nn.Embedding.from_pretrained(table),\n",
"            nn.Linear(in_features=time_emb_dims, out_features=time_emb_dims_exp),\n",
"            nn.SiLU(),\n",
"            nn.Linear(in_features=time_emb_dims_exp, out_features=time_emb_dims_exp),\n",
"        )\n",
"\n",
"    def forward(self, time):\n",
"        \"\"\"Embed a tensor of integer timesteps; the output gains a trailing dim of time_emb_dims_exp.\"\"\"\n",
"        return self.time_blocks(time)\n",
"\n",
"\n",
"class AttentionBlock(nn.Module):\n",
"    \"\"\"Residual multi-head self-attention over the spatial positions of a feature map.\n",
"\n",
"    The input is group-normalised, flattened so each pixel becomes one token of\n",
"    width `channels`, attended with 4 heads, reshaped back and added to the input.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, channels=64):\n",
"        super().__init__()\n",
"        self.channels = channels\n",
"\n",
"        self.group_norm = nn.GroupNorm(num_groups=8, num_channels=channels)\n",
"        self.mhsa = nn.MultiheadAttention(embed_dim=self.channels, num_heads=4, batch_first=True)\n",
"\n",
"    def forward(self, x):\n",
"        batch, _, height, width = x.shape\n",
"        # [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]\n",
"        tokens = self.group_norm(x).reshape(batch, self.channels, height * width).transpose(1, 2)\n",
"        attended, _ = self.mhsa(tokens, tokens, tokens)\n",
"        # [B, H*W, C] -> [B, C, H*W] -> [B, C, H, W]\n",
"        attended = attended.transpose(1, 2).reshape(batch, self.channels, height, width)\n",
"        return x + attended\n",
"\n",
"\n",
"class ResnetBlock(nn.Module):\n",
"    \"\"\"Pre-activation residual block conditioned on a time embedding.\n",
"\n",
"    GroupNorm -> SiLU -> 3x3 conv, plus a per-channel bias projected from the\n",
"    time embedding, then GroupNorm -> SiLU -> Dropout2d -> 3x3 conv. The input is\n",
"    added back (through a 1x1 conv when channel counts differ) and, when\n",
"    apply_attention is set, the sum goes through an AttentionBlock.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, *, in_channels, out_channels, dropout_rate=0.1, time_emb_dims=512, apply_attention=False):\n",
"        super().__init__()\n",
"        self.in_channels = in_channels\n",
"        self.out_channels = out_channels\n",
"\n",
"        # NOTE: submodule names (including the 'normlize' spelling) become state_dict\n",
"        # keys, and creation order fixes the parameter initialisation order.\n",
"        self.act_fn = nn.SiLU()\n",
"        self.normlize_1 = nn.GroupNorm(num_groups=8, num_channels=in_channels)\n",
"        self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=\"same\")\n",
"        self.dense_1 = nn.Linear(in_features=time_emb_dims, out_features=out_channels)\n",
"        self.normlize_2 = nn.GroupNorm(num_groups=8, num_channels=out_channels)\n",
"        self.dropout = nn.Dropout2d(p=dropout_rate)\n",
"        self.conv_2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=\"same\")\n",
"\n",
"        # 1x1 projection so the skip path matches out_channels.\n",
"        self.match_input = (\n",
"            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)\n",
"            if in_channels != out_channels\n",
"            else nn.Identity()\n",
"        )\n",
"        self.attention = AttentionBlock(channels=out_channels) if apply_attention else nn.Identity()\n",
"\n",
"    def forward(self, x, t):\n",
"        # Norm -> act -> conv on the input.\n",
"        h = self.conv_1(self.act_fn(self.normlize_1(x)))\n",
"        # Broadcast the projected time embedding over the spatial dims.\n",
"        h = h + self.dense_1(self.act_fn(t))[:, :, None, None]\n",
"        # Norm -> act -> dropout -> conv.\n",
"        h = self.conv_2(self.dropout(self.act_fn(self.normlize_2(h))))\n",
"        # Residual connection, then optional self-attention.\n",
"        return self.attention(h + self.match_input(x))\n",
"\n",
"\n",
"class DownSample(nn.Module):\n",
"    \"\"\"Halve the spatial size with a stride-2 3x3 convolution; channels are unchanged.\"\"\"\n",
"\n",
"    def __init__(self, channels):\n",
"        super().__init__()\n",
"        self.downsample = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)\n",
"\n",
"    def forward(self, x, *args):\n",
"        # Extra positional args (e.g. the time embedding) are accepted and ignored so a\n",
"        # caller can invoke every layer uniformly as layer(h, time_emb).\n",
"        return self.downsample(x)\n",
"\n",
"\n",
"class UpSample(nn.Module):\n",
"    \"\"\"Double the spatial size (nearest-neighbour) then smooth with a 3x3 convolution.\"\"\"\n",
"\n",
"    def __init__(self, in_channels):\n",
"        super().__init__()\n",
"        layers = [\n",
"            nn.Upsample(scale_factor=2, mode=\"nearest\"),\n",
"            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),\n",
"        ]\n",
"        self.upsample = nn.Sequential(*layers)\n",
"\n",
"    def forward(self, x, *args):\n",
"        # Extra positional args (e.g. the time embedding) are accepted and ignored so a\n",
"        # caller can invoke every layer uniformly as layer(h, time_emb).\n",
"        return self.upsample(x)\n",
"\n",
"\n",
"class UNet(nn.Module):\n",
"    \"\"\"Time-conditioned UNet mapping (x, t) to a tensor with output_channels channels.\n",
"\n",
"    Args:\n",
"        input_channels: channels of the input image.\n",
"        output_channels: channels produced by the final convolution.\n",
"        num_res_blocks: ResnetBlocks per encoder level (the decoder uses one more per level).\n",
"        base_channels: width of the stem conv; level i has base_channels * base_channels_multiples[i].\n",
"        base_channels_multiples: per-level width multipliers; its length is the number of levels.\n",
"        apply_attention: per-level flags enabling self-attention inside that level's ResnetBlocks.\n",
"        dropout_rate: Dropout2d probability inside each ResnetBlock.\n",
"        time_multiple: the time-embedding MLP width is base_channels * time_multiple.\n",
"        total_time_steps: number of rows in the sinusoidal timestep table; timesteps passed\n",
"            to forward must lie in [0, total_time_steps). Defaults to 1000 as before.\n",
"\n",
"    Raises:\n",
"        ValueError: if apply_attention has fewer entries than base_channels_multiples.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        input_channels=3,\n",
"        output_channels=3,\n",
"        num_res_blocks=2,\n",
"        base_channels=128,\n",
"        base_channels_multiples=(1, 2, 4, 8),\n",
"        apply_attention=(False, False, True, False),\n",
"        dropout_rate=0.1,\n",
"        time_multiple=4,\n",
"        total_time_steps=1000,\n",
"    ):\n",
"        super().__init__()\n",
"\n",
"        num_resolutions = len(base_channels_multiples)\n",
"        if len(apply_attention) < num_resolutions:\n",
"            raise ValueError(\n",
"                f\"apply_attention needs one flag per level ({num_resolutions}), got {len(apply_attention)}\"\n",
"            )\n",
"\n",
"        # The table size used to be fixed at the helper's default (1000), so a larger\n",
"        # number of diffusion steps would index past the end of the embedding table.\n",
"        time_emb_dims_exp = base_channels * time_multiple\n",
"        self.time_embeddings = SinusoidalPositionEmbeddings(\n",
"            total_time_steps=total_time_steps,\n",
"            time_emb_dims=base_channels,\n",
"            time_emb_dims_exp=time_emb_dims_exp,\n",
"        )\n",
"\n",
"        self.first = nn.Conv2d(in_channels=input_channels, out_channels=base_channels, kernel_size=3, stride=1, padding=\"same\")\n",
"\n",
"        # Encoder: ResnetBlocks per level, DownSample between levels.\n",
"        # curr_channels records the channel count of every tensor that forward() pushes\n",
"        # onto its skip stack (stem output + each encoder layer output), in order.\n",
"        self.encoder_blocks = nn.ModuleList()\n",
"        curr_channels = [base_channels]\n",
"        in_channels = base_channels\n",
"\n",
"        for level in range(num_resolutions):\n",
"            out_channels = base_channels * base_channels_multiples[level]\n",
"\n",
"            for _ in range(num_res_blocks):\n",
"                block = ResnetBlock(\n",
"                    in_channels=in_channels,\n",
"                    out_channels=out_channels,\n",
"                    dropout_rate=dropout_rate,\n",
"                    time_emb_dims=time_emb_dims_exp,\n",
"                    apply_attention=apply_attention[level],\n",
"                )\n",
"                self.encoder_blocks.append(block)\n",
"\n",
"                in_channels = out_channels\n",
"                curr_channels.append(in_channels)\n",
"\n",
"            if level != (num_resolutions - 1):\n",
"                self.encoder_blocks.append(DownSample(channels=in_channels))\n",
"                curr_channels.append(in_channels)\n",
"\n",
"        # Bottleneck: an attention block followed by a plain block at the lowest resolution.\n",
"        self.bottleneck_blocks = nn.ModuleList(\n",
"            (\n",
"                ResnetBlock(\n",
"                    in_channels=in_channels,\n",
"                    out_channels=in_channels,\n",
"                    dropout_rate=dropout_rate,\n",
"                    time_emb_dims=time_emb_dims_exp,\n",
"                    apply_attention=True,\n",
"                ),\n",
"                ResnetBlock(\n",
"                    in_channels=in_channels,\n",
"                    out_channels=in_channels,\n",
"                    dropout_rate=dropout_rate,\n",
"                    time_emb_dims=time_emb_dims_exp,\n",
"                    apply_attention=False,\n",
"                ),\n",
"            )\n",
"        )\n",
"\n",
"        # Decoder: num_res_blocks + 1 ResnetBlocks per level, each consuming one skip\n",
"        # tensor (popped in reverse order), with UpSample between levels.\n",
"        self.decoder_blocks = nn.ModuleList()\n",
"\n",
"        for level in reversed(range(num_resolutions)):\n",
"            out_channels = base_channels * base_channels_multiples[level]\n",
"\n",
"            for _ in range(num_res_blocks + 1):\n",
"                encoder_in_channels = curr_channels.pop()\n",
"                block = ResnetBlock(\n",
"                    in_channels=encoder_in_channels + in_channels,\n",
"                    out_channels=out_channels,\n",
"                    dropout_rate=dropout_rate,\n",
"                    time_emb_dims=time_emb_dims_exp,\n",
"                    apply_attention=apply_attention[level],\n",
"                )\n",
"\n",
"                in_channels = out_channels\n",
"                self.decoder_blocks.append(block)\n",
"\n",
"            if level != 0:\n",
"                self.decoder_blocks.append(UpSample(in_channels))\n",
"\n",
"        self.final = nn.Sequential(\n",
"            nn.GroupNorm(num_groups=8, num_channels=in_channels),\n",
"            nn.SiLU(),\n",
"            nn.Conv2d(in_channels=in_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=\"same\"),\n",
"        )\n",
"\n",
"    def forward(self, x, t):\n",
"        \"\"\"Run the UNet; `t` holds integer timestep indices for the time-embedding lookup.\"\"\"\n",
"        time_emb = self.time_embeddings(t)\n",
"\n",
"        h = self.first(x)\n",
"        outs = [h]  # skip-connection stack\n",
"\n",
"        for layer in self.encoder_blocks:\n",
"            h = layer(h, time_emb)\n",
"            outs.append(h)\n",
"\n",
"        for layer in self.bottleneck_blocks:\n",
"            h = layer(h, time_emb)\n",
"\n",
"        for layer in self.decoder_blocks:\n",
"            if isinstance(layer, ResnetBlock):\n",
"                # Concatenate the matching encoder activation along channels.\n",
"                out = outs.pop()\n",
"                h = torch.cat([h, out], dim=1)\n",
"            h = layer(h, time_emb)\n",
"\n",
"        h = self.final(h)\n",
"\n",
"        return h"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:53.203694Z",
"start_time": "2023-02-23T07:34:53.177669Z"
},
"code_folding": [
0,
25,
31,
39,
68
],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:36.201566Z",
"iopub.status.busy": "2023-02-22T16:01:36.201193Z",
"iopub.status.idle": "2023-02-22T16:01:36.219110Z",
"shell.execute_reply": "2023-02-22T16:01:36.218110Z",
"shell.execute_reply.started": "2023-02-22T16:01:36.201528Z"
}
},
"outputs": [],
"source": [
"def to_device(data, device):\n",
"    \"\"\"Move a tensor, or a (possibly nested) list/tuple of tensors, to `device`.\n",
"\n",
"    Sequences are returned as lists; tensors are moved with non_blocking=True.\n",
"    \"\"\"\n",
"    if not isinstance(data, (list, tuple)):\n",
"        return data.to(device, non_blocking=True)\n",
"    return [to_device(item, device) for item in data]\n",
"\n",
"class DeviceDataLoader:\n",
"    \"\"\"Iterable wrapper that moves every batch of `dl` to `device` as it is yielded.\"\"\"\n",
"\n",
"    def __init__(self, dl, device):\n",
"        self.dl = dl\n",
"        self.device = device\n",
"\n",
"    def __iter__(self):\n",
"        \"\"\"Lazily yield each batch of the wrapped loader after to_device.\"\"\"\n",
"        yield from (to_device(batch, self.device) for batch in self.dl)\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of batches in the wrapped loader.\"\"\"\n",
"        return len(self.dl)\n",
"\n",
"def get_default_device():\n",
"    \"\"\"Return torch.device('cuda') when CUDA is available, else torch.device('cpu').\"\"\"\n",
"    if torch.cuda.is_available():\n",
"        return torch.device(\"cuda\")\n",
"    return torch.device(\"cpu\")\n",
"\n",
"def save_images(images, path, **kwargs):\n",
"    \"\"\"Tile `images` with make_grid(**kwargs) and write the grid to `path` with PIL.\n",
"\n",
"    NOTE(review): Image.fromarray needs a dtype PIL understands (e.g. uint8);\n",
"    presumably callers pass uint8 tensors - confirm.\n",
"    \"\"\"\n",
"    grid = torchvision.utils.make_grid(images, **kwargs)\n",
"    # (C, H, W) -> (H, W, C), the layout PIL expects.\n",
"    Image.fromarray(grid.permute(1, 2, 0).to(\"cpu\").numpy()).save(path)\n",
" \n",
"def get(element: torch.Tensor, t: torch.Tensor):\n",
"    \"\"\"\n",
"    Look up element at indices `t` with gather along the last dim and reshape the\n",
"    result to (-1, 1, 1, 1) so it broadcasts against a batch of images.\n",
"    \"\"\"\n",
"    return element.gather(-1, t).reshape(-1, 1, 1, 1)\n",
"\n",
"def setup_log_directory(config):\n",
"    \"\"\"Create the next versioned log and checkpoint directories.\n",
"\n",
"    Looks for sub-folders named \"version_<int>\" in config.root_log_dir and uses\n",
"    \"version_<max + 1>\". If the root does not exist yet, or holds no such folder,\n",
"    config.log_dir is used as the version name instead.\n",
"\n",
"    Returns:\n",
"        (log_dir, checkpoint_dir): the created paths under config.root_log_dir and\n",
"        config.root_checkpoint_dir.\n",
"    \"\"\"\n",
"    version_name = config.log_dir\n",
"\n",
"    if os.path.isdir(config.root_log_dir):\n",
"        # Only entries that follow the \"version_<int>\" pattern count. Previously any\n",
"        # other entry (e.g. .ipynb_checkpoints) crashed int(), and an empty root\n",
"        # crashed max([]).\n",
"        prefix = \"version_\"\n",
"        version_numbers = [\n",
"            int(entry[len(prefix):])\n",
"            for entry in os.listdir(config.root_log_dir)\n",
"            if entry.startswith(prefix) and entry[len(prefix):].isdecimal()\n",
"        ]\n",
"        if version_numbers:\n",
"            version_name = f\"{prefix}{max(version_numbers) + 1}\"\n",
"\n",
"    log_dir = os.path.join(config.root_log_dir, version_name)\n",
"    checkpoint_dir = os.path.join(config.root_checkpoint_dir, version_name)\n",
"\n",
"    # Create new directory for saving new experiment version\n",
"    os.makedirs(log_dir, exist_ok=True)\n",
"    os.makedirs(checkpoint_dir, exist_ok=True)\n",
"\n",
"    print(f\"Logging at: {log_dir}\")\n",
"    print(f\"Model Checkpoint at: {checkpoint_dir}\")\n",
"\n",
"    return log_dir, checkpoint_dir\n",
"\n",
"def frames2vid(images, save_path, fps=25):\n",
"    \"\"\"Write a sequence of frames to a video file with OpenCV (mp4v codec).\n",
"\n",
"    Args:\n",
"        images: non-empty sequence of equally sized frames indexed as (H, W, ...);\n",
"            each is passed unchanged to cv2.VideoWriter.write (presumably uint8\n",
"            BGR arrays as OpenCV expects - confirm against callers).\n",
"        save_path: output file path.\n",
"        fps: frames per second (defaults to the previously hard-coded 25).\n",
"\n",
"    Raises:\n",
"        ValueError: if `images` is empty.\n",
"    \"\"\"\n",
"    if len(images) == 0:\n",
"        raise ValueError(\"frames2vid needs at least one frame\")\n",
"\n",
"    height, width = images[0].shape[:2]\n",
"\n",
"    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
"    video = cv2.VideoWriter(save_path, fourcc, fps, (width, height))\n",
"    try:\n",
"        for image in images:\n",
"            video.write(image)\n",
"    finally:\n",
"        # Release even if a write fails so the output file is finalised/closed.\n",
"        video.release()\n",
"\n",
"def display_gif(gif_path):\n",
"    \"\"\"Render the GIF at `gif_path` inline in the notebook as a base64 data-URI <img>.\"\"\"\n",
"    # Context manager closes the file deterministically; the previous\n",
"    # open(...).read() left closing the handle to garbage collection.\n",
"    with open(gif_path, 'rb') as gif_file:\n",
"        b64 = base64.b64encode(gif_file.read()).decode('ascii')\n",
"    display(HTML(f'<img src=\"data:image/gif;base64,{b64}\" />'))"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-02T18:23:28.639407Z",
"start_time": "2023-02-02T18:23:28.624407Z"
}
},
"source": [
"# Configurations"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:54.785018Z",
"start_time": "2023-02-23T07:34:53.639004Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:36.712080Z",
"iopub.status.busy": "2023-02-22T16:01:36.711714Z",
"iopub.status.idle": "2023-02-22T16:01:36.720529Z",
"shell.execute_reply": "2023-02-22T16:01:36.719078Z",
"shell.execute_reply.started": "2023-02-22T16:01:36.712047Z"
}
},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class BaseConfig:\n",
"    \"\"\"Global settings: compute device, dataset choice and logging/checkpoint paths.\"\"\"\n",
"\n",
"    DEVICE = get_default_device()\n",
"    DATASET = \"MNIST\"  # \"MNIST\", \"Cifar-10\", \"Cifar-100\", \"Flowers\"\n",
"\n",
"    # Root folders for logged inference images and for model checkpoints.\n",
"    root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n",
"    root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\", \"checkpoints\")\n",
"\n",
"    # Default version sub-folder name (setup_log_directory falls back to it on a first run).\n",
"    log_dir = \"version_0\"\n",
"    checkpoint_dir = \"version_0\"\n",
"\n",
"@dataclass\n",
"class TrainingConfig:\n",
"    \"\"\"Diffusion and optimisation hyper-parameters.\"\"\"\n",
"\n",
"    TIMESTEPS = 1000  # Define number of diffusion timesteps\n",
"    # Single-channel images for MNIST, three channels otherwise; always 32x32.\n",
"    IMG_SHAPE = (1, 32, 32) if BaseConfig.DATASET == \"MNIST\" else (3, 32, 32)\n",
"    NUM_EPOCHS = 30\n",
"    BATCH_SIZE = 128\n",
"    LR = 2e-4\n",
"    NUM_WORKERS = 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-02T18:24:36.837306Z",
"start_time": "2023-02-02T18:24:36.8273Z"
}
},
"source": [
"# Load Dataset & Build Dataloader"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:55.313081Z",
"start_time": "2023-02-23T07:34:55.291079Z"
},
"code_folding": [
0,
21,
37
],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:37.597072Z",
"iopub.status.busy": "2023-02-22T16:01:37.596702Z",
"iopub.status.idle": "2023-02-22T16:01:37.608430Z",
"shell.execute_reply": "2023-02-22T16:01:37.607135Z",
"shell.execute_reply.started": "2023-02-22T16:01:37.597040Z"
}
},
"outputs": [],
"source": [
"def get_dataset(dataset_name='MNIST'):\n",
"    \"\"\"Load the training split of `dataset_name`, resized to 32x32 and scaled to [-1, 1].\n",
"\n",
"    Supported names (matched case-insensitively): \"MNIST\", \"Cifar-10\", \"Cifar-100\",\n",
"    \"Flowers\" (read from a Kaggle ImageFolder path).\n",
"\n",
"    Raises:\n",
"        ValueError: for an unsupported name (previously surfaced as an\n",
"            UnboundLocalError on `dataset`).\n",
"    \"\"\"\n",
"    key = dataset_name.upper()\n",
"    supported = (\"MNIST\", \"CIFAR-10\", \"CIFAR-100\", \"FLOWERS\")\n",
"    if key not in supported:\n",
"        raise ValueError(f\"Unknown dataset {dataset_name!r}; expected one of {supported}\")\n",
"\n",
"    transforms = TF.Compose(\n",
"        [\n",
"            TF.ToTensor(),\n",
"            TF.Resize((32, 32),\n",
"                      interpolation=TF.InterpolationMode.BICUBIC,\n",
"                      antialias=True),\n",
"            TF.Lambda(lambda t: (t * 2) - 1)  # Scale between [-1, 1]\n",
"        ]\n",
"    )\n",
"\n",
"    if key == \"MNIST\":\n",
"        return datasets.MNIST(root=\"data\", train=True, download=True, transform=transforms)\n",
"    if key == \"CIFAR-10\":\n",
"        return datasets.CIFAR10(root=\"data\", train=True, download=True, transform=transforms)\n",
"    if key == \"CIFAR-100\":\n",
"        # This branch previously loaded CIFAR10 by mistake.\n",
"        return datasets.CIFAR100(root=\"data\", train=True, download=True, transform=transforms)\n",
"    return datasets.ImageFolder(root=\"/kaggle/input/flowers-recognition/flowers\", transform=transforms)\n",
"\n",
"def get_dataloader(dataset_name='MNIST',\n",
"                   batch_size=32,\n",
"                   pin_memory=False,\n",
"                   shuffle=True,\n",
"                   num_workers=0,\n",
"                   device=\"cpu\"\n",
"                   ):\n",
"    \"\"\"Build get_dataset(dataset_name), wrap it in a DataLoader and move batches to `device`.\n",
"\n",
"    Returns a DeviceDataLoader, so iterating it yields batches already on `device`.\n",
"    \"\"\"\n",
"    torch_loader = DataLoader(\n",
"        get_dataset(dataset_name=dataset_name),\n",
"        batch_size=batch_size,\n",
"        shuffle=shuffle,\n",
"        num_workers=num_workers,\n",
"        pin_memory=pin_memory,\n",
"    )\n",
"    return DeviceDataLoader(torch_loader, device)\n",
"\n",
"def inverse_transform(tensors):\n",
"    \"\"\"Convert tensors from [-1., 1.] to [0., 255.]; values outside are clamped first.\"\"\"\n",
"    clipped = tensors.clamp(-1, 1)\n",
"    return (clipped + 1.0) / 2.0 * 255.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualize Dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:35:01.528613Z",
"start_time": "2023-02-23T07:34:57.085306Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:39.192513Z",
"iopub.status.busy": "2023-02-22T16:01:39.191545Z",
"iopub.status.idle": "2023-02-22T16:01:39.267958Z",
"shell.execute_reply": "2023-02-22T16:01:39.266946Z",
"shell.execute_reply.started": "2023-02-22T16:01:39.192465Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data\\MNIST\\raw\\train-images-idx3-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "941ed307c4714d94a98b12b2349decd8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/9912422 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data\\MNIST\\raw\\train-images-idx3-ubyte.gz to data\\MNIST\\raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\\MNIST\\raw\\train-labels-idx1-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5d3498e2738b4056868e2ca3b19be20b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28881 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data\\MNIST\\raw\\train-labels-idx1-ubyte.gz to data\\MNIST\\raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "50b753def6a742e5b9a2b637f63bec15",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1648877 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to data\\MNIST\\raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "423a266fda0247859ccb30c8ff2b1f89",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4542 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to data\\MNIST\\raw\n",
"\n"
]
}
],
"source": [
"# Preview loader: batches of 128 images from BaseConfig.DATASET, kept on the CPU.\n",
"loader = get_dataloader(dataset_name=BaseConfig.DATASET, device='cpu', batch_size=128)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-21T08:04:00.707591Z",
"start_time": "2023-02-21T08:03:59.270574Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:39.573357Z",
"iopub.status.busy": "2023-02-22T16:01:39.572658Z",
"iopub.status.idle": "2023-02-22T16:01:39.920898Z",
"shell.execute_reply": "2023-02-22T16:01:39.919984Z",
"shell.execute_reply.started": "2023-02-22T16:01:39.573319Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA64AAAHiCAYAAADoA5FMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9d2ykd34e/kzvvQ9nhsMhh71sr5JWZ51k6VR8dz7f2WcntuMSB0HiOLAdx7AdBIYRw4cgiBEgsC8BHJf4bMd3vrI6SadT29VWbuGy9+HMkFM4vff5/aHf56vhiiutdCRnuHofgFhpl8t933m/7/f7Kc/neXjNZrMJDhw4cODAgQMHDhw4cODAoUPBb/cFcODAgQMHDhw4cODAgQMHDh8GLnHlwIEDBw4cOHDgwIEDBw4dDS5x5cCBAwcOHDhw4MCBAwcOHQ0uceXAgQMHDhw4cODAgQMHDh0NLnHlwIEDBw4cOHDgwIEDBw4dDS5x5cCBAwcOHDhw4MCBAwcOHQ0uceXAgQMHDhw4cODAgQMHDh0NLnHlwIEDBw4cOHDgwIEDBw4dDS5x5cCBAwcOHDhw4MCBAwcOHQ0uceXAgQMHDhw4cODAgQMHDh0N4cf55mg0isnJSTSbzf26nrbC7XZjeHgYd+7cQSgUavfl7AsEAgHOnDkDgUCAK1euoFartfuS9gVmsxknTpzA4uIiVldX2305+4Zjx45Br9fjypUrKBQK7b6cfYFKpcK5c+cQCoVw7969dl/OvmFwcBA9PT24ceMG4vF4uy9nXyAWi3H+/HkUi0XcuHEDjUaj3Ze0L3A6nRgfH8fU1BSCwWC7L2dfwOfzcerUKUilUly5cgWVSqXdl7QvMBgMOHXqFNbW1rC4uNjuy9k3TExMwGKx4OrVq8hms+2+nH2BQqHAuXPnEIvFcOfOnXZfzr6hr68P/f39mJycRDQabffl7AtEIhHOnj2LWq2G69evo16vt/uS9gV2ux1HjhzB3NwcfD5fuy9nX8Dj8XDy5EmYTKaH+wvNj4GLFy82hUJhk8/nP5Jfv/7rv95sNBrNr3zlK22/lv36UiqVzRs3bjSnp6ebWq227dezX18vvvhis1arNX/3d3+37deyX18CgaD5D//wD81oNNr0er1tv579+pqYmGim0+nm//7f/7vt17KfX3/8x3/cLJfLzc985jNtv5b9+rJYLM3V1dXmG2+80ZRIJG2/nv36+pVf+ZVmo9Fo/ot/8S/afi379SWVSptvv/12c3l5uWk2m9t+Pfv19dRTTzUrlUrzj/7oj9p+Lfv59Rd/8RfNVCrVHB8fb/u17NfXwMBAMxaLNb/xjW80BQJB269nv75+//d/v1mr1ZovvPBC269lv770en1zdna2ee3ataZCoWj79ezX11e/+tVmo9Fo/tt/+2/bfi379SUSiZqvvPLKQ+eiH6vj2mw20Wg0HtkqOXWS6T4fRVBV6lF/lnRfj/KzBN5fs9yzPPz4tDzLT8t90q+P6j1+Ws6ST0NcALx3f/vxLAUCAVwuF7RaLdxuN7a3txEOhxEKhZDP5/fs33kYtJ4lj2qHDvh0nCWflv3n03CW8Hi8j8Xk/ViJKwcOHDhweLTA578ndUCHB31x4MCh/eDxeDt+vR+d/r4KBAK43W50d3fjwoULmJ+fx927d5FOpw88ceXAAXj/zNvLd4fH47GvVnBn6t6DS1w5cODA4VMIqVQKuVyOp59+GlarFV1dXQgEArh79y7W1tawvb2NarXKHbgcOLQBWq0WGo0GDocDCoUCKpUKUqkUSqUSIpEIAoEAm5ubiMViuH79OiqVSsd1EaVSKdRqNcbHx9HX14eJiQnkcjksLi5CIBC0+/I4fMogk8mgVCrx2c9+Fs1mE4uLi9je3kYkEkGtVvvEZ51EIsHx48fR1dWFsbExSCQSiMViFItF5HI5vPPOO4hEIlhdXX1ku6YHCS5x5cDhgCAUCiEQCCAUCt
FsNlGpVB5p+geHzoZEIoFKpcLw8DDcbjc8Hg/0ej3S6TSSySTS6fSPdJhz4PBJIBAIwOfzIZFIduyXhHq9jlKphHq9/kiKC/L5fIjFYhiNRlgsFvT19UGj0UCv10Mul0Oj0UAsFkMoFEKj0WBzcxPz8/PIZrMdJ9Ank8mgVqtht9ths9nYPQgEggd2kDlw2C8IhULIZDL09vai0WggHA4jm82yTunHOev4fD4EAgE7R/v6+uDxeHDmzBlIpVJIJBKkUinE43HMzMwwwbOP++9w+CC4xJUDhwOC3W6HyWRCd3c3arUaZmdn2cbGgcNBgc/nQygUwul0wu124+mnn4bH44FWq4XBYGCV4lKpBL/f/8iqxXLoPPB4PBiNRuh0Ohw7dgxGoxFOpxP1eh2NRgP1eh2pVApXrlxBKBR6JFU29Xo9hoaGcP78eRw5cgQOhwNKpRIajYZ1XCn4XV5eRiAQQDabxcrKCm7fvt3uy9+B3t5e9PX14ezZszCZTOy6uYIYh3ZAKBRCIpHAZrOhUCggn8+jXC7v0F54WOh0OphMJpw8eRIulwtPP/00LBYL3G43K8osLi6iWCxie3sbyWSSowzvEQ5V4qpSqSAWi6FQKMDn8xlPHXhvcLlWqyGfz6NSqaBUKn2ixciBw15DKBRCLBajq6sLbrcbXq8X5XIZsVgM9XqdS1w5HCh4PB6EQiGMRiNcLhd0Oh3kcjlqtRpqtRqjG3IdEQ4HCbFYDIlEApfLBZvNhqNHj8JoNMJutzNmSqPRQDKZRCaTgVqtRrPZRDqdRrFYRKVSObTnPY/Hg0AggFqtRldXF4aHhzEwMIDe3l7odDqIRCLU63WUy2XUajXIZDKIxWIYDAY0Gg309/ejUqlgeXkZpVIJ1Wq13bcE4D26s9VqhcFggEqlYuwibm/h0A7U63VUKhWEQiEUCgWkUikUCoWHzhXo7JTL5XA4HPB4PDhy5AhcLhcrMDUaDWSzWWSzWayurmJ9fR3JZJKb595DHJrElc/nw+VywWAwwOPxQCwWQywWsz+vVCrI5/NYXl5GPB5HKBRCpVJ5JKlEHA4XZDIZq8wdO3YMo6OjyOVyrAL3KHYNOHQuhEIhpFIpvF4vTp06BYPBAKFQiEgkgkgkgnA4jHw+33HzchwebajVauj1epw/fx5erxcvvPACNBoN1Gr1ju/LZDIwm81YXV2Fy+XCnTt34Pf7kUgkDu15LxAIIJfL0d/fj9HRUbz00ksYHBxEd3c3KpUKCoUCFhcXkcvlkE6n0dXVBZPJBIfDAb1ej89+9rNQKBRYXl5GKBRiZ0u74XA4MDAwgK6uLkgkEmxvbzOKJZe8cjholMtlpNNpXL58GaVSCT6fj7E5HgYCgQBKpRLd3d04c+YMzpw5gwsXLsBut0MkEqFYLCIWi2F+fh6Li4u4fv06gsEg1tbWOObSHqJjE1eqQIrFYkilUshkMpw7dw7d3d04cuQIEycg1Ot1ZLNZJixy9+5dRCIRZDKZQ01LIe49n89/IDf+sB7WD4JUKmUzPgBQKpWQz+eRy+UOVRedx+NBJBLBYrFgYmIC4+PjGB0dhV6vR6PRgFQqhUgkavdlctgntL6znUQRonVJNEyJRIJKpYKVlRXMzs7i1q1b8Pl8iMfjj9zewqFzQUI+PT098Hg80Gg0kEgkaDabO5IcqVSKgYEBmEwmuN1uaLVazM/PY3JyEplMpuPmPB8GNAvq9XrR09MDk8mEaDTKOkPZbBaTk5NIpVKIRqMYGxuD1+tlAk4ulwuJRALHjx/H9evXkUql2rrfKJVKaLVaOBwO2O12CAQClMtlBINBRKNRZLNZbm/hcOCo1WooFotYXV1l7KKHSVpFIhHEYjE8Hg+sViuOHj2K8fFxjI+PQ6fTgcfjIZ1OIxKJ4NatW7h9+zampqawubn5qVjrPB5vxzgDAGxubqJare7LvXds4koCBQqFAlqtFlqtFsePH8fg4CDOnTsHofCDl57JZKBUKq
FSqRCPx1EqlZiIQ6cEjR8XRIkWiUQfSFx5PB6b+zms93c/eDweU37zeDwAwA7rYrF4qISMeDweJBIJzGYzRkdHMTQ
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(12, 6), facecolor='white')\n",
"\n",
"# Take one batch, map it back to [0, 255] and show it as a 16-wide grid.\n",
"b_image, _ = next(iter(loader))\n",
"b_image = inverse_transform(b_image).cpu()\n",
"grid_img = make_grid(b_image / 255.0, nrow=16, padding=1, pad_value=1, normalize=True)\n",
"plt.imshow(grid_img.permute(1, 2, 0))\n",
"plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Diffusion Process"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:35:01.574603Z",
"start_time": "2023-02-23T07:35:01.561607Z"
},
"code_folding": [
1
],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:40.386727Z",
"iopub.status.busy": "2023-02-22T16:01:40.385687Z",
"iopub.status.idle": "2023-02-22T16:01:40.397948Z",
"shell.execute_reply": "2023-02-22T16:01:40.396829Z",
"shell.execute_reply.started": "2023-02-22T16:01:40.386688Z"
}
},
"outputs": [],
"source": [
"class SimpleDiffusion:\n",
"    \"\"\"Precomputed DDPM schedule tensors for a (scaled) linear beta schedule.\n",
"\n",
"    initialize() sets beta, alpha, sqrt_beta, alpha_cumulative,\n",
"    sqrt_alpha_cumulative, one_by_sqrt_alpha and sqrt_one_minus_alpha_cumulative,\n",
"    each a 1-D float32 tensor of length num_diffusion_timesteps on `device`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        num_diffusion_timesteps=1000,\n",
"        img_shape=(3, 64, 64),\n",
"        device=\"cpu\",\n",
"    ):\n",
"        self.num_diffusion_timesteps = num_diffusion_timesteps\n",
"        self.img_shape = img_shape\n",
"        self.device = device\n",
"\n",
"        self.initialize()\n",
"\n",
"    def initialize(self):\n",
"        \"\"\"Compute every schedule tensor from get_betas().\"\"\"\n",
"        self.beta = self.get_betas()\n",
"        self.alpha = 1 - self.beta\n",
"\n",
"        # Was assigned to a local named `self_sqrt_beta` (typo), so the instance\n",
"        # never had a `sqrt_beta` attribute.\n",
"        self.sqrt_beta = torch.sqrt(self.beta)\n",
"        self.alpha_cumulative = torch.cumprod(self.alpha, dim=0)\n",
"        self.sqrt_alpha_cumulative = torch.sqrt(self.alpha_cumulative)\n",
"        self.one_by_sqrt_alpha = 1. / torch.sqrt(self.alpha)\n",
"        self.sqrt_one_minus_alpha_cumulative = torch.sqrt(1 - self.alpha_cumulative)\n",
"\n",
"    def get_betas(self):\n",
"        \"\"\"linear schedule, proposed in original ddpm paper\"\"\"\n",
"        # Endpoints are rescaled by 1000 / T so that T=1000 gives 1e-4 .. 0.02.\n",
"        scale = 1000 / self.num_diffusion_timesteps\n",
"        beta_start = scale * 1e-4\n",
"        beta_end = scale * 0.02\n",
"        return torch.linspace(\n",
"            beta_start,\n",
"            beta_end,\n",
"            self.num_diffusion_timesteps,\n",
"            dtype=torch.float32,\n",
"            device=self.device,\n",
"        )\n",
" \n",
"def forward_diffusion(sd: SimpleDiffusion, x0: torch.Tensor, timesteps: torch.Tensor):\n",
"    \"\"\"Sample x_t ~ q(x_t | x_0) in closed form.\n",
"\n",
"    Returns (x_t, eps): the noised batch sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps and\n",
"    the Gaussian noise eps that produced it (the target the model learns to predict).\n",
"    \"\"\"\n",
"    noise = torch.randn_like(x0)\n",
"    signal_scale = get(sd.sqrt_alpha_cumulative, t=timesteps)\n",
"    noise_scale = get(sd.sqrt_one_minus_alpha_cumulative, t=timesteps)\n",
"    return signal_scale * x0 + noise_scale * noise, noise"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sample Forward Diffusion Process"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-21T08:04:15.117858Z",
"start_time": "2023-02-21T08:04:14.427843Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:01:41.581306Z",
"iopub.status.busy": "2023-02-22T16:01:41.580908Z",
"iopub.status.idle": "2023-02-22T16:01:42.154133Z",
"shell.execute_reply": "2023-02-22T16:01:42.153007Z",
"shell.execute_reply.started": "2023-02-22T16:01:41.581270Z"
}
},
"outputs": [],
"source": [
"sd = SimpleDiffusion(num_diffusion_timesteps=TrainingConfig.TIMESTEPS, device=\"cpu\")\n",
"\n",
"# Small CPU batches of 6 images, exposed as an iterator so batches can be pulled with next().\n",
"preview_loader = get_dataloader(\n",
"    dataset_name=BaseConfig.DATASET,\n",
"    batch_size=6,\n",
"    device=\"cpu\",\n",
")\n",
"loader = iter(preview_loader)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-21T08:04:15.117858Z",
"start_time": "2023-02-21T08:04:14.427843Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:01:41.581306Z",
"iopub.status.busy": "2023-02-22T16:01:41.580908Z",
"iopub.status.idle": "2023-02-22T16:01:42.154133Z",
"shell.execute_reply": "2023-02-22T16:01:42.153007Z",
"shell.execute_reply.started": "2023-02-22T16:01:41.581270Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAF+CAYAAAAFumw3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3Qc1fn+Pyvtaot2pZVWvXfJ6u62XHA3NrYDmF4MgVBNCJjQIRBKSEgIECCE3gklBhsMBuPesWy5yZZk9d67Vn13fn/43Ptd2YZYwiThxzzn+IC2zN65c+fOW573eTWKoiioUKFChQoVKlSoUKFCxRmE2397ACpUqFChQoUKFSpUqPj/D6qjoUKFChUqVKhQoUKFijMO1dFQoUKFChUqVKhQoULFGYfqaKhQoUKFChUqVKhQoeKMQ3U0VKhQoUKFChUqVKhQccahOhoqVKhQoUKFChUqVKg441AdDRUqVKhQoUKFChUqVJxxqI6GChUqVKhQoUKFChUqzjhUR0OFChUqVKhQoUKFChVnHKqjoUKFih8Fb775JhqN5pT/fvvb3/63h/ejoaysDI1Gw5tvvvm9n9u8efOQOfHw8MDf358pU6Zw//33U15eftJ3xJyWlZUNef2BBx4gIiICrVaL1WoFoL+/nxtvvJHg4GDc3d3JzMw8MyfogquvvpqoqKgzftzTwYlrytvbmxkzZvDFF1/8V8ajQoUKFSpOhva/PQAVKlT8/4033niDpKSkIa+FhIT8l0bzv4c//OEPzJw5E4fDQXNzM99++y2vv/46Tz/9NK+88gqXX365/Ow555zDrl27CA4Olq+tXr2axx9/nPvvv58FCxag1+sBePHFF3nppZd47rnnGDt2LGaz+YyP/cEHH+Q3v/nNGT/u6eKCCy7gjjvuwOl0UlJSwmOPPcbixYv5/PPPOeecc/5r41KhQoUKFcehOhoqVKj4UZGamsq4cePO+HG7u7sxmUxn/Lj/6d+Oj49n0qRJ8u8lS5Zwxx13MGfOHK6++mrS09NJS0sDwN/fH39//yHfz83NBeDWW28lICBgyOtGo5FbbrnljIzzVIiNjf3Rjn06CAwMlHOXlZXF5MmTiYuL45lnnvlOR2NgYACNRoNWqz7+VKhQoeLHhkqdUqFCxX8Vn332GZMnT8ZkMmGxWJg7dy67du0a8pmHH34YjUZDTk4OF1xwAT4+PsTGxvLFF1+g0WjIzs6Wn125ciUajeYkQzM9PZ2lS5fKv1944QWmT59OQEAAnp6epKWl8eSTTzIwMDDkezNmzCA1NZWtW7eSlZWFyWTimmuuAaCmpoaLLroIi8WCt7c3F198MXV1dT94Tnx9fXnppZcYHBzk6aeflq+fSJ2KiorigQceAI4b3RqNRs7Vq6++Sk9Pj6QWvfnmm99L6xLfFWhsbOT6668nPDwcvV4vaV3r16+XnzkVdaq3t5d7772X6OhoPDw8CA0NZfny5bS1tQ35XFRUFIsWLeKrr75izJgxGI1GkpKSeP3110c8b7Gxsfj7+0vamaCnvfPOO9xxxx2Ehoai1+spKioC4PXXXycjIwODwYCvry/nnXceeXl5Jx3322+/ZfHixdhsNgwGA7Gxsdx2221DPlNYWMhll11GQEAAer2eUaNG8cILLwz5jNPp5LHHHiMxMRGj0YjVaiU9PZ1nn31WfuZ05l2FChUqfipQQzoqVKj4UeFwOBgcHBzymogmv//++1x++eXMmzePf/7zn/T19fHkk08yY8YMNmzYwNSpU4d87/zzz+eSSy7hxhtvxG63c9ZZZ6HT6Vi/fj3jx48HYP369RiNRrZs2cLAwAA6nY6GhgZyc3O56aab5LGKi4u57LLLpEF88OBBHn/8cfLz808ydmtra7niiiu46667+MMf/oCbmxs9PT3MmTOHmpoannjiCRISEvjiiy+4+OKLz8i8jR8/nuDgYLZu3fqdn/n000954YUXeO211/jqq6
/w9vYmLCyMs88+m0cffZRNmzaxceNG4LgRbrfbT/v3r7zySnJycnj88cdJSEigra2NnJwcmpubv/M7iqJw7rnnsmHDBu69916mTZvGoUOHeOihh9i1axe7du2S1C6AgwcPcscdd3DPPfcQGBjIq6++yrXXXktcXBzTp08/7bEKtLa20tzcTHx8/JDX7733XiZPnsw//vEP3NzcCAgI4IknnuC+++7j0ksv5YknnqC5uZmHH36YyZMnk52dLY/x9ddfs3jxYkaNGsVf//pXIiIiKCsrY926dfL4R48eJSsri4iICJ566imCgoL4+uuvufXWW2lqauKhhx4C4Mknn+Thhx/mgQceYPr06QwMDJCfnz/ECRvJvKtQoULF/ywUFSpUqPgR8MYbbyjAKf8NDAwoDodDCQkJUdLS0hSHwyG/19nZqQQEBChZWVnytYceekgBlN/97ncn/c7UqVOVWbNmyb/j4uKUO++8U3Fzc1O2bNmiKIqivPfeewqgHDt27JRjdTgcysDAgPL2228r7u7uSktLi3zvrLPOUgBlw4YNQ77z4osvKoCyevXqIa9fd911CqC88cYb3zs/mzZtUgDl448//s7PTJw4UTEajfJvMaelpaXyNTE3jY2NQ7571VVXKZ6enkNeKy0t/c6xAcpDDz0k/zabzcptt932vedw1VVXKZGRkfLvr776SgGUJ598csjnPvzwQwVQXn75ZflaZGSkYjAYlPLycvlaT0+P4uvrq9xwww3f+7tivDfffLMyMDCg9Pf3K3l5ecqCBQsUQHnhhRcURfm/OZ4+ffqQ77a2tipGo1FZuHDhkNcrKioUvV6vXHbZZfK12NhYJTY2Vunp6fnOscyfP18JCwtT2tvbh7x+yy23KAaDQa6nRYsWKZmZmd97Xqcz7ypUqFDxU4FKnVKhQsWPirfffpvs7Owh/7RaLQUFBdTU1HDllVfi5vZ/W5HZbGbp0qXs3r2b7u7uIcdypT4JzJ49mx07dtDT00N5eTlFRUVccsklZGZm8s033wDHsxwRERFDIt379+9nyZIl2Gw23N3d0el0LFu2DIfDwbFjx4b8ho+PD7NmzRry2qZNm7BYLCxZsmTI65dddtnIJuoUUBTljB1ruJgwYQJvvvkmjz32GLt37z6JUnYqiOzJ1VdfPeT1Cy+8EE9PTzZs2DDk9czMTCIiIuTfBoOBhISEUypunQp///vf0el0eHh4MGrUKHbu3MkjjzzCzTffPORzJ66bXbt20dPTc9I4w8PDmTVrlhznsWPHKC4u5tprr8VgMJxyDL29vWzYsIHzzjsPk8nE4OCg/Ldw4UJ6e3vZvXs3cHxODx48yM0338zXX39NR0fHSccbybyrUKFCxf8qVEdDhQoVPypGjRrFuHHjhvwDJBXEVUFJICQkBKfTSWtr65DXT/XZOXPm0NfXx/bt2/nmm2/w8/Nj9OjRzJkzR/LaN2zYwJw5c+R3KioqmDZtGtXV1Tz77LNs27aN7Oxsyanv6en5t7/b3NxMYGDgSa8HBQV973wMBxUVFf81ha4PP/yQq666ildffZXJkyfj6+vLsmXLvrcGpbm5Ga1We1LBukajISgo6CT6j81mO+kYer3+pPn/Llx00UVkZ2ezd+9eCgoKaG5u5sEHHzzpcydev3+39sT7jY2NAISFhX3nGJqbmxkcHOS5555Dp9MN+bdw4UIAmpqagOMUrr/85S/s3r2bBQsWYLPZmD17Nnv37pXHG8m8q1ChQsX/KtQaDRUqVPxXIIzM2trak96rqanBzc0NHx+fIa9rNJqTPjtx4kTMZjPr16+nrKyM2bNno9FomD17Nk899RTZ2dlUVFQMcTRWrVqF3W7nk08+ITIyUr5+4MCBU471VL9rs9nYs2fPSa+fKYNwz5491NXVce21156R4wEyKt/X1zfk9VPx//38/HjmmWd45plnqKio4LPPPuOee+6hoaGBr7766pTHt9lsDA4O0tjYOMTZUBSFur
o6WUdzpuDv739aimYnXr9/t/b8/Pzk8QGqqqq+89g+Pj64u7tz5ZVXsnz58lN+Jjo6Gjhem7RixQpWrFhBW1sb69e
"text/plain": [
"<Figure size 1000x500 with 12 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"x0s, _ = next(loader)\n",
"\n",
"noisy_images = []\n",
"specific_timesteps = [0, 10, 50, 100, 150, 200, 250, 300, 400, 600, 800, 999]\n",
"\n",
"for timestep in specific_timesteps:\n",
" timestep = torch.as_tensor(timestep, dtype=torch.long)\n",
"\n",
" xts, _ = forward_diffusion(sd, x0s, timestep)\n",
" xts = inverse_transform(xts) / 255.0\n",
" xts = make_grid(xts, nrow=1, padding=1)\n",
" \n",
" noisy_images.append(xts)\n",
"\n",
"# Plot and see samples at different timesteps\n",
"\n",
"_, ax = plt.subplots(1, len(noisy_images), figsize=(10, 5), facecolor='white')\n",
"\n",
"for i, (timestep, noisy_sample) in enumerate(zip(specific_timesteps, noisy_images)):\n",
" ax[i].imshow(noisy_sample.squeeze(0).permute(1, 2, 0))\n",
" ax[i].set_title(f\"t={timestep}\", fontsize=8)\n",
" ax[i].axis(\"off\")\n",
" ax[i].grid(False)\n",
"\n",
"plt.suptitle(\"Forward Diffusion Process\", y=0.9)\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T13:01:00.619395Z",
"start_time": "2023-02-13T13:01:00.605395Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:43.160882Z",
"iopub.status.busy": "2023-02-22T16:01:43.160156Z",
"iopub.status.idle": "2023-02-22T16:01:43.170501Z",
"shell.execute_reply": "2023-02-22T16:01:43.169283Z",
"shell.execute_reply.started": "2023-02-22T16:01:43.160846Z"
}
},
"outputs": [],
"source": [
"# Algorithm 1: Training\n",
"\n",
"def train_one_epoch(model, sd, loader, optimizer, scaler, loss_fn, epoch=800, \n",
" base_config=BaseConfig(), training_config=TrainingConfig()):\n",
" \n",
" loss_record = MeanMetric()\n",
" model.train()\n",
"\n",
" with tqdm(total=len(loader), dynamic_ncols=True) as tq:\n",
" tq.set_description(f\"Train :: Epoch: {epoch}/{training_config.NUM_EPOCHS}\")\n",
" \n",
" for x0s, _ in loader:\n",
" tq.update(1)\n",
" \n",
" ts = torch.randint(low=1, high=training_config.TIMESTEPS, size=(x0s.shape[0],), device=base_config.DEVICE)\n",
" xts, gt_noise = forward_diffusion(sd, x0s, ts)\n",
"\n",
" with amp.autocast():\n",
" pred_noise = model(xts, ts)\n",
" loss = loss_fn(gt_noise, pred_noise)\n",
"\n",
" optimizer.zero_grad(set_to_none=True)\n",
" scaler.scale(loss).backward()\n",
"\n",
" # scaler.unscale_(optimizer)\n",
" # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
"\n",
" scaler.step(optimizer)\n",
" scaler.update()\n",
"\n",
" loss_value = loss.detach().item()\n",
" loss_record.update(loss_value)\n",
"\n",
" tq.set_postfix_str(s=f\"Loss: {loss_value:.4f}\")\n",
"\n",
" mean_loss = loss_record.compute().item()\n",
" \n",
" tq.set_postfix_str(s=f\"Epoch Loss: {mean_loss:.4f}\")\n",
" \n",
" return mean_loss "
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:36:29.926465Z",
"start_time": "2023-02-23T07:36:29.906486Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:43.842403Z",
"iopub.status.busy": "2023-02-22T16:01:43.841296Z",
"iopub.status.idle": "2023-02-22T16:01:43.857743Z",
"shell.execute_reply": "2023-02-22T16:01:43.856757Z",
"shell.execute_reply.started": "2023-02-22T16:01:43.842351Z"
}
},
"outputs": [],
"source": [
"# Algorithm 2: Sampling\n",
" \n",
"@torch.inference_mode()\n",
"def reverse_diffusion(model, sd, timesteps=1000, img_shape=(3, 64, 64), \n",
" num_images=5, nrow=8, device=\"cpu\", **kwargs):\n",
"\n",
" x = torch.randn((num_images, *img_shape), device=device)\n",
" model.eval()\n",
"\n",
" if kwargs.get(\"generate_video\", False):\n",
" outs = []\n",
"\n",
" for time_step in tqdm(iterable=reversed(range(1, timesteps)), \n",
" total=timesteps-1, dynamic_ncols=False, \n",
" desc=\"Sampling :: \", position=0):\n",
"\n",
" ts = torch.ones(num_images, dtype=torch.long, device=device) * time_step\n",
" z = torch.randn_like(x) if time_step > 1 else torch.zeros_like(x)\n",
"\n",
" predicted_noise = model(x, ts)\n",
"\n",
" beta_t = get(sd.beta, ts)\n",
" one_by_sqrt_alpha_t = get(sd.one_by_sqrt_alpha, ts)\n",
" sqrt_one_minus_alpha_cumulative_t = get(sd.sqrt_one_minus_alpha_cumulative, ts) \n",
"\n",
" x = (\n",
" one_by_sqrt_alpha_t\n",
" * (x - (beta_t / sqrt_one_minus_alpha_cumulative_t) * predicted_noise)\n",
" + torch.sqrt(beta_t) * z\n",
" )\n",
"\n",
" if kwargs.get(\"generate_video\", False):\n",
" x_inv = inverse_transform(x).type(torch.uint8)\n",
" grid = torchvision.utils.make_grid(x_inv, nrow=nrow, pad_value=255.0).to(\"cpu\")\n",
" ndarr = torch.permute(grid, (1, 2, 0)).numpy()[:, :, ::-1]\n",
" outs.append(ndarr)\n",
"\n",
" if kwargs.get(\"generate_video\", False): # Generate and save video of the entire reverse process. \n",
" frames2vid(outs, kwargs['save_path'])\n",
" display(Image.fromarray(outs[-1][:, :, ::-1])) # Display the image at the final timestep of the reverse process.\n",
" return None\n",
"\n",
" else: # Display and save the image at the final timestep of the reverse process. \n",
" x = inverse_transform(x).type(torch.uint8)\n",
" grid = torchvision.utils.make_grid(x, nrow=nrow, pad_value=255.0).to(\"cpu\")\n",
" pil_image = TF.functional.to_pil_image(grid)\n",
" pil_image.save(kwargs['save_path'], format=save_path[-3:].upper())\n",
" display(pil_image)\n",
" return None"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:36:07.313353Z",
"start_time": "2023-02-23T07:36:07.307373Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:44.742690Z",
"iopub.status.busy": "2023-02-22T16:01:44.742000Z",
"iopub.status.idle": "2023-02-22T16:01:44.748704Z",
"shell.execute_reply": "2023-02-22T16:01:44.747497Z",
"shell.execute_reply.started": "2023-02-22T16:01:44.742652Z"
}
},
"outputs": [],
"source": [
"@dataclass\n",
"class ModelConfig:\n",
" BASE_CH = 64 # 64, 128, 256, 512\n",
" BASE_CH_MULT = (1, 2, 4, 8) # 32, 16, 8, 4 \n",
" APPLY_ATTENTION = (False, False, True, False)\n",
" DROPOUT_RATE = 0.1\n",
" TIME_EMB_MULT = 2 # 128"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T13:01:00.588388Z",
"start_time": "2023-02-13T13:01:00.344403Z"
},
"code_folding": [
0,
13,
23
],
"hide_input": false
},
"outputs": [],
"source": [
"model = UNet(\n",
" input_channels = TrainingConfig.IMG_SHAPE[0],\n",
" output_channels = TrainingConfig.IMG_SHAPE[0],\n",
" base_channels = ModelConfig.BASE_CH,\n",
" base_channels_multiples = ModelConfig.BASE_CH_MULT,\n",
" apply_attention = ModelConfig.APPLY_ATTENTION,\n",
" dropout_rate = ModelConfig.DROPOUT_RATE,\n",
" time_multiple = ModelConfig.TIME_EMB_MULT,\n",
")\n",
"model.to(BaseConfig.DEVICE)\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=TrainingConfig.LR)\n",
"\n",
"dataloader = get_dataloader(\n",
" dataset_name = BaseConfig.DATASET,\n",
" batch_size = TrainingConfig.BATCH_SIZE,\n",
" device = BaseConfig.DEVICE,\n",
" pin_memory = True,\n",
" num_workers = TrainingConfig.NUM_WORKERS,\n",
")\n",
"\n",
"loss_fn = nn.MSELoss()\n",
"\n",
"sd = SimpleDiffusion(\n",
" num_diffusion_timesteps = TrainingConfig.TIMESTEPS,\n",
" img_shape = TrainingConfig.IMG_SHAPE,\n",
" device = BaseConfig.DEVICE,\n",
")\n",
"\n",
"scaler = amp.GradScaler()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T13:01:00.603387Z",
"start_time": "2023-02-13T13:01:00.590387Z"
}
},
"outputs": [],
"source": [
"total_epochs = TrainingConfig.NUM_EPOCHS + 1\n",
"log_dir, checkpoint_dir = setup_log_directory(config=BaseConfig())\n",
"\n",
"generate_video = False\n",
"ext = \".mp4\" if generate_video else \".png\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T12:45:46.057703Z",
"start_time": "2023-02-13T12:45:39.770695Z"
},
"_kg_hide-output": true,
"hide_input": false
},
"outputs": [],
"source": [
"for epoch in range(1, total_epochs):\n",
" torch.cuda.empty_cache()\n",
" gc.collect()\n",
" \n",
" # Algorithm 1: Training\n",
" train_one_epoch(model, sd, dataloader, optimizer, scaler, loss_fn, epoch=epoch)\n",
"\n",
" if epoch % 5 == 0:\n",
" save_path = os.path.join(log_dir, f\"{epoch}{ext}\")\n",
" \n",
" # Algorithm 2: Sampling\n",
" reverse_diffusion(model, sd, timesteps=TrainingConfig.TIMESTEPS, num_images=32, generate_video=generate_video,\n",
" save_path=save_path, img_shape=TrainingConfig.IMG_SHAPE, device=BaseConfig.DEVICE,\n",
" )\n",
"\n",
" # clear_output()\n",
" checkpoint_dict = {\n",
" \"opt\": optimizer.state_dict(),\n",
" \"scaler\": scaler.state_dict(),\n",
" \"model\": model.state_dict()\n",
" }\n",
" torch.save(checkpoint_dict, os.path.join(checkpoint_dir, \"ckpt.tar\"))\n",
" del checkpoint_dict"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:36:13.403286Z",
"start_time": "2023-02-23T07:36:10.026409Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:03:08.012329Z",
"iopub.status.busy": "2023-02-22T16:03:08.011929Z",
"iopub.status.idle": "2023-02-22T16:03:10.213564Z",
"shell.execute_reply": "2023-02-22T16:03:10.212573Z",
"shell.execute_reply.started": "2023-02-22T16:03:08.012295Z"
},
"hide_input": false
},
"outputs": [],
"source": [
"model = UNet(\n",
" input_channels = TrainingConfig.IMG_SHAPE[0],\n",
" output_channels = TrainingConfig.IMG_SHAPE[0],\n",
" base_channels = ModelConfig.BASE_CH,\n",
" base_channels_multiples = ModelConfig.BASE_CH_MULT,\n",
" apply_attention = ModelConfig.APPLY_ATTENTION,\n",
" dropout_rate = ModelConfig.DROPOUT_RATE,\n",
" time_multiple = ModelConfig.TIME_EMB_MULT,\n",
")\n",
"# checkpoint_dir = \"/kaggle/working/Logs_Checkpoints/checkpoints/version_0\"\n",
"\n",
"\n",
"model.load_state_dict(torch.load(os.path.join(checkpoint_dir, \"ckpt.tar\"), map_location='cpu')['model'])\n",
"\n",
"model.to(BaseConfig.DEVICE)\n",
"\n",
"sd = SimpleDiffusion(\n",
" num_diffusion_timesteps = TrainingConfig.TIMESTEPS,\n",
" img_shape = TrainingConfig.IMG_SHAPE,\n",
" device = BaseConfig.DEVICE,\n",
")\n",
"\n",
"log_dir = \"inference_results\"\n",
"os.makedirs(log_dir, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:42:49.677019Z",
"start_time": "2023-02-23T07:41:04.890036Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling :: 100%|████████████████████████████████████████████| 999/999 [01:43<00:00, 9.63it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAAESCAIAAAC+Vc10AADrFElEQVR4nOz9ZWBUSdMwDJ/xuHsCIZAEDQQNTnB3WdwXWJxFF3ZZJLC4X8CyuLu7BQgWHEJCHBLi7jOZyaTfH7Xp7ek+52Tguu/3+d7v2foBk3P6dHVXV1dVV3dXSRBC3L/wL/wL3wLS/9MN+Bf+hf8PAkLof1zhSCQSXiz4uUwmI9/Cn1KpVCo1mMZQHv8rlUrxb/hEIpHgJxKJhLcvbGNEmi2tAOohbqcQFhYdiRfXQDVGpG3fMS5sbRSdqcawWCqlFTkKQkAiBVwkFl764N8koaRSKRBcpGFkeZZixg89+QnbQfLPv7H8l6xm5Ifs8EDjyG7jP/Fbtjb2IU+XmFffRz7yB9lUFotQU78VKVkhFjQUexk/8YwBXlYjsfAyNPUEj5pQAYyFHGJK5FH1ULNFpNcg5kSmjRCJ3Nzcjh07NmPGDFNTU/I55j1KiBv05bunjTESFP8mmQA+JCcGy3nUpKeEECcgR8nhqbT9vMzBGQpjXtqJSDUjmVic9eFPCgvVKo6PaN+Kl8VClmS1Lvmbt83seFFYhIQLlqGUuKRGmZfOMpmMnDaSCkuk0ua5ubk9fPjw69evbm5uHCMFhACw8JcwRtmVl5fzFiaNMfYVflJeXg6/pVIpqgCOkLVQAGt5zEzl5eUSiQRj522eMeYNOZPJ8mTNgM6YqozHyxYT+ZOXkt+KDlclbo9hruUqOk71ndUeeGrBn+Tg8qJgLQ5cHsYacwV8S40yfoX5QSqV6vV6/BWg0Ov1vHip5mk0mrCwMA8Pj5o1a5K4gMGEyPg3KcRfYyKKF/smoPhSZGrhwvBELpeTxeAHzDq2zUJAWa5C4lboW5E/vwN4m8r7kDW4KVaoFCqdnNQ8odQLFljkW0qs8LaQBVKdksZbWVkZ9aGQsCClCWXtiwBbxsLCguM4FxcXc3NzqrPiVcnFX1Pz+/ukKQVkbVKpVC6X63S66tWrOzs7m5ubl5SUaLVaS0tLmUyGEMrLy4uJicnPz0cIlZWVkSYsCKfy8nL4IaTiKNRCssTS0tLGxsbOzs7U1NTa2hoh9ODBA61WK9IvShiL08cY6mFxy7YQmk2OgvHDUSngesg6sfjn+FiWd9KSioJtGxZwLFNhzV+pIsVlqCYZYxRQZdzd3du3b88xvihjqFrJtCHZ9H9kkLCVietECNnZ2S1ZsmTAgAHW1tZarTY9Pb1KlSpQPiEhYefOnVeuXImPjy8tLcVtkMlkWBd/k8TFqJVKpZmZmbW1tV6v9/b2rlGjRtOmTX19fatVq+bl5cVxXJUqVZKTkzH3VMqmRr4VYQ4hu5cjtKI4Y0kkEtDJ1tbWJiYmMO2lUqmFhYVEIklJSSkuLuZtHtCTnCocw2fGdFPEPMFzgzOcQqTykUqlZmZmKpXKwcGhtLRULpdrtVq5XF5WVpaXl1dSUgIaSXz2GgPm5uYBAQFVq1ZNT09PT08vKSn5tu+/yUrBIJVKFQqFSqWSyWRyubxSpySpIsiZLZVKzc3N58yZk5ubixDSaDQ6nQ4hVFpaqtPpYCARQh8/fhw+fDjH528Qt7iEDB53d/eRI0eeOHEiKSkJEQDqq6Sk5MGDB2ZmZtSHvIt1iUSiUqkUCgXQAf+QyWQqlcrc3NzMzEwul6tUKrlcbrzzwJhxgUqUSiU0wNnZuXnz5r169QoMDNy3b19kZGRISEhoaGh4eLhWq0UIAQ15sfB6/yUSiVKphPrNzc0tLS0tLCxwL2QymUKhEHJLkKND9oUsj51D+Cs/P79p06Zt3bpVq9V+/vxZr9eHh4enpKS8efNm9uzZfn5+4PVih16IYizDADg5OQUFBen1+h07djg5OYnTmcVSibbhhbZt244dO7Zhw4be3t4WFhZJSUn79+/fvn17VlaW+IfYRsIizdzcvHXr1tOnT7eysuI4LiYm5ujRo9nZ2V26dImMjO
zbt2/9+vXLy8vr1q27devWHj16HD58+Pbt26zgF1IFvMUGDBiwdOlSPz8/bLDB/MROmLi4uLlz54IEElIO+PnmzZunTZuWkJDw7Nmz8+fPN2nSJD4+XiaTBQQEdOrUycPDo7i4WC6XX7t27fnz5ydOnEhJSYGqSIVpJJCN8fLycnV19fX1bd++ffPmzZVKpampqZWVFd56kslkNWvWBDoDa/r7+wcHB6emplJVkUYOaZj5+vrOmTPHxcUFIVSrVq1atWqlpqZu3779wYMHbdu2LSkpycrKysnJ+fr1a3FxcWJiIkV5djphw4w0NTHqBQsWzJw509nZubS0NCoqSq/X5+fn165dWyKRuLq6NmjQIDMzc8+ePUFBQVqtVmg+kNoSDxlpNAHY29u7u7vrdLrr16/n5eXxUphquYEpSEkCoQEDqFat2tatW9PS0kpKSq5fvz5t2rSdO3empqYihKKioubPn8/rF+YMFTQAlLSwsJg5cybIwmPHjrVs2dLExMTU1NTR0dHa2rp27drr1q0rKiqCz0tKSi5fvswZSixyR4yUnSI96tWr16tXr6CwRqPBqqa4uPjmzZuTJ09u0KAB6X7g7Qv+c/369biGvLy8wsLCjIyMjIwM3GwAvV5fUlJy6tSpFi1akCNBiV5yhcZynkKhsLa29vHxOXnyZGxsbGpqKowFMg6SkpJ69uzJOy7sclkqla5fvz4rK0ur1b579+7hw4cxMTFArqSkJI1GU1RUlJ2dnZSU9Pz584MHD9arV4/8XEjb4LekfjM3N9+yZUtRUZFer3/8+PGsWbOqVavm4eHh4eHRr1+/SZMmnT17Fir5+PEjXr6T1JPL5SwWFqB8YGBgcHBwYWHh+fPnbW1tcXvEvyUpJqceiXxQs2bNefPmBQYG3rhx4+zZs3FxcQkJCTY2Nm/fvp0yZUqjRo1Gjx596tQpSupQjUaGAlutVl++fDkmJkYul0dERHz+/BkmtFqt5jguPz9//fr1YWFhc+fObdCggampaYMGDRwdHTMzMzG9eFc47CKBFELh4eGvXr1q3LgxQgimx7t377Zv3x4REZGfnw+LKGMoCN3Zu3dvUlJS48aNq1ev7uLiUq1aNfDPcBwXGxt769atkJAQBweHyZMn+/n5DRgwICEhISoqKicnB5ynVJuFRB08r1u3blBQkFQqbd26taWlJVvyxIkT6enpJiYmZmZm5ubmxcXFX7586dy5c4sWLRBC7u7uQgYJ9qxgdCqVqnfv3vb29q9fv96+ffuzZ8+sra29vLxq1Khhampat25dpVL55csXS0vLzp07N2zY0MfH58SJE7t378aeG6HucIZrkmHDhg0aNKh3795Pnz7du3dvWFhYQkIC1gCZmZkKheLevXsfP378/fffnZ2dSYcbXinhhyTR2AbAn1WrVk1JSZk0adLz589zc3M5gY1yxGdcABhrpPn6+v7yyy/du3c/c+bM1q1bY2Ji4Llarb59+3abNm0aNWpkamoKIkRoAY0q9mG4CnbX6/Vfvnz58uULL1KpVJqZmXny5Ek7O7vZs2dXq1bNzMysbdu2ly5dgs+xO4538Ur6asgmlZaWYlmo1Wr37t17+fLlkJAQY2YLqRwAaVRUVGJiYrVq1WxsbKytre3t7ZVKpbm5eVlZ2devX8PCwhITE5VKZVZW1ubNm11dXXv06BEeHn7s2DF2UwLrH0oWwL9NmzadOXMmVhfl5eUpKSnv37//8OFDenq6TqfjOO7hw4f5+fmggZVKZXFxsaWl5ZgxY6Dyp0+f4oFjAZMISOrk5FStWjWO48LDw0NCQuLj4zmOe/nypaWlpUKhcHd3l0qlubm5pqamDx48GDhwYK9evapWrRoeHv7o0aNKjU9ApFQqBwwYsHz5cktLy8uXL+/evfv+/fvUnk9paWlpaWlRUVFYWBjHcV++fMErokpdF5T/CWDw4MHNmze/cePGiRMnyMaQFfKyLv03q0Cpb0xNTTdt2pSSkn
L06NHatWuTZWxtbadNmxYdHa3VatevX29jY8NLpr9VG7MmJvWsubl5u3btAgMDHR0d8SugUc+ePZ8/f44Qys3NHTl
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x274>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results\\20230223-131104.mp4\n"
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=64,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=8,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T13:56:16.260979Z",
"start_time": "2023-02-13T13:50:06.139878Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:03:53.237665Z",
"iopub.status.busy": "2023-02-22T16:03:53.237279Z",
"iopub.status.idle": "2023-02-22T16:08:06.306724Z",
"shell.execute_reply": "2023-02-22T16:08:06.304477Z",
"shell.execute_reply.started": "2023-02-22T16:03:53.237633Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling :: 100%|██████████| 999/999 [04:06<00:00, 4.05it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABEIAAAESCAIAAAB2Fud+AAEAAElEQVR4nOz9ZXRVyfI4gPbxuHsCMUKwhOBOCO4Mg7vrDC4zBGYYZnCZwd1dg7sHDwRJQgIkIe7uOZLs96F+6dvp3nvnMPf+31tvLepD1k6f7q6u6uqqaquWcByHvsN3+A7f4Tt8h+/wHb7Dd/gO3+H/f0D6/+sGfIfv8B2+w3f4Dt/hO3yH7/AdvsM3Asdx7IaMRCIRSaF+xf9KpYKTIoxFPLNEIqEyyGQy/C9ZhLcNJBbIjCuk8uPayA82g0QikclkVKswFlyEahtvC8maeb+pIrz9gqqzCGoQQkSV4m0k1S+Yb2QKbifgAp5Av8BfSMHsQgxDAAv+laSdxMjyQUjSMBNwN5H9wks+EpVhEaByYlpQTcJMii5ZFSZciDTE9L4+o+9fgAgWCvTEwiuNJBah8fjv0AlhEWobFldECDNixhSq3rMkkL1fI5DSzgtCA7lGWv418GoYSiyF1LiefYcB6iFpIRnCO4J45Vykd4RoQVVKiaIIM5myDrxkkkoPGk9qfn2YxqZTw5yXkxgLmB6h7iDtIy+vyAxCml9SBUK2lbedcrkcEUODLUVhwSqalxXfCkK9z5owcSoA2FYJMVxcWwqRg1ErFAoRB0lPLDWaez1/JbGIuFVUusioYWugxgvVcrZHatSWQgBYFAoFFuYaCQGQy+Vsj1ADCoDXu/hvLK8QABahMasnUlIpiWsYqk7SMgph5K2QNKOUjJHdIZfLJQTgItjHBsLlcjmkoOoihynC3uZ/aKFUs/hQp1QkEuC1COOEWCwigkKMY/1yrDTJOjFGTDxOFPEvsQ2QVHe+5XI5a/71H4H6q3Jeh4nUEexIwx/k1EtoWEIi1S9sR7M9QsqAUBtQdWXHOhkUsOlCKk9EFYo4f+Kq7ZuAlDHK0RFvJ/sv7hq2YZTrL0IUL7Cjknec6uOU87oppL6rkav/75xysnm8HONtHrZepErFXck7ycQ9JWTMSPWF03lHkz5CqA/HJML+Fgs19j7ZQpFqhRqP+QMOLhLoF0RIDiLkB1eL/WMyM2t0WIpYSWa1E6szUXW1SRVHjFbk7ReyeUJ+A2naWKmjSGZpIZkmVDlrT0n2wr+knRJ3ZYSA6guy70RkTCSDUP3fBCwWsmH/Wu1TBYVGJSs25E8AIh6CPlh4Zxr6D38WaqTlm5jGyiqJRVLlp4qUQtXHCJVT3L+qUVuyegZVH4lkTrY97HKMCBX/JbCjUmTqSLkfqLr9EmoYOyXjzSniqepDMsaCFzskhJEltR/VO6TiwvMZ3BHkr9U4Ji4EvILFdn+NgLGwC3LsEBURLNZlJBMpw0xOYJBo17KA2ympPvmTMK4/Ehj8+mhSymdiVQAmhDWQJAeoRKFs7K+8ZobdRpBWX04jpZDFhZh+pAwz1UFsU3lHJi9gosi1n/858Ayb6qwm54qkyhbSv0JyzsrY/1OgsPB68KxI66nFhLAI1Yy/zczMTp06denSpfnz59vb25NFROyZ0AQDStnb2x85ciQ0NHT48OEkLYgQZmpug/hGkwiWGhP1zyDuLv87YAcRi4X0dGtccBEalZTGZrWlUE7en6iOEEJKahhqLipkZcTJkVR3PeFf3gkGa5VIpLwWQWS1FdMCMzqyODlFFGEmae9wZvJXcbdMpGGIbwDyGiZWW1LwTcJcY2Z9tKWQ4hLSQvqM/f856Kn5STHj7ax/xzHxeZFUKm3evPm1a9dmzpxpZ2enD1IRWoSsDJZVcLq+CYuQX4FFlB2PIr9SNfwPe5+1+ywtiG9eJ2J/Raw2m05hESrIi536SUSXUhzDmcmJDSwpkoAIb5Oct8hkMlIl4r//h4VE9q/39Sj4b1QAL1/wKSYqD8VoXhcTEfJKVYvXZfW3muICLTIMeNOpguRf/TnGTtIcHR2nTZu2ePHiFi1aiDeS1y2jjn6RLCINNiXEpALC+UlacB
FqOsRLAuKz1rw+Ac6jJ8f+SwtKYiGneWR+qpS1tfXPP/88evToOnXq8A6x/95k/guiWJMprvv+HWo9DTMGKyur5cuXQ5GCgoL27dvrWZA1AHg+qVQqf/zxx6KiIo7jVq9ezWv7hTwYLMxQJ6+GEdLmQsCrXkRWyv8fQY2OLEXXN9GIi/BqmBqrpcZ+jajJfmHVCJtfqVQ2btx4woQJ8+fPb9q0qUKh4MVF6jpezc+u7yC+DXC2VRSlpEXD2pLMJiGA6h1e9wu3HA5ykJWQHBNqFatdyXRp9QVRqn78kz5jn8XOakg9e1/cyPJq5hqB7ZdvKv4vALCQqwn/us1UitDSEtuPLCdlMlnjxo0fPHjAcdzOnTudnJzIIkItFHGX/x05vJUIWTF2BIlj4R0m+N//Ye+LNEx8VAKIGy99MPJqGN78+s8IKAZKq1++ID1DTAU5nyF1GimHWHSlUilMY8iZD+4XOdmUyspKPRuNm8t2LW+iSB5JladL2Qmcp2XLlgEBAcXFxffv3//06RNkBtqgwaxjwRqbiooK6DmO4+RyuU6nk0gkZCIekJCZlyEymQz/hJuHP6D9mBwqz8iRIw0MDJ4/fx4bG6vT6TBessHkX5ZX8IGpxoCbBNlcXV2nTJkyc+ZMiUSiVCpfv35NVSXiu0DNFRUVmApIkUgk8AF/gQ94eOOCZHFckKKR6i8oDkUopmEycU6SaWSTREQOdy5VOZuH7NwaeQWAUZMDD4QK5zEyMvrpp5+mTp0qkUj279+/b9++5ORkkijexuj5778AhUKh1Wp5UbMs4m0kJYHi7dGntSTGgICAmTNnwreZmZmRkRGVR4hvuG0IocrKSugIGJKVlZXNmzdXKBQfPnwICQkhB7iQjSRFGiRNhPP69whVCclGUvb0BIVCYWJiYmFhYWNjY2pq2qRJE5DhFy9evHnzRqPRsBhrbBLZMKGC/v7+9evXNzAwKCkpCQsLe/XqlVA9vKoMJwKTebGQnKH6iBzOrDIEBa7T6VD1gw2ouqJwcnL6448/WrdubWRk1KVLl507d966dQvrNKxbSIXMNgPqxHoDJ5LajGIFbjBZA6l/WCWJ20DRTiYKcZ7jOFKueJlGIWKbIWSnyE4kM1PcEwecB7NRSAmLD3wk4L2QDNSzSbzNEykIHGCbJ5fLHR0di4uLDQ0NLSwsatWq5eLiYm1tLZFIwAEICgqKj4+n2knRwhJF6WfWOvAKg4gSY2lkkcpkMh8fn4CAAIRQSkpKSUkJWUT/jsaDiFICNRZkdYWQZuM19OS/1tbWHTt2dHBwUCgUFRUVt2/fjo2NxULLgshPGGqkgm0h22ChRGrY8vYjSy+qSXrJLiAzYBGiDCWlCsTpwtzAo4NiI9hoVF3YqMopbxO0OvwlyUQIVZvGsE10dnauU6eOSqWqqKgA77+goECtVqekpOTm5opzRwQwm8hRR/UWTh8+fPi8efPi4uISEhI+ffqEm07ST4oRab2wXcEoyL4h2UEiNTAw8PLycnFx4TguISEhOTm5qKgICXAcdzDldpN5zM3Nt2zZYmNjs3Hjxq1btyYlJeHGQCmq/WRVpLSR8y7qV/wxePDg6dOnW1pafvnyJSUlRVq1wce2igJsXxExikjNiCUSGIjTsT2mmCOTySh3hJyxYBTU+MdsoSwQ1XhedrH/kuyi2kPVSUkFyyWhMUzOf6j65XJ5jx495s+fb25ujhDq0KHDvXv3kpOT8fQPMYtJvDquRhDKJpFIzMzMGjdubGBgEBoampOTA3MYIRBxlVCVbACNHh4eVlZWxcXFWVlZVlZWcrk8MTERzBt0ATUt1Kfxbdu2HTlyJJxYKC8vv3HjxtevX6k84gwhPSosCSqVysfHR6VSPXr06PXr11gCSZ8SVfd4qPEl5JWyP7H6Woi3VlZWvr6+paWlnz9/LioqYqf9vECKt6WlZefOnbt06WJvby+RSDp06ABi9vz5861bt96+fTs/P59sJy/3WHtJtR
mnw99WrVpt2LABb/Pev39/1qxZUVFRZPOEbB6lRijlII6drAH3HcVnUPjYzrE6BLfK2tr6hx9+gGy9e/c2NTUtLCx
"text/plain": [
"<PIL.Image.Image image mode=RGB size=1090x274>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-160353.mp4\n"
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=256,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=32,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T14:00:18.151514Z",
"start_time": "2023-02-13T13:57:07.698420Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:16:59.757373Z",
"iopub.status.busy": "2023-02-22T16:16:59.756510Z",
"iopub.status.idle": "2023-02-22T16:21:11.518330Z",
"shell.execute_reply": "2023-02-22T16:21:11.517233Z",
"shell.execute_reply.started": "2023-02-22T16:16:59.757332Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling :: 100%|██████████| 999/999 [04:04<00:00, 4.08it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAIiCAIAAAB61jR9AAEAAElEQVR4nOz9d1xVR/M4AO8tXHqvAgIiRRFr7L13jT2xa+wxajT2RGNsMUYTS2yJ3dhii713RUTFgooUAUFAeu9w9/1jHvaZu3vOAfP8vu/7+byfzB96OHfPzs7s7Mzs7OyuilJK/oV/4V/4F/6Ff+H/BtT/v27Av/Av/Av/wr/w/9dAKa3+hEatVhNCVCpVlSWhDCsJWLRarVwB8Vs57PDAyqgqgWHRaDTcV/AGv2efiP8qNECj0ajVasBSzZZz9TMSqvwWsODyuM24MHuufuXspVzvcyiIwHM5RGJLGMcUyrMKWU9BDYC0OqxmtOBG4nqw6GJauDZL0oifqzNelMWJIVJgI8MCVCij41Bj5uMRp4xFroxkv4uDV0H2RI4pdKjcCMLPkrhY70ODWRl45j4RhUfyT4aOfS7Z+9XXh9UBhVGJy3A9yxEoJzNykoxVq2R5zE9JWtRqtVarFZvxHywfZWb+MSh0j+SYlBuZ4r+YMIwFmx9cCadNsAwpq1GojVOanEBzuIh8t0myAj8rGADJPxXUVnXU2T8DhZqxRawOlirlWE61sTfMMDPHgqlIeMlVy4Yi60TO2kkCwyIpotUH5c5S5hiW5Gr+irmqoDQVxkKVBvif0fL/CrhRyd5XaWCqNG9Y1fx/k5bqDC7ujXJ57pnRIintHApRobEhIKKGCg16X5JxogQrzGOqY6gBC0ePnMaUdOHFT7hnps64WQtmh6jCMFJWQHRsMUM5H7BK8iXtENcAlUql0WhwAQUfUHJIKJgZrhh+o+DPKtRmZmY2a9YsHx8fOSzEkPPVkbGP8tkl28kGJ3CSuXvY2LD+xdxWVc4YOM8dF2M1cLQodIpIY5XAmCCpaLjuq3LyKjaJs47/u9KsjkKoUmmS/6H3RSzKA0H0NfGMWY4cVp6j5aNGjTKI/SKHRdm6NGzYcM2aNSNGjCCKricRel/kjCRbuAGiMnT0xaYCFj6KBYj1ej3XJngjKZTVl9SKigrWer1ezxQcPLB6MHaucvyn3DP3Oe4w5odKEgj8EhvG6udGC5g0IIqVV6vVXOXsT/hckiJKaUVFhZyUc/wRGc7eiNglK8HNlqyTlRcbrNPpfvnll/bt24eEhERHR4u4GCJJRcn1MnsD7WncuLGvr29iYmJeXp6lpeWDBw8YOWJLxDfwEjgJ/7LGAGfgX8l6mPxj04gdf0mxgaqgmNgYhlHkkiQhcp2iEmaEUK0kBwjqAlyYGA4NTndIii6REhXJFlb5UhnkUFSHe5INwJzBmo3pN0YXk0MiMJwgzuMK1Wo1fCs5EuU6pco2K/zE+MBVzn21aNGiLl261K5dOzs7W65OjlKuKhAqrnJMFH6JtaLyYP+vmZEcsXLC98+A9ZnYnVwxySFdZeXwwEY+1zGMHbTS6+EKsD/hJ25M0srIL36JhweuRJlpmGpjY2O1Wl1SUiJp3hTUvSRUx8YQ+VFdJVJTU9PPPvtsxIgRe/bsiY+PVyaNGAoV9xNWnc7OzrVq1apfv37Pnj0DAwOjo6OLi4t1Ol1QUNCJEyciIyOJIJ/Mz+IGNvaHsDfHaXyYl0BJ0Q/AfhzW5uIEQhRpOW5g7MrKQvxc0oAps7dJkyZt27bNyMg4c+ZMXl6eWLnYAEnhwdaaEGJhYeHo6Ojp6enk5GRlZQUMSUlJOX/+fEVFhbIfIHJDUqrr1avXrl07MzOzI0eOJCUlsffKBk+SOvhEpFRUL1jgRf+SGLJLrhmiN1B9ELU/fmban2lzUX2ZmZlNnjzZy8vr0qVLN27cwE3iWqjQbCz/mEvKRDEpZX3EvDooUMUKIZHygDBKLC5VqkJRgYrl5ey2ZFVyChFXIokdm2IiBEMUZI6bu8g1Xs7GYKQmJiaNGzd2dXW9fft2SUkJ5onoUDBwcnJq1KgRISQvLy81NTU1NbWgoECv12u1WmNjY3
d3dxcXl+zs7NjY2NzcXMlKqgwxS/KBEGJhYTF06NDVq1f//fffS5YsYfXjMgq9xr0HoXRzc2vWrNknn3zSqlWrZs2amZubE0L8/PygTN++fc3MzDZu3Jienq5QG37WarWgWeSGJTG0MZLVchxglcjNZqpDssK0jAHnnVRpuvAb/NLW1varr74aP358Tk7O8+fPX758iT+szhqDqGqdnZ3r1avXtGnTgICAwMBAT09PMzMzrVar0+nevXtXXl5+6dIl5SmyyA3myEJ7rKysWrZsOXLkyEGDBpmamj548ACbmerYGE4IsY0U5zTwslWrVvb29q9evYqJiVHuzWpiJ4T4+/v7+PiUl5drNJqSkpIbN25UWaec9mfP3BQK9w6UdHNzg9jjmTNnQkJC8K/iV5LAGXKusIeHR0BAQGlpqYmJSVFRUVhYWHp6OosAQRl4gHrYKPtPXZgYlRCGk1xD+9hYqoJYS3qI7NnS0tLW1tbCwsLIyEg50i0ONiAHBxBNTExq1qzp7u5ua2trZmYm2QbJUCN+4DiGH1SG00ZJiszMzNzc3Pz8/AYPHnzo0KFNmzY5OztzxeQUgZ2d3YIFC2glnD17dsqUKR06dGjQoEHbtm0nTJhw7NgxvV5/+/bt5s2bi9hFBiqrG04e2rRpExsbm5WVVadOHZEzkgxRxmJlZbV27VpqCMXFxenp6dnZ2fDnixcvPv30U0lcHDkMCxdlhvUVBtwKuZWVlampqYuLi4eHh7e3d+3atWvVquXu7m5mZmZkZIQ/IdV2V+XEgAhyIley+v0i9qlGo5k2bVpkZCSlNCIiokmTJlViUWgwVFivXr2ffvopNTWVUpqRkZGVlRUTE3P+/PmTJ08+efIkODh41apVcjoBsMihwO0fPXp0YmIipbSkpOTkyZPM28CEyzUSsEiu/mJJYBTBT02aNElLS6OU/vLLL0ZGRpaWlk5OTra2tjY2NhYWFg4ODjY2NkZGRqyq6vQ+IeTAgQNYpFu1amVlZVXlVxwtH7Wwp1arR48enZmZeeHChSZNmhBDXknyDfcLYxTOi2ElNRpNnTp1tm/fjokaNWoUKykqAT7NRFTN1WlilTRLkqRQoSgfKpXK2tp60aJFmzdvnjJlSp06dYhhdpBYCVY0jHhcc/PmzcPCwpKSkvbu3Tt8+HBLS0sFdcmlDOIZhihqIq+5r+DB399/3rx5ISEhlNLU1NTevXsDFsZ2ZSxeXl4HDx6E8Fp5eTnr8pycnLi4OPZnQkJC//79CSFGRkbGxsZiq1iTmKhV2eMajWb8+PExMTGwuqgA4udyg9PZ2fnOnTtllVBeXl5aWvrgwYNZs2bt378fZmmU0qCgoD59+igjZVgYPzFL2TPLgQGwtLRcvHjxnDlzHjx4gIfQ48ePx48f37JlS/BF2CBUngdUf6RoNBqdTqfVamvVquXn52dpacl9Xk11hvEyOZ8yZQoIQ1ZW1oYNGySbh5Wm2GxOaPv27Xv79m1KaWlp6cuXL4cOHTp+/Hg3NzcoY2Ji4uLiotPpJBumEibokqDRaL744otXr16BMJ87d87e3h5+Ul5hZs9YNYtJNzj3gT3b2tpeunSprKyMUnrgwIGRI0euW7fuxIkTa9euXbx48Zdffnnw4MFff/21ZcuWHMcUCCGEmJubnzx5khrC1q1bPTw8lD9kbWZYqmObGdKIiAhK6YoVK5jbqoCCYeF4JQqJvb395MmTw8PDKaUQJwCYOXMmLsa1iu0lkDYz3Jf/AJQNAFdGTsGZmJiMGjXq6dOnRUVFpaWlS5cu9fPza9myZa9evRT0pmT34Ow6b2/v0NBQSqler8/Kytq0aZO7u7tkVebm5s2bNwdHBoNaft+MSipDn4Gvr++qVauSk5Mppa9fv/7++++bNWtGZNglp87UarWnp+evv/66fft25u+DycnIyMB/Dhs2zNraetKkSdOnTxfbKXKMIGUq2X4vL6/Zs2f/+OOPCrXJgdzgtLW13bZtW2
FhIRTIysras2dPu3btNBqNs7Pz999/n5eXBz/dunULRnuVWEC+xVxBjkAjI6OhQ4cmJiYWFBSUl5eDrsFQVFSUn5/
"text/plain": [
"<PIL.Image.Image image mode=RGB size=546x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-161659.mp4\n"
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=256,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=16,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"execution": {
"iopub.execute_input": "2023-02-22T16:30:05.578985Z",
"iopub.status.busy": "2023-02-22T16:30:05.578388Z",
"iopub.status.idle": "2023-02-22T16:34:16.926424Z",
"shell.execute_reply": "2023-02-22T16:34:16.925434Z",
"shell.execute_reply.started": "2023-02-22T16:30:05.578945Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling :: 100%|██████████| 999/999 [04:04<00:00, 4.08it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAIiCAIAAAB61jR9AAEAAElEQVR4nOy9Z3gUx9Iw2rPaVc45IgmEJEAEAQKRc85ggk0GBwwOmGAMmGNwABuMAXPIOZucc85ZAURSQDmgnKXVanfujzrqU+qeGS1+3++7z3Ov64c0O9Pd1VVdXV1dXd0tiKJI/oF/4B/4B/6Bf+D/DKj+367AP/AP/AP/wD/w/2kQRVEURUEQTExMCCHwlxAiCAJNQ59VKhX9if/iZJBGEgstXKPRMAloBQRBgBIwLh6dIAgYNUMLkwWnVKlUKpVKMjtPryRewCKXt05QwIuBx2JkRgV0TAkqlYrh2PsWSGRahD5DmxrDMUmxeS9gsEgSxb9Uq9UK5FOJpdmNbP06WcrQy6SnWHAFQHSJfJuamJjgYiV7qEqlgjJp60tWTBAEjUbDdxPcd+CB6a203ekzljETExNgOFSVdnmsOgAFfQBdAZ+gWEojzk5pYbq/AjDskmwRoQbgJaaFYSwmQbKPUHYxNcTkGN9f/h5gGiWx8EzjiVVITGpkjOGYtDrj8ytrAdz2CsOMQr3lyFMmRhKLkXJGs9MRTq4ymGtynfN/C9jmMTo9/okrLJmMvmGw8D2NZ6ZkR1Kom4Jq5jWgXFFUyzDZcWWUu43CcMv0dh41Tl9nuyiXpgx86yvwWUFFktojihwuHgsdRdRqNWQXpIAOeNRoo2VimYFPgIW2FO53hBs/4CcMRYwiYrBgs9jExEQOC88xBVUjxyhaE8ZcNqY0fjQSkEmNZZjW+f+ohqFAsfCmCdOsuIaMnsdywqP4r4wx4zOPks8pV2/8ieEdZhyVIYYwyaIYUVPGq2Cdkdp9if/JF4j7Fa7n/2UhYGrF1JDPqGwm05e8alYo828DxsJw8m/g4kuQsJukKkAQdXIDgFytGLbLtb6Ro6+Rww/TLkZ2OvxS0i3BpGFooYiwNsFqERQKHnsI6iYENQdGSgcAUEl8NSQVMamt7PDMg2pnrPje1yyT5Cqj/eTmsozK4sVSrVYTec7j6ZpkAp4WI8WmzsmAJBZe8iW7qiA1OmAZwK3PaBhBrJmbGwwGExMTvV7PFAQiYjAY6EtcOfpMB0ZISb9iaQYdzfCFlsyXz5Qpl4AWBZUHLDiZIAi0Vrj3UtLUanV1dTW8h2emcMwZZTOTyfi3gdIOlX+vYvnEcszELKJcatCgweeff25mZmZubl5ZWanRaPR6PbA3LS3txo0bz549Ky8vp8XWWTcFjvEtYsxXXmwI4hiThSmEppRjlwI5WJAU6FWg0UjA7SKXV65fkNrdkHKA/qXl43ahEs50YUw1LZxmx1YX32VEUTQYDFiScRtRzVNdXQ36h9RoDMKpHTmu0qKUe6VcduU3PIfhJ+UAI2DKgqQs6gQ1gZG01Nn7JDHSB36MkSxNoRV4bS9LC0WGVa0kgCjI9WF+1ME15hmH0zNTBDnC3kudMSxmmMsUSyuM8To4OPTu3Vuj0dy9ezcxMVESCwblXtG8eXN3d3cvL6969epZWlpGRUVduHAhPz/fGFp42nEDM9XGdDGCxTOQp0WlUk2aNGnbtm1yFbty5crLly+3bNkSExPDU40Hbx4LoxqU669AkeQn4wcApnqMYPMAfQaqCp2HkeT/FcMCw3spTckKSJqGTBbJwYwfjZg3Go1Gp9NhLIIggBUCKSk/GSww0oDJwgwMUI6khDs7O4eEhHh5eZmbmzdu3Pjdu3fnzp17+fIlLQEmNFAl3pCV48/78hOAMWVoA2HdhVWZHPNxFtwFMB8YLP+HZAwai2lHUjOZMxgM1MZVqVRBQUG9evWKjo5+8uRJWVkZkdctAprKEELU9AOMMYz+oqqZ9jSmlpjRUEsq3MYMCQCMlvkfKh2GcqghZRYzYyOIF/DX0tLSz8+vXr16bdu2nTdvXlxcXGpqamJiYp3CKlltExMTPz+/1q1bjx49OjQ01M/PD95nZ2dbWlpu3b
rVSDKZtqTkYJmWFEd+sJdDAX/NzMxKS0sTEhJ0Ol1cXFxxcbGvr6+zs3N2dnZ1dXVgYGDPnj179uxpa2u7cePGJ0+egNbAjS5IzR0JEi2oPE0vNwTyoJBAOS9TOK0tmFZUHkxMTBwcHMzMzMrLy0tLS6urq3luvy9qSVC2SIxPjGtlZ2dnYWHh6upaUlKSkpICRDk7O5ubm6enpyuzV6xxaWJlR1FQ3QfPOp3OwcEhJCTE19c3MzMzMjIyPz+fjjGEEGYgwVVl5IT2O6xYKc8dHR1btmzZsWPHgQMHOjs7P3v2bNCgQZA9JSWluLiY6ihlnSOKorW1tb+/f/369QVBqK6uNhgMer0+JSXl1atXkiwlnAEkyUBMBSgWLNu4KILmZ0w5lHbmE9MvjOkgCsao3IhFuyR8xVaUXq/HitHW1nbq1KmzZ88+evTot99++/btWzndAr41GL3+SwylEycSpBbrTExMXF1dGzZs2LBhQ3d3d3BB8qSSGl+qgCYWIhehwfz8e4B94oCFX4wBiWFqiN9oNBpLS0sXF5fAwMDJkyffuXOnoKBAFMXMzMyZM2e6u7tjXHz/4Z/pQ8eOHc+cOSOKYnV1dV5e3oMHD7Zs2XLt2jVRFCMiIhSI4scPZVBOw3xl2kVVs1YMf62srDp06NC7d29I3LZt2759+zo7OxNCxo0bd+bMmezsbKh/3759MWcYEigu3Pp11laj0bi6unp7e9vZ2RlPODMYy6GQfAmjS1BQ0OjRo5cvX75p06ZFixaNGTPGzMyMTyzZLhqNxtzcXKPR2NnZOTg4uLq6enl5+fr6enp6Ojs7W1lZGelbl8RSZ+Xd3Nxatmz5+eef//LLL/fu3du3b1+3bt0IISEhIUuWLPnhhx+YNQZaGb5d6gRBECwsLObMmZOYmCiK4uPHj9u1a0dqphQ4WJR2OooFDGShBphlf7r8A8/Ozs6zZs1KTU2F7Nu3b2/YsOHly5dFUfzll19ol6SFyBkBKpXK1dV1zJgxN27cqKioMBgM7969S0pKqq6uvnr1avPmzUGJGbOqgSVZTrrqVDUAarVao9GYmJiYmpqamZnR5WpQm8ZIskajsbKysra29vT09PX1dXd3Vwj3kASGY1SZS+oKQoiHh8e1a9d0Ol10dHSrVq1w9VQofFEai2TzyDE9MDBw165dkOX69eudO3eukxhqcSgPxcBiU1NTaAAzMzPoutAMpqam8Em5S1AhoIuHciyjQ6BarQ4ODu7Tp8/MmTNPnjxZVFQEhRQUFBw/frxXr16ga6Co91psJIT06NHj0aNHoigmJCRs27atZ8+eEJ05fPjwxMTEe/fu1UmLJKNAQEEiNRoNXlmFegLrJOWGGXoVaJGUAUEQbG1t169fn5ubK4rinTt3evbsKZkdM5/Hgr/i6qlUqg4dOly/fr2ysnLFihUODg5MlVTyQXQMFqb+vOQAA1UqlZub27fffvvs2TMw34qKivR6fUFBQceOHXm6eFqcnJw6duw4YsSInj17/vnnnwcOHLh582Z6enp+fv7z58+3bdv25ZdfNm3a1NLSUpJRPOB24W0mnEylUvXu3RusFjBl4CEqKmrx4sW7d+/W6/VLliwxNTVlClGhGDAjlSY8WFlZLVmyJCUl5erVq0lJSS9fvhw5ciQtkyZT1azPk9oTVlwmdHk+SMnU1LRZs2a3bt0qKSmBOcfChQttbW1tbGwOHTpUWVm5YMECFxcXSEx1Ah4yaYFqtbpZs2b79+9PT08/duxYnz59AgMDnZyc/P399+7dW1lZee/evWbNmiloFaZAapZhtlBi8V+mBIJUnJOTU7t27bp169a6detu3boNGzYsJCQEDxKS/QX/tLCwGDhw4LRp09asWZOVlSWK4uXLl8PCwviNInyDKmBRBmdn56NHj4KMDRgwQLJ8eMDk1xpmVLUj1nnCCCF+fn5//PGHWAM6na68vHzWrFkKWQDUarUySUFBQR988MH06dM3bNiwad
OmZcuWbdu27c6dO8nJyREREVFRUVevXn316tWjR4+YgY3BKClqPLi7uwcEBFhbW6tUqs6dO8MMGthXVFR09uzZCRM
"text/plain": [
"<PIL.Image.Image image mode=RGB size=546x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-163005.mp4\n"
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=256,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=16,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"execution": {
"iopub.execute_input": "2023-02-22T16:34:16.929146Z",
"iopub.status.busy": "2023-02-22T16:34:16.928274Z",
"iopub.status.idle": "2023-02-22T16:42:15.960070Z",
"shell.execute_reply": "2023-02-22T16:42:15.959235Z",
"shell.execute_reply.started": "2023-02-22T16:34:16.929115Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling :: 100%|██████████| 999/999 [07:45<00:00, 2.15it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABEIAAAIiCAIAAABDhPpOAAEAAElEQVR4nOz9d1hUSdM4gJ4ZZoacc5JoAMyIghjAAOaw5rjGNeecdlddw5rz6ppzTqsY1iyKAcUICghIzjkPw/T9oz76a7rPOYy+7/e79z6P9QfPcE6frtDVVdWpWoIQ4n7AD/gBP+AH/IAf8AN+wA/4AT/g/39A+v9tAn7AD/gBP+AH/IAf8AN+wA/4AT/gGwEhJLIgI5Xyj3Pwc4lEQj6U1AC8wsUwFvyW+va/AiwvGAVQAv9SxJN/tbS04K8IkeISY2sWAhAXS6rmWFiMGpKBnwAWspVJQUFj4WblGH3Ab8VZw7zwqg37pE658X5bp8TEkYojoiQm8pWGQHYElhJWYhrSqclzlheyz4ooktAT8R4tIrHvtgBCvVIul3O11Z7VT6zPWNVZtcTWABsEKANYNNdVtk72LfVcIpGQWEj6JRIJmCbefg2k8lbISxjZ96VSqUwmowqTHV8EpFIproRFR+oYWVgTClnADQdAtQsmg/IypBeAT9iuLUIDRvodFoZ8RUoJN6uQJkNzaIKF18dxNUrLS5JMJqPahRcLL4WUbHkpwR0H8wKS5/2QwiKpAY5RZqoj4N+ktRTXJfIrtmZeW4F/k63P6xrEgWomIfIAi1BfpkAT78ALvJGSEC4sNArIluJqO3cAbMc0NJjsK3HjycsL1bPqRIdFrQkWVvlxL8PqDb8xGWBtJDWREu7+uAw2aCwvmCMhqtiSpA4LhVusFxMBzY0zBVQM8931cKKRzP9gIQUnZLk0wcHrOfCTbxLc9wFpzvAT8jfv6IXlgvwX/8YayWkcyP6HzIrwoiFqTbQfu38h58SaMCGTRFl21gF8n5GiUOBqWZeDJUb6S7ZyDY0CJUNsbjRsfV7UvM95y1BGs05/xgtC7gfjxVhIO8vSSdYjFP9RLo189U0S+z6g2oUMzXH8JKSr5ENMNhQgPwTd43X/QlT95w6Aq3F45FtKvTm+flSnwlDWkjKJ+HMS9fexo4nESGqpaRQhGoT6PtvKmjhRIStHGQEhTaY6qbhKiDcNbhfK1pGI2N+cBmaNkkOdvVLcerBk8z4ksUCF7DQZJTfev0LjE/hWaIAhRKpQVeJf8UZKvJ9rHklzDKeUr2TlQwEZfbG9mEXEG8OIi1fkibi0eXslRQ/vw28yNSKarInkRZ6TBcignLcbUmIhhytYaHj+BT+hZitw3xchW3MGheoRbxfxytmHeC5J3CZLakfgvMop0qeEnvwPFhElEHIAlHGnKKNIlHxvKKO5CcBABX9sGQ0bhnWluKSEGSz9H8F/gkVcG9jOSfp7XIZsUJEJVCH9I1uBl5dvMlW8QHl3sttQ5LG6Kl4zLsB2Ac3bhe2rGgKpY2QlQsOMOmvjBCwa5TJFQEiAktrTovgHOeX8/7K/kEsKIDreeU2JRCKXy83MzObPn29iYsLVMCgezn4fL9/haf7rEuM1ZWwoQ5XnmAV2tlrqORtwkzaZ10FgmYuoN1UtNXEmYQJZoWlICou0Zg2KKkw1OscX/JFWUShsZeUjIdYAOQHgDWRZIllBwV+QJ+ZLCNG36hjbFhTwhnfiFkbEzJKGhVcz8b8SYuDHFhAaA7MkUfrJTtIL8UIOzyjdI9Wemhhl+SWxsIwIAaW04sXYXqnJqI/8Qc2h8DIC3Yq1MHWqgSZkUIDXr9jyrNh5C2iCCE9gUW1KlQEgl2JwZ5TL5XiuUGTWWKRX8hpPFnr37h0SEnL16tUuXboIMciryeL/CgFvPyKx/LeGZFQZSpP5l0o16RIU0WQBae3lC+7bO+d3gEQg+COp4iWA4lqoYTAIqZqIHoi3Ja9MNAn9NTRhLIjML/KSiruuiEcRAdacsVi+aZFEBIuItaXqEZqCEkJEuZn/U2CxYHpIskV8NuWGNdQxsgvgryhJCs0N48LSulYv/xO/JQRUiMkx/R3vJoInJiYmGz
ZseP/+va2tLf6KVwlJSf5XWr9O1niDDCHrpMlbCkAOZH/RsF9r/opyMyIuQ7wb1okO4gYhy0/xReL6poUmTVqfV4ZC8YpQDWyQwYZNEiLEZ6sVil2ouTkhibGb2YSshybA+n1WUVkGNQFyDCnuKyk7QP6oc0gpxIsm/YVtF6FOStbGyws1vVInwQqFon///tu3b1+0aJGhoaEIL1SFIlrHopYKrOGL8/JfB3H/ItQKLIhLlQzKpbVXyPEAhqu9o5XqpLytT20SrlNiIp3dxsZm2LBhjx8/RgilpqZ26tSJLUPZse8Ow+oE1o7x0i8O7du337Rp0+TJk0U+BCy1bJZarcbvgA6KLPyWrAWXJ/9FCGlpaVVXV/OyR1X+XwGWKgrUarVUKsU8ApBPMPHwb8eOHVNSUpKSkqqqqrjaEqiTAOpfCqnIh7xYyIe8Bb5VmBQ9ZKNAX1Wr1SxSqo+xZEOL4w9FeGF7i5CIyAaimo+iEFMllUopxeOtnH0opOS8vIjrw39LwzHLmClSthQKKACvyMJktCpkwlgzQXbqZs2aeXt7v3379s2bNyTS9u3bOzg4XLt2raSkBNcDyiNisDSRzDdJD8deWHXBv6rValKxsVbLZDJfX9/x48dHRUWRLkTcbfOChg2tof5QGki2NW4Lqn9RISnuv5h4FouQDaR8D9m5oFkRQnK5vFevXg4ODtHR0S9fviwoKMDUckw3xIgo78AJCFxLS8vZ2dnS0tLLy0uhUOjo6Mjl8qqqqqqqqvz8/PDw8NjYWMpKQ7PyMkKaI0BEMk6aCFZE0FPIV2x/QTWzs1i72Aopg8k2IllbnTVgXCTxVGMhhAwMDBo0aGBlZdW4ceOEhIS7d+8WFhaS7AiZcS0tLZVKRYlCSL29vLyCg4PT09MvXryoVCo5ovexFgY/AR5JqWLt1dbWbtCggUKhaNiwYUVFRWVlpa2tbVVV1aNHj75+/UpKAA8/SBtIxRtYzSiMJOBXQJiQf6Eai/WDbFxB1UBJwNnZuV+/fmlpaffv38/JyREPD7gaXWWx8HbwRo0a9enTZ+DAgT4+Pjdv3jx8+HBxcTEvSdS/vPza2dl17NjRzs7u1q1bnz9/xpRgIeAeR3YxEU/NC1Aet6CTk1Pbtm2tra3lcnlWVtb169dzc3O52haPqkFI1NRbUvEoYWrucUCpYPaEYpP8F/oCOG7sviUSiZOTk7e3d2xsbGxsbHl5OVcz4KmzXah/SePp5OQ0bNiwkSNHenl5ZWZm7tix4+XLlyxfFI+s4tWpirxS5Y0VMUaRsIrlDsDY2Lh169azZs3q2bPnx48f9+3bV0frkC6QI9RRaGVGT0+vXbt2QUFB5ubm1LIvWQNVCdXzxYHCaG5u7uXl5eXl5eHh0bRpU5HxoiZYMJGYVLJCfX39Zs2adenSZdiwYU+ePDl58mTDhg2/AwsGmUxmYWFRr1691q1bBwYGBgYGduvWLSgoqHnz5np6eiIffhMWERCPw6jWZ79lnxsZGQUFBXXr1q1du3ZOTk4KhaJOjICFt+H09PSCg4Pbt28fEBDQoUOHtm3bdu7cuW3btnZ2doaGhrz7s9mzyBQv/yFI+GaY8FyLEBYROdcZCrPlSSx4pudb66kTxPs+/ldXV3fbtm3FxcW//PILJonjOLlcvnPnzoKCggYNGlD04x9CEjMzM+vYsaO7u7s4hVKptF69ejY2NprzghdP8CQZ7uaYMAcHh02bNhUWFo4cORJ3Q2qZizVuVLtQzPKC5k3GOyMrNJFcJ0CqA945Y8wXhUUTFI0bN544cWJiYiJC6MiRI46OjkKMYOUhJUa6Bvarhg0bTpw48fDhw/fu3UtPT8/JyUEE5Ofn79+/v0WLFrw2QbzvY4x1rvwIiZp3kwwJZmZm7u7ufn5+nTp16tSpU0BAQOfOnYUsJC9G/BuwiMiKLE+6MwMDAw8Pj3bt2k2bNu3atWsfP35ECKWkpCxZssTe3p7ChbGI1E/hogz44s
WLEUJZWVnGxsYsFyQWrnYTsOwbGxs3bdp0xowZly5dCgsLy8rKKioqSktLg8/37dvXsmVLEWnwSkzkCBn3LSrB1e7
"text/plain": [
"<PIL.Image.Image image mode=RGB size=1090x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-163416.mp4\n"
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=512,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=32,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2023-02-22T16:16:33.646107Z",
"iopub.status.busy": "2023-02-22T16:16:33.645729Z",
"iopub.status.idle": "2023-02-22T16:16:59.754676Z",
"shell.execute_reply": "2023-02-22T16:16:59.753710Z",
"shell.execute_reply.started": "2023-02-22T16:16:33.646076Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling :: 100%|██████████| 999/999 [00:25<00:00, 39.22it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAIoAAACKCAIAAAD65UUzAAA6L0lEQVR4nO19eVxV1fb4PnfiMiMgIJNDgKJoJKihmWamKUU5gdrzmaglT8UG0wafQz3jZaVmaT4rTcsUNRvUHCg1nFNRFCdQmS6IMskVLnC59+7fH0u2++69z7lX3/t+fv3R+uN+zt1n773WXmvttdYej4QxRn/BnxVU/78J+AsUAWOMMVap7smJfmZSJEmSJAlSIJGkkHTyTBdksKjVamWqSA0MDXQ6jZfGIszGV84X54sI6aSxSJIEhKlUKroq8lfVCkzlDJPJXygoSdJdLAQZg0CtVpO/hA4eGZ1ByA6oFrAIeSHkywO8RRTjkCNByslGQWYMFpITGki4x+irMBHZazzNT2C7QDySJGk0Gr5GmnRaBkzvIUrEUAYpNOMcgkMxyIEcFkZ1CEOFbx32bEbVGLVGrT1GqK9QuZxa0A+seMhrWjx0p0OUVNRqNW9tkLz2OWSck+k8Frphyr3nwUCuLYwMFIoQe4XsZQmJRGB0F1SpVHbiAXbTNlej0Qh7j0qlYvSLlpOw1D1kjloul67AX/qVnHgeDORqIL2HZq5yQV5XeG9EAOoELBqC0mazIUoHrVYrahUJvIJskiSRV8AL+FWr1TabzWKxMKXo4nw76XqYdGGKRqMJCQkJCwvTarWXLl2qqKiQqwRSGM0IDQ319fX19vbWarV6vV6n05nN5lu3blmtVoxxbW1tRUVFc3OzkAwGBbTLarUyvUHBTtDsonmCMfb09PTz89PpdHfu3KmoqCDCuyseklutVlssFpvNplarhVwDCZFfkg6CAXaQRJVKZbPZlBVZ2BhXV1edTldXV4cQcnd3d3V11Wg03t7ekZGRSUlJqamparU6PT197dq1DQ0NTCWAFFoEv5CCEBo9evSzzz7bp08fd3d3kr+0tBTyHzhw4Lvvvjt69GhjY6NDOqVWsNlsJPxB9jpHlIP0NldXV29vb1dXV9DmmpoajUbj6en56KOPDhkypH379rt3716yZAnBoiHIaC4jhKCLQC0kD6BXq9VWqxXSdTodajWGZrMZVI8ICT2QkQkICBgyZEhYWNiyZcvCw8OHDx/et2/f/v37BwUFQQabzYYxjoyMbNeu3dWrV5nioBBCAoKDgx9++GF3d3er1Xr58uXs7OxHHnmkffv2NputTZs2qampffr0Wbhw4bZt25wkFXDxdgLkQdJVKpWvr6+bm1uvXr1effXVvn37YoyrqqqysrJKSkrGjBnz0EMPIYRMJlNOTo4dzUxn5AMwOdBoNN27d//4448vX75cXV1dW1t78eJFvV7PMwXZjxXodB5FcHDwkiVLMMbNzc1lZWV1dXUmkwmMD1jgysrKbdu2JSYmenl5MWWJ0+Yr12g08+bNu3XrFoj2yJEjw4cP12g0fn5+np6e3bp1i46O/uKLL0pKSt58802FVvMck+wBcZFbUFDQyy+/fPjw4YqKCovFAo6jrq4OY2yxWFpaWuBh7969o0ePZkaKGoZTROxqtXrMmDFjx4795ZdfysrKEEJt27b19/d3cXHx8fEJCgrq06ePp6enp6enq6srVNrS0uLq6trU1ATdC7QGeqFCUwl2jLGXl9fUqVNnz56NENLpdMHBwSRnfn7+hQsX7ty5s2jRooaGhurqatLRGaBVmE739vZ2c3OTJGn58uUrVqwoLy+3WCzV1dUIoUuXLiGEFi5cuHLlyqqqKpok3qvRb8G6MP6PyM/X1/f5558fMWJEnz592rZte+fOnYMHD27btu3SpUseHh4TJkxISUkpLy/PyMgoLCy8evVqcXExcP4e00jkJkmSVqsFRuv1+sGDB1+5csVsNl+5cqWgoOD69esGg6GioqKiogJYQ6yqxWKxWCwNDQ3btm1zdXVFrWE3UR/nI7f4+PidO3dC5paWltOnTx87duzdd98dPXp03759u3TpEhERISPoe20R1q
zVanft2oUxXrp0KS11ITh0liA2euTOF2zXrt1HH31UU1MDSjB37txhw4Z16NBBo9GMHDny3LlzVVVVRUVFM2bM8PPzo4Nh0k8wCaxJ0A3v4uPjCwoKwJi0tLQYjcaampr6+nroidXV1bm5ucXFxWBz4Leurm7NmjUuLi6IspAEq1w8w0BKSkpeXh7G2GAwpKWlDR48uH///gEBAQ4LKmCBFnXs2LGoqAhjvGvXrtTU1MGDBw8dOnTIkCFt27Z1snLCHNqEygly2LBhGzZsqKmpuXHjxoIFC5KTkyEYUalUL774Yn5+Psb4999/T05O9vHxIaWAXfRA/q5xIyEWWIagoKCIiAjI8f333+fm5tbX16tUKr1e7+XlVVVVVVpa6uvr27lz5z59+sTExHh7e4PYcGvUSIc0TNQv5CDGODQ0dNCgQVFRUbdv316/fv3nn3+uzCzGAZC/JE4DANvg5uYGqjN8+PC4uLiysjKNRmO1Wk+dOnX06FGj0fjHH38YDAYFjJiLY0kkzeAdPnz4m2++2b9//8LCwiVLlnz11VctLS0IoUcffXTYsGEvvfSSXq//7LPPMjMzDx8+TDeBhBJ2aDDGWq2WfpGYmAhWC2P8j3/8w9/fX47oF1988dKlSxjj2trad955h9RD+g34SWd6T1paWklJCcZ4586d4eHhypkl0cwT77TJs5+f35YtW4qKim7evFlfXw+GgUQcVVVVH3zwQbdu3eRmdGjPT7CQ+THSXpVK1bVr12vXrmGMjUbjokWLoHhoaOjIkSNPnTpVVVV1/Pjx9PR0Pz8/ZB9WMBjteg+4WeizZHQJwoyLi9uxYwdQQ/yNRqOBPLdv34ZRgiRJ7dq1I1WDcyP9VJnXCCFPT08Ipi9durRly5aSkhKGO4Q2gkK5ZqZ71dTUvPvuu88884xGo9Hr9S4uLlar1cPDo3v37hEREX5+fnPmzImMjFy9enVOTo7JZDKZTAyzeBTEgUutQ3WVSjVu3DjgQ3Z29u7duxFC3bt3nzRp0tSpU11cXDZu3LhkyRKIRJD9IJLp9PcQYy6wHjx4MOiX2WyeP39+YGAgeUs7fITQW2+91djYiDEuLy+fNWsW6T10iIkodyrHzfHjx4NvyMzMDAoKklpnj+iVC74UM3dOewUnwc3NLSEh4dtvv62srIR45NChQxMnTlSohOYYHxosXLgQjHxOTs7s2bOnTZt2/vx5jPGNGzc+/PDD9u3bE8qZnkqYJggNEMV6nU43ffp0oNVoNKakpLi5uSF7S0Lm1l555ZWbN29ijCsqKkjYLkkSyUDPIMkx18/PLysrC0xNcXHxd999t3Hjxtzc3KtXr3755ZcjRozghzjCengsNPuEhouoUb9+/fbv3w813LhxIzU1lZYQ/Qx5aNWhqwoKCjp79izkaWlpAR6uW7cuNDSUj/QYzaMz2ImHjhkiIyO//vpr8D02my0hIYGui0gRVh8mTpwIMZ7BYBg6dCgkQh56upBmHE+lm5vb1atXwSRarVaz2QwhIsa4qamptrb24MGDzz33nFAkNGsUfI8ckDySJPXr12/NmjWgbZWVlePGjRNWRfoo6Cu/AJaWlnb48GHwcBjj+vr6p59+mq5KrVYrBOXi3kNeDxw48Ny5cxhjk8mUlJTk4eGBkN1sAqJWFqZMmQL+/Pz58z4+PjxKEt4oeyBQCACj0VheXg7DPQK7d+9+4okneGbRQGPhp+4VsAO4uLiEhoYOHTr022+/xRhfv369d+/ecljkWooQ8vb2HjJkSHZ2ts1my8/Pz8zMTEtLi4yMpImnl++YlTMai4b+T0b7CKHq6upPP/30559/hrf0RBadLTg42NfXF/Ddvn2bsAaiasbXCaNqgJUrVx45csTX19dkMlVXVzc0NLRp08bDw6Nv375Dhgzx8/Pr1q1bRETEgQMHkHOxBhNbO8yPEGpubjYYDAaDwWg0du7cOT4+Pj4+Pj8/n7RLCMBWMveKMa6rq/P19Q0ODs
7Pz1+6dOnp06ejo6MXLVq0du3a/fv3Q04iY0k0l0oqvzclSppUVla2Zs0anU63fv16Plgi8zTw18vLy93d3WKxkEE
"text/plain": [
"<PIL.Image.Image image mode=RGB size=138x138>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-161633.mp4\n"
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=16,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=4,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T14:06:41.407250Z",
"start_time": "2023-02-13T14:03:27.834241Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:21:11.522425Z",
"iopub.status.busy": "2023-02-22T16:21:11.519897Z",
"iopub.status.idle": "2023-02-22T16:23:35.317409Z",
"shell.execute_reply": "2023-02-22T16:23:35.316345Z",
"shell.execute_reply.started": "2023-02-22T16:21:11.522385Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-162111.mp4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling :: 100%|██████████| 999/999 [02:20<00:00, 7.11it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAAIiCAIAAACLx9BEAAEAAElEQVR4nOz9ZXhWxxMwDu9tcXchRkhwDwR392LFi7u3QIHSYqUtBYpDgUKhxd2KS/AACSGEQARC3F1v3f+Habab3XNObujved7rvZ7Oh1wn596zszM7Ozs7uzsjwxij/+A/+A8+BuT/v27Af/Af/P8hYIwxxjKZrNqSCoUCHvjCgm/ol4AFISSXVxmopE4a5HI5fEv/JdUqFAqmcvJMsDDtMYY6vkL+K4VCIZfLAQtDCN94iQbAtwQLjQ4+VygUpF8IwFfySiAfMuwiZYC38AwP0H6lUqlQKGhaBHtBjARCODyQb2mG0PQyHKMpQhzbP7anaAAsKpXKyA/F6gGKBAWM0PKPqBkPpBbSl/BAepHHygh0tQTQfUMkgBRmBI6A8bR82kD6WCz/piU0FkbU6JFJ9Aiq7AUyTuAnerDxnzMCjYSGh6A0S7zhlZoYx3hREUPBaxYxjtHYGeUrpm2rJY18LjpsBNtEc5BHIKab6WL0sBHT0wSUSqVgi5nG8CDYPYLdKcYawWdpLDIOBD+X+IlvJ42FUX70PMP8Sp4ZXEqlUrA9/wiBJB8EO13sX0TJA+kpCUUjyBMjBZr5l8ZCdIfYV/RLf3//vXv3hoaGfv/99z4+PkwxRkOxwwbeCjZXEHjxlVBUtCn4UaOcrqTa8fa/nQf+PRa+58Ro5EvSZi2vNfixQWtiGWUFyCjrDnEyLWii0/8KzjnVKh0aaINT7CtpkfgodSaraviImY4A5ubmAwYMOHHihFarxRiHhYV16NBBum3VG2k8+ySUKDLOumUMYgLE0hDEYiRIdM//ED55cEo0jP+JVzT0yCHDRoyftNwISjw9bD62ndWqV4l5wJgRQjQ402zBf5k5TXoc0p/L5fLAwMBLly7hSoiKiho2bJj0V1DyH3NILpcbDAa+KBAMP/EsxpWWMc0XaD15Q6ol7NPr9Uw9BoOBfMVjsbW1VSqVFhYWvr6+RUVFr1+/1ul0NCU0m/4vTDgMEGJpqnl+Sv9EaOE/of8ihAwGA2gi8hL4aWtr27RpU5VKlZyc/O7dO61WC1gAHeGtQqEwGAx8XxOA8nRLlEpljRo1ZDJZQkICqtp9DC1iPcgzgQbCN/IXUJDCIGa0IJG/hDTmJzJ+yFeMZJqZmXXr1q1du3akcHFxcWFhIU8UD/8MG8EhgRBSKBRWVlYIISsrKysrK3hZWFiYn5+v0WhQ5aiA9hF80lgFRYR/Y29vb2Ji4unp2bJlS1tbW29v75kzZ75+/XrZsmUJCQnv3r0rKysTM9A/DegOMB6YPgMgHLC2toZhn5KSAqOdcIzRNeRbet5WKBREhmDq0Ov1NKuVSqWNjY2rq2uPHj1++eUXmUx29+7dzZs337t3r7i4WKfTMVj0ej2vvJnZgNECjRo1mjZtmlqtnjt3Ll0MOt3MzMzDw0Mmk6Wnp5eXlzM9wkwRguxlXvKaF4lLlMT0QppBo4a/KpWqQYMG48aNs7S0hF8TExNPnDhx9+5dxA1FgUbSVTNGJPxbr169ZcuW/fzzzxcuXIDCGo3m9u3bHTp0MDU1lW43zQWaAAkg5kft2rV37doVGRmZl5en0WhwVcjMzOzYsaMYmz4NJEx8QSy855cHf3//tWvXZmdn63S6Nm3aWFlZmZiYkA8ZDyH9LIiFQadSqczNzTt16nT06NH4+HiMsU6ng3Gi1Wr379/v7+/PN4/2TAKWat0tixcvfv/+/aFDhwT506xZs6dPn2KMu3fvbmJiwlcFWGRVF8+MpUf7LRhWQG0mJiY8n3mvIM8lwd5p1KjR48ePQYno9XqM8fLly2lflKBZ+0+/CIoajXjHjh0FBQWkPwAMBoNGo4mOjh4yZIggFxggJAn+Si+c4Llx48YJCQmwUMvMzLxz586hQ4euXLlSUVEBVd25c6dhw4Z0g/+H0w6jQQRpMWYx1qdPnx
cvXsCYv3r1alxc3IcPHyIjI7/55hsXFxfpBtC0EK+AjFrvIoS2bt1aWFio0+nApsIYFxYWlpSUAMbk5OTFixcjqjdhu4ZgoTlGM59oLijWrVu3J0+eXLp0qVatWjR/oLC3t/fKlSsLCgr0ev3atWstLS15HjL9wg/OXr16LVy4sHnz5oLcqFu37qZNmzZs2ICqOuhodqGqEwDjR0FV+8jV1XXlypUajQbmbeDe2rVreTnkGwxYqrh66d/IX3d3d6VS+ddffx09ejQ6OhpjXLdu3eDg4JkzZ9atW3fcuHEJCQnh4eFgSEgbOfRPsqqLH2KKGAyGxo0bb9y40dfXFyH05s2b77777urVqwEBATNmzID5LT09fe3atW/evEGUhPF4PTw8AgICmjZt2rp162fPniUkJPTt2/fdu3cIoVatWlVUVNy9e9ff39/R0TEtLW3Hjh0pKSmMkS0BNEZaI+BK66JNmzYLFy6sV6/eoUOH/vjjj9GjR3fr1g2ktqCg4MmTJzdv3hSsmV8tgM1A86pr164LFy5s27Yt2M8IoYSEhH379t2+ffvDhw9Hjhzp3Lmzq6trzZo1EaVTwDwj6yJ+TSK4gnVycnJ0dLS1ta1Tp058fDyxoAwGg5eX15w5c+bOnWtiYpKZmbl///6ysjIkLgaAmje3Pnz4MHny5LFjx96+ffvBgwfx8fHW1tYGg8Hd3b1Xr17t27evVavW1atXkYgI0fznX9JLOxAzBweHJk2awMYo6a/58+c7OTn9+OOPHz58kFErIsStrKp0P70jRkPXrl0HDBjQoEEDMzMzeGNhYeHq6jpz5sy0tLTy8nJQabQtwezHIXHziVceDg4O8+bNg/IhISF9+/aF9yNHjkxKSsIYV1RUfPXVV9bW1oJdQnOwb9++169fz8nJKSsrS0tLi4yMLCwszMzMTE9PLysrKy0tff/+fVZWVnl5eUFBwU8//eTl5cW3kG4qzTHekKChbt26Fy9eLCoq2r59u4ODg1wuj4mJwRhv2LDh/Pnz+/fvd3NzYxqPquo5Zk6jHcEuLi537txRq9VQJiMjY+3atc2bN3dwcEAI1alT59mzZxjj7OzstWvXIkrx05VUMTkooM+CQPmtW7dijK9fv96gQQOGTF9f30OHDkElmZmZfKfQHEPUvMewTqVSderU6Ztvvtm2bdupU6cSEhLi4+OTkpJycnKIfX7u3DnB7mCwCM78zMvZs2eXlJSAHsGUZatWq8+fP9+qVSvBD2lJZl0CuNINQHwpt2/fZjhbVlZWVlZWXFwMplp5eTmiVBczUgUJIEyERsBQhr9+fn7du3dHCBUWFn777bchISGOjo7Dhg2bNm2al5eXVqtdt27d77//XlxcDBUSXcK4WRBCarXa3t7e0dERIWRubu7u7o4QsrGxIY3x8/ODBzMzs3Hjxnl5eR04cODGjRuCLTdyFgXo3LlzUFDQ7du3v//++5KSkn79+mVkZBw7diwxMVGv1z9//jwjI0OsBpp7MAOTqVgmk6lUKn9//06dOiGE7t27FxERERYWdu3atczMTISQj4/P6tWr69SpgxB6/PjxyZMnUaWyxFV30GHW4mkhDgPomnbt2rVv376oqOjWrVtRUVGkJQAlJSWpqalarValUmk0GmM21pgaoG1arfbu3buvX792cnIKDAx8//69ra1tw4YNMzIymjRp4ufnp9frS0pKyFe8p4tWavRLQjv5t2fPnqNGjbK0tNRqtWTuVSgUOp3OxMSkV69e9vb2Fy5c2L59O6wIBPuI1QQS20PkJ4VCYWFhcf78eYzxmTNnmjVrJs0sJDLb0L1InocNG5aZmYkxjo+Pr1GjRs+ePTdu3JiQkIAxLi4uPnr0KOhUvpG87rSzs5szZ050dHRubi78VFRUlJSUlJGRAUsCfSWQZduDBw9GjRplZ2f3sbTQJDg4OFy4cCE6Oho2Aezt7fv169evX7/mzZsfO3Zs9erVMNVILJ/onS5moWViYl
K/fv3169f//PPPnTp1srW1hfc1atQICgr64YcfYBZ6+vTpwIEDkZBHgR4tTO8z5Jibm588edJgMDx//rxnz558O93
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-162111.mp4\n"
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"print(save_path)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=128,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=8,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T14:03:27.628269Z",
"start_time": "2023-02-13T14:00:18.376509Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:23:35.323346Z",
"iopub.status.busy": "2023-02-22T16:23:35.322602Z",
"iopub.status.idle": "2023-02-22T16:27:47.271778Z",
"shell.execute_reply": "2023-02-22T16:27:47.270801Z",
"shell.execute_reply.started": "2023-02-22T16:23:35.323305Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-162335.mp4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling :: 100%|██████████| 999/999 [04:04<00:00, 4.08it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAIiCAIAAAB61jR9AAEAAElEQVR4nOz9d1xVR/M4AJ9bgEvvHelSLdhiR7Bg773EWGI3xhZr7C1qTKyxd2OvJDZsKCJWBAGp0osgvZfL3fePedhn2T3ncM3z/b3v5/N+Mn/Auefs7uzMzs7M7s7uShBC3L/wL/wL/8K/8C/8vwHp/68r8C/8C//Cv/Av/P81IIR4BzQSiaTJNwDSBiCT4Z/wDFjgE285ZC6JRILTSCQSsnAWNSSGBIAFJ4YHXCZOz0uOOL0ymQw/AxZ4Q5aGq43/4q8UczBduOYsYFqE2E7xQSaTCRXFSxE8q9P66lRABJ1EIiGxkIxSp56UIIlkxFioZDKZDHObEgMhRPCJbHRcc5Zj4rSINIpcLucILpGlCWHh7VziHZMtnGoXSm4pYPsLS5E4B8h2oXoN2Si8fKCwUH1KIpEADzmGY1gpkQqE6q1kr2T7LC/5gEVDQ0Mmk8nlcplMBhSRigjXUy6X4/bCNaHaEcsYWR9Sw+BPZHYASAAdn2IjpheEH2pCyjOvJP+fA9b8HGkASGAl72vVDZkek8QrRiJdkbcXCcmi+owT6ldUTXgrxouFFFxWF/BWnmPEmpdjTRbCW1URxUG+/1qOCWEXBwpLk2ZGzWpQb0ilyXY8Clh2kZJG9mqqtk1yjJdRZDlNNgpHqDO2quwb0p3iLZYXr4QwM0LVFqreVwGFBfcRca8IoyO9BHjDKmtSaZIvcXohQljVzL7nGBkDza6hoYHtBNnxSfMgaTBylCCR5IOtIj+xDhOv/42ZQxUoIYAjhBknozj2/xSaMDMsUKynDACvj8OrznjFS0gmKDeEKpnKJW7MhIB1dsSB8jWaBOxwqYmFFAKWUl49y/YQoZ8Ueyks4iDe3CLAmhm2ZPEKqEMgpc4ojrEC1qxZs+7du+OXuE8KeeuSxuMMlhvUmyblipcPXGNFI2QhyGeKe+ICwEuLOAhZHcq8ibcLPKup8UXqTAJZAZJjvJ1FiNv4WcJnAkl/hWts/jFqGNZQLiZVFNbvUqkUj3JwyaS14AT6ixD5uDIsUE1D/f1n2vJr4T+9EpvNJgUOaib0EjOIN5k6WMiiKC7zFiuVSlUqFZkMfpJ8pNKIUIErSaHgOA6XwNs2QrRIRP0F3oqRVEBGmUwGD0IMAUQqlQoKpHiI6SK/kn/VpIWXOVR92GcMLBbeZPi9rq5u//79AwIC5HL5+fPnw8LCKioqWI7xYsFKB3OGJcHV1XX48OHdunUzMjJKSkr69ddfY2NjyQpQzUcWhduFlAqyAkJ8a926tYmJSceOHe3t7SMjIwMDA3NycijCgUapVFpfX88RDgdJHdWOnHAHYVlNdQrxXilUoDpfSaDsJUkIyyVSjeASMEMAZDIZ/CTz8mKhKgzlNCnPuBUo9cJr+8kORdErzh+oj0qlItsOWlypVHKNJRkhBFT/dxqqAQWIolQqVSgU7dq16927t6ura2Ji4vnz5+Pj40mukoRLJBKSh/8YHBwcWrZs2bx5c4TQu3fv3rx5U11dzXJMTr4S5w7ve/ySffhaMWVLIH+y1RPpWhyhW0Xek61LyjGuIZVMCBfbz0UMDABluoRykZaDJIH6S/5kWYTRkQ2vvv3j7bHw0tzcvKamprS0lLeor5Ul/N7BwWHFihXt2rUrKip6/PgxNA3VFkKIEDGgwcqC1ALOzs4zZ86cPXu2vr4+x3G+vr6XL1+Oi4ujVDavb4GBbBeRhsafNDQ0Fi9e3KxZMy8vL0tLy+LiYjc3t8OHDyclJUFRuFiOkA3cmqQlYC0NbzXI6lFVFep3GDDTeDkslEscKDZiEoTUDvmT+oS7KtlMvFgwgLuGM2KkvCUA+aSloQwJ+QZ7JDKZTKlUAiKVSqWjo+Pk5OTl5eXg4AAYVSpVVVVVVVXVly9f4uLiPn36RAotWAuVSoWnSUiMYGMwdVgeQHHJZLI+ffr06N
GjXbt2ffr0gSxGRkarV68uLy+ndLuQYvwqsLCwcHNz69ChQ9euXVu3bu3g4KBSqV69evXrr78GBQXV1NRwjVvtP2YGeIpNnAhIJBJ3d/fCwsLy8vL27dsbGxuXl5fn5eWlp6dXVFSAHcZsovJS1HIN85gGBgZeXl51dXW1tbXa2toKhaKurq6kpERXV1dLSys3N/fz58/FxcUcwXpWmKhewRHdlXxPZiSFnk1JlszbMLj9QMJ4+7mBgYGdnV1RUVFubi5lM/T19R0cHGJiYljmUIWIG2zWtOD683Y5cbeX9eB4MXIcZ2JiMnny5IqKiqNHj/KKzddqIgBNTU1XV9d27dpxHPf06dMXL15UV1eL63TyJe6BpMOOs+vr60+dOnXmzJn6+vrAt+rq6oKCAope1HjMx/EN6XA/5+2xpPw0b968e/furq6u5eXlV65c0dLS8vf3nz9/fkJCQmpqam1tLSs5+BlkBiu7Jj0AihDKX+Z15FkSeHU3K5lCWp6XFtJk4lzUuIp80NHRMTc3z83Nra6uxun19fU9PT319PTkcnlUVBQ5HBRhCKXWyK5BigeWHFKJ4ZdCHROzV6lUyuVysDQcx7Vq1WrJkiXdunWztLQkK1NfX//x48enT5/u2LEjJydHqVRSk368HZDyPLBrAmbJ19d3+fLlfn5+HMdVV1eXlZWZm5t/++23r169On/+PC7qn5kWSjwsLS3bt2/v7+/v6+vr5eWlqamZk5MTHx/frFkzX1/fmJiYyMjI9PR0uhSS1+pgbd68+f379xcuXGhpafn06VOEELic8+bNGzlypJOTk6mpqba2tvpY5HJ5x44d09LSiouLP378mJ2djRAqLi5+9epVdnZ2eXl5RETE/PnzraysID0ZUkWutnECWpVMQD6QQS//YKWBcm3wAzVFGxAQ8PDhwxUrVlBV5TjOxcXlzJkzCoWCly0kLeKrWeQDxos9ejILXv3D6dl2obIoFAp3d3cHBwcdHR2qAsOGDcvMzMzKyhoyZAhvXl5auMZheywbdXV1v/32W8gSFRUF9kYdINuF+ovbZeTIkbGxsWBN6+vrP3/+fOzYMTMzM1wBMjEvT9SRMRIcHBzOnDkTExMTEBCAXw4ePLi2tnbLli2ampq8JZBYMBWsGMjlck1NTR0dHZlMplAodHV12TqI9Gu29cUlTQioRRohLGwuNotMJjM0NHRycvr2229PnDgxbtw4FxcXW1tbW1tbLy+vWbNmZWZmQoHfffedEC24a+MeIQ6sEAotqmEseKEFUFDL+AqFYu3atZC4qqpKqVRWVVVVVlZWVVWBV5Gfnz9//nwtLS2Se7goVsNQyopknbe394sXL0Cks7Ozr1y5cuHCBeDS48eP7e3tKdIwW1gZY59JMDIyat269fbt28vKyhBCX758CQwMXLx4sb+/f8eOHffs2ZOQkLBixQqsqBtxTKTbsGBqarpq1SqE0Llz54yMjGxsbPr27btkyZJHjx4hhOrr6z98+HD06NGJEyfa2tryNg8vJXp6eq9evaqrqwPXQKlUwgMJR44cIa0XrxqlsOCWYxce27Rp07FjR5yM+ioumrzNw5tFT09v2bJlNTU1R48epdS0XC53c3ObOXMmBAiQdaBMZpOxBmz9eTswBZBMSNQwTJs27eXLlzdu3OjZsydVw4MHD1ZXVyOE7t+/z1aDrF6TMkbm9fHxCQ4Orq+vRwjt2LGDcgZZKvAza2ZYOHbsGNQZuvqsWbNwOZSNwYqGkp8mOUa+d3R03L59e0xMTLdu3chCRo8eXVtbu2HDBqH1WyFaKCa3bdt25syZU6dO7dy585AhQ2bNmtWlSxd3d3dra2sdHR2WKBEs4oRQgJ0VTU1NkXYnseACheRZIpEYGRm1b99+69atcXFxWBUUFRXl5eXl5ORAw8HMUmZm5siRI0lE6ugx6BcKhUJbW1tLS0uhUGhqavI2AdWv8V/MMSiKDfQC6kaMGPHmzRuEUHV1dUZGRl
FR0atXrx4+fHjt2rUnT54kJyffuXMnICAAR6mRwBHtQvmLZAQ8fggICID0NTU1EyZM4Dhu+vTpOTk5CKFPnz61bNk
"text/plain": [
"<PIL.Image.Image image mode=RGB size=546x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-162335.mp4\n"
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"print(save_path)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=256,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=16,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"ExecuteTime": {
"start_time": "2023-02-13T14:02:13.807Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:27:47.273986Z",
"iopub.status.busy": "2023-02-22T16:27:47.273382Z",
"iopub.status.idle": "2023-02-22T16:30:05.576098Z",
"shell.execute_reply": "2023-02-22T16:30:05.575242Z",
"shell.execute_reply.started": "2023-02-22T16:27:47.273949Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling :: 100%|██████████| 999/999 [02:17<00:00, 7.24it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAESCAIAAABPRClNAAEAAElEQVR4nOy9d1hVx/M4fG7jXnqVohRBuoC99941KpoYe09MbKn2FhNLEmOvUWPvvXdUsAKKBaUJSO8dLnC55/1jvuxn7u45B5LP9/d7n+d9M3/wHM7ds1N2dmZ2d3ZXxvM89y/8C//Cv/Av/Av/Z0D+/zYB/8K/8C/8C//C/6eB53kY0MjltMuRyWT/oEJZLQhiqSewxIi9FMQil8vrLFzP+lkhABZ4Tz5hWSY/KRQKMUnir+RyOflXJpNhLP8Y6vycbRf2E/KG/UmpVAoiBQAJEF6kq4XyAJRsWdQsJQqFAmP5B60vARgdxQtpXDFREyGwTEkAwYKlRLSaYKTkLJPJFAoFVYCrVS0sWHiQ6JUS7AA6TAn1Camf4kUCgGxcD/7L/koQ/S0s/wxwEwAWghceiMyxJqtUKrlcrlQqcRtxHKdUKokAob0IdwQR2/oYcHNTz0QHqBaHr7Ahwr2yTiuBeyUrGUEdoOwY3TwSBrE+BLGY/i5LbG31R0qwsC0HQFpCos467SyWGFsYOx78QAmQpZBqrf8T3YallpWYYDECrHGpD4g5AMpkCLZLnS1FTKegaa4nhfXXhzods6AVZssIRnWsoSHlsQnDEQy2WZT540S6D3lZTx2jmoxynKwnoD6k2oWYQolmkuhWYiDIC9sZxeqXBrZdiBBwc5A4ifgSEAsmHscK2DFgYRJnRn6iYlBK87F6UHaGQg0/gW+jAllKVn/XSArCf3qlmKrVv69KMEYjEyeUFcc/AAmXib2xrDboUygUAQEBO3fuHDt2rIWFhRhtHArbKQdQT+FQIGas8RsJEyAmn38gNyo6o8j4W06FcpyCQQZFbT2dihgBVE8QxEKoEuxvgpVLc00ZGuqr+pDN0kZFGBzTLpShIS+x75GWEh7lQGDLCTkzaYJZOiXwEvhfDJgoCWDfVn8s9dFnMT3E7YL/crXjWuIwsOMnrkKlUnGGnkYQHeFFpVLJZDL4iwGKkaESJoaKQlhxEW0R7Pv1lAyrimLlAYsSF8XtBBTo9XrSqfR6PYVJr9eT9+RbUowTNzG4PPsVfv4vgTCF6+R5vqamhuM4a2vrX375ZfDgwcbGxmFhYcXFxZQQyL86nY5jRMRyQaRBiQUCnGHDhnXs2PHixYv37t2jvv1b3ktMPv9YbiTYhAfyL9WymDa2TUGkHMOLWBADykMVqI+q1PlekCpWwTCzGAS5FsSO9QH3DtJ3rKyshg4dGhAQoFKpqqur79279+DBg7KyMihWU1MjrVGkTqqBqMKkRXDTAAHwF3BhlnnDmR+KZUDEopPL5VCPYLfCwH5LASHsb3kgqjBp3L8F9cEoaARY4atUKmwZQM4kqmYtZ3V1NVdrWrlaYw02lmDBvken05HP4T08KxSKmpoawIVblvQpLB+ZTKZWq7VaLegbYCQemrQC+YmTbD6xXsMxlgEGTPD8HzfDKhbbM7EFEdMzXIkETRKgVqs1Go2rq2tBQUFqaqoYG2JIMUfwIKjTNjY2c+bMGTJkCMdxL168KC4uZj+UMAG4b/PMIgrrRLt06bJgwYKWLVtaW1u/evUqLy+PKs/KrU6Q0AbBn6TLU8XEPq+P/f3HZJNfzc3Ne/fu7ejoaGZmxvN8VVVVXFzc/fv3y8vL61MbvITeKE1MfUDQ4IJGsQWIj3Fzc/v0009nzJjRpEkT+Kl3796HDh06efJkWloapTNiMsGIcB9k3xPUYC9w14O/pqambdq0KSsre/78OYeaidTDdluKKiLMOlVLzC8SEPPljRs37t+/f1JS0r179yorKzFhf8snSRf28PDo0KGDo6OjTqdTKBTl5eVVVVWpqakJCQn5+fkFBQW4sITBBbcBfoUzlBs4AMF4i/gGHJoQTqnyUBLCFMIXOBi9Xk8eiIiwQXZ2du7bt2+bNm0KCwsvXrwYFRVVXl5OxXY4yO
ANhxaUGO3s7Fq0aBEUFFRQUBAeHv727VvcuVjjjH8VnQsSFC55aWZm5uzs7OTkZGRkxPO8VqtNSEiorKzMzc2ltLP+9s7Hx8fOzs7Ozq5JkybGxsaBgYGZmZmnT59+/Pgx9tiCIGgFyL/Yx5CHtm3bLly4kOO49+/fX7x4MT8/n3wo0eFZjFi3CCKqBlNT0zlz5jRv3pzjOGdnZ19f37CwMFYI/2VwV+dP0hGKGAGUzbKwsFAqlYWFhcTqubu7u7m5wRCNQ3ZWjDYcC5MhP7aVSqVy4sSJixYtcnJyIl+9e/duyZIlZ8+e5ThOoVCQZhLjlITwdYKTk5OlpaW9vb2ZmVlqauqrV6/E5EOA9EyquYER+HXAgAHLli0zMTFJTEyMi4vz8vJq0aJFo0aNnJycjhw58ubNGzz+gwCTIlipVEI8CyjAeFEBE6EBPhccTLu7u3t5edna2rZp0yYqKur58+diyoyBUoaGDRva29s7ODjo9fqqqio2MmPHQNLKLKhsQUFBO3bsePHixeTJk1+9esW6LupDsU7E+kgAlUrVqVOn8ePHjx492szMDP8UGxsbFhb27NmzAwcOVFRUSFDOGdp9MiLB/9bU1BgbGzdq1MjHx6eqqkqj0eTk5Hz8+DEjI4MMMjhDFZKh6S/MDmcYYcB7vV6vUqmgxUl3IF1PrVY3b9585MiRn376Kc/zarW6QYMGGzduBK0T1B/4SyGCvxqNplmzZoMGDerVq1e7du0UCsWVK1c2b9589+5dGMzVbbVYmygImHlLS8vg4OAbN27wtVBUVLR3797169cHBgYKzoZLYzEyMmrVqtXx48cLCgp4Q3jy5MnIkSPNzc3Zr2TMzBKLhS0DVDk5OW3YsIHn+eLi4nnz5lEKR9HPjlTEeCFRCf5Eo9FMmjSJcLRt27ZGjRoJfo5nbOvZLv8lsPrNPhMwNTVt0qRJ9+7dW7RooVar4aWXl9fmzZvv3LlDlcfLkoQXWS2QYgQdftmyZcuSkhJKE2pqakJDQ2GRjMpwI/E7/zdz80xNTd3d3detW3f79u3CwkKe58PDw+3t7SUqwVgkijk5Of3xxx88zyclJQ0YMECpVE6ZMiUyMhI+v3//vq+vLyUElhciGUpKeBRFPmQXn2AxoHHjxgcPHiwuLr558+bIkSNNTU1JmTolJpfLraysHBwcOnTosG7dOrD7PM9XVlZOmDDByMgIF2broXjhGBvKQr9+/aD8t99+a2xsLF0YQz37S9euXR8+fAjq9OHDh2fPnkVHR8fFxcXHxxcVFfE8X1JSMmzYMIo1CgssinCoRWAVRIayMNRqdfv27U+dOkUUODk5ecWKFb6+vlgIxFqSGojE8MQaZ7jKSAiQoWVmrla81tbWYJ+fP38+f/58a2vrH3744caNG4MHD+aQblC9UkxiGo1mxIgRYWFhpaWl27ZtmzlzZnh4eE1NzaVLl9q3b8+WpxaE/geLYPMI6jGAqanpvHnzSktLeZ7X6/U6na6qqoqIMiwsrGfPnsQGkc8pLLhaU1PT2bNn5+XlEQ1OTU198uRJUlLSu3fvdDqdTqf75ZdfLC0tBaXAsiQhMvLThAkT8vLyampq7t69C+tyFFA+pk6XSS2+YYxt2rSBARnP8zqd7rPPPuNql+9Ywgg6FsvfMqD1/IqVmJgBNTc3HzVq1OnTp3v37k2svJmZ2bp1627evGlqaopdCKW4FBYsJWxD4U379u3BHOv1+srKyurqahKpZWVlDR48mKRXUTTX09AQXC4uLqtXr37//j18CGqs1WqPHz+Oa3ZxcXFzcwOWBduFY/wEx3E2NjaLFi2KiYk5fPgwvFcqlUuWLMnKyuJ5vrCwcOXKlWI1EF7IWj3rnjHg5X34F94YGRm1bNly69ater3+9u3bzZo1oyTA8gJ208jICPJxvb29ly1b9v79+6qqKgicdTodNMebN28GDRokSA+2tiwWaYXs1q1bcXGxXq+/du2ao6OjWP0SISbrjA
lYWVmdPn2a5/nq6uoPHz60aNGC4zhvb+/AwMA2bdp8//33b968qa6urqioCAgIEPSIgAXnbsAzqAfWzE6dOl26dAn
"text/plain": [
"<PIL.Image.Image image mode=RGB size=546x274>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"inference_results/20230222-162747.png\n"
]
}
],
"source": [
"generate_video = False\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=128,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=16,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}