for notes

.gitignore
@@ -1,8 +1,10 @@
 ./flowers/*
 .DS_Store
-./UNet/train_image/*
-./UNet/params/*
-./UNet/__pycache__/*
+UNet/train_image/*
+UNet/params/*
+UNet/__pycache__/*
+UNet/test_image
 data/
 archive.zip
 flowers/*
+UNet/result/result.jpg
@@ -7,9 +7,9 @@ from net import *
 from torchvision.utils import save_image
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-weight_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/UNet/params/unet.pth'
-data_path = r'/Users/hanzhangma/Document/DataSet/VOC2007'
-save_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/Unet/train_image'
+weight_path = r'D:\\MasterThesis\\UNet\\params\\unet.pth'
+data_path = r'D:\\MasterThesis\\data\\VOCdevkit\\VOC2007'
+save_path = r'D:\\MasterThesis\\UNet\\train_image'
 
 if __name__ == '__main__':
     data_loader = DataLoader(MyDataset(data_path), batch_size= 4, shuffle=True)
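Note: the hunk above only swaps one machine's hard-coded absolute paths for another's. Purely as an illustration (not the repository's code), a minimal sketch of deriving the same three locations from one project root so the script would run unchanged on both machines; MASTER_THESIS_ROOT is an environment variable assumed here, not something the repo defines.

import os
from pathlib import Path

# Hypothetical alternative to per-machine absolute paths: everything hangs
# off a single root directory supplied by the environment.
ROOT = Path(os.environ.get("MASTER_THESIS_ROOT", ".")).resolve()

weight_path = ROOT / "UNet" / "params" / "unet.pth"
data_path = ROOT / "data" / "VOCdevkit" / "VOC2007"
save_path = ROOT / "UNet" / "train_image"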
@@ -665,6 +665,7 @@
     "\n",
     "        num_resolutions = len(base_channels_multiples)\n",
     "\n",
+    "        # encoder blocks = resnetblock * 3 + \n",
     "        self.encoder_blocks = nn.ModuleList()\n",
     "        curr_channels = [base_channels]\n",
     "        in_channels = base_channels\n",
@@ -799,6 +800,7 @@
     "        self.sqrt_one_minus_alpha_cumulative = torch.sqrt(1-self.alpha_cumulative)\n",
     "\n",
     "    def get_betas(self):\n",
+    "        \"\"\"linear schedule, proposed in the original DDPM paper\"\"\"\n",
     "        scale = 1000 / self.num_diffusion_timesteps\n",
     "        beta_start = scale * 1e-4\n",
     "        beta_end = scale * 0.02\n",
@@ -896,66 +898,6 @@
     "## Training"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 99,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "@dataclass\n",
-    "class ModelConfig:\n",
-    "    BASE_CH = 64  # 64, 128, 256, 256\n",
-    "    BASE_CH_MULT = (1, 2, 4, 4) # 32, 16, 8, 8 \n",
-    "    APPLY_ATTENTION = (False, True, True, False)\n",
-    "    DROPOUT_RATE = 0.1\n",
-    "    TIME_EMB_MULT = 4 # 128"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 100,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "model = UNet(\n",
-    "    input_channels          = TrainingConfig.IMG_SHAPE[0],\n",
-    "    output_channels         = TrainingConfig.IMG_SHAPE[0],\n",
-    "    base_channels           = ModelConfig.BASE_CH,\n",
-    "    base_channels_multiples = ModelConfig.BASE_CH_MULT,\n",
-    "    apply_attention         = ModelConfig.APPLY_ATTENTION,\n",
-    "    dropout_rate            = ModelConfig.DROPOUT_RATE,\n",
-    "    time_multiple           = ModelConfig.TIME_EMB_MULT,\n",
-    ")\n",
-    "model.to(BaseConfig.DEVICE)\n",
-    "\n",
-    "optimizer = torch.optim.AdamW(model.parameters(), lr=TrainingConfig.LR)\n",
-    "\n",
-    "dataloader = get_dataloader(\n",
-    "    dataset_name  = BaseConfig.DATASET,\n",
-    "    batch_size    = TrainingConfig.BATCH_SIZE,\n",
-    "    device        = BaseConfig.DEVICE,\n",
-    "    pin_memory    = True,\n",
-    "    num_workers   = TrainingConfig.NUM_WORKERS,\n",
-    ")\n",
-    "\n",
-    "loss_fn = nn.MSELoss()\n",
-    "\n",
-    "sd = SimpleDiffusion(\n",
-    "    num_diffusion_timesteps = TrainingConfig.TIMESTEPS,\n",
-    "    img_shape               = TrainingConfig.IMG_SHAPE,\n",
-    "    device                  = BaseConfig.DEVICE,\n",
-    ")\n",
-    "\n",
-    "scaler = amp.GradScaler()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Training"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": 101,
@@ -1051,13 +993,16 @@
     "        for x0s, _ in loader:\n",
     "            tq.update(1)\n",
     "            \n",
+    "            # generate noise\n",
     "            ts = torch.randint(low=1, high=training_config.TIMESTEPS, size=(x0s.shape[0],), device=base_config.DEVICE)\n",
     "            xts, gt_noise = forward_diffusion(sd, x0s, ts)\n",
     "\n",
+    "            # forward & get loss\n",
     "            with amp.autocast():\n",
     "                pred_noise = model(xts, ts)\n",
     "                loss = loss_fn(gt_noise, pred_noise)\n",
     "\n",
+    "            # gradient scaling and backpropagation\n",
     "            optimizer.zero_grad(set_to_none=True)\n",
     "            scaler.scale(loss).backward()\n",
     "\n",
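To show the three commented steps in one place, a hedged sketch of a single mixed-precision training iteration as the loop above performs it. forward_diffusion below is a simplified stand-in written against the sqrt_alpha_cumulative buffers from the schedule hunk, and the scaler.step / scaler.update calls at the end are the standard torch.cuda.amp GradScaler pattern, assumed rather than copied from the notebook.

import torch
from torch.cuda import amp

def forward_diffusion(sd, x0s, ts):
    # q(x_t | x_0): scale the clean images and add Gaussian noise at timestep t
    # (stand-in for the notebook's own forward_diffusion)
    eps = torch.randn_like(x0s)
    mean_coef = sd.sqrt_alpha_cumulative[ts].view(-1, 1, 1, 1)
    noise_coef = sd.sqrt_one_minus_alpha_cumulative[ts].view(-1, 1, 1, 1)
    return mean_coef * x0s + noise_coef * eps, eps

def train_step(model, sd, x0s, optimizer, scaler, loss_fn, timesteps, device):
    # generate noise: one random timestep per image, then the noised batch
    ts = torch.randint(low=1, high=timesteps, size=(x0s.shape[0],), device=device)
    xts, gt_noise = forward_diffusion(sd, x0s, ts)

    # forward & get loss under autocast (reduced precision where safe)
    with amp.autocast():
        pred_noise = model(xts, ts)
        loss = loss_fn(gt_noise, pred_noise)

    # gradient scaling and backpropagation, then the scaled optimizer step
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()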