MasterThesis/diffusion.ipynb

479 lines
1.1 MiB
Plaintext
Raw Normal View History

2024-03-31 14:51:57 +02:00
{
"cells": [
{
"cell_type": "code",
2024-04-01 00:16:59 +02:00
"execution_count": 17,
2024-03-31 14:51:57 +02:00
"metadata": {},
"outputs": [],
"source": [
"import gc\n",
"import os\n",
"import cv2\n",
"import math\n",
"import base64\n",
"import random\n",
"import numpy as np\n",
"from PIL import Image \n",
"from tqdm import tqdm\n",
"from datetime import datetime\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.cuda import amp\n",
"import torch.nn.functional as F\n",
"from torch.optim import Adam, AdamW\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"import torchvision\n",
"import torchvision.transforms as TF\n",
"import torchvision.datasets as datasets\n",
"from torchvision.utils import make_grid\n",
"\n",
"from torchmetrics import MeanMetric\n",
"\n",
"from IPython.display import display, HTML, clear_output\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions"
]
},
{
"cell_type": "code",
2024-04-01 12:38:50 +02:00
"execution_count": 45,
2024-03-31 14:51:57 +02:00
"metadata": {},
"outputs": [],
"source": [
"def to_device(data, device):\n",
"    \"\"\"Recursively move tensor(s) to the chosen device.\n",
"\n",
"    Lists/tuples are traversed element-wise; anything else is assumed to\n",
"    expose a ``.to`` method (e.g. a torch.Tensor).\n",
"    \"\"\"\n",
"    if isinstance(data, (list, tuple)):\n",
"        return [to_device(x, device) for x in data]\n",
"    # non_blocking=True allows asynchronous host-to-device copies\n",
"    return data.to(device, non_blocking=True)"
2024-03-31 14:51:57 +02:00
]
},
{
"cell_type": "code",
2024-04-01 00:16:59 +02:00
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"class DeviceDataLoader:\n",
"    \"\"\"Wrap a dataloader so every yielded batch is moved to a device.\"\"\"\n",
"\n",
"    def __init__(self, dl, device):\n",
"        self.dl = dl\n",
"        self.device = device\n",
"\n",
"    def __iter__(self):\n",
"        \"\"\"Yield each batch after moving it to the target device.\"\"\"\n",
"        for b in self.dl:\n",
"            yield to_device(b, self.device)\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of batches in the wrapped dataloader.\"\"\"\n",
"        return len(self.dl)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def get_default_device():\n",
"    \"\"\"Return the CUDA device when available, otherwise fall back to CPU.\"\"\"\n",
"    if torch.cuda.is_available():\n",
"        return torch.device(\"cuda\")\n",
"    return torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def save_images(images, path, **kwargs):\n",
"    \"\"\"Arrange a batch of images into a grid and save it to disk.\n",
"\n",
"    Args:\n",
"        images: Tensor batch with pixel values in [0, 255].\n",
"        path: Destination file path.\n",
"        **kwargs: Forwarded to torchvision.utils.make_grid.\n",
"    \"\"\"\n",
"    grid = make_grid(images, **kwargs)\n",
"    # BUGFIX: Image.fromarray requires uint8 data; make_grid returns a float\n",
"    # tensor, which previously raised TypeError. Cast during the transfer.\n",
"    ndarr = grid.permute(1, 2, 0).to(\"cpu\", torch.uint8).numpy()\n",
"    im = Image.fromarray(ndarr)\n",
"    im.save(path)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def get(element: torch.Tensor, t: torch.Tensor):\n",
"    \"\"\"Gather the values of ``element`` at index positions ``t`` along the\n",
"    last axis and reshape them to (batch, 1, 1, 1) so they broadcast over\n",
"    a batch of images.\n",
"    \"\"\"\n",
"    return element.gather(-1, t).reshape(-1, 1, 1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# Toy example for get(): a 3x5 value table and one index per row\n",
"element = torch.tensor([[1,2,3,4,5],\n",
"                        [2,3,4,5,6],\n",
"                        [3,4,5,6,7]])\n",
"t = torch.tensor([1,2,0]).unsqueeze(1)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1, 2, 3, 4, 5],\n",
" [2, 3, 4, 5, 6],\n",
" [3, 4, 5, 6, 7]])\n",
"tensor([[1],\n",
" [2],\n",
" [0]])\n"
]
}
],
"source": [
"# Inspect the toy inputs before gathering\n",
"print(element)\n",
"print(t)"
]
},
{
"cell_type": "code",
"execution_count": 25,
2024-03-31 14:51:57 +02:00
"metadata": {},
"outputs": [
{
2024-04-01 00:16:59 +02:00
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[[2]]],\n",
"\n",
"\n",
" [[[4]]],\n",
"\n",
"\n",
" [[[3]]]])\n"
2024-03-31 14:51:57 +02:00
]
}
],
2024-04-01 00:16:59 +02:00
"source": [
"# Each row's value at its index position, reshaped to (3, 1, 1, 1)\n",
"extracted_scores = get(element, t)\n",
"print(extracted_scores)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def setup_log_directory(config):\n",
"    \"\"\"Create and return (log_dir, checkpoint_dir) for a new experiment version.\n",
"\n",
"    Scans ``config.root_log_dir`` for existing ``version_<n>`` folders and\n",
"    uses ``version_<n+1>`` for this run; falls back to ``config.log_dir``\n",
"    on the very first run.\n",
"    \"\"\"\n",
"    if os.path.isdir(config.root_log_dir):\n",
"        # Collect the numeric suffix of every existing version_<n> folder\n",
"        folder_numbers = [int(folder.replace(\"version_\", \"\")) for folder in os.listdir(config.root_log_dir)]\n",
"\n",
"        # Latest version number present in the log dir\n",
"        last_version_number = max(folder_numbers)\n",
"\n",
"        # BUGFIX: the underscore was missing (version{n}), which broke the\n",
"        # naming convention and made the int(...) parse above crash on the\n",
"        # next run.\n",
"        version_name = f\"version_{last_version_number + 1}\"\n",
"\n",
"    else:\n",
"        version_name = config.log_dir\n",
"\n",
"    # Resolve the per-version output directories\n",
"    log_dir = os.path.join(config.root_log_dir, version_name)\n",
"    checkpoint_dir = os.path.join(config.root_checkpoint_dir, version_name)\n",
"\n",
"    # Create them for this experiment version\n",
"    os.makedirs(log_dir, exist_ok=True)\n",
"    os.makedirs(checkpoint_dir, exist_ok=True)\n",
"\n",
"    print(f\"Logging at: {log_dir}\")\n",
"    print(f\"Model Checkpoint at: {checkpoint_dir}\")\n",
"\n",
"    return log_dir, checkpoint_dir\n"
]
},
{
"cell_type": "code",
2024-04-01 12:38:50 +02:00
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"def frames2vid(images, save_path):\n",
"    \"\"\"Write a sequence of HxWxC frames to an mp4 video at 25 fps.\"\"\"\n",
"    WIDTH = images[0].shape[1]\n",
"    HEIGHT = images[0].shape[0]\n",
"\n",
"    # fourcc = cv2.VideoWriter_fourcc(*'XVID')\n",
"    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
"    video = cv2.VideoWriter(save_path, fourcc, 25, (WIDTH, HEIGHT))\n",
"\n",
"    try:\n",
"        # Append the frames to the video one by one\n",
"        for image in images:\n",
"            video.write(image)\n",
"    finally:\n",
"        # Release the writer even if a frame write fails\n",
"        video.release()\n",
"    return\n",
"\n",
"def display_gif(gif_path):\n",
"    \"\"\"Embed a GIF file inline in the notebook output.\"\"\"\n",
"    # BUGFIX: use a context manager so the file handle is closed promptly\n",
"    with open(gif_path, 'rb') as f:\n",
"        b64 = base64.b64encode(f.read()).decode('ascii')\n",
"    display(HTML(f'<img src=\"data:image/gif;base64,{b64}\" />'))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configurations"
]
},
{
"cell_type": "code",
"execution_count": 29,
2024-04-01 00:16:59 +02:00
"metadata": {},
"outputs": [],
2024-03-31 14:51:57 +02:00
"source": [
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class BaseConfig:\n",
"    # NOTE(review): these attributes have no type annotations, so @dataclass\n",
"    # generates no fields -- the class acts as a plain namespace.\n",
"    DEVICE = get_default_device()\n",
"    DATASET = \"Flowers\" # one of MNIST / Cifar-10 / Cifar-100 / Flowers (must match get_dataset)\n",
"\n",
"    # Root folders for inference logs and model checkpoints\n",
"    root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n",
"    root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\",\"checkpoints\")\n",
"\n",
"    # Current log / checkpoint directory names (first-run defaults)\n",
"    log_dir = \"version_0\"\n",
"    checkpoint_dir = \"version_0\"\n",
"\n",
"@dataclass\n",
"class TrainingConfig:\n",
"    TIMESTEPS = 1000  # number of diffusion timesteps\n",
"    IMG_SHAPE = (1,32,32) if BaseConfig.DATASET == \"MNIST\" else (3,32,32)  # (C, H, W)\n",
"    NUM_EPOCHS = 800\n",
"    BATCH_SIZE = 32\n",
"    LR = 2e-4\n",
"    NUM_WORKERS = 2"
]
2024-04-01 12:38:50 +02:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Dataset & Build Dataloader"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"def get_dataset(dataset_name='MNIST'):\n",
"    \"\"\"Return the dataset object to be passed to the DataLoader.\n",
"\n",
"    Every image gets three preprocessing transforms (ToTensor, 32x32 bicubic\n",
"    resize, scaling to [-1, 1]) and one augmentation (random horizontal flip).\n",
"\n",
"    Raises:\n",
"        ValueError: if ``dataset_name`` is not a supported dataset.\n",
"    \"\"\"\n",
"    transforms = TF.Compose(\n",
"        [\n",
"            TF.ToTensor(),\n",
"            TF.Resize((32,32),\n",
"                      interpolation=TF.InterpolationMode.BICUBIC,\n",
"                      antialias=True),\n",
"            TF.RandomHorizontalFlip(),\n",
"            TF.Lambda(lambda t: (t * 2) - 1)  # scale between [-1, 1]\n",
"        ]\n",
"    )\n",
"\n",
"    # Case-insensitive matching (backward compatible: the exact names used\n",
"    # previously still work).\n",
"    name = dataset_name.upper()\n",
"    if name == \"MNIST\":\n",
"        dataset = datasets.MNIST(root=\"data\", train=True, download=True, transform=transforms)\n",
"    elif name == \"CIFAR-10\":\n",
"        dataset = datasets.CIFAR10(root=\"data\", train=True, download=True, transform=transforms)\n",
"    elif name == \"CIFAR-100\":\n",
"        # BUGFIX: this branch previously loaded CIFAR10\n",
"        dataset = datasets.CIFAR100(root=\"data\", train=True, download=True, transform=transforms)\n",
"    elif name == \"FLOWERS\":\n",
"        dataset = datasets.ImageFolder(root=\"data/flowers\", transform=transforms)\n",
"    else:\n",
"        # BUGFIX: an unknown name previously caused UnboundLocalError on return\n",
"        raise ValueError(f\"Unknown dataset: {dataset_name!r}\")\n",
"    return dataset"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def get_dataloader(dataset_name='MNIST',\n",
"                   batch_size=32,\n",
"                   pin_memory=False,\n",
"                   shuffle=True,\n",
"                   num_workers=0,\n",
"                   device=\"cpu\"\n",
"                   ):\n",
"    \"\"\"Build a DataLoader for the named dataset and wrap it so every batch\n",
"    is delivered on ``device``.\n",
"    \"\"\"\n",
"    loader = DataLoader(\n",
"        get_dataset(dataset_name=dataset_name),\n",
"        batch_size=batch_size,\n",
"        pin_memory=pin_memory,\n",
"        num_workers=num_workers,\n",
"        shuffle=shuffle,\n",
"    )\n",
"    return DeviceDataLoader(loader, device)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"def inverse_transform(tensors):\n",
"    \"\"\"Map tensors from the model range [-1., 1.] back to image range [0., 255.].\n",
"\n",
"    Values outside [-1, 1] are clamped first; the result is still a float tensor.\n",
"    \"\"\"\n",
"    return ((tensors.clamp(-1, 1) + 1.0) / 2.0) * 255.0"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset ImageFolder\n",
" Number of datapoints: 4317\n",
" Root location: data/flowers\n",
" StandardTransform\n",
"Transform: Compose(\n",
" ToTensor()\n",
" Resize(size=(32, 32), interpolation=bicubic, max_size=None, antialias=True)\n",
" RandomHorizontalFlip(p=0.5)\n",
" Lambda()\n",
" )\n"
]
}
],
"source": [
"# Sanity check: load the Flowers dataset and show its summary\n",
"dataset = get_dataset(dataset_name='Flowers')\n",
"print(dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize Dataset"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# Build a CPU dataloader for visualization (batches of 128 images)\n",
"loader = get_dataloader(\n",
"    dataset_name=BaseConfig.DATASET,\n",
"    batch_size=128,\n",
"    device='cpu'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA64AAAHiCAYAAADoA5FMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9Waxta5bfCf3G18xuNbs/7W0jbtyIyN5dpl0yKhvbVWWgpILCpMAq8QBP1AvYCIH8YCEheECinigBT5RUUBRNQSEol6tcWFhygpts7MyMjIz2tqfd7epm83U8fHOuvc+NyHREOlLkwx5H66zdrD3XXHN+zfiP8R//ISmlxL3d273d273d273d273d273d273d2x9RU///PoF7u7d7u7d7u7d7u7d7u7d7u7d7u7ffz+6B673d273d273d273d273d273d2739kbZ74Hpv93Zv93Zv93Zv93Zv93Zv93Zvf6TtHrje273d273d273d273d273d273d2x9puweu93Zv93Zv93Zv93Zv93Zv93Zv9/ZH2u6B673d273d273d273d273d273d2739kbZ74Hpv93Zv93Zv93Zv93Zv93Zv93Zvf6TtHrje273d273d273d273d273d273d2x9puweu93Zv93Zv93Zv93Zv93Zv93Zv9/ZH2u6B673d273d273d273d273d273d2739kTbz47x4d9ny8nfO6WjppcXRE1MgxoQPga7vEQERQRAAQowkEiklIJEAhSAiGGNQSmGMRkShRJFSIqbEMPTEGAkhIKLyMUV+8KQkv1N+XxAlKCUopYDxPcf3hkQ+LaEoS7TRxOjyYZTw4Owxbz99nxfnN6y3PXH8DIzv+8b/Al/44u5J3T29P7Dl895/94Uv0+1P0ptf73/7xuvy/0rg3ccnpDDwW7/zawyDI/hAiIkUQVQ+adGAKJRSxJiICVKMkCCGiChQSkjT5/Mp34fxuicgxURKt58j36dxbKQ03YrxvMYYioDRBlFCCJ6UIEoaby6IJJQCYw0pQQgJQiLFRAyRmPL9Pj054+d/9o/T3ryiXb2mX6+RlDh6fILSioTgO0ccPLoqAdjdXEMMEAK2rDFFSQoRRFBlSXCOfrfNr4kRsQalNbYocc7T7Xr6rsc7T1UqrNXUsxnRB4be0fUO7wNVU2KsoSwrxBiUtSiVP19MPl+vCESQmD+vaEHPG8RYxNb5c0YwZkFKwsff+VWGvgOEGBIhRIxSKBEg4UOi7SJaa6zRNLVFa2HwAzEknI9one93iJEUIYS4n9sx5vsp4z0PxP39slqhteDbSPSJ4G9Hq4zz04hgreX07AxjNEYr2q5ncJ6mmWGtoSwL8mVIeO8JIdB3PdoWnL3/JZJS+NspzDTEp2kyfX37nG5/l+5+PX5PeuNrEsT9/JHxb9J+/E7H2K9l4xtN7zl98QPz8vYHt1M4xfG1cTxG5PRwwdGi5jf+6a9yfXOF1qCUQmuNVgYlirbtCCGQ1Hh+IYHK0yOE/OZKBCWCFmGmFKUo/OCQmKh8wmpFaQyH2tAoxefBsUmR58HTJOEIRZESJp/d/rMbBCuKmxTYkbiS/PnqkLAIpQiGvL4PRqiU4WlZEwEPPIsdnVH81C/8KWKCzz/7LF+HOzYt1enO+vZDLt+4rgqi8nqidN4f8jqSN4O8Z+RrmH+kfuA4KeYBpSQfS43rDEnwIREi+JiYFqqAwkfFxglNoTmeG5aNoSo029bT+cTFFg4bw9OTkm984xs8e/Z8f77TfvXmfnHnA057jZL9ngayH7tffPn+WdJ0+NtnQJLs11xR5LV3v2Wl/LPxb/LP07iPTudzZ8xGxvOX/bj84N2fQkTzD/7Br9APAyne3sMfxZS63bd1TEiCIoFKoKfTvHO1pufEfokkQt6HlKCMxjY1pigoqpoYIzEGVqt1XqO74Y33FwFr9f46hxDx/s0xeXJywi/+4i/y+uo1L86f5z
VK8l4nIih167+IUuM1vjMe99dsvA/jOQXviCkSUyRN407y2hBTGOfG7djLd08Qye+HqPFPBFF6/77qzvtBIgSPDwHn3bgnp/1cEHXrHz05e4t5PefXfu0f0bZbqsrQlCWzekbf9zjviSpijGE5n9F3nnbbs+vy72yhsVazPGywRmOMwg0R7wLnF2tSgsKa/bo7OJf3mBRJKe9RSgvaCHVVoZUmBpevh7j8GcVgVEGMcH21hgRGK7ROKJ2wRqEUxKQgCRIVZZkoioSThC0bPnz/T/Lq1Wt+/dd//UcbpH9A0zqP4P3eAG/sJX+Y9sEHH/Dhhx/yq7/2q7x+/fp2QZje9/fwS2VcdLL/LKM/d+fvUva13vhMk4/9Qz+T/N7v9UNP5I63+sWFX958vTWGP/1LfxqRwHe+9xtYESpjMIWgdGLoPTEJPpXoskSXNZcXV3RtS1MZmqrg0elhni+S8KP/GGMixEjbDwTvCd5zcrikspa27wkx4kIgSt6l3BAAOGiq0d8CHyEgzJoaozXeu4xtYn6fECO79ZbgHMp5VEpoIMY3fY4YE4uTRzz68s/yO9/8Bh99/DFK5XkfY3zTzxKYN3kt2/Vhf+VSHB9MfrigDGiT55rWwsHcZH8+gQ8J7xNa5z1Rq3zebRtRCpROlEV+/WbjCWHEB+Pnq6sGqw0hJmL0DHFHUarsE+sahabtHCnlcxmcx/nAT3/1Fzg6PP3hg+WLYyf9GDPo+3//U/7D//5/xjP5iOf6U27Sc4a4o+0D6+2WZy9eZOdXC1obSND2HTFGXPD7AW60xmjDwWJOWZbM5jNKW2KMJfjI4Byvzl/T9z1tu8MYi9YGrfXkr+YFHkGPQNXoDKSKUmELQ1mXpOTz4h8diUhKLm9uWnH88DHNfEY/XIMETGH4L/3Lv8x/69/47/J/+U9/ld/6zjMcmoSMG4QA6tZRGoFGdpQUCbWf9G8A3cnBujMF907MfiBNv7uzOOxH3eQk5800Tp5ByhtcGn+Xpu8nZzuGO7+Lo2Oendb/zl/5c/juNf/mX/9lLi8vWa+3dG3EDQlTCsoINBqtDdZWdD7gfCIMnuQCbtejrcLU+jZnvwlIAlPZffDBu5BB0ZCDA0ZpjCi0KGR0UkQllBKssnnzVsJyMaMwhs1ui0+RXkWSFjBCYSPGCAdHc2KE3dYT20DoIl3b4YPHhcBf/HP/Cv/b/9X/gY9/9T/ik9/427z4rd9GRcef/St/DttUeNFsP7uifbli/vZjIpGPf+0fE3Zb2O04efIuy9NHDLselKZ6/JDt9SUvv/dt0q4FN1AcLSmamsMHD7m5XPPZ91/w/LOXrC6veftxxeHBjPc+/ArtuuXy+QWfv7jiZtPy1pcfsThY8ODJWxSLBeXRIabUKCP0bpWBSS/IEFFdoqzAVJrmZ76Enh+iDt4hhIRzgeXyZ/HO8r/8n/7rXLz8FEmWrh3YbXuWZUVpNDEmtrvAR89a5rOGg8WcL727ZDYzXNxc0naO66uBqimpmoJ2N+BcpNs6+sGx2rQMQ8S5hCryPd/iSJJB0+GiYD6z3Hy/Y1gHtqtITEJUCjMCqIUxHB8e8Jf+8l9mOW9YNCWfPHvJ+eU173/5KxwcHPDkyQlGgyGyWa9ot1tefPaS2fEZ/9J/76/ji4p1iOhxioU4LbR57oQEMUJM4GIGoX4E3yEl4hjgcCGNjlLeQMIYIIspMoRATBBEEX0gOD8C+YQfn4PPfxND2B8zB37ywp3ubD4pkU/qzuaeSKTgSTESoyNGT/ADf+mXfoo/8zPv8m/8t/8K//Af/33qOlGUBbPZglnZUJqKjz76jG3bQanxPjLsPLpMKJvYbQIpQKEtldbMreGDouSRsazOb1Cd4+nac1RVPDo84M/UM94rCv73u2t+23f8B+2a96Lhz8SChymwSJEhJQI5iLFQhmNl+VXf8f3k+BWbr8M7beBYFA+0YZkUVoSrecGTesZfffAevQhrgX+v+4xny5J/69
/9f9IPgf/dv/Pv4IPfr3wJSGHc3GO+TvFO4C3GW6dFK4VSGm00WivKqshBGVvsHXtd5MBSWU4BUnMLrFCQEsF5lOT
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(12, 6), facecolor='white')\n",
"\n",
"# Show a single batch as an image grid, then stop iterating\n",
"for b_image, _ in loader:\n",
"    b_image = inverse_transform(b_image).cpu()\n",
"    # NOTE(review): normalize=True rescales to [0, 1] on its own, so the\n",
"    # division by 255 is redundant (but harmless); padding=True relies on\n",
"    # bool -> int coercion (make_grid expects a pixel count).\n",
"    grid_img = make_grid(b_image / 255.0,\n",
"                         nrow = 16,\n",
"                         padding=True,\n",
"                         pad_value=1,\n",
"                         normalize=True\n",
"                         )\n",
"    plt.imshow(grid_img.permute(1, 2, 0))\n",
"    plt.axis(\"off\")\n",
"    break"
]
2024-03-31 14:51:57 +02:00
}
],
"metadata": {
"kernelspec": {
"display_name": "DLML",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-04-01 00:16:59 +02:00
"version": "3.11.8"
2024-03-31 14:51:57 +02:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}