# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from os import path as osp
from copy import deepcopy as copy
from tqdm import tqdm
import warnings, time, random, numpy as np

from pts_utils import generate_label_map
from xvision import denormalize_points
from xvision import identity2affine, solve2theta, affine2image
from .dataset_utils import pil_loader
from .landmark_utils import PointMeta2V
from .augmentation_utils import CutOut
import torch
import torch.utils.data as data


class LandmarkDataset(data.Dataset):
    """Image dataset with landmark annotations, heatmap targets and affine warps.

    Samples are registered through :meth:`load_list` (or :meth:`append`).
    Each item is produced by running ``transform`` on the loaded image and its
    label, warping the image with the returned affine parameter(s) to
    ``shape`` and generating heatmaps at ``shape // downsample`` resolution.
    """

    def __init__(
        self,
        transform,
        sigma,
        downsample,
        heatmap_type,
        shape,
        use_gray,
        mean_file,
        data_indicator,
        cache_images=None,
    ):
        """Store the configuration and start with an empty sample list.

        Args:
            transform: callable invoked as ``transform(image, target)`` that
                returns ``(image, target, theta)``; must not be None.
            sigma: forwarded to ``generate_label_map``.
            downsample: integer divisor between ``shape`` and the heatmap size.
            heatmap_type: forwarded to ``generate_label_map``.
            shape: output ``[H, W]`` of the warped image.
            use_gray: forwarded to ``pil_loader``.
            mean_file: path of a ``torch.load``-able file indexed by box id,
                or None to disable the mean face.
            data_indicator: dataset name, stored as ``dataset_name`` and handed
                to every ``PointMeta2V`` label.
            cache_images: optional mapping from image path to an object with a
                ``clone()`` method; used instead of reading the file from disk.
        """
        assert transform is not None, "transform : {:}".format(transform)
        self.transform = transform
        self.sigma = sigma
        self.downsample = downsample
        self.heatmap_type = heatmap_type
        self.dataset_name = data_indicator
        self.shape = shape  # [H,W]
        self.use_gray = use_gray
        self.mean_file = mean_file
        if mean_file is None:
            self.mean_data = None
            warnings.warn("LandmarkDataset initialized with mean_data = None")
        else:
            assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file)
            # NOTE(security): torch.load unpickles the file; only point
            # mean_file at trusted data.
            self.mean_data = torch.load(mean_file)
        self.reset()
        self.cutout = None
        self.cache_images = cache_images
        print("The general dataset initialization done : {:}".format(self))
        # NOTE: this changes the process-wide warning filter, not only the
        # warnings emitted by this class.
        warnings.simplefilter("once")

    def __repr__(self):
        """Return a one-line summary built from the instance attributes."""
        template = (
            "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, "
            "heatmap_type={heatmap_type}, length={length}, cutout={cutout}, "
            "dataset={dataset_name}, mean={mean_file})"
        )
        return template.format(name=self.__class__.__name__, **vars(self))

    def set_cutout(self, length):
        """Enable CutOut with ``int(length)`` when ``length >= 1``, else disable it."""
        if length is not None and length >= 1:
            self.cutout = CutOut(int(length))
        else:
            self.cutout = None

    def reset(self, num_pts=-1, boxid="default", only_pts=False):
        """Set the number of points and, unless ``only_pts``, clear all samples.

        Args:
            num_pts: new value of ``NUM_PTS``.
            boxid: key used to select the mean face from ``mean_data`` and
                stored as ``BOXID``.
            only_pts: when True only ``NUM_PTS`` is updated; the stored
                samples, ``BOXID`` and the mean face are left untouched.
        """
        self.NUM_PTS = num_pts
        if only_pts:
            return
        self.length = 0
        self.datas = []
        self.labels = []
        self.NormDistances = []
        self.BOXID = boxid
        if self.mean_data is None:
            self.mean_face = None
        else:
            self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
            # The mean face must lie inside the [-1, 1] range on every entry.
            assert (self.mean_face >= -1).all() and (
                self.mean_face <= 1
            ).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face)
        # assert self.dataset_name is not None, 'The dataset name is None'

    def __len__(self):
        """Return the number of samples after checking it matches ``datas``."""
        assert len(self.datas) == self.length, "The length is not correct : {}".format(
            self.length
        )
        return self.length

    def append(self, data, label, distance):
        """Register one sample.

        Args:
            data: path of the image; must be an existing file.
            label: annotation object stored alongside the path.
            distance: normalization distance for the sample (may be None).
        """
        assert osp.isfile(data), "The image path is not a file : {:}".format(data)
        self.datas.append(data)
        self.labels.append(label)
        self.NormDistances.append(distance)
        self.length = self.length + 1

    def _check_lengths(self):
        """Assert that every per-sample list has exactly ``self.length`` entries."""
        assert (
            len(self.datas) == self.length
        ), "The length and the data is not right {} vs {}".format(
            self.length, len(self.datas)
        )
        assert (
            len(self.labels) == self.length
        ), "The length and the labels is not right {} vs {}".format(
            self.length, len(self.labels)
        )
        assert (
            len(self.NormDistances) == self.length
        ), "The length and the NormDistances is not right {} vs {}".format(
            self.length, len(self.NormDistances)
        )

    def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
        """Load annotation list file(s) and register every sample they contain.

        Args:
            file_lists: one path or a list of paths. Each file is read with
                ``torch.load`` and must hold either a list of annotations or a
                dict whose ``"datas"`` entry is such a list.
            num_pts: number of landmark points per sample.
            boxindicator: suffix selecting the ``"box-<boxindicator>"`` entry
                of each annotation.
            normalizeL: suffix selecting the ``"normalizeL-<normalizeL>"``
                entry of each annotation, or None to store no distance.
            reset: when True the dataset is cleared first; otherwise
                ``num_pts`` and ``boxindicator`` must match the current state.

        Raises:
            ValueError: if a loaded file is neither a list nor a dict.
        """
        if reset:
            self.reset(num_pts, boxindicator)
        else:
            assert (
                self.NUM_PTS == num_pts and self.BOXID == boxindicator
            ), "The number of point is inconsistance : {:} vs {:}".format(
                self.NUM_PTS, num_pts
            )
        if isinstance(file_lists, str):
            file_lists = [file_lists]
        samples = []
        for idx, file_path in enumerate(file_lists):
            print(
                ":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path)
            )
            # NOTE(security): torch.load unpickles the file; only load
            # trusted annotation lists.
            xdata = torch.load(file_path)
            if isinstance(xdata, list):
                entries = xdata  # image or video dataset list
            elif isinstance(xdata, dict):
                entries = xdata["datas"]  # multi-view dataset list
            else:
                raise ValueError("Invalid Type Error : {:}".format(type(xdata)))
            samples.extend(entries)
        # samples is a list of annotation dicts. Each one provides the image
        # path ('current_frame'), the 'points' and one or more
        # 'box-<indicator>' entries (points presumably (3, num_pts) -- TODO confirm).
        print("GeneralDataset-V2 : {:} samples".format(len(samples)))

        box_key = "box-{:}".format(boxindicator)
        norm_key = None if normalizeL is None else "normalizeL-{:}".format(normalizeL)
        for annotation in tqdm(samples):
            image_path = annotation["current_frame"]
            points, box = annotation["points"], annotation[box_key]
            label = PointMeta2V(
                self.NUM_PTS, points, box, image_path, self.dataset_name
            )
            normDistance = None if norm_key is None else annotation[norm_key]
            self.append(image_path, label, normDistance)

        self._check_lengths()
        print(
            "Load data done for LandmarkDataset, which has {:} images.".format(
                self.length
            )
        )

    def __getitem__(self, index):
        """Load the ``index``-th image (from the cache when available) and process it.

        Negative indices are rejected. See :meth:`_process_` for the returned
        tuple.
        """
        assert 0 <= index < self.length, "Invalid index : {:}".format(index)
        image_path = self.datas[index]
        if self.cache_images is not None and image_path in self.cache_images:
            image = self.cache_images[image_path].clone()
        else:
            image = pil_loader(image_path, self.use_gray)
        target = self.labels[index].copy()
        return self._process_(image, target, index)

    def _process_(self, image, target, index):
        """Apply ``transform`` and build the training tuple for one sample.

        Args:
            image: the loaded image, in whatever form ``transform`` accepts.
            target: label object exposing ``is_none()``, ``copy()`` and
                ``get_points()``.
            index: position of the sample in the dataset; ``-1`` means the
                sample has no stored path (only used in warning messages).

        Returns:
            ``(affineImage, heatmaps, mask, norm_trans_points, THETA,
            transpose_theta, torch_index, torch_nopoints, torch_shape)``.
            When ``transform`` returns a list or tuple of thetas, the first
            six entries are stacked along a new leading dimension, one slice
            per theta. ``torch_shape`` holds the ``[H, W]`` of the image
            returned by ``transform``.
        """
        # transform the image and points
        image, target, theta = self.transform(image, target)
        _, H, W = image.size()

        # whether this sample comes without landmark annotation
        nopoints = bool(target.is_none())
        image_path = None if index == -1 else self.datas[index]

        if isinstance(theta, (list, tuple)):
            # One warp per theta; collect the 6-tuples, then stack field by field.
            aux_info = "P[{:}]@{:}".format(index, image_path)
            results = [
                self.__process_affine(image, target, one_theta, nopoints, aux_info)
                for one_theta in theta
            ]
            affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
                torch.stack([result[field] for result in results])
                for field in range(6)
            )
        else:
            (
                affineImage,
                heatmaps,
                mask,
                norm_trans_points,
                THETA,
                transpose_theta,
            ) = self.__process_affine(
                image, target, theta, nopoints, "S[{:}]@{:}".format(index, image_path)
            )

        torch_index = torch.IntTensor([index])
        torch_nopoints = torch.ByteTensor([nopoints])
        torch_shape = torch.IntTensor([H, W])

        return (
            affineImage,
            heatmaps,
            mask,
            norm_trans_points,
            THETA,
            transpose_theta,
            torch_index,
            torch_nopoints,
            torch_shape,
        )

    def __process_affine(self, image, target, theta, nopoints, aux_info=None):
        """Warp one image with ``theta`` and generate its landmark targets.

        Args:
            image: image tensor with three dimensions ``(C, H, W)``.
            target: label object exposing ``copy()`` and ``get_points()``.
            theta: affine parameter tensor handed to ``affine2image``.
            nopoints: True when the sample has no landmark annotation; zero
                points/heatmaps and an all-ones mask are returned in that case.
            aux_info: text appended to the warning emitted when too few points
                remain visible.

        Returns:
            ``(affineImage, heatmaps, mask, norm_trans_points, theta,
            transpose_theta)`` where ``theta`` is the cloned input.
        """
        # Operate on copies of the inputs.
        image, target, theta = image.clone(), target.copy(), theta.clone()
        _, H, W = image.size()
        height, width = self.shape
        map_h, map_w = height // self.downsample, width // self.downsample
        if nopoints:  # do not have label
            norm_trans_points = torch.zeros((3, self.NUM_PTS))
            heatmaps = torch.zeros((self.NUM_PTS + 1, map_h, map_w))
            mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
            transpose_theta = identity2affine(False)
        else:
            # NOTE(review): apply_affine2point and apply_boundary are not
            # imported in the visible part of this file -- confirm where they
            # are defined, otherwise this branch raises NameError.
            norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W))
            norm_trans_points = apply_boundary(norm_trans_points)
            # Rows 0-1 are converted with denormalize_points for heatmap
            # generation; row 2 is kept as is (it is compared against 1 below
            # to count the usable points).
            real_trans_points = norm_trans_points.clone()
            real_trans_points[:2, :] = denormalize_points(
                self.shape, real_trans_points[:2, :]
            )
            heatmaps, mask = generate_label_map(
                real_trans_points.numpy(),
                map_h,
                map_w,
                self.sigma,
                self.downsample,
                nopoints,
                self.heatmap_type,
            )  # H*W*C
            # Move the channel axis first: H*W*C -> C*H*W.
            heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(
                torch.FloatTensor
            )
            mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
            if self.mean_face is None:
                # warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
                transpose_theta = identity2affine(False)
            elif torch.sum(norm_trans_points[2, :] == 1) < 3:
                # Fewer than three points are flagged with 1 in row 2, so fall
                # back to the identity instead of calling solve2theta.
                warnings.warn(
                    "In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}".format(
                        aux_info
                    )
                )
                transpose_theta = identity2affine(False)
            else:
                transpose_theta = solve2theta(
                    norm_trans_points, self.mean_face.clone()
                )

        affineImage = affine2image(image, theta, self.shape)
        if self.cutout is not None:
            affineImage = self.cutout(affineImage)

        return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta