#!/usr/bin/env python

import pytorch_lightning as pl

import json
import os
import sys

# Turn off HuggingFace tokenizers' internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make the sibling ../data and ../model directories importable (in that order),
# relative to this script's own location.
_script_dir = os.path.dirname(__file__)
sys.path.extend(
    os.path.join(_script_dir, subdir) for subdir in ("../data", "../model")
)


from model_BLIP2_umms import UMMSTransformerBLIP2
from data_loader_BLIP2 import UMMSDataModuleBLIP2
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.profilers import AdvancedProfiler
from transformers import AutoTokenizer

import argparse
import numpy as np
import torch

# Let PyTorch trade some float32 matmul precision for speed ("medium" setting).
torch.set_float32_matmul_precision("medium")

parser = argparse.ArgumentParser(description="UMMS training parameters.")

parser.add_argument(
    "--mode",
    type=str,
    default="val",
    help="Whether to compute on dev or test set.",
)
# Required: both the run name and the results directory are derived from this
# path. Without it, os.path.dirname(None) below would fail with an opaque
# TypeError instead of a clear usage error from argparse.
parser.add_argument(
    "--ckpt",
    type=str,
    required=True,
    help="Path to the checkpoint",
)

mms_args = parser.parse_args()

# Run name = the directory two levels above the checkpoint file.
# NOTE(review): assumes a layout like .../<run_name>/<subdir>/<file>.ckpt — confirm.
training_name = os.path.basename(os.path.dirname(os.path.dirname(mms_args.ckpt)))

# Runs whose name contains "_append_task_id" enable the flag that is forwarded
# to both the data module and the model.
append_task_id = "_append_task_id" in training_name

# Every dataset-path field of the data module is set to the literal placeholder
# "__TEMPLATE__". NOTE(review): presumably these must be replaced with real
# dataset locations before running — confirm against the data module.
_PATH_PLACEHOLDER = "__TEMPLATE__"
_DATA_PATH_FIELDS = (
    "mlask_path_txt",
    "mlask_path_vid_raw",
    "mlask_path_vid",
    "mlask_path_img",
    "m3ls_path_txt",
    "m3ls_path_img_src",
    "m3ls_path_img_tgt",
    "m3ls_path_img_src_raw",
    "m3ls_path_img_tgt_raw",
    "pens_path",
)

_data_args = argparse.Namespace(
    **{field: _PATH_PLACEHOLDER for field in _DATA_PATH_FIELDS},
    append_task_id=append_task_id,
    pretrained_model="Salesforce/blip2-flan-t5-xl",
    # max_*_len: presumably token-length limits for source/target text.
    max_src_len=1024,
    max_tgt_len=128,
    num_workers=4,
    smart_frame_sample=False,
    mlask_only=False,
    pens_only=False,
    m3ls_only=False,
    train_batch_size=1,
    dev_batch_size=12,
    test_batch_size=12,
)
mms_data = UMMSDataModuleBLIP2(_data_args)


# Map each supported --mode to the data-module method that builds its loader.
# Only the selected method is called.
_loader_factories = {
    "val": mms_data.val_dataloader,
    "test": mms_data.test_dataloader,
}
if mms_args.mode not in _loader_factories:
    # sys.exit with a string prints it to stderr (not stdout) and exits with
    # status 1, matching the original exit code.
    sys.exit(f"Mode {mms_args.mode} not supported for validation!")
test_loader = _loader_factories[mms_args.mode]()


# TensorBoard logs go to validation/unmms_blip2_<mode>/<training_name>
# (Lightning's save_dir/name/version layout).
tb_logger = TensorBoardLogger(
    save_dir="validation",
    name=f"unmms_blip2_{mms_args.mode}",
    version=training_name,
)

# Single-GPU evaluation in bfloat16 mixed precision.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision="bf16-mixed",
    logger=tb_logger,
)

# Instantiate the model. The fine-tuned weights are restored afterwards through
# ``ckpt_path`` in ``trainer.validate``; ``pre_trained_ckpt`` is presumably the
# HuggingFace id used to build the backbone (TODO confirm in the model class).
model = UMMSTransformerBLIP2(
    append_task_id=append_task_id,
    pre_trained_ckpt="Salesforce/blip2-flan-t5-xl",
)

# Evaluate with the weights restored from --ckpt.
_results = trainer.validate(model, dataloaders=test_loader, ckpt_path=mms_args.ckpt)
# trainer.validate returns one metrics dict per dataloader; a single loader is
# passed here, so only the first entry is relevant.
res = {k: round(v, 2) for k, v in _results[0].items()}

# Mirror the checkpoint's run directory into the validation tree, e.g.
#   trainings/unmms_blip2/<run>/<subdir>/x.ckpt
#   -> validation/unmms_blip2_<mode>/<run>/test_results.json
# Only the parent part of the path is rewritten. The run-name component is kept
# verbatim (as the TensorBoard logger's version=training_name does), even when
# it happens to contain "unmms_blip2" or "trainings".
_run_dir = os.path.dirname(os.path.dirname(mms_args.ckpt))
_parent_dir, _run_name = os.path.split(_run_dir)
results_dir = os.path.join(
    _parent_dir.replace("unmms_blip2", f"unmms_blip2_{mms_args.mode}").replace(
        "trainings", "validation"
    ),
    _run_name,
)
# The derived directory is not guaranteed to exist (e.g. for an absolute
# checkpoint path outside the CWD the logger writes to). Create it so the
# metrics of a long run are not lost. An empty path means "current directory".
if results_dir:
    os.makedirs(results_dir, exist_ok=True)

# ensure_ascii=False can emit non-ASCII text, so pin the file encoding instead
# of relying on the platform default.
with open(
    os.path.join(results_dir, "test_results.json"), "w", encoding="utf-8"
) as f:
    json.dump(res, f, ensure_ascii=False, indent=4)
