#!/usr/bin/env python

import pytorch_lightning as pl

import sys
import os

# Turn off HuggingFace tokenizers' internal parallelism; this is commonly done
# so that forked DataLoader workers do not trigger its fork warning/deadlock.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Make the sibling ``data`` and ``model`` directories importable so the
# project-local imports below resolve relative to this script's location.
_script_dir = os.path.dirname(__file__)
for _subdir in ("../data", "../model"):
    sys.path.append(os.path.join(_script_dir, _subdir))


from model_umms import UMMSTransformerT5
from data_loader import UMMSDataModule
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.profilers import AdvancedProfiler
from transformers import AutoTokenizer

import argparse
import numpy as np
import torch

# Trade float32 matmul precision for speed; "medium" is the least strict of
# PyTorch's three settings ("highest", "high", "medium").
torch.set_float32_matmul_precision("medium")

parser = argparse.ArgumentParser(description="UMMS training parameters.")
parser.add_argument(
    "--visual_weight",
    type=int,
    default=1,
    help="Integer weight forwarded to UMMSTransformerT5 as visual_weight "
    "(also recorded in the run name).",
)
parser.add_argument(
    "--append_task_id",
    action="store_true",
    help="Whether to add the input specific tag to each sequence.",
)
parser.add_argument(
    "--smart_frame_sample",
    action="store_true",
    help="Whether to use the algorithmic approach for frame sampling.",
)
parser.add_argument(
    "--use_smooth_labels",
    action="store_true",
    help="Whether to use the smooth labels for image/frame similarity",
)
parser.add_argument(
    "--use_all_mlask_frames",
    action="store_true",
    help="Whether to use all frames from mlask",
)
parser.add_argument(
    "--start_with_img",
    action="store_true",
    help="Whether to start the generation with image id",
)

# Training "only" on one dataset is inherently exclusive: passing several of
# these would feed contradictory switches to the data module and the model
# (e.g. is_text_only from PENS combined with a visual dataset), so argparse
# rejects such combinations up front.
_single_dataset_group = parser.add_mutually_exclusive_group()
_single_dataset_group.add_argument(
    "--mlask_only",
    action="store_true",
    help="Whether to train only on MLASK",
)
_single_dataset_group.add_argument(
    "--pens_only",
    action="store_true",
    help="Whether to train only on PENS",
)
_single_dataset_group.add_argument(
    "--m3ls_only",
    action="store_true",
    help="Whether to train only on M3LS",
)

parser.add_argument(
    "--version",
    type=int,
    default=1,
    help="Manual versioning, to be able to compute variance for several runs.",
)

mms_args = parser.parse_args()

# Build a run name that encodes the version, the visual weight and every
# enabled switch, so each TensorBoard run directory describes its own config.
training_name = (
    f"version_{mms_args.version}_visual_weight_{mms_args.visual_weight}"
)

# Each enabled flag contributes "_<flag name>". The tuple fixes the order so
# that a given configuration always maps to the same run directory.
for _flag in (
    "append_task_id",
    "smart_frame_sample",
    "use_smooth_labels",
    "mlask_only",
    "pens_only",
    "m3ls_only",
    "use_all_mlask_frames",
    "start_with_img",
):
    if getattr(mms_args, _flag):
        training_name += f"_{_flag}"

# MLASK video source depends on whether all frames are used. Plain string
# assignment on purpose: a trailing comma here would silently turn the path
# into a 1-tuple that is then handed to the data module.
if mms_args.use_all_mlask_frames:
    _mlask_path_vid = "__TEMPLATE__"
else:
    _mlask_path_vid = "__TEMPLATE__"

# Both callbacks track the same logged metric, treating larger values as
# better: keep only the single best checkpoint, and stop training after 5
# consecutive validation checks without improvement.
_rouge_metric = "ROUGEL_SCORE_F_MIN"

ROUGE_checkpoint = ModelCheckpoint(
    monitor=_rouge_metric,
    mode="max",
    save_top_k=1,
    filename="{epoch}-{step}-{ROUGEL_SCORE_F_MIN:.2f}",
)
ROUGE_stop = EarlyStopping(
    monitor=_rouge_metric,
    mode="max",
    patience=5,
)

# Data module configuration. Dataset locations are "__TEMPLATE__" placeholders
# that must be filled in per machine; the MLASK video path is chosen earlier
# based on --use_all_mlask_frames.
mms_data = UMMSDataModule(
    argparse.Namespace(
        mlask_path_txt="__TEMPLATE__",
        mlask_path_vid=_mlask_path_vid,
        mlask_path_img="__TEMPLATE__",
        m3ls_path_txt="__TEMPLATE__",
        m3ls_path_img_src="__TEMPLATE__",
        m3ls_path_img_tgt="__TEMPLATE__",
        pens_path="__TEMPLATE__",
        append_task_id=mms_args.append_task_id,
        text_model="google/t5-v1_1-base",
        max_src_len=1024,
        max_tgt_len=128,
        num_workers=6,
        # Forward the CLI flag. A hard-coded False here made
        # --smart_frame_sample a no-op while the run name still claimed it
        # was enabled.
        smart_frame_sample=mms_args.smart_frame_sample,
        use_all_mlask_frames=mms_args.use_all_mlask_frames,
        use_smooth_labels=mms_args.use_smooth_labels,
        mlask_only=mms_args.mlask_only,
        pens_only=mms_args.pens_only,
        m3ls_only=mms_args.m3ls_only,
        start_with_img=mms_args.start_with_img,
        train_batch_size=5,
        dev_batch_size=48,
        test_batch_size=48,
    )
)

# Build the loaders up front; they are passed to trainer.fit directly rather
# than handing the data module itself to the trainer.
train_loader, val_loader = (
    mms_data.train_dataloader(),
    mms_data.val_dataloader(),
)

# TensorBoard logs are written under trainings/unmms_t5/<training_name>/.
tb_logger = TensorBoardLogger(
    "trainings",
    name="unmms_t5",
    version=training_name,
)

trainer = pl.Trainer(
    # Hardware: 3 GPUs with bfloat16 mixed precision.
    accelerator="gpu",
    devices=3,
    precision="bf16-mixed",
    # Schedule: up to 10 epochs; gradients are accumulated over 10 batches
    # per optimizer step.
    max_epochs=10,
    accumulate_grad_batches=10,
    # Validation once per training epoch, plus one sanity-check batch
    # before training starts.
    val_check_interval=1.0,
    num_sanity_val_steps=1,
    # Logging and checkpoint/early-stopping callbacks.
    logger=tb_logger,
    log_every_n_steps=50,
    callbacks=[ROUGE_checkpoint, ROUGE_stop],
)

# Any of the *_only switches puts the model into single-task mode; only the
# PENS-only run is treated as text-only.
_single_task = any(
    (mms_args.mlask_only, mms_args.pens_only, mms_args.m3ls_only)
)

model = UMMSTransformerT5(
    pre_trained_ckpt="google/t5-v1_1-base",
    append_task_id=mms_args.append_task_id,
    visual_weight=mms_args.visual_weight,
    is_single_task=_single_task,
    is_text_only=mms_args.pens_only,
    start_with_img=mms_args.start_with_img,
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

