import argparse
import json
import logging
import random
from pathlib import Path
from typing import List

from mile.data.problem import ProblemFlag, ProblemEpisode
from mile.preproc.alg514 import ALG514Preprocessor
from mile.preproc.dolphin import DolphinPreprocessor
from mile.preproc.mathqa import MathQAPreprocessor
from mile.preproc.mawps import MAWPSPreprocessor
from mile.preproc.math23k import Math23KPreprocessor
from mile.preproc.preproc import get_counter_logstring, update_counter


def preprocess_each():
    """
    Run the individual dataset preprocessors (ALG514, DRAW, MAWPS, MathQA).

    A dataset is skipped when its marker file is already present in the output
    directory; otherwise the matching preprocessor is built and run on the
    corresponding input path. Reads the module-level ``args`` namespace
    (``input``, ``output``, ``alg``, ``draw``, ``mawps``, ``mathqa``).
    """
    out_dir = args.output
    in_dir = args.input

    def read_question(problem):
        # ALG514, DRAW and MAWPS records keep the problem text under 'sQuestion'.
        return problem['sQuestion']

    # ---- ALG514 ----
    if Path(out_dir, 'ALG.all.json').exists():
        print("\033[1;32mALG514 already exists\033[0m")
    else:
        print("\033[1;44;38mBegin to process ALG514.\033[0m")
        alg_preproc = ALG514Preprocessor(text_reader=read_question, save_path=out_dir,
                                         answer_time_limit=5, allow_reversed_answer=True,
                                         require_exact_match=False, error_limit=1E-2)
        alg_preproc.run(Path(in_dir, args.alg), [], 'ALG')

    # ---- DRAW (same structure as ALG514, so the same preprocessor is reused) ----
    if Path(out_dir, 'DRAW.all.json').exists():
        print("\033[1;32mDRAW already exists\033[0m")
    else:
        print("\033[1;44;38mBegin to process DRAW.\033[0m")
        draw_preproc = ALG514Preprocessor(text_reader=read_question, save_path=out_dir,
                                          answer_time_limit=5, allow_reversed_answer=True,
                                          require_exact_match=False, error_limit=1E-2)
        draw_preproc.run(Path(in_dir, args.draw), ['train', 'test', 'dev'], 'DRAW')

    # ---- MAWPS ----
    if Path(out_dir, 'MAWPS.all.json').exists():
        print("\033[1;32mMAWPS already exists\033[0m")
    else:
        print("\033[1;44;38mBegin to process MAWPS.\033[0m")
        mawps_preproc = MAWPSPreprocessor(text_reader=read_question, save_path=out_dir,
                                          answer_time_limit=5, allow_reversed_answer=False,
                                          require_exact_match=False, error_limit=1E-2)
        mawps_preproc.run(Path(in_dir, args.mawps), None, 'MAWPS')

    # ---- MathQA (text is under 'Problem'; the marker is the test split file) ----
    if Path(out_dir, 'MQ.test.json').exists():
        print("\033[1;32mMathQA already exists\033[0m")
    else:
        print("\033[1;44;38mBegin to process MathQA.\033[0m")
        mathqa_preproc = MathQAPreprocessor(text_reader=lambda problem: problem['Problem'],
                                            save_path=out_dir, answer_time_limit=5,
                                            allow_reversed_answer=False, require_exact_match=True)
        mathqa_preproc.run(Path(in_dir, args.mathqa), ['train', 'dev', 'test'], 'MQ')


def remove_duplicates(items: List['ProblemEpisode']) -> List['ProblemEpisode']:
    """
    Drop problems that repeat an earlier (text, equation) pair.

    Two items count as duplicates when their whitespace-stripped ``text_indexed``
    and their ``formula['postfix']['basic']`` are both equal. The first occurrence
    is kept and the input order is preserved.

    :param items: Problems to filter. Each item must expose ``text_indexed`` (str)
        and ``formula['postfix']['basic']`` (must be hashable, as it is used in a
        set key).
    :return: A new list holding only the first occurrence of each unique key.
    """
    new_items = []
    text_eqn_set = set()

    for item in items:
        key = (item.text_indexed.strip(), item.formula['postfix']['basic'])

        if key not in text_eqn_set:
            # Remember the key; without this every item would look unseen
            # and nothing would ever be filtered out.
            text_eqn_set.add(key)
            new_items.append(item)

    return new_items


def unite_all_dataset():
    """
    Merge the preprocessed datasets into unified ``mile`` train/dev/test files.

    Reads every ``*.json`` file in ``args.output`` (skipping names that contain
    ``_with_original`` or ``mile``), removes duplicated problems, assigns each
    remaining problem to ``train``, ``dev`` or ``test`` and writes
    ``mile.<split>.json`` back into ``args.output``. Per-split statistics are
    sent to the ``logging`` module.

    Split assignment:
      * Datasets whose file names contain ``train`` or ``all`` are treated as
        shipping their own split/fold information. A set name of ``train``,
        ``dev`` or ``test`` is used directly; otherwise the set name is parsed
        as a fold index, with fold 3 mapped to ``dev``, fold 4 to ``test`` and
        every other fold to ``train``.
      * All other datasets are split randomly (seed 1), roughly 8:1:1.

    Reads the module-level ``args`` namespace (``output``).

    :raises ValueError: If a dataset with provided folds has a set name that is
        neither train/dev/test nor an integer.
    """
    print("\033[1;44;38mIntegrate whole dataset\033[0m")

    # All problems collected from the per-dataset files.
    dataset = []
    # The sets which train/dev/test split or fold information is provided, we will not separate such sets.
    fold_or_train_provided = set()

    # Path.glob() yields entries in a filesystem-dependent order. Sorting keeps the
    # seeded random split and the "first duplicate wins" rule reproducible.
    for path in sorted(Path(args.output).glob('*.json')):
        if '_with_original' in path.name or 'mile' in path.name:
            # We don't need to preserve original format when integrate all.
            continue

        if 'train' in path.name or 'all' in path.name:
            fold_or_train_provided.add(path.name.split('.')[0])

        # Only collect problems matching the SOLVABLE flag.
        dataset += ProblemEpisode.read_json(path, ProblemFlag.SOLVABLE)

    # Split training, development, test by 8:1:1
    random.seed(1)
    set_to_write = {
        'train': [],
        'dev': [],
        'test': []
    }

    for item in remove_duplicates(dataset):
        dataname, setname = item.set.split('/')[:2]
        if dataname in fold_or_train_provided:
            if setname in set_to_write:
                writeto = setname
            else:
                # For datasets which are provided with 5 folds,
                # we assign the 4th(03) and the 5th(04) fold as dev and test respectively.
                fold = int(setname)
                if fold == 3:
                    writeto = 'dev'
                elif fold == 4:
                    writeto = 'test'
                else:
                    writeto = 'train'
        else:
            # Randomly assign a split if training set or folds were not provided.
            randomid = random.randrange(10)
            writeto = 'test' if randomid == 0 else ('dev' if randomid == 1 else 'train')

        # Overwrite set information, keeping the original set name under 'origin'.
        item = item.json
        item['origin'] = item['set']
        item['set'] = 'mile/%s' % writeto

        set_to_write[writeto].append(item)

    for set_type, group in set_to_write.items():
        # The context manager closes the file even when json.dump() raises.
        with Path(args.output, 'mile.%s.json' % set_type).open('w+t', encoding='UTF-8') as wp:
            json.dump(group, wp)

        # Write statistics
        counter = {'total': 0, 'solvable': 0, 'group1': 0, 'group2': 0,
                   'missingInfo': 0, 'syntaxError': 0, 'evaluationError': 0, 'notsolvable': 0}

        for prob in group:
            update_counter(counter, prob)

        logging.info('Information of mile/%s', set_type)
        logging.info(get_counter_logstring(counter))


if __name__ == '__main__':
    # Script entry point: parse the command line, prepare the output directory and
    # log file, preprocess each dataset, then merge them into the unified splits.
    parser = argparse.ArgumentParser()

    # Datasets
    parser.add_argument('--input', '-in', '-I', type=str, required=True, help='Path of input files')
    parser.add_argument('--output', '-out', '-O', type=str, required=True, help='Path to save output files')

    parser.add_argument('--alg', type=str, help='Path of ALG514 files in input directory',
                        default='alg514/alg514.json')
    parser.add_argument('--draw', type=str, help='Path of DRAW files in input directory',
                        default='draw/draw.json')
    parser.add_argument('--mawps', type=str, help='Path of MAWPS files in input directory',
                        default='mawps/mawps.json')
    parser.add_argument('--mathqa', type=str, help='Path of MathQA files in input directory',
                        default='MathQA')
    parser.add_argument('--dolphin18k', type=str, help='Path of Dolphin files in input directory',
                        default='dolphin18k')
    parser.add_argument('--dolphin1878', type=str, help='Path of Dolphin files in input directory',
                        default='dolphin1878')
    parser.add_argument('--math23k', type=str, help='Path of Math23K files in input directory',
                        default='math23k/math23k.json')

    # Tokenizer
    parser.add_argument('--bert', type=str, default='bert-base-uncased', help='Name of encoder used for parsing')

    # Parse arguments. `args` must stay a module-level name: preprocess_each() and
    # unite_all_dataset() read it as a global.
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race and makes re-runs a no-op.
    outpath = Path(args.output)
    outpath.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(filename=str(Path(args.output, 'preprocess.log')),
                        level=logging.INFO,
                        format='[%(asctime)s] %(levelname)s @ %(name)s:%(funcName)s :: %(message)s')

    preprocess_each()
    unite_all_dataset()
