From 0436bbe2116116ef5609794aa7077b2702b7a87a Mon Sep 17 00:00:00 2001 From: Hanzhang ma Date: Sun, 31 Mar 2024 14:51:57 +0200 Subject: [PATCH] add diffusion.ipynb --- diffusion.ipynb | 125 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 diffusion.ipynb diff --git a/diffusion.ipynb b/diffusion.ipynb new file mode 100644 index 0000000..ff5c8a2 --- /dev/null +++ b/diffusion.ipynb @@ -0,0 +1,125 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import gc\n", + "import os\n", + "import cv2\n", + "import math\n", + "import base64\n", + "import random\n", + "import numpy as np\n", + "from PIL import Image \n", + "from tqdm import tqdm\n", + "from datetime import datetime\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.cuda import amp\n", + "import torch.nn.functional as F\n", + "from torch.optim import Adam, AdamW\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "import torchvision\n", + "import torchvision.transforms as TF\n", + "import torchvision.datasets as datasets\n", + "from torchvision.utils import make_grid\n", + "\n", + "from torchmetrics import MeanMetric\n", + "\n", + "from IPython.display import display, HTML, clear_output\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helper functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def to_device(data, device):\n", + " \"\"\"将张量移动到选择的设备\"\"\"\n", + " \"\"\"Move tensor(s) to chosen device\"\"\"\n", + " if isinstance(data, (list, tuple)):\n", + " return [to_device(x, device) for x in data]\n", + " return data.to(device, non_blocking=true)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'get_default_device' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[5], line 4\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdataclasses\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m dataclass\n\u001b[0;32m 3\u001b[0m \u001b[38;5;129;43m@dataclass\u001b[39;49m\n\u001b[1;32m----> 4\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mBaseConfig\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mDEVICE\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mget_default_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mDATASET\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mFlowers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;49;00m\n", + "Cell \u001b[1;32mIn[5], line 5\u001b[0m, in \u001b[0;36mBaseConfig\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[0;32m 4\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mBaseConfig\u001b[39;00m:\n\u001b[1;32m----> 5\u001b[0m DEVICE \u001b[38;5;241m=\u001b[39m \u001b[43mget_default_device\u001b[49m()\n\u001b[0;32m 6\u001b[0m DATASET \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFlowers\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;00m\n\u001b[0;32m 8\u001b[0m \u001b[38;5;66;03m# 记录推断日志信息并保存存档点\u001b[39;00m\n", + "\u001b[1;31mNameError\u001b[0m: name 'get_default_device' is not defined" + ] + } + ], + "source": [ + "from dataclasses import dataclass\n", + "\n", + "@dataclass\n", + "class BaseConfig:\n", + " DEVICE = get_default_device()\n", + " DATASET = \"Flowers\" #MNIST \"cifar-10\" \"Flowers\"\n", + "\n", + " # 记录推断日志信息并保存存档点\n", + " root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n", + " root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\",\"checkpoints\")\n", + "\n", + " #目前的日志和存档点目录\n", + " log_dir = \"version_0\"\n", + " checkpoint_dir = \"version_0\"\n", + "\n", + "@dataclass\n", + "class TrainingConfig:\n", + " TIMESTEPS = 1000\n", + " IMG_SHAPE = (1,32,32) if BaseConfig.DATASET == \"MNIST\" else (3,32,32)\n", + " NUM_EPOCHS = 800\n", + " BATCH_SIZE = 32\n", + " LR = 2e-4\n", + " NUM_WORKERS = 2" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "DLML", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}