# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from os import path as osp
from copy import deepcopy as copy
from tqdm import tqdm
import warnings, time, random, numpy as np

from pts_utils import generate_label_map
from xvision import denormalize_points
from xvision import identity2affine, solve2theta, affine2image
from .dataset_utils import pil_loader
from .landmark_utils import PointMeta2V
from .augmentation_utils import CutOut
import torch
import torch.utils.data as data


class LandmarkDataset(data.Dataset):
  """Image dataset with landmark annotations, served as affine-warped samples.

  Samples are registered through `load_list` (or `append`). Each item is
  produced by running `transform` on the image and its label, then warping the
  image with the returned affine parameter(s) and building heatmaps / masks
  for the transformed points.

  Attributes set by `reset`:
    NUM_PTS, BOXID, length, datas (image paths), labels, NormDistances and
    mean_face (None when no mean file was given).
  """

  def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None):
    """Store the configuration and start with an empty sample list.

    Args:
      transform: called as ``transform(image, target)`` and expected to return
        ``(image, target, theta)``; must not be None.
      sigma: forwarded to ``generate_label_map``.
      downsample: heatmaps are generated at ``shape // downsample`` and the
        value is also forwarded to ``generate_label_map``.
      heatmap_type: forwarded to ``generate_label_map``.
      shape: output size as ``[H, W]``.
      use_gray: forwarded to ``pil_loader``.
      mean_file: path read with ``torch.load`` to obtain the mean-face data
        (indexed by box id in `reset`), or None to disable the mean face.
      data_indicator: stored as ``dataset_name`` and passed to ``PointMeta2V``.
      cache_images: optional mapping from image path to an already loaded
        image object that supports ``.clone()``.

    Raises:
      AssertionError: if `transform` is None or `mean_file` is not a file.
    """
    assert transform is not None, 'transform : {:}'.format(transform)
    self.transform    = transform
    self.sigma        = sigma
    self.downsample   = downsample
    self.heatmap_type = heatmap_type
    self.dataset_name = data_indicator
    self.shape        = shape # [H,W]
    self.use_gray     = use_gray
    self.mean_file    = mean_file
    if mean_file is None:
      self.mean_data  = None
      warnings.warn('LandmarkDataset initialized with mean_data = None')
    else:
      assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file)
      # NOTE: torch.load unpickles its input; only point this at trusted files.
      self.mean_data  = torch.load(mean_file)
    # reset() reads self.mean_data, so it has to run after the block above.
    self.reset()
    # cutout must exist before the print below because __repr__ formats it.
    self.cutout       = None
    self.cache_images = cache_images
    print ('The general dataset initialization done : {:}'.format(self))
    # This changes the process-wide warnings filter, not only this instance.
    warnings.simplefilter( 'once' )


  def __repr__(self):
    """Return a one-line summary built from the instance attributes.

    All fields are looked up in ``self.__dict__``, so this raises KeyError if
    it is called before every referenced attribute has been assigned.
    """
    template = '{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'
    return template.format(name=self.__class__.__name__, **self.__dict__)


  def set_cutout(self, length):
    """Enable CutOut with ``int(length)`` when length >= 1, otherwise disable it."""
    if length is not None and length >= 1:
      self.cutout = CutOut( int(length) )
    else:
      self.cutout = None


  def reset(self, num_pts=-1, boxid='default', only_pts=False):
    """Set the number of points and, unless `only_pts`, clear all samples.

    Args:
      num_pts: new value for ``NUM_PTS``.
      boxid: new value for ``BOXID``; also the key used to pick the mean face
        out of the loaded mean data.
      only_pts: when True only ``NUM_PTS`` is updated and everything else is
        left untouched.

    Raises:
      AssertionError: if the selected mean face has values outside [-1, 1].
    """
    self.NUM_PTS = num_pts
    if only_pts: return
    self.length  = 0
    self.datas   = []
    self.labels  = []
    self.NormDistances = []
    self.BOXID = boxid
    if self.mean_data is None:
      self.mean_face = None
    else:
      # The stored entry is copied and transposed before becoming a tensor.
      self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
      assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face)


  def __len__(self):
    """Return the number of registered samples (checked against ``datas``)."""
    assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length)
    return self.length


  def append(self, data, label, distance):
    """Register one sample.

    Args:
      data: path of the image; must be an existing file.
      label: annotation object stored for this image (`__getitem__` calls
        ``.copy()`` on it).
      distance: normalization distance stored alongside; may be None.

    Raises:
      AssertionError: if `data` is not an existing file.
    """
    assert osp.isfile(data), 'The image path is not a file : {:}'.format(data)
    self.datas.append( data )
    self.labels.append( label )
    self.NormDistances.append( distance )
    self.length = self.length + 1


  def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
    """Load annotation list file(s) and register every sample they contain.

    Args:
      file_lists: a single path or a list of paths; each is read with
        ``torch.load`` and must yield either a list of annotations or a dict
        whose ``'datas'`` entry is such a list.
      num_pts: number of landmarks per sample.
      boxindicator: selects the ``'box-<boxindicator>'`` entry of every
        annotation; also used as the box id when resetting.
      normalizeL: when not None, selects the ``'normalizeL-<normalizeL>'``
        entry of every annotation as the normalization distance.
      reset: when True the dataset is cleared first; otherwise `num_pts` and
        `boxindicator` must match the current state.

    Raises:
      AssertionError: on inconsistent `num_pts` / `boxindicator`, or when an
        image path is not a file.
      ValueError: if a loaded file is neither a list nor a dict.
    """
    if reset: self.reset(num_pts, boxindicator)
    else    : assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of point is inconsistance : {:} vs {:}'.format(self.NUM_PTS, num_pts)
    if isinstance(file_lists, str): file_lists = [file_lists]
    samples = []
    for idx, file_path in enumerate(file_lists):
      print (':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path))
      # NOTE: torch.load unpickles its input; only load trusted list files.
      xdata = torch.load(file_path)
      if isinstance(xdata, list)  : samples.extend(xdata)          # image or video dataset list
      elif isinstance(xdata, dict): samples.extend(xdata['datas']) # multi-view dataset list
      else: raise ValueError('Invalid Type Error : {:}'.format( type(xdata) ))
    # samples is a list of annotation dicts; each one provides 'current_frame'
    # (the image path), 'points' and one or more 'box-*' entries.
    print ('GeneralDataset-V2 : {:} samples'.format(len(samples)))

    for annotation in tqdm(samples):
      image_path  = annotation['current_frame']
      points, box = annotation['points'], annotation['box-{:}'.format(boxindicator)]
      label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name)
      if normalizeL is None: normDistance = None
      else                 : normDistance = annotation['normalizeL-{:}'.format(normalizeL)]
      self.append(image_path, label, normDistance)

    assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas))
    assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels))
    assert len(self.NormDistances) == self.length, 'The length and the NormDistances is not right {} vs {}'.format(self.length, len(self.NormDistances))
    print ('Load data done for LandmarkDataset, which has {:} images.'.format(self.length))


  def __getitem__(self, index):
    """Load sample `index` (0 <= index < length) and return its processed tuple.

    The image comes from ``cache_images`` when its path is cached there,
    otherwise it is read with ``pil_loader``. See `_process_` for the return
    value.
    """
    assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index)
    image_path = self.datas[index]
    if self.cache_images is not None and image_path in self.cache_images:
      # Clone so that later processing cannot modify the cached copy.
      image = self.cache_images[ image_path ].clone()
    else:
      image = pil_loader(image_path, self.use_gray)
    target = self.labels[index].copy()
    return self._process_(image, target, index)


  def _process_(self, image, target, index):
    """Transform one (image, target) pair and build the training tensors.

    Args:
      image: input accepted by ``self.transform``.
      target: label object accepted by ``self.transform``; its ``is_none()``
        tells whether annotated points are available.
      index: sample index, or -1 when the sample has no entry in ``datas``.

    Returns:
      A 9-tuple ``(affineImage, heatmaps, mask, norm_trans_points, THETA,
      transpose_theta, torch_index, torch_nopoints, torch_shape)``. When the
      transform returns a list/tuple of thetas, the first six entries are
      stacked along a new leading dimension, one slice per theta.
    """
    # transform the image and points
    image, target, theta = self.transform(image, target)
    _, H, W = image.size()

    # whether the sample comes without point annotations
    nopoints = bool(target.is_none())
    image_path = None if index == -1 else self.datas[index]

    if isinstance(theta, (list, tuple)):
      # One full set of outputs per theta, stacked field by field afterwards.
      per_theta = [self.__process_affine(image, target, _theta, nopoints, 'P[{:}]@{:}'.format(index, image_path))
                   for _theta in theta]
      affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \
          [torch.stack(field) for field in zip(*per_theta)]
    else:
      affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \
          self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, image_path))

    torch_index = torch.IntTensor([index])
    torch_nopoints = torch.ByteTensor( [ nopoints ] )
    torch_shape = torch.IntTensor([H,W])

    return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape


  def __process_affine(self, image, target, theta, nopoints, aux_info=None):
    """Warp `image` with `theta` and build the labels for the warped points.

    Args:
      image: transformed image tensor; ``image.size()`` must unpack to (C, H, W).
      target: label object providing ``copy()`` and ``get_points()``.
      theta: affine parameter tensor used for both the points and the image.
      nopoints: True when the sample has no annotated points; zero heatmaps,
        an all-ones mask and an identity ``transpose_theta`` are returned then.
      aux_info: text appended to the warning emitted when too few points
        remain to solve for ``transpose_theta``.

    Returns:
      ``(affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta)``.
    """
    # Work on copies so the caller can reuse its inputs for further thetas.
    image, target, theta = image.clone(), target.copy(), theta.clone()
    (_, H, W), (height, width) = image.size(), self.shape
    if nopoints: # do not have label
      norm_trans_points = torch.zeros((3, self.NUM_PTS))
      heatmaps          = torch.zeros((self.NUM_PTS+1, height//self.downsample, width//self.downsample))
      mask              = torch.ones((self.NUM_PTS+1, 1, 1), dtype=torch.uint8)
      transpose_theta   = identity2affine(False)
    else:
      # NOTE(review): apply_affine2point and apply_boundary are neither imported
      # nor defined in the visible part of this file - confirm where they come from.
      norm_trans_points = apply_affine2point(target.get_points(), theta, (H,W))
      norm_trans_points = apply_boundary(norm_trans_points)
      # Rows 0-1 are converted with denormalize_points for heatmap generation;
      # row 2 is carried over unchanged.
      real_trans_points = norm_trans_points.clone()
      real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2,:])
      heatmaps, mask = generate_label_map(real_trans_points.numpy(), height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C
      # H*W*C -> C*H*W
      heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor)
      mask     = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
      if self.mean_face is None:
        transpose_theta = identity2affine(False)
      elif torch.sum(norm_trans_points[2,:] == 1) < 3:
        # Fewer than 3 points have a third-row value of 1 (presumably the
        # visibility flag), so fall back to the identity transform.
        warnings.warn('In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'.format(aux_info))
        transpose_theta = identity2affine(False)
      else:
        transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone())

    affineImage = affine2image(image, theta, self.shape)
    if self.cutout is not None: affineImage = self.cutout( affineImage )

    return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta