import os
import re
import h5py
import json
import utils
import torch
import torch.nn as nn
import numpy as np
from scipy import spatial
from torch.utils.data import DataLoader, ConcatDataset, Dataset
import _pickle as cPickle
from xml.etree.ElementTree import parse
# from transformers import *
from tqdm import tqdm
from torch.nn import functional as F

class Indexer(object):
    """Two-way mapping between hashable objects and dense integer ids.

    Ids are assigned in insertion order starting at 0: the first distinct
    object added gets 0, the next gets 1, and so on.
    """

    def __init__(self):
        # Forward (object -> id) and reverse (id -> object) lookup tables.
        self.objs_to_ints = {}
        self.ints_to_objs = {}

    def __repr__(self):
        # Render the objects in id order as a list of their str() forms.
        rendered = []
        for position in range(len(self)):
            rendered.append(str(self.get_object(position)))
        return str(rendered)

    def __str__(self):
        return repr(self)

    def __len__(self):
        return len(self.objs_to_ints)

    def get_object(self, index):
        """Return the object stored under ``index``, or None if unknown."""
        return self.ints_to_objs.get(index)

    def contains(self, object):
        """Return True when ``object`` has already been assigned an id."""
        return self.index_of(object) != -1

    def index_of(self, object):
        """Return the id of ``object``, or -1 when it was never added."""
        return self.objs_to_ints.get(object, -1)

    def add_and_get_index(self, object, add=True):
        """Return the id of ``object``, registering it first when ``add``.

        With ``add=False`` this is a pure lookup that yields -1 for unknown
        objects and never modifies the indexer.
        """
        if not add:
            return self.index_of(object)
        if object in self.objs_to_ints:
            return self.objs_to_ints[object]
        assigned = len(self.objs_to_ints)
        self.objs_to_ints[object] = assigned
        self.ints_to_objs[assigned] = object
        return assigned

class WordEmbeddings:
    """Embedding table addressed by word through an indexer.

    ``word_indexer`` must provide ``index_of(word)`` returning -1 for words
    it does not know; ``vectors`` is indexed by the ids the indexer returns.
    """

    def __init__(self, word_indexer, vectors):
        self.vectors = vectors
        self.word_indexer = word_indexer

    def get_embedding_length(self):
        """Return the length of the first stored vector."""
        first_vector = self.vectors[0]
        return len(first_vector)

    def get_embedding(self, word):
        """Return the vector for ``word``, or the "UNK" vector if unknown."""
        position = self.word_indexer.index_of(word)
        if position == -1:
            # Unknown word: use whatever row the indexer assigns to "UNK".
            position = self.word_indexer.index_of("UNK")
        return self.vectors[position]

    def get_embeddings(self, word_list):
        """Return an array stacking the embedding of every word in order."""
        return np.array([self.get_embedding(token) for token in word_list])

    def similarity(self, w1, w2):
        """Return the cosine similarity between the embeddings of two words."""
        first = self.get_embedding(w1)
        second = self.get_embedding(w2)
        cosine_distance = spatial.distance.cosine(first, second)
        return 1 - cosine_distance


class Flickr30dataset(Dataset):
    """Flickr30k Entities entries paired with pre-extracted region features.

    ``self.entries`` and ``self.img_id2idx`` come from ``load_dataset``; the
    region features and their per-image row ranges are read from
    ``<dataroot>/my<name>.hdf5``. Every item is padded or truncated to fixed
    sizes so that items can be stacked into batches.
    """

    def __init__(self, wordEmbedding, name='train', dataroot='data/flickr30k/', use_bert = False, lite_bert=False):
        """Load the entries and the HDF5 feature arrays for split ``name``.

        Args:
            wordEmbedding: object exposing ``word_indexer``, used to turn
                words into vocabulary ids.
            name: split name, used to build the data file names.
            dataroot: directory holding the HDF5 / pickle files. It is now
                honoured for the HDF5 files as well (it used to be
                overwritten with the default path).
            use_bert, lite_bert: when either is true, ``label_features`` is
                additionally read from ``<dataroot>/bert_feature_<name>``.
        """
        super(Flickr30dataset, self).__init__()
        self.use_bert = use_bert
        self.lite_bert = lite_bert
        self.vgg = False
        self.entries, self.img_id2idx = load_dataset(name, dataroot, use_bert = use_bert, lite_bert = lite_bert, vgg=self.vgg)
        # img_id2idx: dict {img_id -> val} val can be used to retrieve image or features
        self.indexer = wordEmbedding.word_indexer

        h5_path = os.path.join(dataroot, 'my%s.hdf5' % name)
        with h5py.File(h5_path, 'r') as hf:
            self.features = np.array(hf.get('features'))
            self.pos_boxes = np.array(hf.get('pos_bboxes'))

        if use_bert or lite_bert:
            h5_path = os.path.join(dataroot, 'bert_feature_%s' % name)
            with h5py.File(h5_path, 'r') as hf:
                self.label_features = np.array(hf.get('label_features'))

    @staticmethod
    def _fit_length(items, size, filler):
        """Return a NEW list of exactly ``size`` items: truncated or padded.

        ``items`` itself is never modified, so lists stored in
        ``self.entries`` (shared between entries of one image) stay intact.
        """
        fitted = list(items[:size])
        fitted.extend([filler] * (size - len(fitted)))
        return fitted

    def __getitem__(self, index):
        '''
        Build the fixed-size tensors for entry ``index``.

        Returns a 14-tuple, in this order:

        image id      : scalar tensor
        labels        : [K=100] vocabulary ids of the detected labels
        attrs         : [K=100] vocabulary ids of the attributes (zeros if vgg)
        feature       : [K=100, dim] region features, zero padded
        querys        : [Q=32, len=12] vocabulary ids of the query tokens
        label_feats, sent_toks, phrase_toks, entity_indices, entity_feats:
                        placeholder tensors ``tensor([0])``
        bboxes        : [K=100, 5] x1, y1, x2, y2, area
        target_bboxes : [Q=32, B=20, 4]
        num_obj       : scalar tensor, min(number of labels, K)
        num_query     : scalar tensor, min(number of queries, Q)

        Id 0 is used for padding; words missing from the indexer map to 1.
        '''
        K = 100    # detected regions kept per image
        Q = 32     # queries kept per entry
        lens = 12  # tokens kept per query
        B = 20     # target boxes kept per query

        entry = self.entries[index]
        imgid = entry['image']
        labels = entry['labels']
        querys = entry['query']
        idx = self.img_id2idx[int(imgid)]  # row of this image in pos_boxes

        # Placeholders kept so the returned tuple keeps its 14-slot layout.
        sent_toks = torch.tensor([0])
        phrase_toks = torch.tensor([0])
        entity_indices = torch.tensor([0])
        entity_feats = torch.tensor([0])
        label_feats = torch.tensor([0])

        if self.vgg:
            feature = torch.tensor(entry["features"]).float()
        else:
            # pos holds the [start, end) rows of this image inside features.
            pos = self.pos_boxes[idx]
            feature = torch.from_numpy(self.features[pos[0]:pos[1]]).float()

        if feature.size(0) < K:
            # Append zero rows at the bottom until there are K rows.
            pad = nn.ZeroPad2d((0, 0, 0, K - feature.size(0)))
            feature = pad(feature)
        else:
            feature = feature[:K]

        num_obj = min(len(labels), K)
        num_query = min(len(querys), Q)

        # NOTE(review): the pattern ' ,' splits on a space followed by a
        # comma and keeps the last piece; confirm this is the intended
        # separator for multi-part labels.
        labels_idx = self._fit_length(
            [max(self.indexer.index_of(re.split(' ,', w)[-1]), 1) for w in labels],
            K, 0)

        if self.vgg:
            attr_idx = [0] * K
        else:
            attr_idx = self._fit_length(
                [max(self.indexer.index_of(w), 1) for w in entry['attrs']],
                K, 0)

        querys_idx = []
        for q in querys:
            tokens = q.lower().split()
            token_ids = [max(self.indexer.index_of(tok), 1) for tok in tokens[:lens]]
            querys_idx.append(self._fit_length(token_ids, lens, 0))
        querys_idx = self._fit_length(querys_idx, Q, [0] * lens)

        padbox = [0, 0, 0, 0]

        # Pad copies: the lists inside `entry` must not grow on every access.
        bboxes = torch.tensor(self._fit_length(entry['detected_bboxes'], K, padbox))  # [x1,y1,x2,y2]
        area = (bboxes[..., 3] - bboxes[..., 1]) * (bboxes[..., 2] - bboxes[..., 0])
        bboxes = torch.cat((bboxes, area.unsqueeze_(-1)), -1)

        padline = [padbox for _ in range(B)]
        target_bboxes = self._fit_length(
            [self._fit_length(boxes, B, padbox) for boxes in entry['target_bboxes']],
            Q, padline)

        return torch.tensor(int(imgid)), torch.tensor(labels_idx), torch.tensor(attr_idx), feature, torch.tensor(querys_idx), label_feats, sent_toks, phrase_toks, entity_indices, entity_feats, bboxes, torch.tensor(target_bboxes), torch.tensor(num_obj), torch.tensor(num_query)

    def __len__(self):
        return len(self.entries)


def read_word_embeddings(embeddings_file: str) -> WordEmbeddings:
    """Read a text embeddings file into a ``WordEmbeddings`` object.

    Each non-blank line is ``<word> <float> <float> ...`` separated by
    spaces. Index 0 is reserved for "PAD" and index 1 for "UNK"; both get an
    all-zero vector whose size is taken from the first vector in the file.

    A word that is already indexed (a repeated word, or a literal "PAD" /
    "UNK" token in the file) is skipped. The indexer would not hand out a new
    id for it, so storing its vector would shift every later word/vector pair
    out of alignment.

    Args:
        embeddings_file: path of the UTF-8 text file to read.

    Returns:
        WordEmbeddings whose ``vectors`` rows line up with the indexer ids.

    Raises:
        ValueError: if the file contains no vectors (or a number cannot be
            parsed as a float).
    """
    word_indexer = Indexer()
    vectors = []
    # Make position 0 the PAD token so that 0 can be used for padding.
    word_indexer.add_and_get_index("PAD")
    # Make position 1 the UNK token
    word_indexer.add_and_get_index("UNK")
    # `with` guarantees the handle is closed even when a line fails to parse.
    with open(embeddings_file, encoding='utf-8') as f:
        for line in f:
            if line.strip() == "":
                continue
            space_idx = line.find(' ')
            word = line[:space_idx]
            vector = np.array([float(number_str) for number_str in line[space_idx + 1:].split()])
            # The PAD and UNK vectors can only be created once the first
            # line has been read, because that is where the embedding
            # dimension becomes known.
            if len(vectors) == 0:
                vectors.append(np.zeros(vector.shape[0]))
                vectors.append(np.zeros(vector.shape[0]))
            if word_indexer.contains(word):
                # Keep the first vector of a word; see the docstring.
                continue
            word_indexer.add_and_get_index(word)
            vectors.append(vector)
    if not vectors:
        raise ValueError("No embedding vectors found in %s" % embeddings_file)
    print("Read in " + repr(len(word_indexer)) + " vectors of size " + repr(vectors[0].shape[0]))

    return WordEmbeddings(word_indexer, np.array(vectors))


def load_train_flickr30k(dataroot, img_id2idx, obj_detection, bert_feature_dict, use_bert = False, lite_bert=False, vgg=False):
    """Build one grounding entry per annotated sentence of every image.

    For each image id the sentence file
    ``Flickr30kEntities/Sentences/<id>.txt`` and the annotation file
    ``Flickr30kEntities/Annotations/<id>.xml`` under ``dataroot`` are read.
    Phrases whose entity id has at least one annotated box become queries;
    sentences without any such phrase produce no entry.

    Args:
        dataroot: root path of the dataset.
        img_id2idx: dict {img_id -> val}; only the (integer) keys are used
            here.
        obj_detection: per-image detections with 'bboxes' and 'classes',
            plus 'features' when ``vgg`` else 'attrs'. Looked up with the
            integer image id when ``vgg`` and with ``str(image_id)``
            otherwise.
        bert_feature_dict, use_bert, lite_bert: accepted for interface
            compatibility; not read by this function.
        vgg: selects the detection schema described above. Images whose
            'features' list is empty are skipped in this mode.

    Returns:
        list of dicts with keys 'image', 'target_bboxes' (one list of
        [xmin, ymin, xmax, ymax] boxes per query), 'detected_bboxes',
        'labels', 'query' and 'features' (vgg) or 'attrs' (otherwise).
        Entries of the same image share the same detection / box list
        objects, so consumers must not modify them in place.
    """
    # Compiled once; the pattern strings themselves are unchanged.
    phrase_re = re.compile(r'\[(.*?)\]')
    entity_no_re = re.compile(r'\/EN\#(\d+)')
    # Tally, by first entity type, of phrases whose entity has no box.
    # It is only accumulated here; nothing below reports or returns it.
    missing_entity_count = dict()
    entries = []

    for image_id, _ in tqdm(img_id2idx.items()):

        phrase_file = os.path.join(dataroot, 'Flickr30kEntities/Sentences/%d.txt' % image_id)
        anno_file = os.path.join(dataroot, 'Flickr30kEntities/Annotations/%d.xml' % image_id)

        with open(phrase_file, 'r', encoding='utf-8') as f:
            sents = [x.strip() for x in f]

        # Parse Annotation
        root = parse(anno_file).getroot()

        # entity id -> every box annotated for that entity
        target_bboxes_dict = {}

        for elem in root.findall('./object'):
            bndbox = elem.find('bndbox')
            # Skip objects with a missing or childless <bndbox>.
            if bndbox is None or len(bndbox) == 0:
                continue
            left = int(elem.findtext('./bndbox/xmin'))
            top = int(elem.findtext('./bndbox/ymin'))
            right = int(elem.findtext('./bndbox/xmax'))
            bottom = int(elem.findtext('./bndbox/ymax'))
            assert 0 < left and 0 < top

            for name in elem.findall('name'):
                entity_id = int(name.text)
                assert 0 < entity_id
                target_bboxes_dict.setdefault(entity_id, []).append([left, top, right, bottom])

        if vgg:
            bboxes = obj_detection[image_id]['bboxes']
            labels = obj_detection[image_id]['classes']
            features = obj_detection[image_id]['features']
            if len(features) == 0:
                continue
        else:
            # In this mode the detections are keyed by the id as a string.
            image_id = str(image_id)
            bboxes = obj_detection[image_id]['bboxes']
            labels = obj_detection[image_id]['classes']
            attrs = obj_detection[image_id]['attrs']

        assert(len(bboxes)==len(labels))

        # Parse Sentence
        for sent in sents:
            # NOTE(review): `sentence` / `entity_idx` are not consumed below.
            # The helper calls are kept because `utils` is defined outside
            # this file; drop them once they are confirmed side-effect free.
            sentence = utils.remove_annotations(sent)
            target_bboxes = []
            query = []

            for entity in phrase_re.findall(sent):
                # `entity` looks like "<info> <phrase>"; the id follows
                # "/EN#" in <info> and the types follow the second '/'.
                info, phrase = entity.split(' ', 1)
                entity_id = int(entity_no_re.findall(info)[0])
                entity_type = info.split('/')[2:]
                entity_idx = utils.find_sublist(sentence.split(' '), phrase.split(' '))

                if entity_id not in target_bboxes_dict:
                    if entity_id >= 0:
                        missing_entity_count[entity_type[0]] = missing_entity_count.get(entity_type[0], 0) + 1
                    continue

                assert 0 < entity_id

                # in entity order
                target_bboxes.append(target_bboxes_dict[entity_id])
                query.append(phrase)

            if not query:
                # No phrase of this sentence has an annotated box.
                continue

            entry = {
                'image'          : image_id,
                'target_bboxes'  : target_bboxes, # in order of entities
                "detected_bboxes" : bboxes, # list, in order of labels
                'labels' : labels,
                }
            if vgg:
                entry['query'] = query
                entry['features'] = features
            else:
                entry['attrs'] = attrs
                entry['query'] = query
            entries.append(entry)

    print("Load Down!")
    return entries


def load_dataset(name='train', dataroot='data/flickr30k/', use_bert = False, lite_bert=False, vgg=False):
    '''
    Load the detections and the image-id index for split ``name`` and build
    the grounding entries with ``load_train_flickr30k``.

    Files read:
        vgg=True  : "data/obj_detection_vgg_0.05.json", converted with
                    ``gen_obj_dict``
        vgg=False : "data/<name>_dataset.json", shaped like
                    "xxxx":{
                        bboxes:
                        classes:
                        attrs:
                    }
        use_bert  : additionally "bert_feature_<name>.json"
        always    : "<dataroot>/<name>_imgid2idx.pkl"

    Only the detection file that is actually needed is loaded, and every
    file handle is closed after reading.

    Returns:
        (entries, img_id2idx)
    '''
    bert_feature_dict = None
    if vgg:
        with open("data/obj_detection_vgg_0.05.json", "r") as f:
            obj_detection_dict = gen_obj_dict(json.load(f))
    else:
        with open("data/%s_dataset.json"%name, "r") as f:
            obj_detection_dict = json.load(f)
    if use_bert:
        with open("bert_feature_%s.json"%name, "r") as f:
            bert_feature_dict = json.load(f)

    # NOTE: unpickling can execute arbitrary code; only load index files
    # that come from a trusted source.
    with open(os.path.join(dataroot, '%s_imgid2idx.pkl' % name), 'rb') as f:
        img_id2idx = cPickle.load(f)

    entries = load_train_flickr30k(dataroot, img_id2idx, obj_detection_dict, bert_feature_dict, use_bert = use_bert, lite_bert = lite_bert, vgg=vgg)
    print("load flickr30k data successfully.")
    return entries, img_id2idx


# generate object detection dictionary
def gen_obj_dict(obj_detection):
    """Regroup raw detection records into a dict keyed by integer image id.

    Each record supplies "image" (a file name whose part before the first
    '.' is the integer id) and "objects", a list of dicts with "bbox" (a
    bracketed, comma-separated string of integers), "class", "score" and
    "feature". The result maps every id to parallel lists under the keys
    "bboxes", "classes", "scores" and "features". A later record with the
    same id replaces the earlier one.
    """
    per_image = {}
    for record in obj_detection:
        image_key = int(record["image"].split('.')[0])
        # One tuple per detection: (parsed box, class, score, feature).
        rows = [
            (
                # Drop the surrounding brackets, then parse each coordinate.
                [int(coord) for coord in det["bbox"][1:-1].split(',')],
                det["class"],
                det["score"],
                det["feature"],
            )
            for det in record['objects']
        ]
        per_image[image_key] = {
            "bboxes": [row[0] for row in rows],
            "classes": [row[1] for row in rows],
            "scores": [row[2] for row in rows],
            "features": [row[3] for row in rows],
        }
    return per_image

if __name__=="__main__":
    # Convert the raw detection list into the per-image dictionary format
    # and store it as JSON (integer image ids become string keys on dump).
    name = "test"
    # Read through a context manager so the input handle is closed.
    with open("data/obj_detection_0.1.json", "r") as f:
        obj_detection = json.load(f)
    obj_detection_dict = gen_obj_dict(obj_detection)
    with open("data/%s_detect_dict.json"%name, "w") as f:
        json.dump(obj_detection_dict, f)
