#include <fstream>
#include <algorithm>

#include "rule.h"
#include "token_manager.h"
#include "include/basic.h"

// Out-of-class definition of the static member RuleSet::empty_: an empty list
// of rule-tail pointers, shared by all RuleSet instances.
// NOTE(review): presumably handed out by lookups that find no rule for a
// head — confirm against the RuleSet declaration.
std::vector< const RuleSet::Tail* > RuleSet::empty_;

// Writes a single state to `os` in the form "NAME(VARCOUNT;DIRECTION)".
// NAME comes from TokenManager::stateAt; the two parenthesised fields are
// STATE_VAR_COUNT and STATE_DIRECTION of the same state value.
void printState(std::ostream& os, const RuleSet::State state) {
    os << TokenManager::stateAt(state);
    os << '(' << STATE_VAR_COUNT(state);
    os << ';' << STATE_DIRECTION(state);
    os << ")";
}

// Writes a state list as "{ S1 S2 ... }" (each state followed by a space),
// delegating each element to printState.
void printStates(std::ostream& os, const RuleSet::States& states) {
    os << "{ ";
    std::for_each(states.begin(), states.end(),
                  [&os](const RuleSet::State each) {
                      printState(os, each);
                      os << " ";
                  });
    os << "}";
}

// Writes every equation in `equations` to `os` as "<LHS=RHS> " (each one is
// followed by a space).
//
// LHS (element 0): "G" for the sentence variable, otherwise VAR_TO_STRING.
// RHS (elements 1..): joined with '+'. Lemma vars print as L, sense vars as
// S, IS_THAT_VAR vars as W, literal vars as the quoted text from
// TokenManager::literalAt, and anything else via VAR_TO_STRING.
void printEquations(std::ostream& os,
                    const RuleSet::Equations& equations) {
    // Bind by reference: each equation is itself a container, and copying
    // it on every iteration just to print it is wasted work.
    for (const auto& equation : equations) {
        os << "<";
        if (equation.empty()) {
            // An equation without a left-hand side is malformed. Print it as
            // "<> " instead of reading past the end of the container.
            os << "> ";
            continue;
        }

        if (IS_SENTENCE_VAR(equation[0]))
            os << "G=";
        else
            os << VAR_TO_STRING(equation[0]) << "=";

        // Right-hand side: emit the separator before every term but the
        // first, which avoids index arithmetic (and signed/unsigned mixing).
        const char* separator = "";
        for (auto it = equation.begin() + 1; it != equation.end(); ++it) {
            os << separator;
            separator = "+";

            const int var = *it;
            if (IS_LEMMA_VAR(var))
                os << 'L';
            else if (IS_SENSE_VAR(var))
                os << 'S';
            else if (IS_THAT_VAR(var))
                os << 'W';
            else if (IS_LITERAL_VAR(var))
                os << '"' << TokenManager::literalAt(var) << '"';
            else
                os << VAR_TO_STRING(var);
        }
        os << "> ";
    }
}

// Loads the rules stored in `rules_file` into `rule_set`.
//
// Logs an error and returns early if the file cannot be opened, or if
// reading stops on a failed extraction (truncated or malformed file). In the
// latter case `rule_set` may be incomplete.
void loadRulesFromFile(const std::string& rules_file,
                       RuleSet& rule_set) {
    LOG_INFO(<< "Loading Rules ...");
    std::ifstream is(rules_file);
    if (!is) {
        LOG_ERROR(<< "Can't find file " << rules_file);
        return;
    }

    is >> rule_set;
    if (is.fail()) {
        // A failed extraction means the input did not match the expected
        // layout; report it instead of announcing a successful load.
        LOG_ERROR(<< "Failed to parse rules file " << rules_file
                  << " (" << rule_set.rules().size() << " rules read)");
        return;
    }
    LOG_INFO(<< "Loaded " << rule_set.rules().size() << " Rules");
    // `is` is closed by its destructor.
}

// Strict ordering of rule heads. Keys, most significant first:
//   1. label_index
//   2. number of in-states
//   3. out_state_count
//   4. the in-states, element by element (first differing element decides)
// Heads that agree on every key compare as not-less (equivalent).
// NOTE(review): presumably this is the key ordering of rules_map_ — confirm.
bool RuleSet::Head::operator<(const Head& v) const {
    if (label_index == v.label_index) {
        const auto lhs_size = in_states.size();
        const auto rhs_size = v.in_states.size();
        if (lhs_size == rhs_size) {
            if (out_state_count == v.out_state_count) {
                // Equal lengths, so a single three-iterator mismatch scan
                // cannot run past the end of v.in_states.
                const auto diff = std::mismatch(in_states.begin(),
                                                in_states.end(),
                                                v.in_states.begin());
                return diff.first != in_states.end() &&
                       *diff.first < *diff.second;
            }
            return out_state_count < v.out_state_count;
        }
        return lhs_size < rhs_size;
    }
    return label_index < v.label_index;
}

std::istream& operator>>(std::istream& is, RuleSet& rs) {
    std::string label;
    int rule_count;
    is >> rule_count;

    rs.rules_list_.resize(rule_count);
    for (int n = 0; n < rule_count; ++n) {
        is >> label;

        RuleSet::Head& rule_head = rs.rules_list_[n].head;
        RuleSet::Tail& rule_tail = rs.rules_list_[n].tail;

        rule_head.label_index = TokenManager::indexOfNodeLabel(label);
        rule_tail.rule_index = n;

        int state_count;

        is >> state_count;
        rule_head.in_states.resize(state_count);

        for (int i = 0; i < state_count; ++i) {
            auto& state = rule_head.in_states[i];
            is >> state;
            state = IS_EMPTY_STATE_INTERNAL(state)
                        ? MAKE_EMPTY(state)
                        : MAKE_NORMAL(state);
        }

        is >> state_count;
        rule_head.out_state_count = state_count;

        rule_tail.out_states.resize(state_count);
        for (int i = 0; i < state_count; ++i) {
            auto& state = rule_tail.out_states[i];
            is >> state;
            state = IS_EMPTY_STATE_INTERNAL(state)
                        ? MAKE_EMPTY(state)
                        : MAKE_REVERSED(state);
        }

        int equation_count, variable_count;
        is >> equation_count;

        auto& equations = rule_tail.equations;
        equations.resize(equation_count);

        for (int j = 0; j < equation_count; ++j) {
            is >> variable_count;

            auto& equation = equations[j];
            equation.resize(variable_count);

            for (int k = 0; k < variable_count; ++k)
                is >> equation[k];
        }

        if (equations.empty() || IS_SENTENCE_VAR(equations[0][0]))
            continue;
        int head_index = VAR_MAJOR(equations[0][0]);
        int in_state_count = rule_head.in_states.size();
        if (head_index < in_state_count) {
            // head_state 被排序调到后面
            auto& head_state = rule_head.in_states[head_index];
            int next_pos = head_index;
            while (next_pos < in_state_count &&
                   rule_head.in_states[next_pos] == head_state)
                ++next_pos;
            SET_REVERSED(head_state);

            if (head_index < next_pos - 1) {
                // LOG_DEBUG(<< "Rule head " << n << " changes");
                // LOG_DEBUG(<< "    from: " << rs.rules_list_[n]);
                for (auto& equation : equations)
                    for (auto& var : equation)
                        if (IS_NORMAL_VAR(var)) {
                            int var_major = VAR_MAJOR(var);
                            int var_minor = VAR_MINOR(var);
                            if (var_major == head_index) {
                                var = MAKE_VAR(next_pos - 1, var_minor);
                            } else if (var_major > head_index &&
                                       var_major < next_pos) {
                                var = MAKE_VAR(var_major - 1, var_minor);
                            }
                        }
                std::sort(rule_head.in_states.begin(), rule_head.in_states.end());
                // LOG_DEBUG(<< "    to: " << rs.rules_list_[n]);
            }
        } else {
            // 不会有一样的状态 (后缀不一样, 所以 index 不会改变)
            SET_NORMAL(rule_tail.out_states[head_index - in_state_count]);
        }
    }

    for (auto& rule : rs.rules_list_)
        rs.rules_map_[rule.head].push_back(&rule.tail);

    return is;
}

// Writes all rules in `rs`, grouped by head (one group per rules_map_
// entry):
//   "{ in-states } LABEL\n"
// followed by one line per tail stored under that head:
//   " => { out-states } <equations...>\n"
std::ostream& operator<<(std::ostream& os, const RuleSet& rs) {
    for (const auto& item : rs.rules_map_) {
        const auto& head = item.first.get();
        printStates(os, head.in_states);
        // The label must go to `os` (not std::cerr), so that it stays on the
        // same stream and line as the in-states and the tails below it.
        os << ' ' << TokenManager::nodeLabelAt(head.label_index) << '\n';
        for (auto tail_ptr : item.second) {
            os << " => ";
            printStates(os, tail_ptr->out_states);
            os << ' ';
            printEquations(os, tail_ptr->equations);
            os << '\n';
        }
    }
    return os;
}

// Writes one rule on a single line:
//   "{ in-states } LABEL { out-states } (RULE_INDEX) <equations...>"
// with no trailing newline.
std::ostream& operator<<(std::ostream& os,
                         const RuleSet::Rule& rule) {
    const auto& head = rule.head;
    const auto& tail = rule.tail;

    printStates(os, head.in_states);
    os << ' ' << TokenManager::nodeLabelAt(head.label_index) << ' ';

    printStates(os, tail.out_states);
    os << " (" << tail.rule_index << ") ";

    printEquations(os, tail.equations);
    return os;
}
