import os
import joblib
import numpy as np
from tqdm import tqdm
import os
import json
import importlib
from tokenizers import Tokenizer
import argparse
from easydict import EasyDict as edict
import yaml
from pprint import pprint
from nltk.corpus import wordnet
from copy import deepcopy
import fasttext.util
from scipy.spatial.distance import cosine
import torch.nn as nn
import torch
import networkx
import igraph
import re
import nltk
from nltk.corpus import wordnet as wn
from nltk.stem import WordNetLemmatizer


def get_config(path='config.yml'):
    """Load the project YAML configuration as an attribute-accessible dict.

    Args:
        path: Path to the YAML file. Defaults to ``'config.yml'``, resolved
            against the current working directory.

    Returns:
        An ``EasyDict`` wrapping the parsed YAML (parsed with
        ``yaml.SafeLoader``), so keys can be read as attributes.
    """
    # Context manager so the handle is closed immediately instead of being
    # left for the garbage collector.
    with open(path) as f:
        config = edict(yaml.load(f, Loader=yaml.SafeLoader))
    return config

def collect_files(directory):
    """Recursively list every file underneath ``directory``.

    Paths are built with ``os.path.join(dirpath, filename)`` and returned as a
    flat list in the (top-down) order ``os.walk`` visits them. A directory
    that does not exist yields an empty list.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(directory)
        for filename in filenames
    ]

def get_multisimlex(lang):
    """Load the monolingual MultiSimLex similarity ratings for one language.

    Reads ``datasets/multisimlex/<CODE>.csv`` (relative to the current working
    directory). The first line is skipped as a header. Every other non-blank
    line is split on commas as ``ID, word1, word2, POS_tag, score1, score2, ...``.

    Args:
        lang: Two-letter language code, one of
            ``en, ar, zh, fi, fr, he, pl, ru, es``.

    Returns:
        A tuple ``(word_pairs, unique_words)``:
          * ``word_pairs`` maps each row ID to a dict with keys ``'word1'``,
            ``'word2'``, ``'POS_tag'``, ``'annotator_scores'`` (the raw score
            strings) and ``'total_score'``. The total is the mean of the scores
            that parse as floats, or ``nan`` if none do.
          * ``unique_words`` lists every distinct word, in first-seen order.

    Raises:
        ValueError: if ``lang`` is not a supported language code.
    """
    lang_to_file_code = {'en': 'ENG', 'ar': 'ARA', 'zh': 'CMN', 'fi': 'FIN',
                         'fr': 'FRA', 'he': 'HEB', 'pl': 'POL', 'ru': 'RUS',
                         'es': 'SPA'}
    print('Getting MultiSimLex data...')
    if lang not in lang_to_file_code:
        # Fail with a clear message instead of an AttributeError on a None handle.
        raise ValueError(f'Unsupported MultiSimLex language: {lang!r}')

    path = 'datasets/multisimlex/' + lang_to_file_code[lang] + '.csv'
    # Explicit UTF-8: these files contain Arabic, Hebrew, Chinese, etc., and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    word_pairs = {}
    unique_words = []
    seen_words = set()  # O(1) membership; unique_words preserves the order
    messed_up_annotation_count = 0
    for line in tqdm(lines[1:]):
        # Strip only the line terminator. line[:-1] would chop a real
        # character from a final line that has no trailing newline.
        line = line.rstrip('\r\n')
        if not line:
            continue  # tolerate blank (e.g. trailing) lines
        pieces = line.split(',')
        ID = pieces[0]
        word_pairs[ID] = {'word1': pieces[1],
                          'word2': pieces[2],
                          'POS_tag': pieces[3],
                          'annotator_scores': pieces[4:]}

        # Some languages have a pair with a missing annotator value. Average
        # only the scores that parse, and count such pairs as messed up.
        score_list = []
        for raw_score in word_pairs[ID]['annotator_scores']:
            try:
                score_list.append(float(raw_score))
            except ValueError:
                pass
        if len(score_list) < len(word_pairs[ID]['annotator_scores']):
            messed_up_annotation_count += 1
        word_pairs[ID]['total_score'] = np.mean(score_list) if score_list else np.nan

        for word in (word_pairs[ID]['word1'], word_pairs[ID]['word2']):
            if word not in seen_words:
                seen_words.add(word)
                unique_words.append(word)
    print("Messed up annotation count: " + str(messed_up_annotation_count))
    return word_pairs, unique_words

def get_crosslingual_multisimlex(lang_pair):
    """Load Cross-Lingual MultiSimLex ratings for a language pair.

    Args:
        lang_pair: Two language codes joined by ``'_'``, e.g. ``'en_fr'``. The
            file ``<L1>-<L2>.csv`` is looked up in
            ``config.directories.cross_lingual_multisimlex``, first in the
            given order and then reversed.

    Returns:
        ``(word_pairs, unique_words_src, unique_words_tgt, src, tgt)``.
        ``src``/``tgt`` are the codes in the order used by the file found.
        Returns ``(None, None, None, None, None)`` when neither ordering exists.
        ``word_pairs`` maps row ID to a dict with ``'word1'`` (src word),
        ``'word2'`` (tgt word), ``'POS_tag'`` and ``'annotator_score'`` (the
        raw string). It also has ``'total_score'`` when that string parses as
        a float.

    Raises:
        KeyError: if either code in ``lang_pair`` is unknown.
        ValueError: if the file header's language columns disagree with the
            file name.
    """
    config = get_config()
    print('Getting Cross-Lingual MultiSimLex data...')
    multisimlex_code_dict = {'ar': 'ARA', 'en': 'ENG', 'es': 'SPA', 'fi': 'FIN', 'fr': 'FRA',
                             'he': 'HEB', 'pl': 'POL', 'ru': 'RUS', 'zh': 'CMN'}
    flip_code_dict = {code: lang for lang, code in multisimlex_code_dict.items()}
    lang1, lang2 = lang_pair.split('_')
    data_dir = config.directories.cross_lingual_multisimlex

    # Resolve both candidate names up front (unknown codes raise KeyError
    # here), then take the first ordering that exists on disk.
    candidates = [
        (a, b, os.path.join(data_dir, multisimlex_code_dict[a] + "-" + multisimlex_code_dict[b] + '.csv'))
        for a, b in ((lang1, lang2), (lang2, lang1))
    ]
    found = next((c for c in candidates if os.path.exists(c[2])), None)
    if found is None:
        return None, None, None, None, None
    src, tgt, path = found

    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Header columns 1 and 2 name the languages of word1 / word2.
    header_fields = lines[0].rstrip('\r\n').split(',')
    src_lang = flip_code_dict[header_fields[1]]
    tgt_lang = flip_code_dict[header_fields[2]]
    # Explicit check instead of assert, which is stripped under `python -O`.
    if (src, tgt) != (src_lang, tgt_lang):
        raise ValueError(
            f'Header languages ({src_lang}, {tgt_lang}) do not match file name '
            f'({src}, {tgt}) in {path}')

    word_pairs = {}
    unique_words_src, unique_words_tgt = [], []
    seen_src, seen_tgt = set(), set()  # O(1) membership; lists keep order
    messed_up_annotation_count = 0
    for line in tqdm(lines[1:]):
        line = line.rstrip('\r\n')  # line[:-1] would truncate an unterminated last line
        if not line:
            continue
        pieces = line.split(',')
        ID = pieces[0]
        word_pairs[ID] = {'word1': pieces[1],
                          'word2': pieces[2],
                          'POS_tag': pieces[3],
                          'annotator_score': pieces[4]}
        try:
            total_score = float(word_pairs[ID]['annotator_score'])
        except ValueError:
            # Missing score. The pair stays in word_pairs without
            # 'total_score', and its words are not collected.
            messed_up_annotation_count += 1
            continue
        word1, word2 = word_pairs[ID]['word1'], word_pairs[ID]['word2']
        if word1 not in seen_src:
            seen_src.add(word1)
            unique_words_src.append(word1)
        if word2 not in seen_tgt:
            seen_tgt.add(word2)
            unique_words_tgt.append(word2)
        word_pairs[ID]['total_score'] = total_score
    print("Messed up annotation count: " + str(messed_up_annotation_count))
    return word_pairs, unique_words_src, unique_words_tgt, src, tgt

def get_lancaster_norms(normalize=False):
    """Load the Lancaster sensorimotor norms as per-word feature dicts.

    Reads ``datasets/Lancaster_sensorimotor_norms_for_39707_words.csv``
    (relative to the current working directory) and skips the header line.
    Column 0 is the word. Columns 1-11 are read as these means, in order:
    Auditory, Gustatory, Haptic, Interoceptive, Olfactory, Visual, Foot_leg,
    Hand_arm, Head, Mouth, Torso. Each is stored under ``'<name>.mean'``.

    Args:
        normalize: If True, each 11-d vector is scaled to unit norm,
            mean-centred, then scaled to unit norm again.

    Returns:
        Dict mapping word to a dict of the eleven ``'<name>.mean'`` floats plus
        ``'embed'``, a numpy array of those values in the order above.
    """
    feature_names = (
        'Auditory.mean', 'Gustatory.mean', 'Haptic.mean', 'Interoceptive.mean',
        'Olfactory.mean', 'Visual.mean', 'Foot_leg.mean', 'Hand_arm.mean',
        'Head.mean', 'Mouth.mean', 'Torso.mean',
    )
    with open('datasets/Lancaster_sensorimotor_norms_for_39707_words.csv', 'r') as f:
        lines = f.readlines()

    word_data = {}
    for line in tqdm(lines[1:]):  # lines[0] is the CSV header
        line = line.rstrip('\r\n')
        if not line:
            continue  # tolerate blank (e.g. trailing) lines
        pieces = line.split(',')
        word = pieces[0]
        # Columns 1..11 hold the means, in the same order as feature_names.
        # A short row still raises IndexError, as before.
        values = [float(pieces[i]) for i in range(1, len(feature_names) + 1)]
        entry = dict(zip(feature_names, values))

        vector = np.asarray(values)
        if normalize:
            vector = vector / np.linalg.norm(vector)
            vector = vector - np.mean(vector)
            vector = vector / np.linalg.norm(vector)
        entry['embed'] = vector
        word_data[word] = entry
    return word_data

def dynamic_import(module):
    """Resolve a dotted path such as ``'pkg.mod.Name'`` to the object it names.

    Everything before the last dot is imported with
    ``importlib.import_module``. The final component is then fetched from that
    module with ``getattr``.
    """
    parent_path, _dot, attr_name = module.rpartition('.')
    parent_module = importlib.import_module(parent_path)
    return getattr(parent_module, attr_name)

def load(file_path, file_type=None):
    """Deserialize a JSON or pickle file.

    The format comes from ``file_type`` (``'json'`` or ``'pkl'``). Otherwise it
    is taken from the ``.json`` / ``.pkl`` extension of ``file_path``. JSON is
    checked first.

    Args:
        file_path: Path of the file to read.
        file_type: Optional explicit format, ``'json'`` or ``'pkl'``.

    Returns:
        The deserialized object.

    Raises:
        ValueError: if the format cannot be determined.
    """
    if file_type == 'json' or file_path.endswith('.json'):
        import json
        with open(file_path, 'r') as f:
            return json.load(f)
    if file_type == 'pkl' or file_path.endswith('.pkl'):
        import pickle
        # NOTE: pickle.load can execute arbitrary code. Only load trusted files.
        with open(file_path, 'rb') as f:
            return pickle.load(f)
    # f-string so the offending path is actually interpolated into the message.
    raise ValueError(f'file type {file_path} not implemented')

def dump(data, file_path, file_type=None):
    """Serialize ``data`` to ``file_path`` as JSON or pickle, overwriting it.

    The format comes from ``file_type`` (``'json'`` or ``'pkl'``). Otherwise it
    is taken from the ``.json`` / ``.pkl`` extension of ``file_path``. JSON is
    checked first.

    Args:
        data: Object to serialize.
        file_path: Destination path.
        file_type: Optional explicit format, ``'json'`` or ``'pkl'``.

    Raises:
        ValueError: if the format cannot be determined. Nothing is written in
            that case.
    """
    if file_type == 'json' or file_path.endswith('.json'):
        import json
        with open(file_path, 'w') as f:
            json.dump(data, f)
    elif file_type == 'pkl' or file_path.endswith('.pkl'):
        import pickle
        with open(file_path, 'wb') as f:
            pickle.dump(data, f)
    else:
        # f-string so the offending path is actually interpolated into the message.
        raise ValueError(f'file type {file_path} not implemented')

def str2bool(v):
    """``argparse`` ``type=`` helper that turns common yes/no spellings into a bool.

    Booleans pass through unchanged. Strings are matched case-insensitively
    against ``yes/true/t/y/1`` (True) and ``no/false/f/n/0`` (False). Any other
    value raises ``argparse.ArgumentTypeError``.
    """
    if isinstance(v, bool):
        return v
    truthy_spellings = {'yes', 'true', 't', 'y', '1'}
    falsy_spellings = {'no', 'false', 'f', 'n', '0'}
    normalized = v.lower()
    if normalized in truthy_spellings:
        return True
    if normalized in falsy_spellings:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def divide_index(size, n):
    """Split ``range(size)`` into at most ``n`` contiguous, near-equal chunks.

    Chunk lengths differ by at most one, and earlier chunks take the extra
    elements. When ``size < n`` only ``size`` single-element chunks are
    returned (never empty chunks). ``size <= 0`` yields ``[]``.

    The previous ``size // n + 1`` chunk size overshot whenever ``n`` divided
    ``size``. For example ``(10, 5)`` gave 4 chunks of lengths 3, 3, 3, 1.

    Example:
        >>> divide_index(10, 4)
        [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f'n must be a positive integer, got {n!r}')
    parts = min(n, size)
    if parts <= 0:
        return []
    base, extra = divmod(size, parts)
    chunks = []
    start = 0
    for i in range(parts):
        stop = start + base + (1 if i < extra else 0)
        chunks.append(list(range(start, stop)))
        start = stop
    return chunks

def parse_config(args):
    """Build the run configuration from a YAML file plus parsed CLI arguments.

    Args:
        args: Parsed arguments (e.g. an ``argparse.Namespace``). They must
            provide ``config`` (path to the YAML file) and ``use_half``, and
            must support ``vars()``.

    Returns:
        An ``EasyDict`` of the YAML contents (parsed with ``yaml.SafeLoader``)
        with ``specific_args`` holding all CLI arguments. ``use_half`` is
        overridden only when ``args.use_half`` is truthy. The resulting config
        is pretty-printed to stdout.
    """
    config_path = args.config
    # Context manager so the YAML file handle is closed right away.
    with open(config_path) as f:
        config = edict(yaml.load(f, Loader=yaml.SafeLoader))
    config.specific_args = edict(vars(args))
    # Only a truthy flag overrides the YAML value. False/None leave it as-is.
    if args.use_half:
        config.use_half = args.use_half
    pprint(config)
    return config

def get_synets(word, lang, wn, multi_wn):
    """Look up WordNet synsets for ``word`` and map them through ``multi_wn``.

    ``lang`` is translated to the code passed as ``lang=`` to ``wn.synsets``:
    'es' -> 'spa', 'zh-cn' -> 'cmn', 'ja' -> 'jpn', 'de' -> 'de'. For 'en',
    ``wn.synsets(word)`` is called without a language argument. Any other
    ``lang`` results in ``lang=None``. Each returned synset's ``_name`` is
    then used as a key into ``multi_wn[lang]``.

    Returns:
        List of ``multi_wn[lang][synset._name]`` values, in synset order.

    Raises:
        KeyError: if a synset name is missing from ``multi_wn[lang]``.
    """
    lang_codes = {'es': 'spa', 'zh-cn': 'cmn', 'ja': 'jpn', 'de': 'de', 'en': 'en'}
    wordnet_lang = lang_codes.get(lang)
    found = wn.synsets(word) if wordnet_lang == 'en' else wn.synsets(word, lang=wordnet_lang)
    # multi_wn[lang] is indexed per synset, so an empty result never touches it.
    return [multi_wn[lang][synset._name] for synset in found]

def write_file_from_list(list, path):
    """Write each item of ``list`` to ``path`` on its own line, overwriting the file.

    Items are converted with ``str()``. The parameter keeps its name ``list``
    (which shadows the builtin) so existing keyword callers keep working.
    """
    with open(path, 'w') as f:
        # str(item) rather than "%s\n" % item: the % operator treats a tuple
        # item as its argument tuple, which raises or misformats.
        f.writelines(str(item) + "\n" for item in list)

# def fix_weird_spaces(word):
#     """Normalize stray spaces inside multi-word MultiSimLex entries (extra spaces appear at random,
#        mostly in AR). UPDATE: not used, because without it only about 5 words were left out for some languages."""
#     split_word = word.split(" ")
#     word_count = 0
#     real_words = []
#     # if len(split_word) > 1:
#     #     stop = None
#     for chunk in split_word:
#         if chunk != '':
#             word_count += 1
#             real_words.append(chunk)
#     if word_count > 1:
#         word = " ".join(real_words)
#         # word = word.replace(" ", "~")
#     else:
#         word = word.replace(" ", "")
#     return word

def wordnet_lang_converter(lang):
    """Translate a two-letter MultiSimLex language code to a three-letter WordNet code.

    Supported: en, ar, zh, fi, fr, he, pl, ru, es. Any other code returns ``None``.
    """
    three_letter_codes = {
        'en': "eng",
        'ar': "arb",
        'zh': "cmn",
        'fi': "fin",
        'fr': "fra",
        'he': "heb",
        'pl': "pol",
        'ru': "rus",
        'es': "spa",
    }
    return three_letter_codes.get(lang)

def read_edgelist(edgelist_file, weighted=True):
    '''
    Read a whitespace-separated edge list file into a dictionary.

    Each non-blank line is ``node1 node2 [weight]``. The key is
    ``node1 + "_" + node2``, so node names that themselves contain "_" can
    produce ambiguous keys. The value is ``float(weight)`` when ``weighted``
    is True, otherwise ``1.0``, and the weight column is then not required.
    A repeated edge overwrites the earlier value.

    Raises:
        IndexError: for a line with fewer than two fields, or with no weight
            column when ``weighted`` is True.
        ValueError: for a non-numeric weight when ``weighted`` is True.
    '''
    with open(edgelist_file, 'r') as f:
        lines = f.readlines()
    edgelist = {}
    for line in tqdm(lines):
        # split() also drops the newline (and any '\r'), so an unterminated
        # last line keeps its final character and repeated spaces are tolerated.
        pieces = line.split()
        if not pieces:
            continue  # skip blank lines
        node1, node2 = pieces[0], pieces[1]
        # Only parse the weight when it is used, so 2-column files work unweighted.
        edgelist[node1 + "_" + node2] = float(pieces[2]) if weighted else 1.0
    return edgelist






