import torch

import os
import time
import json
import numpy as np
from collections import defaultdict
from speaker import Speaker
from mbert import mBERT

from utils import read_vocab,write_vocab,build_vocab,Tokenizer,padding_idx,timeSince, read_img_features
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
import warnings
warnings.filterwarnings("ignore")


from tensorboardX import SummaryWriter
from transformers import BertTokenizer

import sys
import random
import math

import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from utils import padding_idx, add_idx, Tokenizer
import model
from collections import defaultdict

from transformers import BertModel, BertConfig, AdamW, get_linear_schedule_with_warmup

import heapq
from bleu import compute_bleu
import string

from rouge import Rouge

def split_sentence(sentence):
    """Tokenize *sentence* into lower-cased words and punctuation marks.

    The stripped sentence is split with ``Tokenizer.SENTENCE_SPLIT_REGEX``.
    Pieces that are empty after stripping are discarded. A piece consisting
    solely of punctuation is exploded into its individual characters, unless
    every character is a full stop (so '..' or '...' stays one token).
    """
    pieces = Tokenizer.SENTENCE_SPLIT_REGEX.split(sentence.strip())
    words = [piece.strip().lower() for piece in pieces if piece.strip()]

    tokens = []
    for word in words:
        only_punct = all(ch in string.punctuation for ch in word)
        only_dots = all(ch in '.' for ch in word)
        if only_punct and not only_dots:
            # e.g. '!?' -> ['!', '?']
            tokens.extend(word)
        else:
            tokens.append(word)
    return tokens


# Dataset splits to evaluate. The full set used elsewhere is
# 'train', 'val_seen' and 'val_unseen'; only 'train' is enabled here.
splits = ["train"]

# Compute BLEU

# for split in splits:
#     print(split)
#     count = 0
#     with open('../RXR/rxr-data/rxr_%s_guide_multi.jsonl' % split) as f:
#         new_data = json.load(f)
#     f.close()
#
#     with open("../RXR/rxr-data/rxr_%s_guide_image.jsonl" % split) as f:
#         results = json.load(f)
#     f.close()
#
#     ins_to_data = dict()
#     for data in new_data:
#         ins_to_data[data['instruction_id']] = data
#
#     ins_set = []
#     for data in new_data:
#         if data['language'].split("-")[0] == "en":
#             ins_set.append(data['instruction_id'])
#     ins_set = list(set(ins_set))
#
#     reference = []
#     candidate = []
#     for i, data in enumerate(new_data):
#         print(i)
#         language = data['language'].split("-")[0]
#         if language != "en":
#             continue
#         ins_id = data['instruction_id']
#         result = results[str(ins_id)]
#         instruction = data['instruction']
#         sim_score = result[0][0]
#
#         # instructions = data['instructions']
#         # languages = data['languages']
#         # for i, ins in enumerate(instructions):
#         #     if languages[i].split("-")[0] == "en":
#         #         picked_instruction = ins
#         #         reference.append([split_sentence(instruction)])
#         #         candidate.append(split_sentence(picked_instruction))
#
#         # if sim_score > 0.99:
#         #     if result[0][1] in ins_to_data:
#         #         similar_data = ins_to_data[result[0][1]]
#         #         language = similar_data['language'].split("-")[0]
#         #         if language != "en":
#         #             continue
#         #         picked_instruction = similar_data['instruction']
#         #     else:
#         #         continue
#         #     reference.append([split_sentence(instruction)])
#         #     candidate.append(split_sentence(picked_instruction))
#         pair_id = np.random.choice(ins_set, 1)[0]
#         picked_instruction = ins_to_data[pair_id]['instruction']
#         reference.append([split_sentence(instruction)])
#         candidate.append(split_sentence(picked_instruction))
#
#
#     print(len(reference))
#     tuple = compute_bleu(reference, candidate)
#     bleu = tuple[0]
#     precision = tuple[1]
#
#     print(bleu)
#     print(precision)


# Compute Rouge
rouge = Rouge()


def _load_json(path):
    """Load and return the JSON document stored at *path*.

    The RxR files carry a ``.jsonl`` extension but are parsed as a single
    JSON document with ``json.load``, exactly as before. UTF-8 is requested
    explicitly because the data holds multilingual text; relying on the
    platform default encoding can fail to decode it.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _collect_similar_pairs(new_data, results, threshold=0.99):
    """Pair each English instruction with its top-ranked English match.

    Args:
        new_data: list of instruction records; each record is read for its
            'instruction_id', 'language' and 'instruction' keys.
        results: mapping from ``str(instruction_id)`` to a list whose first
            entry is read as ``(similarity_score, matched_instruction_id)``
            -- presumably the best match; TODO confirm the ranking order.
        threshold: keep a pair only when the top similarity score is
            strictly greater than this value (default 0.99).

    Returns:
        ``(reference, candidate)``: parallel lists of instruction strings,
        the original instruction in ``reference`` and its match in
        ``candidate``.

    Raises:
        KeyError: if an English instruction id has no entry in *results*.
    """
    # Later records with the same id overwrite earlier ones.
    ins_to_data = {data['instruction_id']: data for data in new_data}

    reference = []
    candidate = []
    for i, data in enumerate(new_data):
        print(i)  # progress indicator
        if data['language'].split("-")[0] != "en":
            continue
        result = results[str(data['instruction_id'])]
        sim_score = result[0][0]
        if not sim_score > threshold:
            continue
        # NOTE(review): ins_to_data is keyed by the raw instruction_id while
        # results is keyed by its str(); if the matched id stored in results
        # has a different type than instruction_id, this lookup never hits.
        match = ins_to_data.get(result[0][1])
        if match is None or match['language'].split("-")[0] != "en":
            continue
        reference.append(data['instruction'])
        candidate.append(match['instruction'])
    return reference, candidate


for split in splits:
    print(split)
    new_data = _load_json('../RXR/rxr-data/rxr_%s_guide_multi.jsonl' % split)
    results = _load_json("../RXR/rxr-data/rxr_%s_guide_image.jsonl" % split)

    reference, candidate = _collect_similar_pairs(new_data, results)

    print(len(reference))
    if not reference:
        # The rouge package's averaged scoring divides by the number of
        # pairs, so an empty list would crash; report and move on instead.
        print("No instruction pairs above the similarity threshold for split %s; skipping ROUGE." % split)
        continue
    score = rouge.get_scores(candidate, reference, avg=True)
    print(score)