#!/usr/bin/env python

import pytorch_lightning as pl

import json
import os
import sys

# Set the TOKENIZERS_PARALLELISM switch (read by the HuggingFace tokenizers
# library) to "false" before any tokenizer gets created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Extend the import search path with the sibling ``data`` and ``model``
# directories, resolved relative to this file, ahead of the project imports.
_script_dir = os.path.dirname(__file__)
for _subdir in ("../data", "../model"):
    sys.path.append(os.path.join(_script_dir, _subdir))


from model_umms import UMMSTransformerT5
from data_loader import UMMSDataModule
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.profilers import AdvancedProfiler
from transformers import AutoTokenizer

import argparse
import numpy as np
import torch

# Select the "medium" float32 matmul precision setting: PyTorch may then use
# faster, lower-precision internal kernels for float32 matrix multiplications.
torch.set_float32_matmul_precision("medium")

# Command-line interface of the evaluation script.
#
# All boolean switches default to False. ``--ckpt`` is mandatory: the rest of
# the script takes the run name and the output location from the checkpoint
# path, so a missing value must be rejected here with a clear usage error
# instead of failing later on a ``None`` path.
parser = argparse.ArgumentParser(description="UMMS training parameters.")
parser.add_argument(
    "--append_task_id",
    action="store_true",
    help="Whether to add the input specific tag to each sequence.",
)
parser.add_argument(
    "--smart_frame_sample",
    action="store_true",
    help="Whether to use the algorithmic approach for frame sampling.",
)
parser.add_argument(
    "--use_smooth_labels",
    action="store_true",
    help="Whether to use the smooth labels for image/frame similarity",
)
parser.add_argument(
    "--use_all_mlask_frames",
    action="store_true",
    help="Whether to use all frames from mlask",
)
parser.add_argument(
    "--start_with_img",
    action="store_true",
    help="Whether to start the generation with image id",
)
parser.add_argument(
    "--mlask_only",
    action="store_true",
    help="Whether to train only on MLASK",
)
parser.add_argument(
    "--pens_only",
    action="store_true",
    help="Whether to train only on PENS",
)
parser.add_argument(
    "--m3ls_only",
    action="store_true",
    help="Whether to train only on M3LS",
)
# Only "val" and "test" are handled further down in the script.
parser.add_argument(
    "--mode",
    type=str,
    default="val",
    help="Whether to compute on dev or test set.",
)
parser.add_argument(
    "--ckpt",
    type=str,
    required=True,
    help="Path to the checkpoint",
)

# Parse the command line with the module-level ``parser``.
mms_args = parser.parse_args()

# The run name is the directory two levels above the checkpoint file,
# i.e. <training_name>/<subdir>/<checkpoint file>.
training_name = os.path.basename(os.path.dirname(os.path.dirname(mms_args.ckpt)))

# A run name containing "_<flag>" switches that flag on, whether or not it was
# also given on the command line. Flags are never switched off here.
for _flag in (
    "append_task_id",
    "smart_frame_sample",
    "use_smooth_labels",
    "mlask_only",
    "pens_only",
    "m3ls_only",
    "use_all_mlask_frames",
    "start_with_img",
):
    if f"_{_flag}" in training_name:
        setattr(mms_args, _flag, True)

# MLASK video path, selected by the run name. It must be a plain string like
# the other dataset paths: a trailing comma after the literal would turn it
# into a 1-tuple. Both branches currently hold the same template placeholder
# (presumably meant to be filled with different locations).
if "_use_all_mlask_frames" in training_name:
    _mlask_path_vid = "__TEMPLATE__"
else:
    _mlask_path_vid = "__TEMPLATE__"

# Configuration handed to the data module as an ``argparse.Namespace``.
# Keys are listed in the order they end up in the namespace.
_data_config = {
    # Dataset locations (template placeholders, plus the MLASK video path
    # chosen earlier in the script).
    "mlask_path_txt": "__TEMPLATE__",
    "mlask_path_vid": _mlask_path_vid,
    "mlask_path_img": "__TEMPLATE__",
    "m3ls_path_txt": "__TEMPLATE__",
    "m3ls_path_img_src": "__TEMPLATE__",
    "m3ls_path_img_tgt": "__TEMPLATE__",
    "pens_path": "__TEMPLATE__",
    "msmo_path_txt": "__TEMPLATE__",
    "msmo_path_img": "__TEMPLATE__",
    "append_task_id": mms_args.append_task_id,
    # Text model identifier, sequence-length limits and loader workers.
    "text_model": "google/t5-v1_1-base",
    "max_src_len": 1024,
    "max_tgt_len": 128,
    "num_workers": 6,
    # Options forwarded unchanged from the parsed arguments.
    "smart_frame_sample": mms_args.smart_frame_sample,
    "use_all_mlask_frames": mms_args.use_all_mlask_frames,
    "use_smooth_labels": mms_args.use_smooth_labels,
    "mlask_only": mms_args.mlask_only,
    "pens_only": mms_args.pens_only,
    "m3ls_only": mms_args.m3ls_only,
    "start_with_img": mms_args.start_with_img,
    # Batch sizes per split.
    "train_batch_size": 5,
    "dev_batch_size": 48,
    "test_batch_size": 48,
}

mms_data = UMMSDataModule(argparse.Namespace(**_data_config))

# Reject anything other than the two supported splits up front.
if mms_args.mode not in ("val", "test"):
    print(f"Mode {mms_args.mode} not supported for validation!")
    sys.exit(1)

# Pick the dataloader of the requested split; only that loader is built.
test_loader = (
    mms_data.val_dataloader()
    if mms_args.mode == "val"
    else mms_data.test_dataloader()
)

# TensorBoard logger rooted at "validation", with one experiment name per
# evaluation mode and the training run name as the version.
_experiment_name = f"unmms_t5_{mms_args.mode}"
tb_logger = TensorBoardLogger("validation", name=_experiment_name, version=training_name)

# Single-GPU trainer using bf16 mixed precision, logging to the logger above.
_trainer_kwargs = {
    "precision": "bf16-mixed",
    "accelerator": "gpu",
    "devices": 1,
    "logger": tb_logger,
}
trainer = pl.Trainer(**_trainer_kwargs)

# The model counts as single-task when any of the dataset-only switches is set.
_single_task = mms_args.mlask_only or mms_args.pens_only or mms_args.m3ls_only

# Constructor arguments for the model, in the order they are passed.
_model_kwargs = {
    "pre_trained_ckpt": "google/t5-v1_1-base",
    "append_task_id": mms_args.append_task_id,
    # This one is not relevant for evaluation
    "visual_weight": 1,
    "is_single_task": _single_task,
    # The PENS-only switch is what marks the model as text-only.
    "is_text_only": mms_args.pens_only,
    "start_with_img": mms_args.start_with_img,
}
model = UMMSTransformerT5(**_model_kwargs)

# Run the validation loop on the selected loader, loading the checkpoint given
# on the command line.
_results = trainer.validate(model, dataloaders=test_loader, ckpt_path=mms_args.ckpt)
# Only the first entry of the returned list is used (a single loader was
# passed); every metric is rounded to two decimals.
res = {k: round(v, 2) for k, v in _results[0].items()}

# Output directory: the run directory (two levels above the checkpoint) with
# "unmms_t5" -> "unmms_t5_<mode>" and "trainings" -> "validation" substituted.
_run_dir = os.path.dirname(os.path.dirname(mms_args.ckpt))
_out_dir = _run_dir.replace(
    "unmms_t5", f"unmms_t5_{mms_args.mode}"
).replace("trainings", "validation")

# The substituted path is not guaranteed to exist (the replacement also
# touches the run name itself), so create it rather than losing the results
# of a finished evaluation. An empty path means the current directory.
if _out_dir:
    os.makedirs(_out_dir, exist_ok=True)

# ensure_ascii=False may emit non-ASCII characters, so fix the encoding.
with open(os.path.join(_out_dir, "test_results.json"), "w", encoding="utf-8") as f:
    json.dump(res, f, ensure_ascii=False, indent=4)
