import csv
import glob
import json
import random
import os
import unicodedata

import pytorch_lightning as pl
import scipy
import torch
from tqdm import tqdm
import numpy as np
import pandas as pd
from natsort import natsorted

from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer


# Length of each visual feature vector (one row per frame / image embedding);
# padded visual feature batches are viewed as (-1, VISUAL_FEATURES_SIZE).
VISUAL_FEATURES_SIZE = 1024


class UMMSDatasetBase(Dataset):
    """
    Common base for the UMMS summarization datasets.

    Sets up the tokenizer (adding the ``img_ind_<k>`` tokens a target uses to
    point at a frame/image), provides readers that fill
    ``self.data[<task id>]`` from the MLASK, M3LS, MSMO and PENS corpora, and
    collates items into model batches.

    Subclasses are expected to call one reader, set ``self.data_str`` (the
    task id, also the key into ``self.data``) and ``self.text_only`` -- both
    are read by ``collate_fn`` -- and implement ``__len__`` / ``__getitem__``.
    """

    # Number of ``img_ind_<k>`` tokens added to the vocabulary; bounds how many
    # candidate frames/images can be referenced by a dedicated token.
    NUM_IMG_TOKENS = 351

    def __init__(self, args, mode):
        """
        Args:
            args: configuration namespace (``text_model``, ``append_task_id``,
                corpus paths, batching and target options).
            mode: data split, one of ``"dev"``, ``"test"`` or ``"train"``.

        Raises:
            ValueError: if ``mode`` is not a known split.
        """
        if mode not in ("dev", "test", "train"):
            raise ValueError(f"Unknown mode {mode!r}; expected 'dev', 'test' or 'train'")
        self.args = args
        self.mode = mode

        self.data = {}

        self.tokenizer = AutoTokenizer.from_pretrained(self.args.text_model)

        # Add the new tokens that will be used to select frames
        img_tokens = [f"img_ind_{_ind}" for _ind in range(self.NUM_IMG_TOKENS)]
        new_tokens = list(img_tokens)
        if self.args.append_task_id:
            new_tokens.extend(["t+v->t+i", "t+i->t+i", "t->t"])
        self.tokenizer.add_tokens(new_tokens)

        # Look the ids up instead of assuming ``old_len + k``: ``add_tokens``
        # skips tokens already present in the vocabulary, which would shift
        # every id after them.
        self.integer_to_token_id_mapping = dict(
            enumerate(self.tokenizer.convert_tokens_to_ids(img_tokens))
        )

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _read_mlask(self):
        """
        Read the MLASK data into ``self.data["t+v->t+i"]``.

        Stores the article (``src``) and title (``tgt``) arrays of the rows
        whose id is listed in ``_SPLIT/<mode>.id``, plus per row the path of
        the reference-image features (``ref_img_path``) and of the video-frame
        features (``video_path``). The frame-feature file variant is selected
        by ``args.smart_frame_sample`` / ``args.use_all_mlask_frames``.
        """
        path_txt = self.args.mlask_path_txt
        path_vid = self.args.mlask_path_vid
        path_img = self.args.mlask_path_img

        if self.args.smart_frame_sample:
            file_name_ending = "_alg.npy"
        elif self.args.use_all_mlask_frames:
            file_name_ending = ".npy"
        else:
            file_name_ending = "_uniform.npy"

        task_data = {}
        self.data["t+v->t+i"] = task_data

        df = pd.read_csv(
            os.path.join(path_txt, "EN_multimodal_mms.tsv"),
            sep="\t",
            quoting=csv.QUOTE_NONE,
        )

        with open(os.path.join(path_txt, "_SPLIT", f"{self.mode}.id"), encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing on int("").
            split_ids = {int(_line) for _line in f if _line.strip()}

        # Set-based membership keeps the filter linear; row order is preserved.
        df = df[df.id.isin(split_ids)]

        task_data["src"] = df.article.values
        task_data["tgt"] = df.title.values
        task_data["ref_img_path"] = []
        task_data["video_path"] = []

        for _id in tqdm(df.id.values):
            _id = int(_id)
            # Feature files are sharded into directories of 200 ids each.
            _video_dir = str(_id // 200)
            _file_stem = str(_id).zfill(5)
            task_data["ref_img_path"].append(os.path.join(path_img, _video_dir, _file_stem + ".npy"))
            task_data["video_path"].append(os.path.join(path_vid, _video_dir, _file_stem + file_name_ending))

    def _read_m3ls(self):
        """
        Read the M3LS data into ``self.data["t+i->t+i"]``.

        Stores the article (``src``) and title (``tgt``) arrays, the path of
        each reference image (``ref_img_path``) and, per article, the list of
        candidate image feature paths (``src_imgs_paths``). The reference
        image is always one of the candidates; candidates are shuffled once
        here in training mode.
        """
        path_txt = self.args.m3ls_path_txt
        path_img_src = self.args.m3ls_path_img_src
        path_img_tgt = self.args.m3ls_path_img_tgt

        task_data = {}
        self.data["t+i->t+i"] = task_data

        df = pd.read_csv(
            os.path.join(path_txt, f"multimodal_m3ls_text_{self.mode}.tsv"),
            sep="\t",
            quoting=csv.QUOTE_NONE,
        )

        task_data["src"] = df.Article.values
        task_data["tgt"] = df.Title.values
        task_data["ref_img_path"] = []
        task_data["src_imgs_paths"] = []

        for _hash in tqdm(df.HASH.values):
            # Feature files are sharded by the first two characters of the hash.
            s1, s2, *_ = _hash
            ref_img = os.path.join(path_img_tgt, s1, s2, _hash + "_ref.npy")
            src_imgs = natsorted(glob.glob(os.path.join(path_img_src, s1, s2, _hash + "*npy")))
            # For this variant we assume the reference image is one of the
            # possible options; don't list it twice if the glob already found it.
            if ref_img not in src_imgs:
                src_imgs.append(ref_img)
            # Shuffle in training, otherwise the model learns to always output
            # the last position (where the reference image was appended).
            if self.mode == "train":
                random.shuffle(src_imgs)
            task_data["ref_img_path"].append(ref_img)
            task_data["src_imgs_paths"].append(src_imgs)

    def _read_msmo(self):
        """
        Read the MSMO test data into ``self.data["t+i->t+i"]``.

        Stores the article (``src``) and title (``tgt``) arrays and, per
        article, the naturally-sorted list of candidate image feature paths
        (``src_imgs_paths``). No reference-image paths are stored.

        Raises:
            ValueError: if called for a split other than ``"test"``.
        """
        if self.mode != "test":
            raise ValueError(f"MSMO is only available for the 'test' split, got {self.mode!r}")

        path_txt = self.args.msmo_path_txt
        path_img = self.args.msmo_path_img

        task_data = {}
        self.data["t+i->t+i"] = task_data

        df = pd.read_csv(
            os.path.join(path_txt, "test_msmo.tsv"),
            sep="\t",
            quoting=csv.QUOTE_NONE,
        )

        task_data["src"] = df.Article.values
        task_data["tgt"] = df.Title.values
        task_data["src_imgs_paths"] = []

        for _id in tqdm(df.ID.values):
            _id = int(_id)
            # Image directories are sharded into groups of 1000 ids.
            _dir = str(_id // 1000)
            src_imgs = natsorted(glob.glob(os.path.join(path_img, _dir, str(_id).zfill(6), "*npy")))
            task_data["src_imgs_paths"].append(src_imgs)

    def _read_pens(self):
        """
        Read the textual PENS data into ``self.data["t->t"]`` (article ``src``,
        title ``tgt``).
        """
        path = self.args.pens_path

        df = pd.read_csv(
            os.path.join(path, f"pens_{self.mode}.tsv"),
            sep="\t",
            quoting=csv.QUOTE_NONE,
        )

        self.data["t->t"] = {
            "src": df.Article.values,
            "tgt": df.Title.values,
        }

    def _with_task_prefix(self, text):
        """Prepend the task id (``self.data_str``) to ``text`` when ``args.append_task_id`` is set."""
        if self.args.append_task_id:
            return self.data_str + " " + text
        return text

    def _build_target_texts(self, batch):
        """
        Build the target strings for a batch.

        Text-only datasets use the title as-is. Otherwise the token
        ``img_ind_<k>`` of the most similar candidate ``k`` is appended to the
        title, or prepended when ``args.start_with_img`` is set.
        """
        if self.text_only:
            return [_item["tgt"] for _item in batch]

        texts = []
        for _item in batch:
            img_token = "img_ind_" + str(np.argmax(_item["tgt_cos_sim"]))
            if self.args.start_with_img:
                texts.append(img_token + " " + _item["tgt"])
            else:
                texts.append(_item["tgt"] + " " + img_token)
        return texts

    def _image_target_distribution(self, cos_sim, num_classes, dtype=torch.float32):
        """
        Spread per-candidate target probabilities over the vocabulary.

        Returns a 1-D tensor of length ``num_classes`` that is zero everywhere
        except at the token id of ``img_ind_<k>``, which holds ``cos_sim[k]``.
        Raises ``KeyError`` if there are more candidates than image tokens.
        """
        cos_sim = np.asarray(cos_sim).reshape(-1)
        token_ids = torch.tensor(
            [self.integer_to_token_id_mapping[_k] for _k in range(cos_sim.size)],
            dtype=torch.long,
        )
        # Direct scatter instead of summing an (n_candidates, vocab) one-hot
        # matrix; the token ids are distinct, so the result is the same.
        distribution = torch.zeros(num_classes, dtype=dtype)
        distribution[token_ids] = torch.as_tensor(cos_sim).to(dtype)
        return distribution

    @staticmethod
    def _pad_visual_features(features):
        """
        Zero-pad per-item visual features along the frame axis.

        Args:
            features: sequence of 2-D arrays, one (n_frames_i, feature_dim)
                array per item.

        Returns:
            ``(padded, mask)``: ``padded`` has shape
            (batch, max_n_frames, feature_dim); ``mask`` is a long tensor of
            shape (batch, max_n_frames) with 1 for real frames and 0 for
            padding. The mask is derived from the item lengths, so a real
            frame whose features contain zeros is never mistaken for padding.
        """
        sequences = [torch.tensor(_f) for _f in features]
        padded = torch.nn.utils.rnn.pad_sequence(
            sequences=sequences,
            batch_first=True,
            padding_value=0.0,
        )
        lengths = torch.tensor([_seq.shape[0] for _seq in sequences])
        positions = torch.arange(padded.shape[1]).unsqueeze(0)
        mask = (positions < lengths.unsqueeze(1)).long()
        return padded, mask

    def collate_fn(self, batch):
        """
        Collate a list of ``__getitem__`` items into a model batch.

        Returns a dict with:
            src / tgt: source strings (task-id prefixed if
                ``args.append_task_id``) and raw target titles.
            src_ids, src_mask, tgt_ids, tgt_mask: token ids and attention
                masks, padded to the longest sequence in the batch and
                truncated to ``args.max_src_len`` / ``args.max_tgt_len``.
            tgt_probs_for_loss: float tensor (batch, tgt_len, len(tokenizer))
                holding the target distribution per position -- one-hot of
                ``tgt_ids``, except that with ``args.use_smooth_labels`` the
                image-token position holds the per-candidate probabilities.
            visual_features, visual_mask, raw_cos_sim: multimodal datasets
                only; features are zero-padded along the frame axis and the
                mask is 1 for real frames, 0 for padding.
            task_type: task identifier taken from the first item.

        Raises:
            ValueError: if ``use_smooth_labels`` is combined with
                ``start_with_img`` on a multimodal dataset.
        """
        src_texts = [self._with_task_prefix(_item["src"]) for _item in batch]
        src_encoded = self.tokenizer(
            src_texts,
            padding="longest",
            truncation=True,
            max_length=self.args.max_src_len,
        )
        src_ids = torch.tensor(src_encoded["input_ids"])
        src_mask = torch.tensor(src_encoded["attention_mask"])

        tgt_encoded = self.tokenizer(
            self._build_target_texts(batch),
            padding="longest",
            truncation=True,
            max_length=self.args.max_tgt_len,
        )
        tgt_ids = torch.tensor(tgt_encoded["input_ids"])
        tgt_mask = torch.tensor(tgt_encoded["attention_mask"])

        num_classes = len(self.tokenizer)
        # The loss is computed with respect to class probabilities (not
        # indices) so that the image-token position can carry a soft target.
        tgt_probs_for_loss = torch.nn.functional.one_hot(tgt_ids, num_classes=num_classes).float()

        _return_dict = {
            "src": src_texts,
            "src_ids": src_ids,
            "src_mask": src_mask,
            "tgt": [_item["tgt"] for _item in batch],
            "tgt_ids": tgt_ids,
            "tgt_mask": tgt_mask,
            "tgt_probs_for_loss": tgt_probs_for_loss,
        }

        if not self.text_only:
            if self.args.use_smooth_labels:
                if self.args.start_with_img:
                    raise ValueError("use_smooth_labels is not supported together with start_with_img")
                target_lengths = tgt_mask.sum(dim=1)
                for _row, _item in enumerate(batch):
                    # The image token is the last text token, i.e. just before
                    # the final position. NOTE(review): assumes right padding,
                    # exactly one trailing special token (e.g. EOS) and that the
                    # image token survived truncation -- confirm for text_model.
                    img_pos = int(target_lengths[_row].item()) - 2
                    tgt_probs_for_loss[_row, img_pos, :] = self._image_target_distribution(
                        _item["tgt_cos_sim"], num_classes, tgt_probs_for_loss.dtype
                    )

            visual_features, visual_mask = self._pad_visual_features(
                [_item["visual_features"] for _item in batch]
            )
            _return_dict["visual_features"] = visual_features
            _return_dict["visual_mask"] = visual_mask
            _return_dict["raw_cos_sim"] = [_item["raw_cos_sim"] for _item in batch]

        _return_dict["task_type"] = batch[0]["task_type"]
        return _return_dict


class UMMSDatasetMLASK(UMMSDatasetBase):
    """
    MLASK corpus: article + video frames -> title + selected frame
    (task ``"t+v->t+i"``).
    """

    # Lower bound for the norm product in the cosine similarity, so an
    # all-zero frame or reference vector yields similarity 0 instead of NaN.
    _NORM_EPS = 1e-12

    def __init__(self, args, mode):
        super().__init__(args, mode)
        self._read_mlask()
        self.data_str = "t+v->t+i"
        self.text_only = False

    def __len__(self):
        return len(self.data[self.data_str]["src"])

    def __getitem__(self, idx):
        """
        Return one example: source/target text, the video-frame features and
        their similarity to the reference image.

        ``raw_cos_sim`` is the cosine similarity of every frame to the
        reference image; ``tgt_cos_sim`` is its softmax, used as the target
        distribution over frames.

        Raises:
            ValueError: if the reference feature file does not hold exactly
                one image (first dimension != 1).
        """
        task_data = self.data[self.data_str]
        _return_dict = {
            "src": task_data["src"][idx],
            "tgt": task_data["tgt"][idx],
            "task_type": self.data_str,
        }

        ref_path = task_data["ref_img_path"][idx]
        visual_features = np.load(task_data["video_path"][idx])
        tgt_visual_features = np.load(ref_path)
        if tgt_visual_features.shape[0] != 1:
            raise ValueError(
                f"Expected a single reference image in {ref_path}, got shape {tgt_visual_features.shape}"
            )
        tgt_visual_features = tgt_visual_features.flatten()

        norms = np.linalg.norm(visual_features, axis=1) * np.linalg.norm(tgt_visual_features)
        _tgt_cos_sim = np.dot(visual_features, tgt_visual_features) / np.maximum(norms, self._NORM_EPS)
        # We must have a probability distribution to train with CE
        tgt_cos_sim = scipy.special.softmax(_tgt_cos_sim, -1)

        _return_dict["visual_features"] = visual_features
        _return_dict["tgt_cos_sim"] = tgt_cos_sim
        _return_dict["raw_cos_sim"] = _tgt_cos_sim

        return _return_dict

class UMMSDatasetPENS(UMMSDatasetBase):
    """Text-only PENS corpus: article -> title (task ``"t->t"``)."""

    def __init__(self, args, mode):
        super().__init__(args, mode)
        self._read_pens()
        self.data_str = "t->t"
        self.text_only = True

    def __len__(self):
        split = self.data[self.data_str]
        return len(split["src"])

    def __getitem__(self, idx):
        """Return the article/title pair at ``idx`` tagged with the task id."""
        split = self.data[self.data_str]
        return dict(
            src=split["src"][idx],
            tgt=split["tgt"][idx],
            task_type=self.data_str,
        )

class UMMSDatasetM3LS(UMMSDatasetBase):
    """
    M3LS corpus: article + candidate images -> title + selected image
    (task ``"t+i->t+i"``).
    """

    # Lower bound for the norm product in the cosine similarity, so an
    # all-zero candidate or reference vector yields similarity 0 instead of NaN.
    _NORM_EPS = 1e-12

    def __init__(self, args, mode):
        super().__init__(args, mode)
        self._read_m3ls()
        self.data_str = "t+i->t+i"
        self.text_only = False

    def __len__(self):
        return len(self.data[self.data_str]["src"])

    def __getitem__(self, idx):
        """
        Return one example: source/target text, the stacked candidate-image
        features and their similarity to the reference image.

        ``raw_cos_sim`` is the cosine similarity of every candidate to the
        reference image; ``tgt_cos_sim`` is its softmax, used as the target
        distribution over candidates.

        Raises:
            ValueError: if the reference feature file does not hold exactly
                one image (first dimension != 1).
        """
        task_data = self.data[self.data_str]
        _return_dict = {
            "src": task_data["src"][idx],
            "tgt": task_data["tgt"][idx],
            "task_type": self.data_str,
        }

        # Candidate features are stacked row-wise; presumably each file holds
        # a (1, feature_dim) array -- TODO confirm.
        visual_features = np.concatenate(
            [np.load(_img) for _img in task_data["src_imgs_paths"][idx]], axis=0
        )
        ref_path = task_data["ref_img_path"][idx]
        tgt_visual_features = np.load(ref_path)
        if tgt_visual_features.shape[0] != 1:
            raise ValueError(
                f"Expected a single reference image in {ref_path}, got shape {tgt_visual_features.shape}"
            )
        tgt_visual_features = tgt_visual_features.flatten()

        norms = np.linalg.norm(visual_features, axis=1) * np.linalg.norm(tgt_visual_features)
        _tgt_cos_sim = np.dot(visual_features, tgt_visual_features) / np.maximum(norms, self._NORM_EPS)
        # We must have a probability distribution to train with CE
        tgt_cos_sim = scipy.special.softmax(_tgt_cos_sim, -1)

        _return_dict["visual_features"] = visual_features
        _return_dict["tgt_cos_sim"] = tgt_cos_sim
        _return_dict["raw_cos_sim"] = _tgt_cos_sim

        return _return_dict


class UMMSDataModule(pl.LightningDataModule):
    """
    Lightning data module combining the MLASK, PENS and M3LS loaders.

    Each ``*_dataloader`` method returns a ``CombinedLoader`` keyed by task
    id. When ``args.mlask_only`` / ``args.pens_only`` / ``args.m3ls_only`` is
    set (checked in that order), only that task's dataset is built and loaded.
    """

    # (task id, attribute suffix, dataset class), in the order loaders are combined.
    _TASKS = (
        ("t+v->t+i", "mlask", UMMSDatasetMLASK),
        ("t->t", "pens", UMMSDatasetPENS),
        ("t+i->t+i", "m3ls", UMMSDatasetM3LS),
    )

    def __init__(self, args):
        super().__init__()
        self.args = args

    def _selected_tasks(self):
        """Return the set of task ids enabled by the ``*_only`` flags."""
        if self.args.mlask_only:
            return {"t+v->t+i"}
        if self.args.pens_only:
            return {"t->t"}
        if self.args.m3ls_only:
            return {"t+i->t+i"}
        return {_task for _task, _, _ in self._TASKS}

    def _combined_loader(self, mode, prefix, batch_size, shuffle, combine_mode):
        """
        Build one DataLoader per selected task and wrap them in a CombinedLoader.

        Only the selected tasks' datasets are constructed, so corpora that are
        not used need not be configured. Each loader is also stored as
        ``self.<prefix>_loader_<suffix>`` (e.g. ``self.train_loader_mlask``);
        the attribute of a task that is not selected is set to ``None``.

        Args:
            mode: dataset split passed to the dataset classes.
            prefix: attribute prefix (``"train"``, ``"val"`` or ``"test"``).
            batch_size: batch size for every loader.
            shuffle: whether the loaders shuffle.
            combine_mode: ``CombinedLoader`` mode.
        """
        selected = self._selected_tasks()
        data_map = {}
        for task, suffix, dataset_cls in self._TASKS:
            loader = None
            if task in selected:
                dataset = dataset_cls(self.args, mode)
                loader = DataLoader(
                    dataset=dataset,
                    batch_size=batch_size,
                    num_workers=self.args.num_workers,
                    shuffle=shuffle,
                    collate_fn=dataset.collate_fn,
                )
                data_map[task] = loader
            setattr(self, f"{prefix}_loader_{suffix}", loader)

        return pl.utilities.combined_loader.CombinedLoader(data_map, combine_mode)

    def train_dataloader(self):
        # "max_size_cycle": run until the longest loader is exhausted, cycling
        # the shorter ones.
        return self._combined_loader(
            "train", "train", self.args.train_batch_size, True, "max_size_cycle"
        )

    def val_dataloader(self):
        return self._combined_loader(
            "dev", "val", self.args.dev_batch_size, False, "sequential"
        )

    def test_dataloader(self):
        return self._combined_loader(
            "test", "test", self.args.test_batch_size, False, "sequential"
        )
