init diffusion
This commit is contained in:
		
							
								
								
									
										171
									
								
								main.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								main.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,171 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "## Network Helper\n", | ||||
|     "\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import torch.nn as nn\n", | ||||
|     "import inspect\n", | ||||
|     "import torch\n", | ||||
|     "import math" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def exists(x):\n", | ||||
|     "    return x is not None\n", | ||||
|     "\n", | ||||
|     "def default(val, d):\n", | ||||
|     "    if exists(val):\n", | ||||
|     "        return val\n", | ||||
|     "    return d() if inspect.isfunction(d) else d\n", | ||||
|     "\n", | ||||
|     "class Residual(nn.Module):\n", | ||||
|     "    def __init__(self, fn):\n", | ||||
|     "        super().__init__()\n", | ||||
|     "        self.fn = fn\n", | ||||
|     "\n", | ||||
|     "    def forward(self, x, *args, **kwargs):\n", | ||||
|     "        return self.fn(x, *args, **kwargs) + x\n", | ||||
|     "\n", | ||||
|     "# 上采样(反卷积)\n", | ||||
|     "def Upsample(dim):\n", | ||||
|     "    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)\n", | ||||
|     "\n", | ||||
|     "def Downsample(dim):\n", | ||||
|     "    return nn.Conv2d(dim, dim, 4, 2 ,1)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "## Positional embedding\n", | ||||
|     "\n", | ||||
|     "目的是让网络知道\n", | ||||
|     "当前是哪一个step. \n", | ||||
|     "ddpm采用正弦位置编码\n", | ||||
|     "\n", | ||||
|     "输入是shape为(batch_size, 1)的tensor, batch中每一个sample所处的t,并且将这个tensor转换为shape为(batch_size, dim)的tensor.\n", | ||||
|     "这个tensor会被加到每一个残差模块中.\n", | ||||
|     "\n", | ||||
|     "总之就是将$t$编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "class SinusolidalPositionEmbedding(nn.Module):\n", | ||||
|     "    def __init__(self, dim):\n", | ||||
|     "        super().__init__()\n", | ||||
|     "        self.dim = dim\n", | ||||
|     "\n", | ||||
|     "    def forward(self, time):\n", | ||||
|     "        device = time.device\n", | ||||
|     "        half_dim = self.dim // 2\n", | ||||
|     "        embeddings = math.log(10000) / (half_dim - 1)\n", | ||||
|     "        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)\n", | ||||
|     "        embeddings = time[:, :, None] * embeddings[None, None, :]\n", | ||||
|     "        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)\n", | ||||
|     "        return embeddings\n", | ||||
|     "        " | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "## ResNet/ConvNeXT block" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "class Block(nn.Module):\n", | ||||
|     "    def __init__(self, dim, dim_out, groups = 8):\n", | ||||
|     "        super().__init__()\n", | ||||
|     "        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)\n", | ||||
|     "        self.norm = nn.GroupNorm(groups, dim_out)\n", | ||||
|     "        self.act = nn.SiLU()\n", | ||||
|     "    \n", | ||||
|     "    def forward(self, x, scale_shift = None):\n", | ||||
|     "        x = self.proj(x)\n", | ||||
|     "        x = self.norm(x)\n", | ||||
|     "\n", | ||||
|     "        if exists(scale_shift):\n", | ||||
|     "            scale, shift = scale_shift\n", | ||||
|     "            x = x * (scale + 1) + shift\n", | ||||
|     "\n", | ||||
|     "        x = self.act(x)\n", | ||||
|     "        return x\n", | ||||
|     "\n", | ||||
|     "class ResnetBlock(nn.Module):\n", | ||||
|     "    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):\n", | ||||
|     "        super().__init__()\n", | ||||
|     "        self.mlp = (\n", | ||||
|     "            nn.Sequential(\n", | ||||
|     "                nn.SiLU(), \n", | ||||
|     "                nn.Linear(time_emb_dim, dim_out)\n", | ||||
|     "            )\n", | ||||
|     "            if exists(time_emb_dim) else None\n", | ||||
|     "        )\n", | ||||
|     "        self.block1 = Block(dim, dim_out, groups=groups)\n", | ||||
|     "        self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)\n", | ||||
|     "        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n", | ||||
|     "    \n", | ||||
|     "    def forward(self, x, time_emb = None):\n", | ||||
|     "        h = self.block1(x)\n", | ||||
|     "\n", | ||||
|     "        if exists(self.mlp) and exists(time_emb):\n", | ||||
|     "            time_emb = self.mlp(time_emb)\n", | ||||
|     "            h = rearrange(time_emb, 'b n -> b () n') + h\n", | ||||
|     "\n", | ||||
|     "        h = self.block2(h)\n", | ||||
|     "        return h + self.res_conv(x)\n", | ||||
|     "    \n", | ||||
|     "    " | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "arch2vec39", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.6.13" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
		Reference in New Issue
	
	Block a user