import os

from BERT_sister_terms_similarity.pedersen_similarities import Comparator, SimilarityFunction
from utility.randomfixedseed import Random
from BERT_utility.words_in_synset import SynsetCouple, SynsetOOVCouple, SaverSynsetCouples, ReaderSynsetCouples, \
    ReaderSynsetOOVCouple

from utility.word_in_vocabulary import WNManager, Checker
from nltk.corpus import wordnet as wn


class Picker:
    """Choose a partner ``(synset, lemma)`` for a given ``(s1, w1)`` pair.

    * Similar partners are sister terms of ``s1``: hyponyms of its first hypernym.
    * Dissimilar partners are drawn at random from all WordNet noun synsets (when
      ``s1`` is a noun) or all verb synsets (otherwise).

    A partner lemma must differ from ``w1``, must not be a multi-word expression
    according to ``WNManager.is_expression`` and must satisfy
    ``checker.is_in_vocabulary``.
    """

    # How many drawn candidate synsets without a usable lemma are tolerated
    # before a dissimilar pick gives up.
    MAX_DISSIMILAR_ATTEMPTS = 7

    def __init__(self, checker):
        self.checker = checker

        # Candidate pools for dissimilar picks. They are shared across calls and
        # consumed: every synset drawn by _dissimilar_word_to is removed from its
        # pool (except s1 itself, which is put back), so later calls never redraw it.
        self.ALL_NAMES = list(wn.all_synsets('n'))
        self.ALL_VERBS = list(wn.all_synsets('v'))

    def pick_from(self, s1, w1, similar=True):
        """Return ``(s2, w2)`` to pair with ``(s1, w1)``, or ``(None, None)``.

        :param s1: source synset.
        :param w1: lemma of ``s1``; never returned as ``w2``.
        :param similar: pick a sister term when true, a random synset of the
            same pool (nouns or verbs) otherwise.
        """
        if similar:
            return self._similar_word_to(s1, w1)
        return self._dissimilar_word_to(s1, w1)

    def _candidate_lemmas(self, synset, w1):
        """Return the lemma names of ``synset`` usable as a partner for ``w1``."""
        return [lemma for lemma in synset.lemma_names()
                if lemma != w1
                and not WNManager.is_expression(lemma)
                and self.checker.is_in_vocabulary(lemma)]

    def _similar_word_to(self, s1, w1):
        """Pick a random sister term of ``s1`` and one of its usable lemmas."""
        hypernyms = s1.hypernyms()
        if not hypernyms:
            return None, None

        # see hypernyms_sister_term_choice file to justify this
        sister_synss = hypernyms[0].hyponyms()
        if s1 in sister_synss:
            sister_synss.remove(s1)
        if not sister_synss:
            return None, None

        s2 = Random.randomchoice(sister_synss)
        in_voc = self._candidate_lemmas(s2, w1)
        if not in_voc:
            return None, None
        return s2, Random.randomchoice(in_voc)

    def _dissimilar_word_to(self, s1, w1):
        """Draw random synsets from the POS pool until one has a usable lemma.

        Each drawn synset is removed from the shared pool. ``s1`` itself is only
        withheld for the duration of this call and is always restored, whichever
        way the call exits.
        """
        syns = self.ALL_NAMES if s1.pos() == wn.NOUN else self.ALL_VERBS

        self_found = False
        attempts = 0
        try:
            while attempts < self.MAX_DISSIMILAR_ATTEMPTS:
                if len(syns) == 0:
                    return None, None
                s2 = Random.randomchoice(syns)
                syns.remove(s2)
                if s1 == s2:
                    # Drawing s1 does not count as an attempt.
                    self_found = True
                    continue

                in_voc = self._candidate_lemmas(s2, w1)
                if in_voc:
                    return s2, Random.randomchoice(in_voc)
                attempts += 1
            return None, None
        finally:
            # Previously s1 was only restored when the loop gave up, so a
            # successful pick silently dropped s1 from the shared pool.
            if self_found:
                syns.append(s1)


def get_couples_from(words, picker: Picker, similar=True, output_path=None):
    """Pair words with partners chosen by ``picker`` and optionally save them.

    For every word in ``words`` and for each POS tag ``'n'`` then ``'v'``, the
    first WordNet synset of that POS (when one exists) is handed to
    ``picker.pick_from``. Every successful pick becomes a ``SynsetCouple``.

    The output path, the ``similar`` flag and the number of couples are printed.
    When ``output_path`` is given, the couples are also written there through
    ``SaverSynsetCouples.save`` with a tab-separated header.

    :return: list of ``SynsetCouple`` in generation order.
    """
    couples = []
    for w1 in words:
        for pos in ('n', 'v'):
            synsets = wn.synsets(w1, pos=pos)
            if not synsets:
                continue
            s1 = synsets[0]
            s2, w2 = picker.pick_from(s1, w1, similar=similar)
            if s2 is None:
                continue
            couples.append(SynsetCouple(s1, w1, s2, w2, s1.pos()))

    print(output_path, similar, len(couples))
    if output_path is not None:
        columns = ['S1', 'S2', 'W1', 'W2', 'S1_POS', '#\n']
        SaverSynsetCouples.save(couples, output_path, '\t'.join(columns))
    return couples


def retrieve_synset_couples_divided_by_value_of_similarity(positive_input_path, negative_input_path, measure_name):
    """Group positive and negative synset couples by their similarity score.

    :param positive_input_path: file read with ``ReaderSynsetCouples.read``.
    :param negative_input_path: file read with ``ReaderSynsetCouples.read``.
    :param measure_name: name passed to ``SimilarityFunction.by_name``.
    :return: dict mapping each score ``f(couple.s1, couple.s2)`` to the list of
        couples with that score. Positive couples precede negative ones inside
        each list.
    """
    score = SimilarityFunction.by_name(measure_name)
    couples_by_score = {}
    for input_path in (positive_input_path, negative_input_path):
        for couple in ReaderSynsetCouples.read(input_path):
            couples_by_score.setdefault(score(couple.s1, couple.s2), []).append(couple)
    return couples_by_score


class OOVSisterTerms_LineReader(object):
    """Field extractor for the original OOV sister-terms line layout.

    ``readline`` takes one line already split on tabs and reads these positions:
    1 = OOV word, 2..3 = ``first`` (two-element slice), 5 = OOV synset,
    6/7/8 = target/w1/w2 POS, 9 = second synset, 10 = second word,
    11 = similarity value (parsed with ``float``).
    """

    def readline(self, line):
        """Return ``(value, oov, synset_oov, first, second, synset_second,
        target_pos, w1_pos, w2_pos)`` taken from the split ``line``."""
        similarity = float(line[11])
        target_pos, w1_pos, w2_pos = line[6:9]
        return (similarity,
                line[1],    # OOV word
                line[5],    # OOV synset
                line[2:4],  # "first" columns
                line[10],   # second word
                line[9],    # second synset
                target_pos, w1_pos, w2_pos)


class OOVSisterTerms_LineReader_V2(object):
    """Field extractor for the V2 OOV sister-terms line layout.

    ``readline`` takes one line already split on tabs and reads these positions:
    0 = OOV synset, 1 = OOV word, 2..3 = ``first`` (two-element slice),
    4/5/6 = target/w1/w2 POS, 8 = second synset, 9 = second word,
    10 = similarity value (parsed with ``float``).
    """

    def readline(self, line):
        """Return ``(value, oov, synset_oov, first, second, synset_second,
        target_pos, w1_pos, w2_pos)`` taken from the split ``line``."""
        similarity = float(line[10])
        target_pos, w1_pos, w2_pos = line[4:7]
        return (similarity,
                line[1],    # OOV word
                line[0],    # OOV synset
                line[2:4],  # "first" columns
                line[9],    # second word
                line[8],    # second synset
                target_pos, w1_pos, w2_pos)


def retrieve_oov_couples_divided_by_value_of_similarity(input_path, version=None):
    """Group the OOV couples stored in a tab-separated file by similarity value.

    :param input_path: file with one tab-separated record per line.
    :param version: ``None`` selects ``OOVSisterTerms_LineReader`` (original
        column layout); any other value selects ``OOVSisterTerms_LineReader_V2``.
    :return: dict mapping each float similarity value to the list of
        ``SynsetOOVCouple`` objects having it, in file order.
    """
    reader = OOVSisterTerms_LineReader() if version is None else OOVSisterTerms_LineReader_V2()

    ordered_couples = {}
    # Read-only mode: the previous 'r+' also required write permission on the file.
    with open(input_path, 'r') as input_file:
        for line in input_file:
            # Strip the terminator so it does not leak into the last column.
            line = line.rstrip('\n')
            # Skip blank lines (e.g. a trailing empty line) instead of
            # failing with an IndexError inside the reader.
            if not line.strip():
                continue

            (value, oov, synset_oov, first, second, synset_second,
             target_pos, w1_pos, w2_pos) = reader.readline(line.split('\t'))

            s_oov = SynsetOOVCouple(oov, synset_oov, first, second, synset_second, target_pos, w1_pos, w2_pos)
            ordered_couples.setdefault(value, []).append(s_oov)

    return ordered_couples


def positive_negative_in_voc_synset_couples_from(positive_input_path, negative_input_path):
    """Read the in-vocabulary positive and negative couple files.

    :return: ``(positive_couples, negative_couples)``, each being what
        ``ReaderSynsetCouples.read`` returns for the corresponding path. The
        positive file is read first.
    """
    return tuple(ReaderSynsetCouples.read(path) for path in (positive_input_path, negative_input_path))


def positive_negative_comparable_synset_couples_from(positive_input_path, negative_input_path,
                                                     s1_index, w1_index,
                                                     s2_index, w2_index, first_indexes,
                                                     s_pos_index, w1_pos, w2_pos,
                                                     exclude_first):
    """Read positive and negative couple files that share one column layout.

    The column indexes and ``exclude_first`` are forwarded unchanged to
    ``ReaderSynsetCouples.read`` for both files.

    NOTE(review): ``first_indexes``, ``w1_pos`` and ``w2_pos`` are accepted but
    never forwarded; confirm whether ``ReaderSynsetCouples.read`` should get them.

    :return: ``(positive_couples, negative_couples)``.
    """
    column_layout = dict(s1_index=s1_index, w1_index=w1_index,
                         s2_index=s2_index, w2_index=w2_index,
                         s_pos_index=s_pos_index, exclude_first=exclude_first)
    positive = ReaderSynsetCouples.read(positive_input_path, **column_layout)
    negative = ReaderSynsetCouples.read(negative_input_path, **column_layout)
    return positive, negative


# NOTE: a two-argument positive_negative_oov_synset_couples_from (reading both
# files with ReaderSynsetCouples) used to be defined here. It was dead code: the
# keyword-argument definition that follows rebound the name at import time, so
# callers always got the ReaderSynsetOOVCouple-based version.


def positive_negative_oov_synset_couples_from(positive_input_path, negative_input_path,
                                              s1_index=5, w1_index=1, s2_index=9, w2_index=10, first_indexes=None,
                                              s_pos_index=6, w1_pos=7, w2_pos=8, exclude_first=False):
    """Read positive and negative OOV couple files that share one column layout.

    Every layout argument is forwarded unchanged to
    ``ReaderSynsetOOVCouple.read`` for both files.

    :param first_indexes: forwarded as ``first_indexes``; defaults to ``[2, 4]``.
        A fresh list is created per call instead of the previous mutable default,
        which was shared across calls and could leak mutations made by the reader.
    :return: ``(positive_couples, negative_couples)``.
    """
    if first_indexes is None:
        first_indexes = [2, 4]

    column_layout = dict(s1_index=s1_index, w1_index=w1_index,
                         s2_index=s2_index, w2_index=w2_index,
                         first_indexes=first_indexes,
                         s_pos_index=s_pos_index, w1_pos=w1_pos, w2_pos=w2_pos,
                         exclude_first=exclude_first)
    positive_couples = ReaderSynsetOOVCouple.read(positive_input_path, **column_layout)
    negative_couples = ReaderSynsetOOVCouple.read(negative_input_path, **column_layout)
    return positive_couples, negative_couples


def compare_couples(couples, similarity_function, similarity_output_path, header):
    """Score ``couples``, write the scores to disk and return them.

    A ``Comparator`` over ``couples`` and ``similarity_function`` writes its
    similarities to ``similarity_output_path`` (preceded by ``header``). The
    result of ``get_similarities()`` is then returned.
    """
    scorer = Comparator(couples, similarity_function)
    scorer.write_similarities(similarity_output_path, header)
    similarities = scorer.get_similarities()
    return similarities


def voc_sim(couples_output_dir='data/similarity_pedersen_test/sister_terms',
            model_name=None, binary=True):
    """Build and save positive/negative couples for the model's vocabulary words.

    Positive couples pair each vocabulary word with a sister term. Negative
    couples pair it with a randomly drawn synset (see ``Picker``). They are saved
    as ``in_voc_sister_terms_positive.txt`` and
    ``in_voc_sister_terms_negative.txt`` inside ``couples_output_dir``.

    :param model_name: passed to ``Checker.get_instance_from_path``.
    :param binary: passed to ``Checker.get_instance_from_path``.
    :return: ``(positive_couples, negative_couples)``.
    """
    checker = Checker.get_instance_from_path(model_name, binary=binary)
    picker = Picker(checker)

    # os.path.join (as in compute_sister_terms) instead of "+ '/...'", which
    # would point at the filesystem root for an empty directory argument.
    positive_couples = get_couples_from(checker.model.vocab, picker=picker, similar=True,
                                        output_path=os.path.join(couples_output_dir,
                                                                 'in_voc_sister_terms_positive.txt'))
    negative_couples = get_couples_from(checker.model.vocab, picker=picker, similar=False,
                                        output_path=os.path.join(couples_output_dir,
                                                                 'in_voc_sister_terms_negative.txt'))
    return positive_couples, negative_couples


def oov_sim(couples_output_dir='data/similarity_pedersen_test/sister_terms',
            model_name=None, binary=True):
    """Build and save positive/negative couples for OOV WordNet lemmas.

    The words are ``checker.get_OOV`` applied to the WordNet lemmas returned by
    ``WNManager.lemma_from_synsets(allow_expression=False)``; presumably these are
    the lemmas missing from the model's vocabulary (TODO confirm in ``Checker``).
    Couples are saved as ``oov_sister_terms_positive.txt`` and
    ``oov_sister_terms_negative.txt`` inside ``couples_output_dir``.

    :param model_name: passed to ``Checker.get_instance_from_path``.
    :param binary: passed to ``Checker.get_instance_from_path``.
    :return: ``(positive_couples, negative_couples)``.
    """
    wn_manager = WNManager()
    checker = Checker.get_instance_from_path(model_name, binary=binary)
    picker = Picker(checker)

    oovs = checker.get_OOV(wn_manager.lemma_from_synsets(allow_expression=False))
    # os.path.join (as in compute_oov_sister_terms) instead of "+ '/...'".
    positive_couples = get_couples_from(oovs, picker=picker, similar=True,
                                        output_path=os.path.join(couples_output_dir,
                                                                 'oov_sister_terms_positive.txt'))
    negative_couples = get_couples_from(oovs, picker=picker, similar=False,
                                        output_path=os.path.join(couples_output_dir,
                                                                 'oov_sister_terms_negative.txt'))

    return positive_couples, negative_couples


def compute_sister_terms(seeds, sister_terms_output_path, model_name='bert-base-uncased', check_if_computed=True):
    """Generate in-vocabulary sister-term couples once per seed.

    For each seed, the seed is printed, the project RNG is seeded, and
    ``voc_sim`` writes its couples into ``<sister_terms_output_path>/seed_<seed>``.

    :param seeds: iterable of seeds, each convertible with ``int()`` (str or int).
    :param sister_terms_output_path: parent of the per-seed directories; it is
        created (including missing parents) when absent.
    :param model_name: forwarded to ``voc_sim``.
    :param check_if_computed: when true, skip seeds whose directory is non-empty.
    """
    for seed in seeds:
        print(seed)
        Random.set_seed(int(seed))
        # str(seed): 'seed_' + seed raised TypeError for integer seeds.
        couples_output_dir = os.path.join(sister_terms_output_path, 'seed_' + str(seed))
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(couples_output_dir, exist_ok=True)

        if not check_if_computed or not os.listdir(couples_output_dir):
            voc_sim(couples_output_dir=couples_output_dir,
                    model_name=model_name)


def compute_oov_sister_terms(seeds, sister_terms_output_path, model_name='bert-base-uncased', check_if_computed=True):
    """Generate OOV sister-term couples once per seed.

    For each seed, the project RNG is seeded and ``oov_sim`` writes its couples
    into ``<sister_terms_output_path>/seed_<seed>``.

    :param seeds: iterable of seeds, each convertible with ``int()`` (str or int).
    :param sister_terms_output_path: parent of the per-seed directories; it is
        created (including missing parents) when absent.
    :param model_name: forwarded to ``oov_sim``.
    :param check_if_computed: when true, skip seeds whose directory is non-empty.
    """
    for seed in seeds:
        Random.set_seed(int(seed))
        # str(seed): 'seed_' + seed raised TypeError for integer seeds.
        couples_output_dir = os.path.join(sister_terms_output_path, 'seed_' + str(seed))
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(couples_output_dir, exist_ok=True)

        if not check_if_computed or not os.listdir(couples_output_dir):
            oov_sim(couples_output_dir=couples_output_dir,
                    model_name=model_name)
