import torch
import torch.nn as nn
import torch.nn.functional as K

import transformers

from transformers import ViTConfig, ViTModel
from transformers import T5Config, T5EncoderModel, T5Tokenizer
from transformers import FlavaConfig, FlavaProcessor, FlavaModel
from transformers import AutoTokenizer, AutoProcessor, AutoFeatureExtractor, AutoModel
from transformers import TrainingArguments, Trainer, logging
from transformers import BertLayer

import os
import numpy as np
import pandas as pd
import json
from copy import deepcopy

import datasets
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score

from tqdm.notebook import tqdm

import argparse
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser()

# Both paths are joined with os.path.join further down, which raises a
# TypeError on None; requiring them makes a missing flag fail at start-up
# instead of after the expensive work has been done.
parser.add_argument('--data_path', required=True, type=str,
                    help="directory holding the 'train', 'dev' and 'test' datasets")

parser.add_argument('--text_model', default='t5-base', type=str)
parser.add_argument('--vision_model', default="google/vit-base-patch16-224-in21k", type=str)
parser.add_argument('--flava_model', default='facebook/flava-full', type=str)

parser.add_argument('--output_dir', required=True, type=str,
                    help="directory for checkpoints and test_eval_result.txt")

parser.add_argument('--max_seq_len', default=24, type=int)
parser.add_argument('--batch_size', default=10, type=int)
parser.add_argument('--epochs', default=5, type=int)
parser.add_argument('--lr', default=3e-5, type=float)
parser.add_argument('--random_seed', default=42, type=int)
parser.add_argument('--dropout', default=0.5, type=float)
parser.add_argument('--scheduler_name', default='linear', type=str)
parser.add_argument('--intermediate_dim', default=1536, type=int)
parser.add_argument('--fp16', default=False, type=_str2bool,
                    help="boolean flag, e.g. --fp16 true")
parser.add_argument('--warmup_ratio', default=0.0, type=float)
parser.add_argument('--ga_steps', default=1, type=int)

args = parser.parse_args()

print(args)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"]="3"

# Fall back to the CPU instead of crashing in get_device_name() when no GPU
# is visible to this process.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print('There are %d GPU(s) available.' % torch.cuda.device_count())
if cuda_available:
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, falling back to the CPU.')

print("Load dataset from disk ... ")
dataset_train = datasets.load_from_disk(os.path.join(args.data_path, 'train'))
dataset_dev = datasets.load_from_disk(os.path.join(args.data_path, 'dev'))
dataset_test = datasets.load_from_disk(os.path.join(args.data_path, 'test'))
print("Done ! ")


@dataclass
class MultimodalCollator:
    """Turn raw examples into the tensors consumed by the model.

    Every example must provide a 'hypothesis' text, an 'image' object that
    supports ``.convert('RGB')`` and a 'label'. The collator accepts either a
    list of such examples (a batch) or a single example given as a dict.
    """
    # Text tokenizer; invoked with padding/truncation to ``args.max_seq_len``.
    tokenizer: AutoTokenizer
    # Image preprocessor; invoked on the RGB-converted images.
    preprocessor: AutoFeatureExtractor

    def tokenize_text(self, texts: List[str]):
        """Tokenize ``texts`` and return 'input_ids' and 'attention_mask'.

        A list of texts keeps its leading batch dimension even when it holds
        a single element; only a bare string yields unbatched tensors.
        """
        encoded_text = self.tokenizer(
            text=texts,
            padding='max_length',
            max_length=args.max_seq_len,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )
        input_ids = encoded_text['input_ids']
        attention_mask = encoded_text['attention_mask']
        if isinstance(texts, str):
            # Drop only the leading axis. A bare ``.squeeze()`` would also
            # strip the batch axis of a one-element batch and break the model.
            input_ids = input_ids.squeeze(0)
            attention_mask = attention_mask.squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def preprocess_images(self, images: List["Image.Image"]):
        """Run the image preprocessor and return 'pixel_values'.

        A list/tuple of images keeps its leading batch dimension even when it
        holds a single element; only a bare image yields an unbatched tensor.
        """
        processed_images = self.preprocessor(
            images,
            return_tensors="pt",
        )
        pixel_values = processed_images['pixel_values']
        if not isinstance(images, (list, tuple)):
            # Single image: remove the leading axis only.
            pixel_values = pixel_values.squeeze(0)
        return {
            "pixel_values": pixel_values,
        }

    def __call__(self, raw_batch_dict):
        """Collate a single example (dict) or a list of examples.

        Returns a dict with 'input_ids', 'attention_mask', 'pixel_values' and
        an int64 'labels' tensor.
        """
        if isinstance(raw_batch_dict, dict):
            texts = raw_batch_dict['hypothesis']
            images = raw_batch_dict['image'].convert('RGB')
            labels = raw_batch_dict['label']
        else:
            texts = [example['hypothesis'] for example in raw_batch_dict]
            images = [example['image'].convert('RGB') for example in raw_batch_dict]
            labels = [example['label'] for example in raw_batch_dict]

        return {
            **self.tokenize_text(texts),
            **self.preprocess_images(images),
            'labels': torch.tensor(labels, dtype=torch.int64),
        }

class BasicStudentModelForSNLI(nn.Module):
    """Text + image classifier built around a FLAVA multimodal encoder.

    A T5 encoder and an image encoder produce token-level features, which are
    linearly projected to the multimodal encoder's hidden size, concatenated
    along the sequence axis and passed through ``flava_mm_encoder``. The first
    position of the encoder output feeds a small classification head.

    Args:
        pretrained_text_name: checkpoint name for ``T5EncoderModel``.
        pretrained_image_name: checkpoint name for ``AutoModel`` (image side).
        flava_mm_encoder: multimodal encoder module; must expose
            ``config.hidden_size`` and return a mapping with
            'last_hidden_state'.
        num_labels: number of output classes.
        intermediate_dim: width of the fusion layer before the classifier.
        dropout: dropout probability of the fusion layer. When None, the
            script-level ``args.dropout`` is used, which keeps the behaviour
            of callers that never passed this argument.
    """

    def __init__(self, pretrained_text_name, pretrained_image_name, flava_mm_encoder, num_labels=3,
                 intermediate_dim=768, dropout=None):
        super().__init__()

        self.num_labels = num_labels
        self.pretrained_text_name = pretrained_text_name
        self.pretrained_image_name = pretrained_image_name

        self.intermediate_dim = intermediate_dim

        # Pretrained transformers for text & image featurization
        self.text_encoder = T5EncoderModel.from_pretrained(self.pretrained_text_name)
        self.image_encoder = AutoModel.from_pretrained(self.pretrained_image_name)

        self.flava_mm_encoder = flava_mm_encoder
        mm_hidden_size = self.flava_mm_encoder.config.hidden_size

        # Map both modalities into the multimodal encoder's hidden size.
        self.text_projection = nn.Linear(self.text_encoder.config.hidden_size, mm_hidden_size)
        self.image_projection = nn.Linear(self.image_encoder.config.hidden_size, mm_hidden_size)

        # An explicit ``dropout`` argument is honoured; it used to be ignored
        # in favour of the global CLI value.
        dropout_prob = args.dropout if dropout is None else dropout

        self.fusion = nn.Linear(mm_hidden_size, self.intermediate_dim)
        self.fusion_activations = nn.Sequential(nn.ReLU(),
                                                nn.Dropout(dropout_prob))

        self.classifier = nn.Linear(self.intermediate_dim, self.num_labels)

        self.criterion = nn.CrossEntropyLoss()

    def forward(
            self,
            input_ids: torch.LongTensor,
            pixel_values: torch.FloatTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None):
        """Encode both modalities, fuse them and classify.

        Returns:
            A dict with 'text_cls', 'image_cls' and 'multimodal_cls' (the
            first sequence position of the projected text features, projected
            image features and multimodal encoder output respectively),
            'logits', and 'loss' (cross-entropy) when ``labels`` is given.
            The insertion order of these keys matters: downstream code picks
            the logits by position.
        """
        encoded_text = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        encoded_image = self.image_encoder(
            pixel_values=pixel_values,
            return_dict=True,
        )

        projected_text = self.text_projection(encoded_text['last_hidden_state'])
        projected_image = self.image_projection(encoded_image['last_hidden_state'])

        # Text tokens first, then image tokens, along the sequence axis.
        # NOTE(review): attention_mask is not forwarded to the multimodal
        # encoder, so padded text positions are presumably attended to like
        # real tokens there — confirm this is intended.
        fused_sequence = torch.cat([projected_text, projected_image], dim=1)
        multimodal_output = self.flava_mm_encoder(fused_sequence)['last_hidden_state']

        # First position of the encoder output (presumably its CLS token —
        # TODO confirm against the FLAVA multimodal model).
        multimodal_cls = multimodal_output[:, 0, :]

        logits = self.classifier(self.fusion_activations(self.fusion(multimodal_cls)))

        out = {
            "text_cls": projected_text[:, 0, :],
            "image_cls": projected_image[:, 0, :],
            "multimodal_cls": multimodal_cls,
            "logits": logits,
        }
        if labels is not None:
            out["loss"] = self.criterion(logits, labels)

        return out


def compute_metrics(eval_tuple: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
    """Compute accuracy and macro-F1 from an evaluation (outputs, labels) pair.

    The first element of ``eval_tuple`` is indexed at position 3 to obtain
    the class scores; the predicted class is the argmax over the last axis.
    """
    model_outputs, gold_labels = eval_tuple
    class_scores = model_outputs[3]
    predicted_labels = class_scores.argmax(axis=-1)

    metrics = {}
    metrics["acc"] = accuracy_score(gold_labels, predicted_labels)
    metrics["f1"] = f1_score(gold_labels, predicted_labels, average='macro')
    return metrics

# Trainer configuration: evaluate and checkpoint once per epoch, keep at most
# five checkpoints and reload the best one (by 'acc') when training ends.
multi_args = TrainingArguments(
    output_dir=args.output_dir,
    seed=args.random_seed,
    learning_rate=args.lr,
    # evaluation_strategy="steps",
    # eval_steps=50,
    # save_strategy="steps",
    # save_steps=50,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=500,
    save_total_limit=5,
    metric_for_best_model='acc',
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    # The collator reads the raw 'hypothesis'/'image'/'label' columns, so the
    # Trainer must not drop columns that are absent from forward()'s signature.
    remove_unused_columns=False,
    num_train_epochs=args.epochs,
    dataloader_num_workers=8,
    load_best_model_at_end=True,
    warmup_ratio=args.warmup_ratio,
    lr_scheduler_type=args.scheduler_name,
    eval_accumulation_steps=500,
    gradient_accumulation_steps=args.ga_steps,
    # Forward the --fp16 flag; it used to be parsed but never applied.
    fp16=args.fp16,
)

# Only the multimodal encoder of FLAVA is needed; copy it out and drop the
# reference to the full model so it can be garbage-collected.
model_flava = FlavaModel.from_pretrained(args.flava_model)
flava_mm_encoder = deepcopy(model_flava.multimodal_model)
model_flava = None  # Mem flush

tokenizer = AutoTokenizer.from_pretrained(args.text_model)
preprocessor = AutoFeatureExtractor.from_pretrained(args.vision_model)
collator = MultimodalCollator(tokenizer=tokenizer, preprocessor=preprocessor)

model = BasicStudentModelForSNLI(pretrained_text_name=args.text_model,
                                 pretrained_image_name=args.vision_model,
                                 flava_mm_encoder=flava_mm_encoder,
                                 intermediate_dim=args.intermediate_dim,
                                 dropout=args.dropout).to(device)

print("#"*20, "Show Model Architecture","#"*20)
print(model)

multi_trainer = Trainer(
    model=model,
    args=multi_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_dev,
    data_collator=collator,
    compute_metrics=compute_metrics,
)


train_multi_metrics = multi_trainer.train()

prd_results = multi_trainer.predict(dataset_test)

# prd_results[0] holds the model outputs as a tuple; index 3 is taken as the
# class scores, matching the position of 'logits' in the model's output dict.
prd_label_ids = np.argmax(prd_results[0][3], axis=-1)

test_acc = accuracy_score(dataset_test['label'], prd_label_ids)
print("test accuracy = "+str(test_acc))

# Make sure the target directory exists before writing the summary file.
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, 'test_eval_result.txt'), 'w') as f:
    f.write("test accuracy = "+str(test_acc))