import re
from abc import ABCMeta, abstractmethod
from typing import List, Dict, Sequence, Optional, Any

from dataclasses import dataclass

from coli.hrgguru.const_tree import ConstTree
from coli.hrgguru.hrg import CFGRule, DerivationPrecursor
from coli.hrgguru.hyper_graph import HyperEdge, GraphNode

# Sequence of per-node tag values appended to a CFG rule label (rendered by
# LabelerBase.format_extra_tag); None when a derivation step carries no tags.
ExtraTagType = Optional[Sequence[Any]]


class LabelerBase(metaclass=ABCMeta):
    """Interface for labelers that append extra tags to CFG rule labels.

    A concrete labeler is built from a list of derivation precursors via
    ``from_derivation_precursors`` and afterwards rewrites the tags of the
    rules generated from a constituency tree via ``rewrite_cfg_label``.
    """

    @classmethod
    @abstractmethod
    def from_derivation_precursors(cls, derivation_precursors: List[DerivationPrecursor]):
        """Construct a labeler from a list of derivation precursors (entries may be ``None``)."""

    @classmethod
    def format_extra_tag(cls, extra_tag: ExtraTagType):
        """Render ``extra_tag`` as a string.

        ``None`` is rendered as the literal ``"None"``; otherwise the ``str()``
        of every element is concatenated without a separator.
        """
        return "None" if extra_tag is None else "".join(map(str, extra_tag))

    @abstractmethod
    def rewrite_cfg_label(self, root: ConstTree, derivations: List[CFGRule], *args):
        """Rewrite the tags of the CFG rules generated from ``root``."""


@dataclass
class EPCountLabeler(LabelerBase):
    """Tag each CFG rule with the number of external nodes of its HRG fragment."""

    # External-node count per derivation step; 0 when the step's precursor is None.
    extra_tags: List[int]

    @classmethod
    def from_derivation_precursors(cls, derivation_precursors: List[DerivationPrecursor]):
        """Count the external nodes of each precursor's HRG (0 for ``None`` precursors)."""
        counts = [0 if precursor is None else len(precursor.hrg.external_nodes)
                  for precursor in derivation_precursors]
        return cls(counts)

    def rewrite_cfg_label(self, root: ConstTree, derivations: List[CFGRule], *args):
        """Append ``#<count>`` to the tag of each rule generated from ``root``.

        Rules and counts are paired positionally; pairing stops at the shorter
        of the two sequences.
        """
        for cfg_rule, count in zip(root.generate_rules(), self.extra_tags):
            cfg_rule.tag = cfg_rule.tag + "#" + str(count)


@dataclass
class EPPredLabeler(LabelerBase):
    """Tag each external node of a derivation step with predicate presence.

    For every derivation step, each external node of the step's HRG fragment
    is tagged ``"s"`` if a terminal unary edge is attached to it -- either in
    the fragment itself or, propagated upward through the nonterminal edge of
    a child subtree, somewhere below -- and ``"n"`` otherwise.
    """

    # One entry per derivation precursor, in input order; ``None`` for steps
    # whose precursor is ``None``.
    extra_tags: List[Optional[Dict[GraphNode, str]]]

    @classmethod
    def from_derivation_precursors(cls, derivation_precursors: List[DerivationPrecursor]):
        """Compute the per-step external-node tags.

        Child entries are looked up by ``postorder_idx`` in the list being
        built, so this assumes ``derivation_precursors`` is in postorder
        (children before parents). Exactly one entry is appended per
        precursor so that list positions stay aligned with ``postorder_idx``
        and with the rules zipped in ``rewrite_cfg_label``.
        """
        extra_tags: List[Optional[Dict[GraphNode, str]]] = []
        for precursor in derivation_precursors:
            if precursor is None:
                extra_tags.append(None)
                continue
            # n: no pred edge at this ep
            # s: has pred edge at this ep
            status = {i: "n" for i in precursor.hrg.external_nodes}
            # terminal pred edge attached directly to an external node
            for edge in precursor.hrg.all_edges:
                if len(edge.nodes) == 1 and edge.is_terminal and edge.nodes[0] in status:
                    status[edge.nodes[0]] = "s"
            children = precursor.cfg.children
            # propagate from left subtree
            if len(children) >= 1:
                cls._propagate_from_child(status, precursor.left_child_edge,
                                          children[0], extra_tags, "left")
            # propagate from right subtree
            if len(children) >= 2:
                cls._propagate_from_child(status, precursor.right_child_edge,
                                          children[1], extra_tags, "right")
            # Always append, even if a child had no tags: skipping would shift
            # every later entry relative to postorder_idx.
            extra_tags.append(status)
        return cls(extra_tags)

    @staticmethod
    def _propagate_from_child(status: Dict[GraphNode, str], child_edge: Any, child_tree: Any,
                              extra_tags: List[Optional[Dict[GraphNode, str]]],
                              side: str) -> None:
        """Mark nodes of ``status`` as "s" where the child subtree has a pred edge."""
        if not isinstance(child_edge, HyperEdge):
            return
        child_extra_tags = extra_tags[child_tree.postorder_idx]
        if child_extra_tags is None:
            # Child carries no tags: nothing to propagate from this side.
            return
        assert len(child_edge.nodes) == len(child_extra_tags), \
            "{} != {} when {}={}".format(len(child_edge.nodes), len(child_extra_tags),
                                         side, child_extra_tags)
        for node in child_edge.nodes:
            if node in status and child_extra_tags[node] == "s":
                status[node] = "s"

    def rewrite_cfg_label(self, root: ConstTree, derivations: List[CFGRule], original_node_map, *args):
        """Append ``#<tags>`` to the tag of each rule generated from ``root``.

        ``original_node_map[step]`` maps each node of the synchronous rule's
        HRG LHS at ``step`` to the node used as a key in ``self.extra_tags``;
        tags are emitted in the order of ``sync_rule.hrg.lhs.nodes``. Steps
        without tags are rendered as ``#None``.
        """
        steps = zip(root.generate_rules(), derivations, self.extra_tags)
        for step, (rule, sync_rule, extra_tag_dict) in enumerate(steps):
            if extra_tag_dict is None:
                extra_tag = None
            else:
                assert sync_rule.hrg is not None
                node_map = original_node_map[step]
                extra_tag = [extra_tag_dict[node_map[i]] for i in sync_rule.hrg.lhs.nodes]
            rule.tag = rule.tag + "#" + self.format_extra_tag(extra_tag)


@dataclass
class QLabeler(LabelerBase):
    """Tag each external node of a derivation step with its quantifier role.

    Tags: ``"q"`` if the node carries a terminal unary ``*_q`` edge, ``"y"``
    if a quantifier node points to it through a terminal ``RSTR/H`` edge, and
    ``"n"`` otherwise. Non-``"n"`` tags of child subtrees are propagated
    upward through the nodes of the children's nonterminal edges.
    """

    # One entry per derivation precursor, in input order; ``None`` for steps
    # whose precursor is ``None``.
    extra_tags: List[Optional[Dict[GraphNode, str]]]

    @classmethod
    def from_derivation_precursors(cls, derivation_precursors: List[DerivationPrecursor]):
        """Compute the per-step external-node quantifier tags.

        Child entries are looked up by ``postorder_idx`` in the list being
        built, so this assumes ``derivation_precursors`` is in postorder
        (children before parents). Exactly one entry is appended per
        precursor so that list positions stay aligned with ``postorder_idx``
        and with the rules zipped in ``rewrite_cfg_label``.
        """
        extra_tags: List[Optional[Dict[GraphNode, str]]] = []
        for precursor in derivation_precursors:
            if precursor is None:
                extra_tags.append(None)
                continue
            # y: a quantifier points to this point
            # n: no quantifier points to this point
            # q: this point is a quantifier
            status = {i: "n" for i in precursor.hrg.external_nodes}
            quantifier_pred_nodes = set()
            # terminal quantifier pred edge
            for edge in precursor.hrg.all_edges:
                if edge.label.endswith("_q") and edge.is_terminal and len(edge.nodes) == 1:
                    if edge.nodes[0] in status:
                        status[edge.nodes[0]] = "q"
                    quantifier_pred_nodes.add(edge.nodes[0])
            # RSTR/H edges leaving a quantifier mark their target
            for edge in precursor.hrg.all_edges:
                if edge.label == "RSTR/H" and edge.is_terminal \
                        and len(edge.nodes) == 2 and edge.nodes[0] in quantifier_pred_nodes:
                    if edge.nodes[1] in status:
                        status[edge.nodes[1]] = "y"
            children = precursor.cfg.children
            # propagate from left subtree
            if len(children) >= 1:
                cls._propagate_from_child(status, precursor.left_child_edge,
                                          children[0], extra_tags, "left")
            # propagate from right subtree
            if len(children) >= 2:
                cls._propagate_from_child(status, precursor.right_child_edge,
                                          children[1], extra_tags, "right")
            # Always append, even if a child had no tags: skipping would shift
            # every later entry relative to postorder_idx.
            extra_tags.append(status)
        return cls(extra_tags)

    @staticmethod
    def _propagate_from_child(status: Dict[GraphNode, str], child_edge: Any, child_tree: Any,
                              extra_tags: List[Optional[Dict[GraphNode, str]]],
                              side: str) -> None:
        """Copy every non-"n" tag of a child subtree onto the shared nodes of ``status``."""
        if not isinstance(child_edge, HyperEdge):
            return
        child_extra_tags = extra_tags[child_tree.postorder_idx]
        if child_extra_tags is None:
            # Child carries no tags: nothing to propagate from this side.
            return
        assert len(child_edge.nodes) == len(child_extra_tags), \
            "{} != {} when {}={}".format(len(child_edge.nodes), len(child_extra_tags),
                                         side, child_extra_tags)
        for node in child_edge.nodes:
            if node in status:
                sub_status = child_extra_tags[node]
                if sub_status != "n":
                    status[node] = sub_status

    def rewrite_cfg_label(self, root: ConstTree, derivations: List[CFGRule], original_node_map, *args):
        """Append ``#<tags>`` to the tag of each rule generated from ``root``.

        ``original_node_map[step]`` maps each node of the synchronous rule's
        HRG LHS at ``step`` to the node used as a key in ``self.extra_tags``;
        tags are emitted in the order of ``sync_rule.hrg.lhs.nodes``. Steps
        without tags are rendered as ``#None``.
        """
        steps = zip(root.generate_rules(), derivations, self.extra_tags)
        for step, (rule, sync_rule, extra_tag_dict) in enumerate(steps):
            if extra_tag_dict is None:
                extra_tag = None
            else:
                assert sync_rule.hrg is not None
                node_map = original_node_map[step]
                extra_tag = [extra_tag_dict[node_map[i]] for i in sync_rule.hrg.lhs.nodes]
            rule.tag = rule.tag + "#" + self.format_extra_tag(extra_tag)


@dataclass
class EPVariablesLabeler(EPPredLabeler):
    """Tag each external node of a derivation step as a variable or a handle.

    A node is tagged ``"h"`` when it is an external node, appears on the
    precursor's ``new_edge`` and its name matches ``h<digits>``; every other
    external node is tagged ``"x"``. Label rewriting is inherited from
    ``EPPredLabeler``.
    """

    @classmethod
    def from_derivation_precursors(cls, derivation_precursors: List[DerivationPrecursor], *args):
        """Compute per-step variable/handle tags; ``None`` precursors map to ``None``."""
        extra_tags: List[Optional[Dict[GraphNode, str]]] = []
        for precursor in derivation_precursors:
            if precursor is None:
                extra_tags.append(None)
                continue
            # x: is variable
            # h: is handle
            # TODO:
            # i: is intrinsic variable
            # b: is bound variable
            status = {i: "x" for i in precursor.hrg.external_nodes}
            # external nodes of the new edge whose name looks like a handle (h<digits>)
            for node in precursor.new_edge.nodes:
                # raw string: "\d" in a plain literal is an invalid escape sequence
                if node in status and re.match(r"^h\d+$", node.name):
                    status[node] = "h"
            extra_tags.append(status)
        return cls(extra_tags)
