2024-03-31 14:51:57 +02:00
|
|
|
|
{
|
|
|
|
|
"cells": [
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"execution_count": 17,
|
2024-03-31 14:51:57 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"import gc\n",
|
|
|
|
|
"import os\n",
|
|
|
|
|
"import cv2\n",
|
|
|
|
|
"import math\n",
|
|
|
|
|
"import base64\n",
|
|
|
|
|
"import random\n",
|
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
"from PIL import Image \n",
|
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
|
"from datetime import datetime\n",
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"import torch\n",
|
|
|
|
|
"import torch.nn as nn\n",
|
|
|
|
|
"from torch.cuda import amp\n",
|
|
|
|
|
"import torch.nn.functional as F\n",
|
|
|
|
|
"from torch.optim import Adam, AdamW\n",
|
|
|
|
|
"from torch.utils.data import Dataset, DataLoader\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"import torchvision\n",
|
|
|
|
|
"import torchvision.transforms as TF\n",
|
|
|
|
|
"import torchvision.datasets as datasets\n",
|
|
|
|
|
"from torchvision.utils import make_grid\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"from torchmetrics import MeanMetric\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"from IPython.display import display, HTML, clear_output\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Helper functions"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"execution_count": 18,
|
2024-03-31 14:51:57 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to ``device``.

    Args:
        data: A tensor (or any object with a ``.to(device, non_blocking=...)``
            method), or a list/tuple of such objects. Nested lists/tuples are
            handled recursively.
        device: Target device, e.g. ``torch.device("cuda")`` or ``"cpu"``.

    Returns:
        The moved object. Lists *and* tuples come back as a ``list``.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    # non_blocking=True lets the copy run asynchronously with respect to the
    # host when possible (e.g. host->GPU copies from pinned memory).
    # Python's boolean literal is `True`; the previous lowercase `true` was a NameError.
    return data.to(device, non_blocking=True)
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"execution_count": 19,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
class DeviceDataLoader:
    """Thin wrapper around a data loader that moves every batch to ``device``.

    Iterating delegates to the wrapped loader and passes each batch through
    ``to_device`` before handing it out; ``len()`` is forwarded unchanged.
    """

    def __init__(self, dl, device):
        # Wrapped loader and target device, kept as public attributes.
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield the wrapped loader's batches, each moved to ``self.device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Number of batches, as reported by the wrapped loader."""
        return len(self.dl)
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 20,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
def get_default_device():
    """Return the CUDA device when PyTorch reports one is available, else the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 21,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
def save_images(images, path, **kwargs):
    """Tile a batch of images into a single grid and write it to ``path``.

    Args:
        images: Whatever ``torchvision.utils.make_grid`` accepts (a (B, C, H, W)
            tensor or a list of (C, H, W) tensors). ``uint8`` input is written
            as-is. Floating-point input is assumed to lie in [0, 1] (e.g. after
            ``make_grid(..., normalize=True)``) and is scaled to [0, 255],
            mirroring ``torchvision.utils.save_image``.
        path: Destination file; PIL picks the format from the extension.
        **kwargs: Forwarded to ``make_grid`` (``nrow``, ``padding``, ``normalize``, ...).
    """
    grid = make_grid(images, **kwargs)
    if grid.is_floating_point():
        # PIL cannot build an image from a float array, so quantise to uint8
        # the same way torchvision's save_image does (round, then clamp).
        grid = grid.mul(255).add_(0.5).clamp_(0, 255).to(torch.uint8)
    # detach() so tensors that still require grad can be converted to numpy.
    ndarr = grid.detach().permute(1, 2, 0).to("cpu").numpy()
    im = Image.fromarray(ndarr)
    im.save(path)
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 22,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
def get(element: torch.Tensor, t: torch.Tensor):
    """
    Pick the values of ``element`` at indices ``t`` along its last dimension
    and reshape them to (-1, 1, 1, 1), so the result broadcasts against a
    batch of images laid out as (N, C, H, W).

    ``t`` must be an int64 tensor with the same number of dimensions as
    ``element`` (a ``torch.gather`` requirement). Presumably ``element`` is a
    per-timestep schedule and ``t`` a batch of timesteps -- confirm at call sites.
    """
    picked = element.gather(dim=-1, index=t)
    return picked.reshape((-1, 1, 1, 1))
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 23,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
# Toy example for `get`: a 3x5 table of values...
element = torch.tensor([
    [1, 2, 3, 4, 5],
    [2, 3, 4, 5, 6],
    [3, 4, 5, 6, 7],
])
# ...and one column index per row, as a (3, 1) column so it has the same
# rank as `element` (torch.gather requires matching ranks).
t = torch.tensor([[1], [2], [0]])
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 24,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"tensor([[1, 2, 3, 4, 5],\n",
|
|
|
|
|
" [2, 3, 4, 5, 6],\n",
|
|
|
|
|
" [3, 4, 5, 6, 7]])\n",
|
|
|
|
|
"tensor([[1],\n",
|
|
|
|
|
" [2],\n",
|
|
|
|
|
" [0]])\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
# Show both inputs of the toy example, one after the other.
print(element, t, sep="\n")
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 25,
|
2024-03-31 14:51:57 +02:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"tensor([[[[2]]],\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" [[[4]]],\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" [[[3]]]])\n"
|
2024-03-31 14:51:57 +02:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"source": [
|
|
|
|
|
# Picks element[0, 1], element[1, 2], element[2, 0] -> 2, 4, 3, shaped (3, 1, 1, 1).
extracted_scores = get(element=element, t=t)
print(extracted_scores)
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 26,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
def setup_log_directory(config):
    """
    Create a fresh, numbered log directory and model-checkpoint directory.

    If ``config.root_log_dir`` already contains ``version_<n>`` folders, the
    next number (highest + 1) is used. Otherwise ``config.log_dir`` is used as
    the version name. The same version name is created under both
    ``config.root_log_dir`` and ``config.root_checkpoint_dir``.

    Args:
        config: Object exposing ``root_log_dir``, ``root_checkpoint_dir`` and
            ``log_dir`` attributes (e.g. ``BaseConfig``).

    Returns:
        Tuple ``(log_dir, checkpoint_dir)`` with the paths of the created directories.
    """
    import re

    # Accept "version_<n>"; also accept the bare "version<n>" form, which an
    # earlier f-string (missing the underscore) produced, so such folders
    # still count towards the next number.
    version_pattern = re.compile(r"version_?(\d+)")

    version_numbers = []
    if os.path.isdir(config.root_log_dir):
        for entry in os.listdir(config.root_log_dir):
            match = version_pattern.fullmatch(entry)
            # Skip unrelated entries such as ".ipynb_checkpoints" or stray files,
            # which previously made int() raise ValueError.
            if match and os.path.isdir(os.path.join(config.root_log_dir, entry)):
                version_numbers.append(int(match.group(1)))

    if version_numbers:
        version_name = f"version_{max(version_numbers) + 1}"
    else:
        # First run, or no recognisable version folders yet (previously max([]) crashed).
        version_name = config.log_dir

    log_dir = os.path.join(config.root_log_dir, version_name)
    checkpoint_dir = os.path.join(config.root_checkpoint_dir, version_name)

    # Create the directories for this experiment version.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    print(f"Logging at: {log_dir}")
    print(f"Model Checkpoint at: {checkpoint_dir}")

    return log_dir, checkpoint_dir
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 27,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
2024-03-31 14:51:57 +02:00
|
|
|
|
"source": [
|
|
|
|
|
"from dataclasses import dataclass\n",
|
|
|
|
|
"\n",
|
|
|
|
|
@dataclass
class BaseConfig:
    # Global settings shared across the notebook.
    # NOTE(review): these are plain class attributes (no type annotations), so
    # @dataclass generates no fields; the class is read as a namespace
    # (e.g. TrainingConfig reads BaseConfig.DATASET).
    DATASET = "Flowers"  # options noted by the author: MNIST, "cifar-10", "Flowers"
    DEVICE = get_default_device()  # evaluated once, when the class body runs

    # Root folders for inference logs and for saved model checkpoints.
    root_log_dir = os.path.join("Logs_Checkpoints", "Inference")
    root_checkpoint_dir = os.path.join("Logs_Checkpoints","checkpoints")

    # Current version sub-directory names for logs and checkpoints.
    log_dir = checkpoint_dir = "version_0"
|
|
|
|
|
"\n",
|
|
|
|
|
@dataclass
class TrainingConfig:
    # Training hyper-parameters, stored as plain class attributes (no dataclass fields).
    TIMESTEPS = 1000  # presumably the number of diffusion steps -- confirm at use site
    # Channels-first (C, H, W): one channel for MNIST, three for the other datasets.
    IMG_SHAPE = (1 if BaseConfig.DATASET == "MNIST" else 3, 32, 32)
    NUM_EPOCHS = 800
    BATCH_SIZE = 32
    LR = 2e-4  # learning rate
    NUM_WORKERS = 2  # presumably DataLoader worker processes -- confirm at use site
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "DLML",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
2024-04-01 00:16:59 +02:00
|
|
|
|
"version": "3.11.8"
|
2024-03-31 14:51:57 +02:00
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 2
|
|
|
|
|
}
|