| 
									
										
										
										
											2019-11-15 17:15:07 +11:00
										 |  |  | ################################################## | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | 
					
						
							|  |  |  | ################################################## | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | import torch, copy, random | 
					
						
							|  |  |  | import torch.utils.data as data | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SearchDataset(data.Dataset): | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |     def __init__(self, name, data, train_split, valid_split, check=True): | 
					
						
							|  |  |  |         self.datasetname = name | 
					
						
							|  |  |  |         if isinstance(data, (list, tuple)):  # new type of SearchDataset | 
					
						
							|  |  |  |             assert len(data) == 2, "invalid length: {:}".format(len(data)) | 
					
						
							|  |  |  |             self.train_data = data[0] | 
					
						
							|  |  |  |             self.valid_data = data[1] | 
					
						
							|  |  |  |             self.train_split = train_split.copy() | 
					
						
							|  |  |  |             self.valid_split = valid_split.copy() | 
					
						
							|  |  |  |             self.mode_str = "V2"  # new mode | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.mode_str = "V1"  # old mode | 
					
						
							|  |  |  |             self.data = data | 
					
						
							|  |  |  |             self.train_split = train_split.copy() | 
					
						
							|  |  |  |             self.valid_split = valid_split.copy() | 
					
						
							|  |  |  |             if check: | 
					
						
							|  |  |  |                 intersection = set(train_split).intersection(set(valid_split)) | 
					
						
							|  |  |  |                 assert ( | 
					
						
							|  |  |  |                     len(intersection) == 0 | 
					
						
							|  |  |  |                 ), "the splitted train and validation sets should have no intersection" | 
					
						
							|  |  |  |         self.length = len(self.train_split) | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |     def __repr__(self): | 
					
						
							|  |  |  |         return "{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})".format( | 
					
						
							|  |  |  |             name=self.__class__.__name__, | 
					
						
							|  |  |  |             datasetname=self.datasetname, | 
					
						
							|  |  |  |             tr_L=len(self.train_split), | 
					
						
							|  |  |  |             val_L=len(self.valid_split), | 
					
						
							|  |  |  |             ver=self.mode_str, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |     def __len__(self): | 
					
						
							|  |  |  |         return self.length | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |     def __getitem__(self, index): | 
					
						
							|  |  |  |         assert index >= 0 and index < self.length, "invalid index = {:}".format(index) | 
					
						
							|  |  |  |         train_index = self.train_split[index] | 
					
						
							|  |  |  |         valid_index = random.choice(self.valid_split) | 
					
						
							|  |  |  |         if self.mode_str == "V1": | 
					
						
							|  |  |  |             train_image, train_label = self.data[train_index] | 
					
						
							|  |  |  |             valid_image, valid_label = self.data[valid_index] | 
					
						
							|  |  |  |         elif self.mode_str == "V2": | 
					
						
							|  |  |  |             train_image, train_label = self.train_data[train_index] | 
					
						
							|  |  |  |             valid_image, valid_label = self.valid_data[valid_index] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise ValueError("invalid mode : {:}".format(self.mode_str)) | 
					
						
							|  |  |  |         return train_image, train_label, valid_image, valid_label |