add diffusion.ipynb
This commit is contained in:
		
							
								
								
									
										125
									
								
								diffusion.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								diffusion.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,125 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import gc\n", | ||||
|     "import os\n", | ||||
|     "import cv2\n", | ||||
|     "import math\n", | ||||
|     "import base64\n", | ||||
|     "import random\n", | ||||
|     "import numpy as np\n", | ||||
|     "from PIL import Image \n", | ||||
|     "from tqdm import tqdm\n", | ||||
|     "from datetime import datetime\n", | ||||
|     "import matplotlib.pyplot as plt\n", | ||||
|     "\n", | ||||
|     "import torch\n", | ||||
|     "import torch.nn as nn\n", | ||||
|     "from torch.cuda import amp\n", | ||||
|     "import torch.nn.functional as F\n", | ||||
|     "from torch.optim import Adam, AdamW\n", | ||||
|     "from torch.utils.data import Dataset, DataLoader\n", | ||||
|     "\n", | ||||
|     "import torchvision\n", | ||||
|     "import torchvision.transforms as TF\n", | ||||
|     "import torchvision.datasets as datasets\n", | ||||
|     "from torchvision.utils import make_grid\n", | ||||
|     "\n", | ||||
|     "from torchmetrics import MeanMetric\n", | ||||
|     "\n", | ||||
|     "from IPython.display import display, HTML, clear_output\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "## Helper functions" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def to_device(data, device):\n", | ||||
|     "    \"\"\"将张量移动到选择的设备\"\"\"\n", | ||||
|     "    \"\"\"Move tensor(s) to chosen device\"\"\"\n", | ||||
|     "    if isinstance(data, (list, tuple)):\n", | ||||
|     "        return [to_device(x, device) for x in data]\n", | ||||
|     "    return data.to(device, non_blocking=true)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 5, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "ename": "NameError", | ||||
|      "evalue": "name 'get_default_device' is not defined", | ||||
|      "output_type": "error", | ||||
|      "traceback": [ | ||||
|       "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | ||||
|       "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)", | ||||
|       "Cell \u001b[1;32mIn[5], line 4\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdataclasses\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m dataclass\n\u001b[0;32m      3\u001b[0m \u001b[38;5;129;43m@dataclass\u001b[39;49m\n\u001b[1;32m----> 4\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mBaseConfig\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[0;32m      5\u001b[0m \u001b[43m    \u001b[49m\u001b[43mDEVICE\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mget_default_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m      6\u001b[0m \u001b[43m    \u001b[49m\u001b[43mDATASET\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mFlowers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;49;00m\n", | ||||
|       "Cell \u001b[1;32mIn[5], line 5\u001b[0m, in \u001b[0;36mBaseConfig\u001b[1;34m()\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[0;32m      4\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mBaseConfig\u001b[39;00m:\n\u001b[1;32m----> 5\u001b[0m     DEVICE \u001b[38;5;241m=\u001b[39m \u001b[43mget_default_device\u001b[49m()\n\u001b[0;32m      6\u001b[0m     DATASET \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFlowers\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;00m\n\u001b[0;32m      8\u001b[0m     \u001b[38;5;66;03m# 记录推断日志信息并保存存档点\u001b[39;00m\n", | ||||
|       "\u001b[1;31mNameError\u001b[0m: name 'get_default_device' is not defined" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "from dataclasses import dataclass\n", | ||||
|     "\n", | ||||
|     "@dataclass\n", | ||||
|     "class BaseConfig:\n", | ||||
|     "    DEVICE = get_default_device()\n", | ||||
|     "    DATASET = \"Flowers\" #MNIST \"cifar-10\" \"Flowers\"\n", | ||||
|     "\n", | ||||
|     "    # 记录推断日志信息并保存存档点\n", | ||||
|     "    root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n", | ||||
|     "    root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\",\"checkpoints\")\n", | ||||
|     "\n", | ||||
|     "    #目前的日志和存档点目录\n", | ||||
|     "    log_dir = \"version_0\"\n", | ||||
|     "    checkpoint_dir = \"version_0\"\n", | ||||
|     "\n", | ||||
|     "@dataclass\n", | ||||
|     "class TrainingConfig:\n", | ||||
|     "    TIMESTEPS = 1000\n", | ||||
|     "    IMG_SHAPE = (1,32,32) if BaseConfig.DATASET == \"MNIST\" else (3,32,32)\n", | ||||
|     "    NUM_EPOCHS = 800\n", | ||||
|     "    BATCH_SIZE = 32\n", | ||||
|     "    LR = 2e-4\n", | ||||
|     "    NUM_WORKERS = 2" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "DLML", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.11.0" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
		Reference in New Issue
	
	Block a user