import argparse
from os import remove
import re
import string
from pysbd import Segmenter
from zhon.hanzi import punctuation

from ioFn import readTxt, readJsonl, saveGeneralCSV

def calOverlap(results_a: list, results_b: list):
    """Return the smoothed average overlap ratio of ``results_a`` w.r.t. ``results_b``.

    Both arguments are lists whose elements are collections of sentence ids
    (one collection per document). The two lists are walked in lockstep; if
    their lengths differ, the extra entries of the longer one are ignored.

    For each pair, the ratio is the number of distinct ids shared by both
    sides divided by ``len(ids_a) + 1e-3``. The sum of these ratios is then
    divided by ``len(results_a) + 1e-3``. The small constant keeps both
    divisions defined for empty inputs, so the result is always slightly
    below the unsmoothed average.
    """
    smoothing = 1e-3
    per_doc_ratios = [
        len(set(ids_a) & set(ids_b)) / (len(ids_a) + smoothing)
        for ids_a, ids_b in zip(results_a, results_b)
    ]

    # Accumulate left-to-right starting from 0.0 so the floating-point
    # result does not depend on how the builtin sum() is implemented.
    running_total = 0.0
    for doc_ratio in per_doc_ratios:
        running_total += doc_ratio
    return running_total / (len(results_a) + smoothing)

def remove_punc(s):
    """Replace every punctuation character in ``s`` with a space and strip the ends.

    Covers both the Chinese marks in ``zhon.hanzi.punctuation`` and the
    ASCII marks in ``string.punctuation``. Interior runs of spaces left by
    the replacement are kept; only leading/trailing whitespace is removed.
    """
    # One translation table: Chinese marks first, then English marks,
    # each mapped to a single space.
    to_space = dict.fromkeys(map(ord, punctuation), ' ')
    to_space.update(dict.fromkeys(map(ord, string.punctuation), ' '))

    cleaned = s.translate(to_space)
    return cleaned.strip()
    

if __name__ == "__main__":
    # Script entry point.
    #
    # 1. Load the first -k examples of the Chinese and English label files
    #    (each example provides 'document', 'label' and 'summary').
    # 2. For every system listed in --system-file (tab-separated lines:
    #    hypothesis file, system name, language), find which document
    #    sentences occur verbatim (punctuation-insensitive) in each
    #    hypothesis, and score the resulting sentence ids against the labels.
    # 3. Write a side-by-side CSV of documents, labels, summaries and system
    #    outputs, then print per-system scores and selected pairwise overlaps.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--l-zh", help="label file",
        default="/home/tiger/wikiLingua_enzh_ext_finetune_english_mspm4_EnZh_chinese/resource/dataset/test.label.jsonl"
    )
    parser.add_argument(
        "--l-en", help="label file",
        default="/home/tiger/wikiLingua_enzh_ext_finetune_english_mspm4_EnZh_english/resource/dataset/test.label.jsonl"
    )
    parser.add_argument("--system-file", default="/opt/tiger/sumtest/multilingual/system_list.txt")
    parser.add_argument('-k', type=int, default=10)
    parser.add_argument('-o', default="ext_case.csv")

    args = parser.parse_args()

    # Only the first k examples of each language are evaluated.
    inputs_zh = readJsonl(args.l_zh)[:args.k]
    oracles_zh = [item['label'] for item in inputs_zh]

    inputs_en = readJsonl(args.l_en)[:args.k]
    oracles_en = [item['label'] for item in inputs_en]

    # Per-language lookup, keyed by the language code used in the system file.
    inputs = {
        'en': {'inputs': inputs_en, "oracles": oracles_en},
        'zh': {'inputs': inputs_zh, "oracles": oracles_zh},
    }

    systems = []
    with open(args.system_file, 'r') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF)
                # instead of failing on the three-way unpack below.
                continue
            hypo_file, name, lang = line.split('\t')
            infos = {}

            infos['hypos'] = readTxt(hypo_file)[:args.k]
            infos['name'] = name
            infos['lang'] = lang

            hypo_sent_ids = []
            match_sents = 0
            for doc_idx, hypo in enumerate(infos['hypos']):
                hypo = remove_punc(hypo)
                document = [
                    remove_punc(sent)
                    for sent in inputs[lang]['inputs'][doc_idx]['document']
                ]
                # A document sentence counts as extracted when its
                # punctuation-free text is a substring of the hypothesis.
                # Sentences that are empty after stripping are skipped:
                # '' is a substring of every string and would otherwise
                # always match. Ids come out in ascending order already.
                sent_ids = [
                    sent_idx
                    for sent_idx, sent in enumerate(document)
                    if sent and sent in hypo
                ]
                hypo_sent_ids.append(sent_ids)
                match_sents += len(sent_ids)
            # Hypotheses are not segmented into sentences here, so this
            # entry stays an empty list.
            infos['hypo_sents'] = []
            infos['hypo_sent_ids'] = hypo_sent_ids
            infos['recall_w_oracle'] = calOverlap(inputs[lang]['oracles'], hypo_sent_ids)
            infos['precision_w_oracle'] = calOverlap(hypo_sent_ids, inputs[lang]['oracles'])
            print("match_sents: ", match_sents)
            systems.append(infos)

    # Never index past the data that was actually loaded: the files may hold
    # fewer than k examples.
    num_rows = min(
        args.k,
        len(inputs_zh),
        *(len(system_infos['hypos']) for system_infos in systems)
    )

    # One CSV row per example; the document/label/summary columns always
    # come from the Chinese label file.
    outputs = []
    for row_idx in range(num_rows):
        example = inputs['zh']['inputs'][row_idx]
        items = [
            " ".join(example['document']),
            example['label'],
            example['summary'],
        ]
        for system_infos in systems:
            items.append(system_infos['hypos'][row_idx])
            items.append(system_infos['hypo_sent_ids'][row_idx])
        outputs.append(items)

    headers = ['document', 'label', 'summary']
    for system_infos in systems:
        headers.extend([
            system_infos['name'] + " Hypothesis",
            system_infos['name'] + " Hypothesis Sentence ids"
        ])
    saveGeneralCSV(outputs, args.o, start=headers)

    for info in systems:
        print(info['name'])
        print(info['precision_w_oracle'])
        print(info['recall_w_oracle'])

    # Pairwise overlap between the 1st/3rd and the 2nd/4th listed systems.
    # A pair is skipped when the system file lists too few systems.
    for a_idx, b_idx in ((0, 2), (1, 3)):
        if max(a_idx, b_idx) >= len(systems):
            print("skip overlap between systems {} and {}: only {} system(s) loaded".format(
                a_idx, b_idx, len(systems),
            ))
            continue
        a_info = systems[a_idx]
        b_info = systems[b_idx]
        overlap = calOverlap(
            a_info['hypo_sent_ids'], b_info['hypo_sent_ids']
        )
        print("overlap between {} and {}: {:.4f}".format(
            a_info['name'],
            b_info['name'],
            overlap,
        ))