import numpy as np


class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric.

    Attributes (all initialised by ``reset``):
        val: the value passed to the most recent ``update`` call.
        sum: running total of ``val * n`` over all ``update`` calls.
        count: running total of the weights ``n``.
        avg: ``sum / count``, or 0.0 while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to its initial value of 0.0."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        """Fold a new observation into the running statistics.

        Args:
            val: the observed value (e.g. a per-batch mean).
            n: weight of the observation (e.g. the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. first update with n=0) would otherwise
        # divide by zero; keep the reset value until real samples arrive.
        self.avg = self.sum / self.count if self.count else 0.0

    def __repr__(self):
        # ``sum`` is also in ``__dict__``; str.format ignores unused keywords.
        return "{name}(val={val}, avg={avg}, count={count})".format(
            name=self.__class__.__name__, **self.__dict__
        )


class RecorderMeter(object):
    """Record per-epoch train/val loss and accuracy and report the best accuracy.

    Attributes (all created by ``reset``):
        total_epoch: number of epochs the buffers can hold.
        current_epoch: one past the index most recently passed to ``update``.
        epoch_losses: float32 array of shape ``(total_epoch, 2)`` with columns
            ``[train, val]``; rows that were never updated hold -1.
        epoch_accuracy: float32 array of the same shape and column layout;
            rows that were never updated hold 0.
    """

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        """Discard all recorded values and reallocate for ``total_epoch`` epochs.

        Raises:
            AssertionError: if ``total_epoch`` is not greater than 0.
        """
        assert total_epoch > 0, "total_epoch should be greater than 0 vs {:}".format(
            total_epoch
        )
        self.total_epoch = total_epoch
        self.current_epoch = 0
        # [epoch, train/val]; -1 marks a loss that has not been recorded yet.
        self.epoch_losses = np.full((self.total_epoch, 2), -1, dtype=np.float32)
        # [epoch, train/val]
        self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        """Store the metrics of epoch ``idx`` and set ``current_epoch`` to ``idx + 1``.

        Returns:
            True-like when the validation accuracy stored for ``idx`` equals
            the maximum validation accuracy over rows ``[0, idx]``.

        Raises:
            AssertionError: if ``idx`` is outside ``[0, total_epoch)``.
        """
        assert (
            0 <= idx < self.total_epoch
        ), "total_epoch : {} , but update with the {} index".format(
            self.total_epoch, idx
        )
        self.epoch_losses[idx, 0] = train_loss
        self.epoch_losses[idx, 1] = val_loss
        self.epoch_accuracy[idx, 0] = train_acc
        self.epoch_accuracy[idx, 1] = val_acc
        self.current_epoch = idx + 1
        # Compare against the stored (float32) value, not the raw argument.
        return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]

    def max_accuracy(self, istrain):
        """Return the highest accuracy among the first ``current_epoch`` rows.

        Args:
            istrain: truthy selects the train column, falsy the val column.

        Returns:
            0 when ``current_epoch`` is not positive, otherwise the column max.
        """
        if self.current_epoch <= 0:
            return 0
        column = 0 if istrain else 1
        return self.epoch_accuracy[: self.current_epoch, column].max()

    def plot_curve(self, save_path):
        """Draw accuracy and 50x-scaled loss curves for all ``total_epoch`` rows.

        Rows that were never updated are plotted with their placeholder values
        (0 for accuracy, -1 for loss).

        Args:
            save_path: destination passed to ``Figure.savefig``; when ``None``
                the figure is built and closed without being written.
        """
        import matplotlib

        matplotlib.use("agg")
        import matplotlib.pyplot as plt

        title = "the accuracy/loss curve of train/val"
        dpi = 100
        width, height = 1600, 1000
        legend_fontsize = 10
        figsize = width / float(dpi), height / float(dpi)
        interval_x = 5
        interval_y = 5

        x_axis = np.arange(self.total_epoch)  # epochs
        # (values, scale, color, linestyle, label); losses are scaled by 50 so
        # they share the 0-100 axis used for accuracy.
        curves = (
            (self.epoch_accuracy[:, 0], 1.0, "g", "-", "train-accuracy"),
            (self.epoch_accuracy[:, 1], 1.0, "y", "-", "valid-accuracy"),
            (self.epoch_losses[:, 0], 50.0, "g", ":", "train-loss-x50"),
            (self.epoch_losses[:, 1], 50.0, "y", ":", "valid-loss-x50"),
        )

        fig = plt.figure(figsize=figsize)
        try:
            plt.xlim(0, self.total_epoch)
            plt.ylim(0, 100)
            plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
            plt.yticks(np.arange(0, 100 + interval_y, interval_y))
            plt.grid()
            plt.title(title, fontsize=20)
            plt.xlabel("the training epoch", fontsize=16)
            plt.ylabel("accuracy", fontsize=16)

            for values, scale, color, linestyle, label in curves:
                # Fresh float64 array per curve, so no plotted line shares a
                # buffer that is later overwritten.
                y_axis = values.astype(np.float64) * scale
                plt.plot(
                    x_axis, y_axis, color=color, linestyle=linestyle, label=label, lw=2
                )
            # One legend built after the last curve lists all four lines.
            plt.legend(loc=4, fontsize=legend_fontsize)

            if save_path is not None:
                fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
                print("---- save figure {} into {}".format(title, save_path))
        finally:
            # Release the figure even if plotting or saving raised.
            plt.close(fig)