#!/usr/bin/env python

import pytorch_lightning as pl

import sys
import os
import platform
# Network name of the machine running the script. It is not referenced in the
# visible remainder of this file.
_sys = platform.node()

# Switch off the Hugging Face tokenizers' own parallelism; this has to be in
# the environment before the tokenizers library is first used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make the sibling "data" and "model" directories importable so the
# project-local modules imported right below can be resolved.
_script_dir = os.path.dirname(__file__)
for _relative_dir in ("../data", "../model"):
    sys.path.append(os.path.join(_script_dir, _relative_dir))


from model_BLIP2_umms import UMMSTransformerBLIP2
from data_loader_BLIP2 import UMMSDataModuleBLIP2
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.profilers import AdvancedProfiler
from transformers import AutoTokenizer

import argparse
import numpy as np
import torch

# Trade float32 matmul accuracy for speed: per the PyTorch documentation,
# "medium" lets float32 matrix multiplications use bfloat16 internally when a
# fast kernel for it is available.
torch.set_float32_matmul_precision("medium")

# Command-line interface of the training script.
parser = argparse.ArgumentParser(description="UMMS training parameters.")
parser.add_argument(
    "--append_task_id",
    action="store_true",
    help="Whether to add the input specific tag to each sequence.",
)
parser.add_argument(
    "--version",
    type=int,
    default=1,
    help="Manual versioning, to be able to compute variance for several runs.",
)
# The data-module configuration further down reads mms_args.mlask_only,
# mms_args.pens_only and mms_args.m3ls_only. Without these options the script
# fails with AttributeError before training starts. They default to False so
# existing invocations keep working unchanged.
# NOTE(review): the flags are assumed to be booleans that restrict the run to
# a single dataset -- confirm against UMMSDataModuleBLIP2.
parser.add_argument(
    "--mlask_only",
    action="store_true",
    help="Forwarded to the data module as mlask_only (use only the MLASK data).",
)
parser.add_argument(
    "--pens_only",
    action="store_true",
    help="Forwarded to the data module as pens_only (use only the PENS data).",
)
parser.add_argument(
    "--m3ls_only",
    action="store_true",
    help="Forwarded to the data module as m3ls_only (use only the M3LS data).",
)

# Parses sys.argv; unknown options abort the run with a usage message.
mms_args = parser.parse_args()

# Name of this run: the manual version number, plus a suffix when the
# --append_task_id switch is active.
training_name = f"version_{mms_args.version}" + (
    "_append_task_id" if mms_args.append_task_id else ""
)

# Both callbacks watch the same logged metric, where larger values are better.
_ROUGE_METRIC = "ROUGEL_SCORE_F_MIN"

# Keep only the single best checkpoint according to the monitored metric; the
# file name records epoch, step and the metric value.
ROUGE_checkpoint = ModelCheckpoint(
    monitor=_ROUGE_METRIC,
    mode="max",
    save_top_k=1,
    filename="{epoch}-{step}-{ROUGEL_SCORE_F_MIN:.2f}",
)

# Stop training once the metric has failed to improve for 5 validation checks.
ROUGE_stop = EarlyStopping(
    monitor=_ROUGE_METRIC,
    mode="max",
    patience=5,
)

# Configuration handed to the data module as a single namespace object.
# NOTE(review): the "__TEMPLATE__" values look like placeholders that must be
# replaced with real dataset locations before running -- confirm.
_data_module_args = argparse.Namespace(
    # Dataset locations (placeholders).
    mlask_path_txt="__TEMPLATE__",
    mlask_path_vid_raw="__TEMPLATE__",
    mlask_path_vid="__TEMPLATE__",
    mlask_path_img="__TEMPLATE__",
    m3ls_path_txt="__TEMPLATE__",
    m3ls_path_img_src="__TEMPLATE__",
    m3ls_path_img_tgt="__TEMPLATE__",
    m3ls_path_img_src_raw="__TEMPLATE__",
    m3ls_path_img_tgt_raw="__TEMPLATE__",
    pens_path="__TEMPLATE__",
    # Text processing: same pretrained checkpoint name as the model below.
    append_task_id=mms_args.append_task_id,
    pretrained_model="Salesforce/blip2-flan-t5-xl",
    max_src_len=1024,
    max_tgt_len=128,
    # Loading behaviour.
    num_workers=4,
    smart_frame_sample=False,
    # Dataset switches, read from the parsed command-line namespace.
    mlask_only=mms_args.mlask_only,
    pens_only=mms_args.pens_only,
    m3ls_only=mms_args.m3ls_only,
    # Batch sizes for the train / dev / test splits.
    train_batch_size=1,
    dev_batch_size=12,
    test_batch_size=12,
)

mms_data = UMMSDataModuleBLIP2(_data_module_args)

# Ask the data module for the training and validation loaders up front; they
# are passed to the trainer explicitly rather than via the data module itself.
train_loader, val_loader = (
    mms_data.train_dataloader(),
    mms_data.val_dataloader(),
)

# TensorBoard output goes under trainings/unmms_blip2/<training_name>.
tb_logger = TensorBoardLogger(
    save_dir="trainings",
    name="unmms_blip2",
    version=training_name,
)

trainer = pl.Trainer(
    # Hardware: three GPUs, DDP with unused-parameter detection switched on,
    # bfloat16 mixed precision.
    accelerator="gpu",
    devices=3,
    strategy="ddp_find_unused_parameters_true",
    precision="bf16-mixed",
    # Schedule: at most 10 epochs, validating once per training epoch after a
    # one-batch sanity check.
    max_epochs=10,
    val_check_interval=1.0,
    num_sanity_val_steps=1,
    # Bookkeeping: TensorBoard logging plus checkpointing / early stopping.
    logger=tb_logger,
    log_every_n_steps=20,
    callbacks=[ROUGE_checkpoint, ROUGE_stop],
)

# Build the model from the same pretrained checkpoint name that the data
# module was configured with.
model = UMMSTransformerBLIP2(
    append_task_id=mms_args.append_task_id,
    pre_trained_ckpt="Salesforce/blip2-flan-t5-xl",
)

# Run training with the loaders created earlier in the script.
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

