add helper function(1)
This commit is contained in:
		
							
								
								
									
										181
									
								
								diffusion.ipynb
									
									
									
									
									
								
							
							
						
						
									
										181
									
								
								diffusion.ipynb
									
									
									
									
									
								
							| @@ -2,7 +2,7 @@ | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "execution_count": 17, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -44,7 +44,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "execution_count": 18, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -58,22 +58,177 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 5, | ||||
|    "execution_count": 19, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "class DeviceDataLoader:\n", | ||||
|     "    \"\"\"包装一个数据加载器,来把数据移动到另一个设备上\"\"\"\n", | ||||
|     "    \"\"\"Wrap a dataloader to move data to a device\"\"\"\n", | ||||
|     "\n", | ||||
|     "    def __init__(self, dl, device):\n", | ||||
|     "        self.dl = dl\n", | ||||
|     "        self.device = device\n", | ||||
|     "\n", | ||||
|     "    def __iter__(self):\n", | ||||
|     "        \"\"\"在移动到设备后生成一个批次的数据\"\"\"\n", | ||||
|     "        \"\"\"Yield a batch of data after moving it to device\"\"\"\n", | ||||
|     "        for b in self.dl:\n", | ||||
|     "            yield to_device(b, self.device)\n", | ||||
|     "\n", | ||||
|     "    def __len__(self):\n", | ||||
|     "        \"\"\"批次的数量\"\"\"\n", | ||||
|     "        \"\"\"Number of batches\"\"\"\n", | ||||
|     "        return len(self.dl)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 20, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def get_default_device():\n", | ||||
|     "    return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 21, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def save_images(images, path, **kwargs):\n", | ||||
|     "    grid = make_grid(images, **kwargs)\n", | ||||
|     "    ndarr = grid.permute(1,2,0).to(\"cpu\").numpy()\n", | ||||
|     "    im = Image.fromarray(ndarr)\n", | ||||
|     "    im.save(path)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 22, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def get(element: torch.Tensor, t: torch.Tensor):\n", | ||||
|     "    \"\"\"\n", | ||||
|     "    Get value at index position \"t\" in \"element\" and \n", | ||||
|     "        reshape it to have the same dimension as a batch of images\n", | ||||
|     "\n", | ||||
|     "    获得在\"element\"中位置\"t\"并且reshape,以和一组照片有相同的维度\n", | ||||
|     "    \"\"\"\n", | ||||
|     "    ele = element.gather(-1, t)\n", | ||||
|     "    return ele.reshape(-1, 1, 1, 1)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 23, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "element = torch.tensor([[1,2,3,4,5],\n", | ||||
|     "                        [2,3,4,5,6],\n", | ||||
|     "                        [3,4,5,6,7]])\n", | ||||
|     "t = torch.tensor([1,2,0]).unsqueeze(1)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 24, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "ename": "NameError", | ||||
|      "evalue": "name 'get_default_device' is not defined", | ||||
|      "output_type": "error", | ||||
|      "traceback": [ | ||||
|       "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | ||||
|       "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)", | ||||
|       "Cell \u001b[1;32mIn[5], line 4\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdataclasses\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m dataclass\n\u001b[0;32m      3\u001b[0m \u001b[38;5;129;43m@dataclass\u001b[39;49m\n\u001b[1;32m----> 4\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mBaseConfig\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[0;32m      5\u001b[0m \u001b[43m    \u001b[49m\u001b[43mDEVICE\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mget_default_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m      6\u001b[0m \u001b[43m    \u001b[49m\u001b[43mDATASET\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mFlowers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;49;00m\n", | ||||
|       "Cell \u001b[1;32mIn[5], line 5\u001b[0m, in \u001b[0;36mBaseConfig\u001b[1;34m()\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[0;32m      4\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mBaseConfig\u001b[39;00m:\n\u001b[1;32m----> 5\u001b[0m     DEVICE \u001b[38;5;241m=\u001b[39m \u001b[43mget_default_device\u001b[49m()\n\u001b[0;32m      6\u001b[0m     DATASET \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFlowers\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;00m\n\u001b[0;32m      8\u001b[0m     \u001b[38;5;66;03m# 记录推断日志信息并保存存档点\u001b[39;00m\n", | ||||
|       "\u001b[1;31mNameError\u001b[0m: name 'get_default_device' is not defined" | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "tensor([[1, 2, 3, 4, 5],\n", | ||||
|       "        [2, 3, 4, 5, 6],\n", | ||||
|       "        [3, 4, 5, 6, 7]])\n", | ||||
|       "tensor([[1],\n", | ||||
|       "        [2],\n", | ||||
|       "        [0]])\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "print(element)\n", | ||||
|     "print(t)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 25, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "tensor([[[[2]]],\n", | ||||
|       "\n", | ||||
|       "\n", | ||||
|       "        [[[4]]],\n", | ||||
|       "\n", | ||||
|       "\n", | ||||
|       "        [[[3]]]])\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "extracted_scores = get(element, t)\n", | ||||
|     "print(extracted_scores)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 26, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def setup_log_directory(config):\n", | ||||
|     "    \"\"\"\n", | ||||
|     "    Log and Model checkpoint directory Setup\n", | ||||
|     "    记录并且建模目录准备\n", | ||||
|     "    \"\"\"\n", | ||||
|     "\n", | ||||
|     "    if os.path.isdir(config.root_log_dir):\n", | ||||
|     "        # Get all folders numbers in the root_log_dir\n", | ||||
|     "        # 在root_log_dir下获得所有文件夹数目\n", | ||||
|     "        folder_numbers = [int(folder.replace(\"version_\", \"\")) for folder in os.listdir(config.root_log_dir)]\n", | ||||
|     "\n", | ||||
|     "        # Find the latest version number present in the log_dir\n", | ||||
|     "        # 找到在log_dir下的最新版本数字\n", | ||||
|     "        last_version_number = max(folder_numbers)\n", | ||||
|     "\n", | ||||
|     "        # New version name\n", | ||||
|     "        version_name = f\"version{last_version_number + 1}\"\n", | ||||
|     "\n", | ||||
|     "    else:\n", | ||||
|     "        version_name = config.log_dir\n", | ||||
|     "\n", | ||||
|     "    # Update the training config default directory\n", | ||||
|     "    # 更新训练config默认目录\n", | ||||
|     "    log_dir         = os.path.join(config.root_log_dir,         version_name)\n", | ||||
|     "    checkpoint_dir  = os.path.join(config.root_checkpoint_dir,  version_name) \n", | ||||
|     "\n", | ||||
|     "    # Create new directory for saving new experiment version\n", | ||||
|     "    # 创建一个新目录来保存新的实验版本\n", | ||||
|     "    os.makedirs(log_dir,        exist_ok=True)\n", | ||||
|     "    os.makedirs(checkpoint_dir, exist_ok=True)\n", | ||||
|     "\n", | ||||
|     "    print(f\"Logging at: {log_dir}\")\n", | ||||
|     "    print(f\"Model Checkpoint at: {checkpoint_dir}\")\n", | ||||
|     "\n", | ||||
|     "    return log_dir, checkpoint_dir\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 27, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from dataclasses import dataclass\n", | ||||
|     "\n", | ||||
| @@ -117,7 +272,7 @@ | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.11.0" | ||||
|    "version": "3.11.8" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user