| 
									
										
										
										
											2019-02-01 01:27:38 +11:00
										 |  |  | import numpy as np | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-30 12:05:52 +00:00
										 |  |  | class AverageMeter(object): | 
					
						
							|  |  |  |     """Computes and stores the average and current value""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self): | 
					
						
							|  |  |  |         self.reset() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reset(self): | 
					
						
							|  |  |  |         self.val = 0.0 | 
					
						
							|  |  |  |         self.avg = 0.0 | 
					
						
							|  |  |  |         self.sum = 0.0 | 
					
						
							|  |  |  |         self.count = 0.0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def update(self, val, n=1): | 
					
						
							|  |  |  |         self.val = val | 
					
						
							|  |  |  |         self.sum += val * n | 
					
						
							|  |  |  |         self.count += n | 
					
						
							|  |  |  |         self.avg = self.sum / self.count | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __repr__(self): | 
					
						
							|  |  |  |         return "{name}(val={val}, avg={avg}, count={count})".format( | 
					
						
							|  |  |  |             name=self.__class__.__name__, **self.__dict__ | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2019-02-01 01:27:38 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RecorderMeter(object): | 
					
						
							| 
									
										
										
										
											2021-03-30 12:05:52 +00:00
										 |  |  |     """Computes and stores the minimum loss value and its epoch index""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, total_epoch): | 
					
						
							|  |  |  |         self.reset(total_epoch) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reset(self, total_epoch): | 
					
						
							|  |  |  |         assert total_epoch > 0, "total_epoch should be greater than 0 vs {:}".format( | 
					
						
							|  |  |  |             total_epoch | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.total_epoch = total_epoch | 
					
						
							|  |  |  |         self.current_epoch = 0 | 
					
						
							|  |  |  |         self.epoch_losses = np.zeros( | 
					
						
							|  |  |  |             (self.total_epoch, 2), dtype=np.float32 | 
					
						
							|  |  |  |         )  # [epoch, train/val] | 
					
						
							|  |  |  |         self.epoch_losses = self.epoch_losses - 1 | 
					
						
							|  |  |  |         self.epoch_accuracy = np.zeros( | 
					
						
							|  |  |  |             (self.total_epoch, 2), dtype=np.float32 | 
					
						
							|  |  |  |         )  # [epoch, train/val] | 
					
						
							|  |  |  |         self.epoch_accuracy = self.epoch_accuracy | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def update(self, idx, train_loss, train_acc, val_loss, val_acc): | 
					
						
							|  |  |  |         assert ( | 
					
						
							|  |  |  |             idx >= 0 and idx < self.total_epoch | 
					
						
							|  |  |  |         ), "total_epoch : {} , but update with the {} index".format( | 
					
						
							|  |  |  |             self.total_epoch, idx | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.epoch_losses[idx, 0] = train_loss | 
					
						
							|  |  |  |         self.epoch_losses[idx, 1] = val_loss | 
					
						
							|  |  |  |         self.epoch_accuracy[idx, 0] = train_acc | 
					
						
							|  |  |  |         self.epoch_accuracy[idx, 1] = val_acc | 
					
						
							|  |  |  |         self.current_epoch = idx + 1 | 
					
						
							|  |  |  |         return self.max_accuracy(False) == self.epoch_accuracy[idx, 1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def max_accuracy(self, istrain): | 
					
						
							|  |  |  |         if self.current_epoch <= 0: | 
					
						
							|  |  |  |             return 0 | 
					
						
							|  |  |  |         if istrain: | 
					
						
							|  |  |  |             return self.epoch_accuracy[: self.current_epoch, 0].max() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return self.epoch_accuracy[: self.current_epoch, 1].max() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def plot_curve(self, save_path): | 
					
						
							|  |  |  |         import matplotlib | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         matplotlib.use("agg") | 
					
						
							|  |  |  |         import matplotlib.pyplot as plt | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         title = "the accuracy/loss curve of train/val" | 
					
						
							|  |  |  |         dpi = 100 | 
					
						
							|  |  |  |         width, height = 1600, 1000 | 
					
						
							|  |  |  |         legend_fontsize = 10 | 
					
						
							|  |  |  |         figsize = width / float(dpi), height / float(dpi) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         fig = plt.figure(figsize=figsize) | 
					
						
							|  |  |  |         x_axis = np.array([i for i in range(self.total_epoch)])  # epochs | 
					
						
							|  |  |  |         y_axis = np.zeros(self.total_epoch) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         plt.xlim(0, self.total_epoch) | 
					
						
							|  |  |  |         plt.ylim(0, 100) | 
					
						
							|  |  |  |         interval_y = 5 | 
					
						
							|  |  |  |         interval_x = 5 | 
					
						
							|  |  |  |         plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x)) | 
					
						
							|  |  |  |         plt.yticks(np.arange(0, 100 + interval_y, interval_y)) | 
					
						
							|  |  |  |         plt.grid() | 
					
						
							|  |  |  |         plt.title(title, fontsize=20) | 
					
						
							|  |  |  |         plt.xlabel("the training epoch", fontsize=16) | 
					
						
							|  |  |  |         plt.ylabel("accuracy", fontsize=16) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         y_axis[:] = self.epoch_accuracy[:, 0] | 
					
						
							|  |  |  |         plt.plot(x_axis, y_axis, color="g", linestyle="-", label="train-accuracy", lw=2) | 
					
						
							|  |  |  |         plt.legend(loc=4, fontsize=legend_fontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         y_axis[:] = self.epoch_accuracy[:, 1] | 
					
						
							|  |  |  |         plt.plot(x_axis, y_axis, color="y", linestyle="-", label="valid-accuracy", lw=2) | 
					
						
							|  |  |  |         plt.legend(loc=4, fontsize=legend_fontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         y_axis[:] = self.epoch_losses[:, 0] | 
					
						
							|  |  |  |         plt.plot( | 
					
						
							|  |  |  |             x_axis, y_axis * 50, color="g", linestyle=":", label="train-loss-x50", lw=2 | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         plt.legend(loc=4, fontsize=legend_fontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         y_axis[:] = self.epoch_losses[:, 1] | 
					
						
							|  |  |  |         plt.plot( | 
					
						
							|  |  |  |             x_axis, y_axis * 50, color="y", linestyle=":", label="valid-loss-x50", lw=2 | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         plt.legend(loc=4, fontsize=legend_fontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if save_path is not None: | 
					
						
							|  |  |  |             fig.savefig(save_path, dpi=dpi, bbox_inches="tight") | 
					
						
							|  |  |  |             print("---- save figure {} into {}".format(title, save_path)) | 
					
						
							|  |  |  |         plt.close(fig) |