2019-09-28 10:24:47 +02:00
|
|
|
import time, sys
|
2019-01-31 15:27:38 +01:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
2019-09-28 10:24:47 +02:00
|
|
|
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the running sum, the total weight and the average."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        ``avg`` is recomputed as ``sum / count``. While ``count`` is still 0
        (e.g. only ``n=0`` updates so far) ``avg`` is left unchanged instead
        of raising ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard: a zero total weight would otherwise divide by zero.
        if self.count:
            self.avg = self.sum / self.count

    def __repr__(self):
        return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
|
2019-01-31 15:27:38 +01:00
|
|
|
|
|
|
|
|
|
|
|
class RecorderMeter(object):
    """Record per-epoch train/validation loss and accuracy, and plot them.

    Values live in two ``(total_epoch, 2)`` float32 arrays, where column 0
    holds the training value and column 1 the validation value. Loss
    entries start at -1 and accuracy entries at 0 until the corresponding
    epoch is stored via :meth:`update`.
    """

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        """Discard all recorded values and allocate storage for ``total_epoch`` epochs.

        Raises:
            AssertionError: if ``total_epoch`` is not positive.
        """
        assert total_epoch > 0, 'total_epoch should be greater than 0 vs {:}'.format(total_epoch)
        self.total_epoch = total_epoch
        self.current_epoch = 0
        # [epoch, train/val]; -1 marks a loss that has not been recorded yet.
        self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) - 1
        # [epoch, train/val]
        self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        """Store the metrics of the 0-based epoch ``idx``.

        Sets ``current_epoch`` to ``idx + 1``, so later queries consider
        epochs ``0..idx``.

        Returns:
            A (numpy) bool that is true when ``val_acc`` equals the best
            validation accuracy among epochs ``0..idx`` (ties count as best).

        Raises:
            AssertionError: if ``idx`` is outside ``[0, total_epoch)``.
        """
        assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
        self.epoch_losses[idx, 0] = train_loss
        self.epoch_losses[idx, 1] = val_loss
        self.epoch_accuracy[idx, 0] = train_acc
        self.epoch_accuracy[idx, 1] = val_acc
        self.current_epoch = idx + 1
        return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]

    def max_accuracy(self, istrain):
        """Return the best train (``istrain`` true) or validation accuracy so far, or 0 if none."""
        if self.current_epoch <= 0:
            return 0
        column = 0 if istrain else 1
        return self.epoch_accuracy[:self.current_epoch, column].max()

    def plot_curve(self, save_path):
        """Plot accuracy and x50-scaled loss curves, optionally saving the figure.

        Only epochs ``0..current_epoch-1`` are drawn, so epochs that have not
        been recorded yet no longer appear as 0% accuracy or -50 loss. The
        axes still span all ``total_epoch`` epochs. The figure is always
        closed, even if saving fails.

        Args:
            save_path: destination passed to ``fig.savefig``; if None the
                figure is built and closed without being saved.
        """
        import matplotlib
        matplotlib.use('agg')
        import matplotlib.pyplot as plt

        title = 'the accuracy/loss curve of train/val'
        dpi = 100
        width, height = 1600, 1000
        legend_fontsize = 10
        figsize = width / float(dpi), height / float(dpi)
        interval_x, interval_y = 5, 5

        fig = plt.figure(figsize=figsize)
        try:
            plt.xlim(0, self.total_epoch)
            plt.ylim(0, 100)
            plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
            plt.yticks(np.arange(0, 100 + interval_y, interval_y))
            plt.grid()
            plt.title(title, fontsize=20)
            plt.xlabel('the training epoch', fontsize=16)
            plt.ylabel('accuracy', fontsize=16)

            recorded = self.current_epoch
            x_axis = np.arange(recorded)  # epochs
            plt.plot(x_axis, self.epoch_accuracy[:recorded, 0], color='g', linestyle='-', label='train-accuracy', lw=2)
            plt.plot(x_axis, self.epoch_accuracy[:recorded, 1], color='y', linestyle='-', label='valid-accuracy', lw=2)
            # Losses are scaled by 50 so they share the 0-100 accuracy axis.
            plt.plot(x_axis, self.epoch_losses[:recorded, 0] * 50, color='g', linestyle=':', label='train-loss-x50', lw=2)
            plt.plot(x_axis, self.epoch_losses[:recorded, 1] * 50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
            plt.legend(loc=4, fontsize=legend_fontsize)

            if save_path is not None:
                fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
                print ('---- save figure {} into {}'.format(title, save_path))
        finally:
            plt.close(fig)
|