init diffusion
This commit is contained in:
		
							
								
								
									
										171
									
								
								main.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								main.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,171 @@ | |||||||
|  | { | ||||||
|  |  "cells": [ | ||||||
|  |   { | ||||||
|  |    "cell_type": "markdown", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "source": [ | ||||||
|  |     "## Network Helper\n", | ||||||
|  |     "\n" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 1, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "import torch.nn as nn\n", | ||||||
|  |     "import inspect\n", | ||||||
|  |     "import torch\n", | ||||||
|  |     "import math" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 2, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "def exists(x):\n", | ||||||
|  |     "    return x is not None\n", | ||||||
|  |     "\n", | ||||||
|  |     "def default(val, d):\n", | ||||||
|  |     "    if exists(val):\n", | ||||||
|  |     "        return val\n", | ||||||
|  |     "    return d() if inspect.isfunction(d) else d\n", | ||||||
|  |     "\n", | ||||||
|  |     "class Residual(nn.Module):\n", | ||||||
|  |     "    def __init__(self, fn):\n", | ||||||
|  |     "        super().__init__()\n", | ||||||
|  |     "        self.fn = fn\n", | ||||||
|  |     "\n", | ||||||
|  |     "    def forward(self, x, *args, **kwargs):\n", | ||||||
|  |     "        return self.fn(x, *args, **kwargs) + x\n", | ||||||
|  |     "\n", | ||||||
|  |     "# 上采样(反卷积)\n", | ||||||
|  |     "def Upsample(dim):\n", | ||||||
|  |     "    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)\n", | ||||||
|  |     "\n", | ||||||
|  |     "def Downsample(dim):\n", | ||||||
|  |     "    return nn.Conv2d(dim, dim, 4, 2 ,1)" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "markdown", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "source": [ | ||||||
|  |     "## Positional embedding\n", | ||||||
|  |     "\n", | ||||||
|  |     "目的是让网络知道\n", | ||||||
|  |     "当前是哪一个step. \n", | ||||||
|  |     "ddpm采用正弦位置编码\n", | ||||||
|  |     "\n", | ||||||
|  |     "输入是shape为(batch_size, 1)的tensor, batch中每一个sample所处的t,并且将这个tensor转换为shape为(batch_size, dim)的tensor.\n", | ||||||
|  |     "这个tensor会被加到每一个残差模块中.\n", | ||||||
|  |     "\n", | ||||||
|  |     "总之就是将$t$编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "class SinusolidalPositionEmbedding(nn.Module):\n", | ||||||
|  |     "    def __init__(self, dim):\n", | ||||||
|  |     "        super().__init__()\n", | ||||||
|  |     "        self.dim = dim\n", | ||||||
|  |     "\n", | ||||||
|  |     "    def forward(self, time):\n", | ||||||
|  |     "        device = time.device\n", | ||||||
|  |     "        half_dim = self.dim // 2\n", | ||||||
|  |     "        embeddings = math.log(10000) / (half_dim - 1)\n", | ||||||
|  |     "        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)\n", | ||||||
|  |     "        embeddings = time[:, :, None] * embeddings[None, None, :]\n", | ||||||
|  |     "        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)\n", | ||||||
|  |     "        return embeddings\n", | ||||||
|  |     "        " | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "markdown", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "source": [ | ||||||
|  |     "## ResNet/ConvNeXT block" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "class Block(nn.Module):\n", | ||||||
|  |     "    def __init__(self, dim, dim_out, groups = 8):\n", | ||||||
|  |     "        super().__init__()\n", | ||||||
|  |     "        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)\n", | ||||||
|  |     "        self.norm = nn.GroupNorm(groups, dim_out)\n", | ||||||
|  |     "        self.act = nn.SiLU()\n", | ||||||
|  |     "    \n", | ||||||
|  |     "    def forward(self, x, scale_shift = None):\n", | ||||||
|  |     "        x = self.proj(x)\n", | ||||||
|  |     "        x = self.norm(x)\n", | ||||||
|  |     "\n", | ||||||
|  |     "        if exists(scale_shift):\n", | ||||||
|  |     "            scale, shift = scale_shift\n", | ||||||
|  |     "            x = x * (scale + 1) + shift\n", | ||||||
|  |     "\n", | ||||||
|  |     "        x = self.act(x)\n", | ||||||
|  |     "        return x\n", | ||||||
|  |     "\n", | ||||||
|  |     "class ResnetBlock(nn.Module):\n", | ||||||
|  |     "    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):\n", | ||||||
|  |     "        super().__init__()\n", | ||||||
|  |     "        self.mlp = (\n", | ||||||
|  |     "            nn.Sequential(\n", | ||||||
|  |     "                nn.SiLU(), \n", | ||||||
|  |     "                nn.Linear(time_emb_dim, dim_out)\n", | ||||||
|  |     "            )\n", | ||||||
|  |     "            if exists(time_emb_dim) else None\n", | ||||||
|  |     "        )\n", | ||||||
|  |     "        self.block1 = Block(dim, dim_out, groups=groups)\n", | ||||||
|  |     "        self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)\n", | ||||||
|  |     "        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n", | ||||||
|  |     "    \n", | ||||||
|  |     "    def forward(self, x, time_emb = None):\n", | ||||||
|  |     "        h = self.block1(x)\n", | ||||||
|  |     "\n", | ||||||
|  |     "        if exists(self.mlp) and exists(time_emb):\n", | ||||||
|  |     "            time_emb = self.mlp(time_emb)\n", | ||||||
|  |     "            h = rearrange(time_emb, 'b n -> b () n') + h\n", | ||||||
|  |     "\n", | ||||||
|  |     "        h = self.block2(h)\n", | ||||||
|  |     "        return h + self.res_conv(x)\n", | ||||||
|  |     "    \n", | ||||||
|  |     "    " | ||||||
|  |    ] | ||||||
|  |   } | ||||||
|  |  ], | ||||||
|  |  "metadata": { | ||||||
|  |   "kernelspec": { | ||||||
|  |    "display_name": "arch2vec39", | ||||||
|  |    "language": "python", | ||||||
|  |    "name": "python3" | ||||||
|  |   }, | ||||||
|  |   "language_info": { | ||||||
|  |    "codemirror_mode": { | ||||||
|  |     "name": "ipython", | ||||||
|  |     "version": 3 | ||||||
|  |    }, | ||||||
|  |    "file_extension": ".py", | ||||||
|  |    "mimetype": "text/x-python", | ||||||
|  |    "name": "python", | ||||||
|  |    "nbconvert_exporter": "python", | ||||||
|  |    "pygments_lexer": "ipython3", | ||||||
|  |    "version": "3.6.13" | ||||||
|  |   } | ||||||
|  |  }, | ||||||
|  |  "nbformat": 4, | ||||||
|  |  "nbformat_minor": 2 | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user