Update UNet

.gitignore (vendored, 11 changed lines)
@@ -1,3 +1,8 @@
-flowers/*
-data/
-archive.zip
+./flowers/*
+.DS_Store
+./UNet/train_image/*
+./UNet/params/*
+./UNet/__pycache__/*
+data/
+archive.zip
+flowers/*

UNet/data.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import os

from torch.utils.data import Dataset
from utils import *
from torchvision import transforms
transform = transforms.Compose([
    transforms.ToTensor()
])


# uses the VOC2007 dataset layout: JPEGImages/ for inputs, SegmentationClass/ for masks
class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.name = os.listdir(os.path.join(path, 'SegmentationClass'))

    def __len__(self):
        return len(self.name)

    def __getitem__(self, index):
        segment_name = self.name[index]  # e.g. 'xxx.png'
        segment_path = os.path.join(self.path, 'SegmentationClass', segment_name)
        image_path = os.path.join(self.path, 'JPEGImages', segment_name.replace('png', 'jpg'))
        segment_image = keep_image_size_open(segment_path)
        image = keep_image_size_open(image_path)
        return transform(image), transform(segment_image)

if __name__ == '__main__':
    data = MyDataset('/Users/hanzhangma/Document/DataSet/VOC2007')
    print(data[0][0].shape)  # shape of the first image tensor
    print(data[0][1].shape)  # shape of the matching segmentation mask tensor
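
A quick usage sketch (an illustration, not part of the commit; the dataset root is a placeholder and assumes the standard VOC layout with JPEGImages/ and SegmentationClass/ subfolders): MyDataset plugs directly into a DataLoader, and both tensors come out as 3x256x256 because keep_image_size_open pads and resizes every image to 256x256 before ToTensor.

    # hypothetical usage; replace the path with your own VOC2007 copy
    from torch.utils.data import DataLoader
    from data import MyDataset

    loader = DataLoader(MyDataset('/path/to/VOC2007'), batch_size=4, shuffle=True)
    images, masks = next(iter(loader))
    print(images.shape, masks.shape)  # torch.Size([4, 3, 256, 256]) for both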
							
								
								
									
UNet/net.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from torch import nn
from torch.nn import functional as F
from torch import randn
import torch

class Conv_Block(nn.Module):  # two 3x3 convolutions with BatchNorm, Dropout2d and LeakyReLU
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.layer(x)

class DownSample(nn.Module):  # stride-2 convolution halves the spatial resolution
    def __init__(self, channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel, 3, 2, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )
    def forward(self, x):
        return self.layer(x)

class UpSample(nn.Module):  # nearest-neighbour upsampling plus a 1x1 conv that halves the channels
    def __init__(self, channel):
        super(UpSample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel//2, 1, 1)
        )
    def forward(self, x, feature_map):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        out = self.layer(up)
        return torch.cat((out, feature_map), dim=1)  # concatenate the encoder skip connection

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.c1 = Conv_Block(3, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)

        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)

        self.out = nn.Conv2d(64, 3, 3, 1, 1)  # per-pixel 3-channel output
        self.Th = nn.Sigmoid()                # squashes outputs into (0, 1) for BCELoss

    def forward(self, x):
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))

        O1 = self.c6(self.u1(R5, R4))
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))

        return self.Th(self.out(O4))

if __name__ == '__main__':
    x = randn(2, 3, 256, 256)
    net = UNet()
    print(net(x).shape)
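
The encoder halves the spatial resolution four times with stride-2 convolutions and the decoder doubles it back with nearest-neighbour upsampling, so input height and width should be divisible by 16 for the skip-connection concatenations to line up. A minimal sanity-check sketch (illustrative, not part of the commit):

    # assumes net.py is importable; verifies shape preservation and the Sigmoid output range
    import torch
    from net import UNet

    net = UNet()
    x = torch.randn(1, 3, 256, 256)        # H and W divisible by 16
    y = net(x)
    print(y.shape)                         # torch.Size([1, 3, 256, 256])
    print(float(y.min()), float(y.max()))  # both inside (0, 1) because of the final Sigmoid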
							
								
								
									
UNet/train.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import torch
from torch import optim
from torch.utils.data import DataLoader
from data import *
from net import *

from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/UNet/params/unet.pth'
data_path = r'/Users/hanzhangma/Document/DataSet/VOC2007'
save_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/Unet/train_image'

if __name__ == '__main__':
    data_loader = DataLoader(MyDataset(data_path), batch_size=4, shuffle=True)

    net = UNet().to(device)
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path))
        print('successfully loaded weights!')
    else:
        print('no saved weights found, training from scratch!')

    opt = optim.Adam(net.parameters())
    loss_fun = nn.BCELoss()

    epoch = 1

    while True:
        for i, (image, segment_image) in enumerate(data_loader):
            image, segment_image = image.to(device), segment_image.to(device)

            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image)

            opt.zero_grad()
            train_loss.backward()
            opt.step()  # apply the gradient update

            if i % 5 == 0:
                print(f'{epoch} -- {i} -- train loss ===>> {train_loss.item()}')

            if i % 50 == 0:
                torch.save(net.state_dict(), weight_path)

            _image = image[0]
            _segment_image = segment_image[0]
            _out_image = out_image[0]

            img = torch.stack([_image, _segment_image, _out_image], dim=0)  # input | target | prediction
            save_image(img, f'{save_path}/{i}.png')

        epoch += 1
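
Note that the loss is BCELoss between the Sigmoid output and the raw RGB mask tensor (both in [0, 1]), so the network learns to reproduce the colour-coded mask per pixel rather than doing per-class cross-entropy. Once unet.pth exists, a hypothetical inference sketch (file names and the weight path below are placeholders, not part of the commit) could look like this:

    # hypothetical inference sketch; adjust paths to your own checkpoint and image
    import torch
    from torchvision import transforms
    from torchvision.utils import save_image
    from net import UNet
    from utils import keep_image_size_open

    net = UNet()
    net.load_state_dict(torch.load('params/unet.pth', map_location='cpu'))
    net.eval()

    img = transforms.ToTensor()(keep_image_size_open('example.jpg')).unsqueeze(0)  # (1, 3, 256, 256)
    with torch.no_grad():
        pred = net(img)
    save_image(pred[0], 'prediction.png')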
							
								
								
									
UNet/utils.py (new file, 10 lines)
@@ -0,0 +1,10 @@
from PIL import Image

def keep_image_size_open(path, size=(256, 256)):
    img = Image.open(path)
    tmp = max(img.size)                             # side length of the square canvas
    mask = Image.new('RGB', (tmp, tmp), (0, 0, 0))  # black square canvas
    mask.paste(img, (0, 0))                         # paste at the top-left; right/bottom become padding
    mask = mask.resize(size)                        # resize without distorting the aspect ratio
    return mask
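
A self-contained check of the padding behaviour (an illustrative sketch with a dummy image, not part of the commit):

    # a red 500x375 dummy image, roughly the shape of a typical VOC JPEG
    from PIL import Image
    from utils import keep_image_size_open

    Image.new('RGB', (500, 375), (255, 0, 0)).save('dummy.jpg')
    out = keep_image_size_open('dummy.jpg')
    print(out.size)                  # (256, 256)
    print(out.getpixel((128, 250)))  # (0, 0, 0): the bottom strip is black padding, not image content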
							
								
								
									
paper/Neural Architecture Search without Training.pdf (new file, 25432 lines)
File diff suppressed because it is too large.
							
								
								
									
paper/Neural Architecture Survey.pdf (new file, 6306 lines)
File diff suppressed because one or more lines are too long.