import base64
import gzip
import os
import random
import re
import time
import typing
from ast import literal_eval
from collections import Counter, defaultdict
from contextlib import ExitStack
from itertools import zip_longest
from pprint import pformat
from typing import Tuple, Mapping

from nltk import WordNetLemmatizer

from coli.basic_tools.common_utils import set_proc_name, ensure_dir, smart_open
from delphin.mrs import eds, simplemrs
from delphin.mrs.components import links as mrs_links
from coli.hrgguru.eds import ScorerResult, EDSScorer
from coli.hrgguru.extract_LFRG import lfrg_to_mrs
from coli.hrgguru.extract_sync_grammar import ExtractionParams
from coli.hrgguru.graph_readers import mrs_reader
from coli.hrgguru.hrg import CFGRule, HRGRule
from coli.hrgguru.hyper_graph import HyperGraph, GraphNode, HyperEdge
from coli.basic_tools.logger import log_to_file, logger, default_logger
from coli.span.const_tree import ConstTree, Lexicon
from coli.hrgguru.const_tree import Lexicon as HLexicon


def eds_for_smatch(sent_id, e):
    """Serialize an EDS graph into the text block used for smatch scoring.

    Layout: ``#<sent_id>``, the node count, one ``<nodeid> <pred>`` line per
    node, the edge count, one ``<source> <target> <label>`` line per edge,
    and a terminating blank line.

    :param sent_id: identifier written after the leading ``#``
    :param e: graph exposing ``nodes()`` (items with ``nodeid`` and ``pred``)
        and ``edges(nodeid)`` (a mapping from edge label to target node id)
    :return: the formatted block as a single string
    """
    node_rows = []
    for node in e.nodes():
        node_rows.append((node.nodeid, str(node.pred)))

    edge_rows = []
    for node in e.nodes():
        for label, target in e.edges(node.nodeid).items():
            edge_rows.append((node.nodeid, target, label))

    node_block = "\n".join(" ".join(row) for row in node_rows)
    edge_block = "\n".join(" ".join(row) for row in edge_rows)
    return f"#{sent_id}\n{len(node_rows)}\n{node_block}\n{len(edge_rows)}\n{edge_block}\n\n"


def mrs_for_smatch(sent_id, m):
    """Serialize an MRS object into the text block used for smatch scoring.

    Nodes come from ``m.eps()`` (``nodeid`` and ``pred``); edges come from
    ``mrs_links(m)`` and are labelled ``<rargname>/<post>``. Links whose
    start id equals 0 are left out.

    :param sent_id: identifier written after the leading ``#``
    :param m: MRS object accepted by ``mrs_links`` and exposing ``eps()``
    :return: the formatted block as a single string
    """
    node_rows = [(str(ep.nodeid), str(ep.pred)) for ep in m.eps()]

    edge_rows = []
    for start, end, rargname, post in mrs_links(m):
        if start == 0:
            continue
        edge_rows.append((str(start), str(end), rargname + "/" + post))

    node_block = "\n".join(" ".join(row) for row in node_rows)
    edge_block = "\n".join(" ".join(row) for row in edge_rows)
    return f"#{sent_id}\n{len(node_rows)}\n{node_block}\n{len(edge_rows)}\n{edge_block}\n\n"


def lf_for_smatch(sent_id, mrs_obj, include_qeq=True):
    """Serialize the logical form of an MRS into the smatch text block.

    Every elementary predication gets a freshly generated random node name
    (base64 of 15 random bytes) carrying its predicate string. Its label and
    each argument value become nodes with the literal label ``None``, linked
    from the predication node by ``LBL`` and by the argument name.

    :param sent_id: identifier written after the leading ``#``
    :param mrs_obj: object exposing ``eps()`` and ``hcons()``
    :param include_qeq: when true, add a ``QEQ`` edge from ``hi`` to ``lo``
        for every handle constraint
    :return: the formatted block as a single string; node lines follow set
        iteration order
    """
    node_set = set()
    edge_rows = []
    for ep in mrs_obj.eps():
        pred_node = base64.b64encode(os.urandom(15)).decode("ascii")
        node_set.add((pred_node, ep.pred.string))
        node_set.add((ep.label, "None"))
        edge_rows.append((pred_node, ep.label, "LBL"))
        for role, value in ep.args.items():
            node_set.add((value, "None"))
            edge_rows.append((pred_node, value, role))

    if include_qeq:
        for hcon in mrs_obj.hcons():
            upper, lower = hcon.hi, hcon.lo
            node_set.add((upper, "None"))
            node_set.add((lower, "None"))
            edge_rows.append((upper, lower, "QEQ"))

    node_block = "\n".join(" ".join(row) for row in node_set)
    edge_block = "\n".join(" ".join(row) for row in edge_rows)
    return f"#{sent_id}\n{len(node_set)}\n{node_block}\n{len(edge_rows)}\n{edge_block}\n\n"


def format_node(node_name, pred_edge, with_span=True):
    """Build the display name of a graph node.

    :param node_name: the node's own name
    :param pred_edge: the edge whose ``span`` (a pair) is appended to the name
    :param with_span: when true (the default) return
        ``<node_name>@<span[0]>,<span[1]>``; when false return the name alone.
        The flag was previously accepted but ignored.
    :return: the display name as a string
    """
    if not with_span:
        return str(node_name)
    return f"{node_name}@{pred_edge.span[0]},{pred_edge.span[1]}"


def to_nodes_and_edges(hg, return_spans=False, node_name_with_span=True):
    """Flatten a hypergraph into node and edge tuples.

    Edges attached to exactly one node are treated as that node's predicate
    (the first one seen wins; later ones are reported on stdout). Edges
    attached to exactly two nodes become real edges. Any other arity is
    reported on stdout and skipped.

    :param hg: object whose ``edges`` are iterated; each edge has ``nodes``,
        ``label`` and ``span``
    :param return_spans: when true, node tuples also carry the predicate span
    :param node_name_with_span: forwarded to ``format_node``
    :return: ``(nodes, edges)`` where nodes are ``(name, label[, span])`` and
        edges are ``(source_name, target_name, label)``
    """
    pred_edge_of = {}
    binary_edges = []
    for hyper_edge in hg.edges:  # type: HyperEdge
        arity = len(hyper_edge.nodes)
        if arity == 2:
            binary_edges.append(hyper_edge)
        elif arity != 1:
            print("Invalid hyperedge with node count {}".format(arity))
        else:
            owner = hyper_edge.nodes[0]  # type: GraphNode
            existing = pred_edge_of.get(owner)
            if existing is not None:
                print("Dumplicate node name {} and {}!".format(
                    existing,
                    hyper_edge.label
                ))
            else:
                pred_edge_of[owner] = hyper_edge

    nodes = []
    for graph_node, pred_edge in pred_edge_of.items():
        assert pred_edge.span is not None
        display_name = format_node(graph_node.name, pred_edge, node_name_with_span)
        entry = (display_name, pred_edge.label)
        if return_spans:
            entry = entry + (pred_edge.span,)
        nodes.append(entry)

    edges = []
    for hyper_edge in binary_edges:
        source, target = hyper_edge.nodes
        endpoint_preds = [pred_edge_of.get(endpoint) for endpoint in hyper_edge.nodes]
        if any(pred is None for pred in endpoint_preds):
            # an endpoint without a predicate edge has no span to name it with
            print("No span for edge {}, nodes {}!".format(hyper_edge, endpoint_preds))
            continue
        source_pred, target_pred = endpoint_preds
        edges.append((
            format_node(source.name, source_pred, node_name_with_span),
            format_node(target.name, target_pred, node_name_with_span),
            hyper_edge.label))
    return nodes, edges


def output_hg(sent_id: str, hg: HyperGraph):
    """Serialize a hypergraph into the text block used for smatch scoring.

    The node and edge tuples come from ``to_nodes_and_edges(hg)`` with its
    default settings.

    :param sent_id: identifier written after the leading ``#``
    :param hg: hypergraph to serialize
    :return: the formatted block as a single string
    """
    nodes, edges = to_nodes_and_edges(hg)
    node_block = "\n".join(" ".join(row) for row in nodes)
    edge_block = "\n".join(" ".join(row) for row in edges)
    return f"#{sent_id}\n{len(nodes)}\n{node_block}\n{len(edges)}\n{edge_block}\n\n"


class HRGParserTrainingMixin(object):
    """Training and prediction drivers shared by parser classes.

    The host class supplies ``train``, ``save``, ``load`` and
    ``predict_and_output``; ``DataType`` is the class used to read sentences.
    """
    DataType = ConstTree
    train: typing.Callable
    save: typing.Callable
    load: typing.Callable
    predict_and_output: typing.Callable

    @classmethod
    def train_parser(cls, options, data_train=None, data_dev=None, data_test=None):
        """Train a parser epoch by epoch, saving and predicting after each one.

        :param options: option object; reads ``title``, ``output``, ``train``,
            ``dev``, ``epochs``, ``max_save`` and ``model``
        :param data_train: training sentences; read from ``options.train``
            through ``cls.DataType.from_file`` when ``None``
        :param data_dev: mapping from dev file name to sentences; built from
            ``options.dev`` when ``None``
        :param data_test: accepted but not used by this method
        """
        set_proc_name(options.title)
        ensure_dir(options.output)
        path = os.path.join(options.output, "{}_{}_train.log".format(
            options.title,
            int(time.time())))
        log_to_file(path)
        logger.name = options.title

        logger.info('Options:\n%s', pformat(options.__dict__))
        if data_train is None:
            data_train = cls.DataType.from_file(options.train)

        if data_dev is None:
            data_dev = {i: cls.DataType.from_file(i, False) for i in options.dev}

        # Best effort: the directory normally exists already at this point.
        try:
            os.makedirs(options.output)
        except OSError:
            pass

        parser = cls(options, data_train)
        # fixed seed so the shuffling order is reproducible between runs
        random_obj = random.Random(1)

        def do_predict(epoch):
            """Predict every dev file into ``<prefix>_epoch_<epoch>.<suffix>``."""
            for file_name, dev_sentences in data_dev.items():
                base_name = os.path.basename(file_name)
                try:
                    prefix, suffix = base_name.rsplit(".", 1)
                except ValueError:
                    # No extension. Use the bare file name, not the full path:
                    # joining a path containing directories (or an absolute
                    # one) would put the output outside options.output.
                    prefix = base_name
                    suffix = ""

                dev_output = os.path.join(options.output, '{}_epoch_{}.{}'.format(prefix, epoch, suffix))
                cls.predict_and_output(parser, options, dev_sentences, dev_output)

        if options.epochs == 0:
            print("Predict directly.")
            do_predict(0)

        for epoch in range(options.epochs):
            logger.info('Starting epoch %d', epoch)
            random_obj.shuffle(data_train)
            parser.train(data_train)

            # save model and delete old model: checkpoints are numbered from 1
            # and those numbered up to (epoch - max_save) are removed
            for i in range(0, epoch - options.max_save):
                path = os.path.join(options.output, os.path.basename(options.model)) + str(i + 1)
                if os.path.exists(path):
                    os.remove(path)
            path = os.path.join(options.output, os.path.basename(options.model)) + str(epoch + 1)
            parser.save(path)
            do_predict(epoch)

    @classmethod
    def predict_with_parser(cls, options):
        """Load a saved model and write predictions for ``options.test``.

        ``options.input_format`` selects how the input is read:
        ``"standard"`` (``DataType.from_file``), ``"space"`` (one
        space-separated sentence per line), ``"english"`` (raw text split with
        NLTK) or ``"tokenlist"`` (a Python literal of word/tag pairs).

        :param options: option object; reads ``input_format``, ``test``,
            ``model`` and ``output``
        :raises ValueError: for any other ``input_format``
        """
        if options.input_format == "standard":
            data_test = cls.DataType.from_file(options.test, False)
        elif options.input_format == "space":
            with smart_open(options.test) as f:
                data_test = [cls.DataType.from_words_and_postags([(word, "X") for word in line.strip().split(" ")])
                             for line in f]
        elif options.input_format == "english":
            from nltk import download, sent_tokenize
            from nltk.tokenize import TreebankWordTokenizer
            download("punkt")
            with smart_open(options.test) as f:
                raw_sents = sent_tokenize(f.read().strip())
                tokenized_sents = TreebankWordTokenizer().tokenize_sents(raw_sents)
                data_test = [cls.DataType.from_words_and_postags([(token, "X") for token in sent])
                             for sent in tokenized_sents]
        elif options.input_format == "tokenlist":
            with smart_open(options.test) as f:
                # SECURITY: eval() executes whatever the input file contains.
                # Only use this format with trusted files; ast.literal_eval
                # would be the safe choice if the file holds plain literals.
                items = eval(f.read())
            # NOTE(review): unlike the other branches this yields one object,
            # not a list of sentences -- confirm predict_and_output expects that.
            data_test = cls.DataType.from_words_and_postags(items)
        else:
            raise ValueError("invalid format option")

        logger.info('Initializing...')
        parser = cls.load(options.model, options)

        ts = time.time()
        cls.predict_and_output(parser, options, data_test, options.output)
        te = time.time()
        logger.info('Finished predicting and writing test. %.2f seconds.', te - ts)


class HRGParserMixin(object):
    grammar: Mapping[tuple, Mapping[CFGRule, int]]
    lemmatizer: WordNetLemmatizer
    lexicon_mapping: Mapping[Tuple[HLexicon, int], typing.Counter]
    terminal_mapping: Mapping[str, typing.Counter]
    gold_graphs: dict
    options: typing.Any

    pattern_number = re.compile(r"^[0-9.,]+$")

    @staticmethod
    def filter_rules(rules_counter, lexical_label):
        new_result = Counter()
        for rule, count in rules_counter.items():
            may_be_lexical_edge = rule.rhs[0][1]
            if may_be_lexical_edge is None:
                current_label = "None"
            else:
                current_label = may_be_lexical_edge.label
            if current_label == lexical_label:
                new_result[rule] = count
        return new_result

    @staticmethod
    def filter_rules_by_attachments(rules_counter, attachments_labels):
        attachments_labels = set(attachments_labels)
        this_attachments_labels = set()
        new_result = Counter()
        for rule, count in rules_counter.items():
            if len(rule.rhs) == 1:
                may_be_lexical_edge = rule.rhs[0][1]
            else:
                may_be_lexical_edge = None
            for edge in rule.hrg.rhs.edges:
                if edge.is_terminal and len(edge.nodes) == 1:
                    if may_be_lexical_edge is None or edge.label != may_be_lexical_edge.label:
                        this_attachments_labels.add(edge.label)
            if this_attachments_labels == attachments_labels:
                new_result[rule] = count
        return new_result

    def rule_lookup(self, tree_node, is_train,
                    lexical_label_list=(),
                    attachment_bags_list=(),
                    internal_bags_list=()):
        """Collect the candidate rules for ``tree_node``.

        Internal node (first child is not a ``Lexicon``): the rules come from
        ``self.grammar`` keyed by ``(tag, child tags)``. Each id in
        ``internal_bags_list`` is decoded into a ``;``-separated label bag and
        the first bag that leaves a non-empty filtered set decides the result;
        if none does, the unfiltered rules are returned.

        Lexical node: word-specific rules from ``self.lexicon_to_graph`` are
        used when ``options.word_threshold`` allows, otherwise (or when there
        are none) the rules from ``sync_grammar_fallback``. They are then
        narrowed by each id in ``lexical_label_list`` and, within a label,
        by each id in ``attachment_bags_list``; the first non-empty doubly
        filtered set is returned. If nothing matches (including when either
        list is empty) the unfiltered rules are returned.

        :param tree_node: node with ``tag`` and ``children``
        :param is_train: when true, fallback rules are merged into the
            word-specific ones rather than used only as a last resort
        :param lexical_label_list: ids resolved through
            ``self.tagger.statistics.pred_tags.int_to_word``
        :param attachment_bags_list: ids resolved through
            ``self.tagger.statistics.attachment_bags.int_to_word``
        :param internal_bags_list: ids resolved through
            ``self.tagger.statistics.internal_bags.int_to_word``
        :return: a mapping from rule to count
        :raises ValueError: when an internal node has no entry in the grammar
        """
        if not isinstance(tree_node.children[0], Lexicon):
            keyword = (tree_node.tag,
                       tuple(i.tag for i in tree_node.children))
            result = self.grammar.get(keyword)
            if result is None:
                raise ValueError(str(tree_node))

            # try the bags in the given order; the first usable one wins
            for internal_bag_int in internal_bags_list:
                bags_str = self.tagger.statistics.internal_bags.int_to_word[
                    internal_bag_int]
                # an empty string stands for "no attachments"
                attachment_bag = set(bags_str.split(";")) if bags_str else set()
                result_filtered = self.filter_rules_by_attachments(result, attachment_bag)
                if result_filtered:
                    # print(f"Use {attachment_bag} for {tree_node}{tree_node.span}")
                    return result_filtered
        else:
            lexicon = tree_node.children[0].string
            tag = tree_node.tag
            # get lexicon-specific rules
            # (a negative word_threshold disables word-specific rules entirely)
            if self.options.word_threshold >= 0:
                result = self.lexicon_to_graph.get((lexicon, tag))
                # print(f"Get {len(result) if result is not None else 0} lexical rules for {lexicon}")
                # low freq word
                if self.options.word_threshold > 0 and result is not None \
                        and sum(result.values()) < self.options.word_threshold:
                    result = None
                    # print(f"filter out low freq word {lexicon}")
            else:
                result = None

            # get postag-specific rules
            fallback_rules = self.sync_grammar_fallback(tree_node)
            if not result:
                # print(f"{lexicon} @{tag} has no rules, falling back")
                result = fallback_rules
            elif is_train:
                # `|` on the two rule mappings; for Counters this keeps the
                # larger count of each rule
                result = result | fallback_rules

            # filter rules by lexical label and attachments
            for lexical_label_int in lexical_label_list:
                # only the part of the tag before "!!!" is the lexical label
                lexical_label = self.tagger.statistics.pred_tags.int_to_word[lexical_label_int].split("!!!")[0]

                result_filtered = self.filter_rules(result, lexical_label)

                if is_train or not result_filtered:
                    fallback_rules_filtered = self.filter_rules(fallback_rules, lexical_label)
                    if not result_filtered:
                        result_filtered = fallback_rules_filtered
                    elif is_train:
                        # NOTE(review): lexical_label comes from str.split, so
                        # this check looks always true -- confirm the intent
                        if lexical_label is not None:
                            result_filtered = result_filtered | fallback_rules_filtered

                if result_filtered:
                    # print(f"filter {lexicon} by {lexical_label}")
                    for attachment_bag_int in attachment_bags_list:
                        bags_str = self.tagger.statistics.attachment_bags.int_to_word[
                            attachment_bag_int]
                        attachment_bag = set(bags_str.split(";")) if bags_str else set()
                        result_filtered_2 = self.filter_rules_by_attachments(result_filtered, attachment_bag)
                        if result_filtered_2:
                            # print(f"filter {lexicon} by {lexical_label} and {attachment_bag}")
                            return result_filtered_2

            # for lexical_label_int in lexical_label_list:
            #     lexical_label = self.tagger.statistics.pred_tags.int_to_word[lexical_label_int]
            #     result = self.lexical_label_mapping.get((lexical_label, tag))
            #     if result is not None:
            #         return result
            # raise Exception(f"No lexical entry for {tag}")
        # no label/bag combination matched: hand back the unfiltered rules
        return result

    def sync_grammar_fallback(self, tree_node):
        tag = tree_node.tag
        results = self.postag_mapping.get(tag)
        if not results:
            results = self.sync_grammar_fallback_2(tree_node)
        return results

    def transform_edge(self, edge, lexicon):
        if "NEWLEMMA" in edge.label:
            word = lexicon.string.replace("_", "+")
            if "_u_unknown" in edge.label:
                item = word
            else:
                pos = edge.label[edge.label.find("NEWLEMMA") + 10]
                if pos in ("n", "v", "a"):
                    item = self.lemmatizer.lemmatize(word, pos)
                else:
                    item = self.lemmatizer.lemmatize(lexicon.string.replace("_", "+"))
            new_label = edge.label.format(NEWLEMMA=item)
            # print(edge.label, lexicon, item, new_label)
            return HyperEdge(edge.nodes, new_label,
                             edge.is_terminal, edge.span)
        return edge

    def recover_rule(self, rule, lexicon, tag, eval_word=True):
        """Instantiate a lexical rule template for a concrete word.

        :param rule: rule exposing ``lhs`` and ``rhs`` (with ``nodes`` and
            ``edges``)
        :param lexicon: word used to fill the edge labels (see
            ``transform_edge``) and placed on the CFG right-hand side
        :param tag: left-hand side of the resulting CFG rule
        :param eval_word: when false, ``rule`` is returned untouched
        :return: ``rule`` itself, or a new ``CFGRule`` whose HRG part has every
            edge passed through ``transform_edge``. The ``CFGRule`` was
            previously built but never returned, so callers received ``None``.
        """
        if not eval_word:
            return rule

        hrg_part = HRGRule(
            lhs=rule.lhs,
            rhs=HyperGraph(
                nodes=rule.rhs.nodes,
                edges=frozenset(self.transform_edge(edge, lexicon)
                                for edge in rule.rhs.edges)
            ))
        return CFGRule(lhs=tag,
                       rhs=((lexicon, None),),
                       hrg=hrg_part
                       )

    def sync_grammar_fallback_2(self, tree_node):
        print(f"Using default rule for {tree_node}")
        rule_name, main_node_count = tree_node.tag.rsplit("#", 1)
        word = tree_node.children[0].string
        try:
            main_node_count = int(main_node_count)
        except ValueError:
            pass
        if main_node_count == 1 or main_node_count == "s":
            main_node = GraphNode("0")
            surface = tree_node.children[0].string

            if self.pattern_number.match(surface):
                label = "card"
            elif rule_name.find("generic_proper") >= 0:
                label = "named"
            else:
                lemma = self.lemmatizer.lemmatize(word)
                if rule_name.find("n_-_c-pl-unk_le") >= 0:
                    label = "_{}/nns_u_unknown".format(lemma)
                elif rule_name.find("n_-_mc_le") >= 0 or rule_name.find("n_-_c_le") >= 0:
                    label = "_{}_n_1".format(lemma)  # more number is used
                elif rule_name.find("generic_mass_count_noun") >= 0:
                    label = "_{}/nn_u_unknown".format(lemma)  # more number is used
                else:
                    candidates = self.lexicon_mapping.get((HLexicon(word), main_node_count))
                    if candidates:
                        return candidates
                    else:
                        label = "named"

            old_edge = HyperEdge(
                nodes=[main_node],
                label=rule_name,
                is_terminal=False
            )

            main_edge = HyperEdge(
                nodes=[main_node],
                label=label,
                is_terminal=True
            )

            fallback = CFGRule(lhs=rule_name,
                               rhs=((tree_node.children[0], None),),
                               hrg=HRGRule(
                                   lhs=old_edge,
                                   rhs=HyperGraph(
                                       nodes=frozenset([main_node]),
                                       edges=frozenset([main_edge])
                                   )
                               ))
        else:
            ret1 = self.terminal_mapping.get(tree_node.tag)
            if ret1:
                return Counter([ret1.most_common(1)[0][0]])
            connected_nodes = [GraphNode(str(i)) for i in range(main_node_count)]
            centural_node = GraphNode(str(main_node_count + 1))
            old_edge = HyperEdge(
                nodes=connected_nodes,
                label=rule_name,
                is_terminal=False
            )
            main_edges = [HyperEdge(
                nodes=[centural_node, i],
                label="???",
                is_terminal=True
            ) for i in connected_nodes]
            fallback = CFGRule(lhs=rule_name,
                               rhs=((tree_node.children[0], None),),
                               hrg=HRGRule(
                                   lhs=old_edge,
                                   rhs=HyperGraph(
                                       nodes=frozenset(connected_nodes + [centural_node]),
                                       edges=frozenset(main_edges)
                                   )
                               ))
        return Counter([fallback])

    def construct_derivation(
            self, sub_graph_map, correspondents_map, sync_rule_map, tree_node):
        """Yield the derivation steps under ``tree_node``, children first.

        Each step is ``(sub_graph, sync_rule, correspondents)``, with the
        correspondents ordered by descending second element of the stored
        pairs. A node with no (or an empty) entry in ``sub_graph_map`` yields
        ``(None, None, [])``.

        :param sub_graph_map: mapping from tree node to its sub-graph
        :param correspondents_map: mapping from tree node to an iterable of
            pairs
        :param sync_rule_map: mapping from tree node to the rule applied
        :param tree_node: root of the subtree to walk
        """
        maps = (sub_graph_map, correspondents_map, sync_rule_map)
        if isinstance(tree_node.children[0], ConstTree):
            yield from self.construct_derivation(*maps, tree_node.children[0])
        if len(tree_node.children) >= 2 and isinstance(tree_node.children[0], ConstTree):
            yield from self.construct_derivation(*maps, tree_node.children[1])

        sub_graph = sub_graph_map.get(tree_node)
        if not sub_graph:
            # ignore empty beam item
            yield None, None, []
            return

        ranked_pairs = sorted(correspondents_map[tree_node],
                              key=lambda pair: pair[1], reverse=True)
        yield sub_graph, sync_rule_map[tree_node], [pair[0] for pair in ranked_pairs]

    def populate_delphin_spans(self, tree, args_and_names=False):
        preterminals = list(tree.generate_preterminals())
        spans = literal_eval(tree.extra["DelphinSpans"])
        assert len(preterminals) == len(spans)

        # if self.options.use_lexical_labels:
        #     lexical_labels = literal_eval(tree.extra["LexicalLabels"])
        #     assert len(lexical_labels) == len(preterminals)
        # else:
        #     lexical_labels = []

        for span, preterminal in zip_longest(spans, preterminals):
            preterminal.extra["DelphinSpan"] = span
            # if self.options.use_lexical_labels:
            #     preterminal.extra["LexicalLabel"] = lexical_label

        for rule in tree.generate_rules():
            if isinstance(rule.children[0], ConstTree):
                rule.extra["DelphinSpan"] = (rule.children[0].extra["DelphinSpan"][0],
                                             rule.children[-1].extra["DelphinSpan"][1])

        if args_and_names:
            args_tuples = literal_eval(tree.extra["Args"])
            tree.extra["DelphinArgsSet"] = frozenset(args_tuples)
            tree.extra["DelphinArgsSetSimple"] = frozenset(
                (i[1], i[3]) for i in args_tuples)
            names_tuples = literal_eval(tree.extra["Names"])
            tree.extra["DelphinNamesSet"] = frozenset(names_tuples)

    def load_gold_graph(self, sent_id):
        """Return the gold data for a sentence, reading it on first use.

        Results are cached in ``self.gold_graphs``. On a miss the file
        ``<options.deepbank_dir>/<sent_id>.gz`` is read and split into
        blank-line separated fields: for graph type ``eds`` the
        second-to-last field is parsed as EDS (after removing ``{...}``
        runs), otherwise the third-to-last field is parsed as SimpleMRS.

        :param sent_id: sentence identifier, also the archive's base name
        :return: tuple ``(gold_graph, smatch_field, hg)``
        """
        cached = self.gold_graphs.get(sent_id)
        if cached is not None:
            return cached

        archive_path = self.options.deepbank_dir + "/" + sent_id + ".gz"
        with gzip.open(archive_path, "rb") as f_gz:
            raw_text = f_gz.read().decode("utf-8")
        fields = raw_text.strip().split("\n\n")

        if self.options.graph_type != "eds":
            assert self.options.graph_type in ("dmrs", "lf")
            m = simplemrs.loads_one(fields[-3])
            gold_graph = EDSScorer.from_mrs(m)
            hg = HyperGraph.from_mrs(m)
            smatch_field = mrs_for_smatch(sent_id, m)
        else:
            eds_literal = re.sub(r"\{.*\}", "", fields[-2])
            e = eds.loads_one(eds_literal)
            hg = HyperGraph.from_eds(e)
            gold_graph = EDSScorer.from_eds(e, sent_id)
            smatch_field = eds_for_smatch(sent_id, e)

        entry = (gold_graph, smatch_field, hg)
        self.gold_graphs[sent_id] = entry
        return entry

    def evaluate_hg(self, results, output_file):
        """Score predicted graphs against gold graphs and write the reports.

        Writes ``<output_file>.txt`` (per-sentence log and scores),
        ``.graph`` (predictions), ``.gold_graph`` (gold smatch blocks) and,
        for graph type ``lf``, ``.mrs``. ``smatch_eval`` is then run on the
        two graph files.

        :param results: iterable of ``(sent_id, graph)`` pairs
        :param output_file: path prefix of every file written
        :return: the accumulated F1 multiplied by 100
        """
        total_result = ScorerResult.zero()
        with open(output_file + ".txt", "w") as f, \
                open(output_file + ".graph", "w") as f_graph, \
                open(output_file + ".gold_graph", "w") as f_gold_graph, \
                (open(output_file + ".mrs", "w")
                 if self.options.graph_type == "lf" else ExitStack()) as f_mrs:

            def write_log(message):
                print(message, file=f)

            for sent_id, graph in results:
                gold_graph, smatch_field, _gold_hg = self.load_gold_graph(sent_id)
                if self.options.graph_type == "lf":
                    f_mrs.write(lfrg_to_mrs(graph) + "\n\n")
                    # score (and later dump) the graph re-read from the MRS text
                    graph = mrs_reader(lfrg_to_mrs(graph), ExtractionParams(), True)
                graph_scorer = EDSScorer.from_hypergraph(graph, sent_id=sent_id)

                result = graph_scorer.compare_with(
                    gold_graph,
                    True, log_func=write_log)
                f.write(str(result))
                total_result += result
                f_graph.write(output_hg(sent_id, graph))
                f_gold_graph.write(smatch_field)

            print("Total:")
            print(total_result)
            f.write(str(total_result))

        self.smatch_eval(output_file + ".gold_graph", output_file + ".graph")

        return total_result.f1 * 100

    @staticmethod
    def smatch_eval(gold_file, system_file):
        """Run the bundled ``utils/smatch_1`` scorer and read back its score.

        The scorer is invoked with the system file, the gold file and three
        output paths derived from the system file (``.smatch``,
        ``.mapoutput``, ``.allmapoutput``). Arguments are passed as a list,
        so paths containing spaces or shell metacharacters are handled; the
        scorer's exit status is not checked.

        :param gold_file: path of the gold graph file
        :param system_file: path of the predicted graph file
        :return: the last ``Total Smatch`` value in ``<system_file>.smatch``
            multiplied by 100, or ``0`` when the file holds no such line
        """
        import subprocess  # local import: not among the file-level imports

        scorer = os.path.join(os.path.dirname(__file__), "utils", "smatch_1")
        subprocess.run(
            [scorer, system_file, gold_file,
             f"{system_file}.smatch",
             f"{system_file}.mapoutput",
             f"{system_file}.allmapoutput"],
            check=False)

        with open(f"{system_file}.smatch") as f:
            content = f.read()

        scores = re.findall(r"Total Smatch (.*)", content)
        if not scores:
            return 0
        return float(scores[-1]) * 100

    # def count_lexicons(self):
    #     postag_mapping = defaultdict(Counter)  # type: Mapping[str, typing.Counter]
    #     lexical_label_mapping = defaultdict(Counter)
    #     for (word, main_node_count), graph_counter in self.lexicon_to_graph.items():
    #         try:
    #             main_node_count = int(main_node_count)
    #         except ValueError:
    #             main_node_count = len(main_node_count)
    #     for keyword, rules_counter in self.grammar.items():
    #         tag = keyword[0]
    #         if isinstance(keyword[1][0], HLexicon) and keyword[1][0].string == "{NEWLEMMA}":
    #             postag_mapping[tag].update(rules_counter)
    #             for rule, count in rules_counter.items():
    #                 may_be_lexical_edge = rule.rhs[0][1]
    #                 lexical_label = may_be_lexical_edge.label \
    #                     if may_be_lexical_edge is not None else "None"
    #                 lexical_label_mapping[lexical_label, tag][rule] += count
    #     return dict(postag_mapping), dict(lexical_label_mapping)

    @classmethod
    def predict_with_parser(cls, options):
        """Load a saved model, parse ``options.test`` and evaluate the output.

        ``options.input_format`` selects how the input is read: ``"standard"``
        (the data format's ``from_file``), ``"space"`` (one space-separated
        sentence per line), anything starting with ``"english"`` (raw text
        tokenized with NLTK; ``"english-line"`` treats each line as one
        sentence) or ``"tokenlist"`` (a Python literal of word/tag pairs).

        :param options: option object; reads ``model``, ``input_format``,
            ``test`` and ``output``, and has ``is_train`` set to ``False``
        :raises ValueError: for any other ``input_format``
        """
        default_logger.info('Loading Model...')
        options.is_train = False
        parser = cls.load(options.model, options)
        parser.logger.info('Model loaded')

        DataFormatClass = cls.get_data_formats()[parser.options.data_format]
        input_format = options.input_format

        if input_format == "standard":
            data_test = DataFormatClass.from_file(options.test, False)
        elif input_format == "space":
            with smart_open(options.test) as f:
                data_test = [
                    DataFormatClass.from_words_and_postags(
                        [(word, "X") for word in line.strip().split(" ")])
                    for line in f]
        elif input_format.startswith("english"):
            from nltk import download, sent_tokenize
            from nltk.tokenize import TreebankWordTokenizer
            download("punkt")
            with smart_open(options.test) as f:
                raw_sents = []
                for line in f:
                    text = line.strip()
                    if options.input_format == "english-line":
                        raw_sents.append(text)
                    else:
                        raw_sents.extend(sent_tokenize(text))
                tokenized_sents = TreebankWordTokenizer().tokenize_sents(raw_sents)
                data_test = [
                    DataFormatClass.from_words_and_postags(
                        [(token, "X") for token in sent])
                    for sent in tokenized_sents]
        elif input_format == "tokenlist":
            with smart_open(options.test) as f:
                # SECURITY: eval() executes whatever the input file contains;
                # only use this format with trusted files.
                items = eval(f.read())
            data_test = DataFormatClass.from_words_and_postags(items)
        else:
            raise ValueError("invalid format option")

        started = time.time()
        results = parser.predict(data_test)
        parser.evaluate_hg(results, options.output)
        finished = time.time()
        parser.logger.info('Finished predicting and writing test. %.2f seconds.', finished - started)
