| 
									
										
										
										
											2024-03-31 14:51:57 +02:00
										 |  |  |  | { | 
					
						
							|  |  |  |  |  "cells": [ | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							| 
									
										
										
										
											2024-04-01 00:16:59 +02:00
										 |  |  |  |    "execution_count": 17, | 
					
						
							| 
									
										
										
										
											2024-03-31 14:51:57 +02:00
										 |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [], | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "import gc\n", | 
					
						
							|  |  |  |  |     "import os\n", | 
					
						
							|  |  |  |  |     "import cv2\n", | 
					
						
							|  |  |  |  |     "import math\n", | 
					
						
							|  |  |  |  |     "import base64\n", | 
					
						
							|  |  |  |  |     "import random\n", | 
					
						
							|  |  |  |  |     "import numpy as np\n", | 
					
						
							|  |  |  |  |     "from PIL import Image \n", | 
					
						
							|  |  |  |  |     "from tqdm import tqdm\n", | 
					
						
							|  |  |  |  |     "from datetime import datetime\n", | 
					
						
							|  |  |  |  |     "import matplotlib.pyplot as plt\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "import torch\n", | 
					
						
							|  |  |  |  |     "import torch.nn as nn\n", | 
					
						
							|  |  |  |  |     "from torch.cuda import amp\n", | 
					
						
							|  |  |  |  |     "import torch.nn.functional as F\n", | 
					
						
							|  |  |  |  |     "from torch.optim import Adam, AdamW\n", | 
					
						
							|  |  |  |  |     "from torch.utils.data import Dataset, DataLoader\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "import torchvision\n", | 
					
						
							|  |  |  |  |     "import torchvision.transforms as TF\n", | 
					
						
							|  |  |  |  |     "import torchvision.datasets as datasets\n", | 
					
						
							|  |  |  |  |     "from torchvision.utils import make_grid\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "from torchmetrics import MeanMetric\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "from IPython.display import display, HTML, clear_output\n" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "markdown", | 
					
						
							|  |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "## Helper functions" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							| 
									
										
										
										
											2024-04-01 00:16:59 +02:00
										 |  |  |  |    "execution_count": 18, | 
					
						
							| 
									
										
										
										
											2024-03-31 14:51:57 +02:00
										 |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [], | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "def to_device(data, device):\n", | 
					
						
							|  |  |  |  |     "    \"\"\"将张量移动到选择的设备\"\"\"\n", | 
					
						
							|  |  |  |  |     "    \"\"\"Move tensor(s) to chosen device\"\"\"\n", | 
					
						
							|  |  |  |  |     "    if isinstance(data, (list, tuple)):\n", | 
					
						
							|  |  |  |  |     "        return [to_device(x, device) for x in data]\n", | 
					
						
							|  |  |  |  |     "    return data.to(device, non_blocking=true)" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							| 
									
										
										
										
											2024-04-01 00:16:59 +02:00
										 |  |  |  |    "execution_count": 19, | 
					
						
							|  |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [], | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "class DeviceDataLoader:\n", | 
					
						
							|  |  |  |  |     "    \"\"\"包装一个数据加载器,来把数据移动到另一个设备上\"\"\"\n", | 
					
						
							|  |  |  |  |     "    \"\"\"Wrap a dataloader to move data to a device\"\"\"\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    def __init__(self, dl, device):\n", | 
					
						
							|  |  |  |  |     "        self.dl = dl\n", | 
					
						
							|  |  |  |  |     "        self.device = device\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    def __iter__(self):\n", | 
					
						
							|  |  |  |  |     "        \"\"\"在移动到设备后生成一个批次的数据\"\"\"\n", | 
					
						
							|  |  |  |  |     "        \"\"\"Yield a batch of data after moving it to device\"\"\"\n", | 
					
						
							|  |  |  |  |     "        for b in self.dl:\n", | 
					
						
							|  |  |  |  |     "            yield to_device(b, self.device)\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    def __len__(self):\n", | 
					
						
							|  |  |  |  |     "        \"\"\"批次的数量\"\"\"\n", | 
					
						
							|  |  |  |  |     "        \"\"\"Number of batches\"\"\"\n", | 
					
						
							|  |  |  |  |     "        return len(self.dl)" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							|  |  |  |  |    "execution_count": 20, | 
					
						
							|  |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [], | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "def get_default_device():\n", | 
					
						
							|  |  |  |  |     "    return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							|  |  |  |  |    "execution_count": 21, | 
					
						
							|  |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [], | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "def save_images(images, path, **kwargs):\n", | 
					
						
							|  |  |  |  |     "    grid = make_grid(images, **kwargs)\n", | 
					
						
							|  |  |  |  |     "    ndarr = grid.permute(1,2,0).to(\"cpu\").numpy()\n", | 
					
						
							|  |  |  |  |     "    im = Image.fromarray(ndarr)\n", | 
					
						
							|  |  |  |  |     "    im.save(path)" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							|  |  |  |  |    "execution_count": 22, | 
					
						
							|  |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [], | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "def get(element: torch.Tensor, t: torch.Tensor):\n", | 
					
						
							|  |  |  |  |     "    \"\"\"\n", | 
					
						
							|  |  |  |  |     "    Get value at index position \"t\" in \"element\" and \n", | 
					
						
							|  |  |  |  |     "        reshape it to have the same dimension as a batch of images\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    获得在\"element\"中位置\"t\"并且reshape,以和一组照片有相同的维度\n", | 
					
						
							|  |  |  |  |     "    \"\"\"\n", | 
					
						
							|  |  |  |  |     "    ele = element.gather(-1, t)\n", | 
					
						
							|  |  |  |  |     "    return ele.reshape(-1, 1, 1, 1)" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							|  |  |  |  |    "execution_count": 23, | 
					
						
							|  |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [], | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "element = torch.tensor([[1,2,3,4,5],\n", | 
					
						
							|  |  |  |  |     "                        [2,3,4,5,6],\n", | 
					
						
							|  |  |  |  |     "                        [3,4,5,6,7]])\n", | 
					
						
							|  |  |  |  |     "t = torch.tensor([1,2,0]).unsqueeze(1)" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							|  |  |  |  |    "execution_count": 24, | 
					
						
							|  |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [ | 
					
						
							|  |  |  |  |     { | 
					
						
							|  |  |  |  |      "name": "stdout", | 
					
						
							|  |  |  |  |      "output_type": "stream", | 
					
						
							|  |  |  |  |      "text": [ | 
					
						
							|  |  |  |  |       "tensor([[1, 2, 3, 4, 5],\n", | 
					
						
							|  |  |  |  |       "        [2, 3, 4, 5, 6],\n", | 
					
						
							|  |  |  |  |       "        [3, 4, 5, 6, 7]])\n", | 
					
						
							|  |  |  |  |       "tensor([[1],\n", | 
					
						
							|  |  |  |  |       "        [2],\n", | 
					
						
							|  |  |  |  |       "        [0]])\n" | 
					
						
							|  |  |  |  |      ] | 
					
						
							|  |  |  |  |     } | 
					
						
							|  |  |  |  |    ], | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "print(element)\n", | 
					
						
							|  |  |  |  |     "print(t)" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							|  |  |  |  |    "execution_count": 25, | 
					
						
							| 
									
										
										
										
											2024-03-31 14:51:57 +02:00
										 |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [ | 
					
						
							|  |  |  |  |     { | 
					
						
							| 
									
										
										
										
											2024-04-01 00:16:59 +02:00
										 |  |  |  |      "name": "stdout", | 
					
						
							|  |  |  |  |      "output_type": "stream", | 
					
						
							|  |  |  |  |      "text": [ | 
					
						
							|  |  |  |  |       "tensor([[[[2]]],\n", | 
					
						
							|  |  |  |  |       "\n", | 
					
						
							|  |  |  |  |       "\n", | 
					
						
							|  |  |  |  |       "        [[[4]]],\n", | 
					
						
							|  |  |  |  |       "\n", | 
					
						
							|  |  |  |  |       "\n", | 
					
						
							|  |  |  |  |       "        [[[3]]]])\n" | 
					
						
							| 
									
										
										
										
											2024-03-31 14:51:57 +02:00
										 |  |  |  |      ] | 
					
						
							|  |  |  |  |     } | 
					
						
							|  |  |  |  |    ], | 
					
						
							| 
									
										
										
										
											2024-04-01 00:16:59 +02:00
										 |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "extracted_scores = get(element, t)\n", | 
					
						
							|  |  |  |  |     "print(extracted_scores)" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							|  |  |  |  |    "execution_count": 26, | 
					
						
							|  |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [], | 
					
						
							|  |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "def setup_log_directory(config):\n", | 
					
						
							|  |  |  |  |     "    \"\"\"\n", | 
					
						
							|  |  |  |  |     "    Log and Model checkpoint directory Setup\n", | 
					
						
							|  |  |  |  |     "    记录并且建模目录准备\n", | 
					
						
							|  |  |  |  |     "    \"\"\"\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    if os.path.isdir(config.root_log_dir):\n", | 
					
						
							|  |  |  |  |     "        # Get all folders numbers in the root_log_dir\n", | 
					
						
							|  |  |  |  |     "        # 在root_log_dir下获得所有文件夹数目\n", | 
					
						
							|  |  |  |  |     "        folder_numbers = [int(folder.replace(\"version_\", \"\")) for folder in os.listdir(config.root_log_dir)]\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "        # Find the latest version number present in the log_dir\n", | 
					
						
							|  |  |  |  |     "        # 找到在log_dir下的最新版本数字\n", | 
					
						
							|  |  |  |  |     "        last_version_number = max(folder_numbers)\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "        # New version name\n", | 
					
						
							|  |  |  |  |     "        version_name = f\"version{last_version_number + 1}\"\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    else:\n", | 
					
						
							|  |  |  |  |     "        version_name = config.log_dir\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    # Update the training config default directory\n", | 
					
						
							|  |  |  |  |     "    # 更新训练config默认目录\n", | 
					
						
							|  |  |  |  |     "    log_dir         = os.path.join(config.root_log_dir,         version_name)\n", | 
					
						
							|  |  |  |  |     "    checkpoint_dir  = os.path.join(config.root_checkpoint_dir,  version_name) \n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    # Create new directory for saving new experiment version\n", | 
					
						
							|  |  |  |  |     "    # 创建一个新目录来保存新的实验版本\n", | 
					
						
							|  |  |  |  |     "    os.makedirs(log_dir,        exist_ok=True)\n", | 
					
						
							|  |  |  |  |     "    os.makedirs(checkpoint_dir, exist_ok=True)\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    print(f\"Logging at: {log_dir}\")\n", | 
					
						
							|  |  |  |  |     "    print(f\"Model Checkpoint at: {checkpoint_dir}\")\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    return log_dir, checkpoint_dir\n" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   { | 
					
						
							|  |  |  |  |    "cell_type": "code", | 
					
						
							|  |  |  |  |    "execution_count": 27, | 
					
						
							|  |  |  |  |    "metadata": {}, | 
					
						
							|  |  |  |  |    "outputs": [], | 
					
						
							| 
									
										
										
										
											2024-03-31 14:51:57 +02:00
										 |  |  |  |    "source": [ | 
					
						
							|  |  |  |  |     "from dataclasses import dataclass\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "@dataclass\n", | 
					
						
							|  |  |  |  |     "class BaseConfig:\n", | 
					
						
							|  |  |  |  |     "    DEVICE = get_default_device()\n", | 
					
						
							|  |  |  |  |     "    DATASET = \"Flowers\" #MNIST \"cifar-10\" \"Flowers\"\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    # 记录推断日志信息并保存存档点\n", | 
					
						
							|  |  |  |  |     "    root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n", | 
					
						
							|  |  |  |  |     "    root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\",\"checkpoints\")\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "    #目前的日志和存档点目录\n", | 
					
						
							|  |  |  |  |     "    log_dir = \"version_0\"\n", | 
					
						
							|  |  |  |  |     "    checkpoint_dir = \"version_0\"\n", | 
					
						
							|  |  |  |  |     "\n", | 
					
						
							|  |  |  |  |     "@dataclass\n", | 
					
						
							|  |  |  |  |     "class TrainingConfig:\n", | 
					
						
							|  |  |  |  |     "    TIMESTEPS = 1000\n", | 
					
						
							|  |  |  |  |     "    IMG_SHAPE = (1,32,32) if BaseConfig.DATASET == \"MNIST\" else (3,32,32)\n", | 
					
						
							|  |  |  |  |     "    NUM_EPOCHS = 800\n", | 
					
						
							|  |  |  |  |     "    BATCH_SIZE = 32\n", | 
					
						
							|  |  |  |  |     "    LR = 2e-4\n", | 
					
						
							|  |  |  |  |     "    NUM_WORKERS = 2" | 
					
						
							|  |  |  |  |    ] | 
					
						
							|  |  |  |  |   } | 
					
						
							|  |  |  |  |  ], | 
					
						
							|  |  |  |  |  "metadata": { | 
					
						
							|  |  |  |  |   "kernelspec": { | 
					
						
							|  |  |  |  |    "display_name": "DLML", | 
					
						
							|  |  |  |  |    "language": "python", | 
					
						
							|  |  |  |  |    "name": "python3" | 
					
						
							|  |  |  |  |   }, | 
					
						
							|  |  |  |  |   "language_info": { | 
					
						
							|  |  |  |  |    "codemirror_mode": { | 
					
						
							|  |  |  |  |     "name": "ipython", | 
					
						
							|  |  |  |  |     "version": 3 | 
					
						
							|  |  |  |  |    }, | 
					
						
							|  |  |  |  |    "file_extension": ".py", | 
					
						
							|  |  |  |  |    "mimetype": "text/x-python", | 
					
						
							|  |  |  |  |    "name": "python", | 
					
						
							|  |  |  |  |    "nbconvert_exporter": "python", | 
					
						
							|  |  |  |  |    "pygments_lexer": "ipython3", | 
					
						
							| 
									
										
										
										
											2024-04-01 00:16:59 +02:00
										 |  |  |  |    "version": "3.11.8" | 
					
						
							| 
									
										
										
										
											2024-03-31 14:51:57 +02:00
										 |  |  |  |   } | 
					
						
							|  |  |  |  |  }, | 
					
						
							|  |  |  |  |  "nbformat": 4, | 
					
						
							|  |  |  |  |  "nbformat_minor": 2 | 
					
						
							|  |  |  |  | } |