
import torch
import pickle

import os
import time
import json
import numpy as np
from collections import defaultdict
from speaker import Speaker
from mbert import mBERT

from utils import read_vocab,write_vocab,build_vocab,Tokenizer,padding_idx,timeSince, read_img_features
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
import warnings
warnings.filterwarnings("ignore")


from tensorboardX import SummaryWriter
from transformers import BertTokenizer

import sys
import random
import math

import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from utils import padding_idx, add_idx, Tokenizer
from collections import defaultdict

from transformers import BertModel, BertConfig, AdamW, get_linear_schedule_with_warmup

import heapq

import CLIP.clip as clip



#-------Main---------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Text encoder used to build instruction features:
#   'language' - mBERT encoder restored from the checkpoint at `path`
#   'mbert'    - off-the-shelf bert-base-multilingual-cased
#   'clip'     - CLIP ViT-B/32 text tower
model = 'language'

path = 'snap/encoder1/state_dict/best_val_unseen_loss'

tok = BertTokenizer.from_pretrained('bert-base-multilingual-cased')


if model == 'language':
    encoder = mBERT("", "", tok, 0)

    # map_location lets a checkpoint that was saved on GPU be restored on a CPU-only machine.
    states = torch.load(path, map_location=device)

    state = encoder.encoder.state_dict()
    model_keys = set(state.keys())
    load_keys = set(states['encoder']['state_dict'].keys())
    if model_keys != load_keys:
        print("NOTICE: DIFFERENT KEYS IN THE LISTEREN")
    # Start from the freshly built weights so any key absent from the checkpoint keeps
    # its initial value instead of making load_state_dict fail.
    state.update(states['encoder']['state_dict'])
    encoder.encoder.load_state_dict(state)
    # The inputs are moved to `device` later, so make sure the weights live there too
    # (a no-op if mBERT already placed them). eval() disables dropout so the extracted
    # features are deterministic.
    encoder.encoder.to(device)
    encoder.encoder.eval()
elif model == 'mbert':
    model_config = BertConfig.from_pretrained("bert-base-multilingual-cased", return_dict=True)
    encoder = BertModel.from_pretrained("bert-base-multilingual-cased", config=model_config).to(device)
elif model == 'clip':
    encoder, preprocess = clip.load("ViT-B/32", device=device)
else:
    # Fail fast instead of hitting a NameError on `encoder` deep inside the split loop.
    raise ValueError("unknown model %r; expected 'language', 'mbert' or 'clip'" % model)

splits = ['train', 'val_seen', 'val_unseen']
# splits = ['val_unseen']
# splits = ['train']

TOP_N = 5          # neighbours stored per instruction (the output filename says "top5")
SIM_CHUNK = 1024   # query rows per similarity matmul; peak memory ~ SIM_CHUNK * num_instructions floats
# When True, ignore neighbours that share the query's path_id or whose path has a
# different number of steps (this filter used to be commented out in the search loop).
EXCLUDE_SAME_PATH = False


def _encode_instruction(data):
    """Encode one instruction record with the text encoder selected by `model`.

    Args:
        data: one record of the split file; reads 'instruction' and, for CLIP,
            'language'.

    Returns:
        A float32 numpy array of shape (1, D), L2-normalised along dim 1, or None
        when the record is skipped (the CLIP path skips te-IN and hi-IN).
    """
    with torch.no_grad():  # pure inference: no autograd graph is needed
        if model == 'clip':
            if data['language'] == 'te-IN' or data['language'] == 'hi-IN':
                return None
            # Tokenize the raw instruction as ONE sequence of shape (1, 77). Passing a
            # list of word-piece strings instead makes clip.tokenize encode each string
            # separately, and row 0 would always be the string "[CLS]".
            # NOTE(review): `truncate=True` requires a CLIP version whose tokenize()
            # accepts it — confirm for the vendored CLIP package.
            text = clip.tokenize(data['instruction'], truncate=True).to(device)
            representation = encoder.encode_text(text)
        else:
            encoding = tok(data['instruction'], padding='max_length', truncation=True, max_length=160)
            input_ids = torch.tensor([encoding['input_ids']], device=device)
            attn = torch.tensor([encoding['attention_mask']], device=device)
            bert = encoder.encoder if model == 'language' else encoder
            text_features = bert(input_ids, attention_mask=attn)
            # Hidden state at the first token position, shape (1, hidden).
            representation = text_features.last_hidden_state[:, 0, :]
        # .float(): CLIP may run in fp16 on GPU; keep features in fp32 for the search.
        return F.normalize(representation.float(), dim=1).cpu().numpy()


def _top_n_neighbours(features, path_id, path_length):
    """Find the TOP_N most similar *other* instructions for every instruction.

    Args:
        features: {instruction_id: (1, D) unit-norm array}; every D must match.
        path_id, path_length: per-instruction metadata keyed like `features`;
            only consulted when EXCLUDE_SAME_PATH is True.

    Returns:
        {instruction_id: [(similarity, neighbour_id), ...]} with exactly TOP_N
        pairs sorted by descending similarity. As with the previous heap seeded
        with (0, "0"), only neighbours with similarity > 0 qualify and unfilled
        slots hold (0, "0").
    """
    ids = list(features)
    results = dict()
    if not ids:
        return results
    # Rows are unit-norm, so a dot product is the cosine similarity.
    matrix = np.concatenate([features[key] for key in ids], axis=0).astype(np.float32)
    pids = np.array([path_id[key] for key in ids])
    plens = np.array([path_length[key] for key in ids])
    num = len(ids)
    keep = min(TOP_N, num - 1)

    for start in range(0, num, SIM_CHUNK):
        stop = min(start + SIM_CHUNK, num)
        sims = matrix[start:stop] @ matrix.T  # (stop - start, num)
        local = np.arange(stop - start)
        sims[local, start + local] = -np.inf  # an instruction is never its own neighbour
        if EXCLUDE_SAME_PATH:
            excluded = ((pids[None, :] == pids[start:stop, None])
                        | (plens[None, :] != plens[start:stop, None]))
            sims[excluded] = -np.inf
        if keep > 0:
            # Unordered column indices of the `keep` largest similarities per row.
            top = np.argpartition(-sims, keep - 1, axis=1)[:, :keep]
        for row in range(stop - start):
            neighbours = []
            if keep > 0:
                cand = top[row][np.argsort(-sims[row, top[row]])]
                neighbours = [(float(sims[row, j]), ids[j]) for j in cand if sims[row, j] > 0]
            neighbours += [(0, "0")] * (TOP_N - len(neighbours))
            results[ids[start + row]] = neighbours
    return results


for split in splits:
    # Read with json.load, so the file must hold a single JSON document (a list of
    # instruction records) despite its .jsonl suffix.
    with open('../RXR/rxr-data/rxr_%s_guide_multi.jsonl' % split) as f:
        new_data = json.load(f)

    features = dict()
    path_length = dict()
    path_id = dict()
    print(len(new_data))
    print("Generating Text features")
    for i, data in enumerate(new_data):
        if i % 1000 == 0:
            print(i)
        feature = _encode_instruction(data)
        if feature is None:
            continue
        features[data['instruction_id']] = feature
        path_length[data['instruction_id']] = len(data['path'])
        path_id[data['instruction_id']] = data['path_id']

    print("Computing Similarity")
    results = _top_n_neighbours(features, path_id, path_length)

    with open("../RXR/rxr-data/rxr_%s_guide_image_exp1lang_top5_screen.jsonl" % split, "w") as f:
        json.dump(results, f)






