import os
import sys
import random
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, EncoderDecoderModel
from models import EncoderDecoderModelWithGates, EncoderModelWithGates

# Widen pandas' console output so wide DataFrames are not column-truncated.
pd.set_option('display.max_columns', 1000)

# Directory holding the fine-tuned checkpoint ('model.pth' is read from here).
model_path = '../models/bart/'

# Minimum / maximum lengths for the source (encoder) and target (decoder) side.
min_len_src, max_len_src = 20, 300
min_len_tgt, max_len_tgt = 20, 300

# Backbone identifiers passed to the tokenizer / model constructors.
model_type = 'bart'
pretrained_encoder_path = 'facebook/bart-base'
pretrained_decoder_path = None

# Seed shared by the Python, NumPy and torch RNGs.
seed = 66
teacher_forcing = True

# Names of the gates the model is built with.
gates = ['mask', 'copy', 'generate', 'skip']
    
# Literal special-token strings removed from each decoded prediction, in this order.
SPECIAL_TOKENS_TO_STRIP = ('<s>', '</s>', '<pad>')


def to_token_id_rows(logits):
    """Greedy-pick token ids from *logits*, returning one row per sequence.

    The argmax is taken over the last axis. Every axis before the
    second-to-last one is then folded into the row axis, so the result is
    always 2-D ``(num_sequences, seq_len)``:

    - a ``(seq, vocab)`` input gives one row;
    - a ``(batch, seq, vocab)`` input gives ``batch`` rows;
    - any extra leading dimensions are flattened into the rows as well.

    NOTE(review): assumes the model's logits have the sequence on the
    second-to-last axis and the vocabulary on the last — confirm against
    EncoderModelWithGates.
    """
    token_ids = np.asarray(logits).argmax(-1)
    return token_ids.reshape(-1, token_ids.shape[-1])


def strip_special_tokens(text, tokens=SPECIAL_TOKENS_TO_STRIP):
    """Remove every literal in *tokens* from *text* (in order) and strip whitespace."""
    for token in tokens:
        text = text.replace(token, '')
    return text.strip()


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description='Run EncoderModelWithGates on the given text(s) and print the '
                    'argmax token ids and the decoded predictions.')
    parser.add_argument('texts', nargs='+', help='input text(s) to run through the model')
    texts = parser.parse_args().texts

    # Seed every RNG before anything random (model construction included) happens.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.empty_cache()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder_tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder_path)
    # Encoder and decoder share a single tokenizer.
    decoder_tokenizer = encoder_tokenizer
    encoder_mask_id = encoder_tokenizer.mask_token_id
    decoder_mask_id = decoder_tokenizer.mask_token_id

    # Build the int64 id tensor directly. Round-tripping through a float
    # torch.Tensor and re-wrapping it with torch.tensor() only adds a copy
    # and a UserWarning.
    valX = torch.tensor(
        [encoder_tokenizer.encode(text, max_length=max_len_src, truncation=True,
                                  padding='max_length', add_special_tokens=True)
         for text in texts],
        dtype=torch.long)

    model = EncoderModelWithGates(model_type, pretrained_encoder_path, gates=gates)
    model.encoder.config.max_length = max_len_src
    model.decoder.config.max_length = max_len_tgt
    model.encoder.config.min_length = min_len_src
    model.decoder.config.min_length = min_len_tgt
    model.encoder_tokenizer = encoder_tokenizer
    model.decoder_tokenizer = decoder_tokenizer

    model = model.to(device)
    # Checkpoint is loaded onto the CPU; load_state_dict copies it into the
    # already-placed parameters.
    model.load_state_dict(torch.load(os.path.join(model_path, 'model.pth'),
                                     map_location=torch.device('cpu')))
    model.eval()

    val_data_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(valX), batch_size=4)

    # Loop-invariant: build the mask-token tensor once instead of per batch.
    encoder_mask_token = torch.tensor([[encoder_mask_id]]).to(device)

    predicted_id_batches = []  # one (num_sequences, seq_len) id array per batch
    # Gate outputs per example; collected for inspection, not used further below.
    all_generate_probs = []
    all_copy_probs = []
    all_masking_probs = []
    all_skip_probs = []

    for batch in tqdm(val_data_loader):
        input_ids = batch[0].to(device)
        with torch.no_grad():
            # The model returns its output object followed by one tensor per gate.
            outputs, generate_prob, copy_prob, masking_prob, skip_prob = model(
                input_ids=input_ids,
                encoder_mask_token_id=encoder_mask_token,
                decoder_mask_token_id=decoder_mask_id,
                return_dict=True)

        predicted_id_batches.append(
            to_token_id_rows(outputs.logits.detach().cpu().numpy()))
        all_generate_probs.extend(generate_prob.detach().cpu().numpy())
        all_copy_probs.extend(copy_prob.detach().cpu().numpy())
        all_masking_probs.extend(masking_prob.detach().cpu().numpy())
        all_skip_probs.extend(skip_prob.detach().cpu().numpy())

    # Concatenate rows (never flatten them) so each input text keeps its own
    # id sequence and is decoded as one prediction.
    predicted_ids = np.concatenate(predicted_id_batches, axis=0)
    print(predicted_ids)

    predicted_texts = [strip_special_tokens(decoder_tokenizer.decode(ids))
                       for ids in predicted_ids]
    print(predicted_texts)
