import os    
import tqdm
import jsonlines
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

# Score transformation applied by expand() before averaging.
# Supported values: 'original', 'square', 'exp'.
strategy = 'original'

# Number of leading `cci_scores` entries averaged per example.
# Use 1 to score with the first entry only; a very large value (as here)
# averages every entry, because the count is capped at the list length.
topk = 10 ** 6

# Directory holding the jsonlines record files that are evaluated.
dir_path_input = './record/'
# Directory holding the per-example '<lang>-<idx>.json' result files.
dir_path_output = './pecora-results-full/'

def expand(num, mode=None):
    """Transform a raw score according to a weighting strategy.

    Args:
        num: Numeric score to transform.
        mode: Strategy name: 'original' (identity), 'square' (``num ** 2``)
            or 'exp' (``np.exp(num)``). When omitted, the module-level
            ``strategy`` setting is used, so ``expand(num)`` keeps working.

    Returns:
        The transformed score.

    Raises:
        ValueError: If the strategy name is not one of the supported values.
            An unknown name (e.g. a typo) used to fall through silently to
            the exponential transform.
    """
    if mode is None:
        mode = strategy
    if mode == 'original':
        return num
    if mode == 'square':
        return num ** 2
    if mode == 'exp':
        return np.exp(num)
    raise ValueError(
        "unknown strategy {!r}; expected 'original', 'square' or 'exp'".format(mode)
    )

# os.listdir() returns entries in arbitrary order. Sort them so files are
# processed, reported, and allowed to overwrite `res_thres` in the same order
# on every run and on every platform.
files = sorted(os.listdir(dir_path_input))
# Query languages for which per-language statistics are collected and printed.
langs = ['te', 'bn', 'ja', 'fi', 'ru']

# Per-language decision thresholds applied to the averaged score. There is one
# table per `topk` setting; the active tables below are the "mean - 0.5SD"
# values. The other candidate cut-offs are kept here for reference.
#
# topk == 1
#   mean + 0SD:   {'te': 9.2696144466891, 'bn': 3.9564370989799498, 'ja': 6.9015719670998426, 'fi': 6.7099887765944, 'ru': 10.221544576187929}
#   mean + 0.5SD: {'te': 12.386288250074093, 'bn': 6.091078200781299, 'ja': 11.35392421906364, 'fi': 11.101380207585775, 'ru': 14.093159208432645}
#   mean + 1SD:   {'te': 15.502962053459086, 'bn': 8.225719302582648, 'ja': 15.806276471027434, 'fi': 15.49277163857715, 'ru': 17.96477384067736}
#
# topk == 1000000
#   mean + 0SD:   {'te': 6.970936516224974, 'bn': 1.9165140522372721, 'ja': 3.1694967514730634, 'fi': 3.554152121106224, 'ru': 5.4493517599332355}
#   mean + 0.5SD: {'te': 9.24855581634759, 'bn': 2.898778687288626, 'ja': 5.272983830426844, 'fi': 5.432363520213508, 'ru': 7.390710744768503}
#   mean + 1SD:   {'te': 11.526175116470208, 'bn': 3.8810433223399805, 'ja': 7.3764709093806236, 'fi': 7.310574919320793, 'ru': 9.332069729603772}

# top 1, mean - 0.5SD
_THRESHOLD_TOP1 = {
    'te': 6.152940643304108,
    'bn': 1.8217959971786004,
    'ja': 2.4492197151360466,
    'fi': 2.3185973456030258,
    'ru': 6.349929943943213,
}

# top 1000000, mean - 0.5SD
_THRESHOLD_FULL = {
    'te': 4.693317216102357,
    'bn': 0.9342494171859179,
    'ja': 1.0660096725192831,
    'fi': 1.6759407219989397,
    'ru': 3.5079927750979674,
}

# Every topk value other than 1 uses the "full" table.
threshold = _THRESHOLD_TOP1 if topk == 1 else _THRESHOLD_FULL
def _mean_topk_score(cci_entries):
    """Return the mean transformed ``cti_score`` over the first ``topk`` entries.

    Args:
        cci_entries: The ``cci_scores`` list of one result file; each entry is
            read through its ``'cti_score'`` key.

    Returns:
        The average of ``expand(cti_score)`` over the first
        ``min(topk, len(cci_entries))`` entries, or 0.0 when that count is not
        positive (the unguarded division used to raise ZeroDivisionError).
    """
    count = min(topk, len(cci_entries))
    if count <= 0:
        return 0.0
    total = 0
    # Explicit left-to-right accumulation, exactly as the original loop did.
    for entry in cci_entries[:count]:
        total += expand(entry['cti_score'])
    return total / count


def _describe(values, reducer):
    """Apply ``reducer`` (e.g. ``np.mean``) to ``values``; NaN for an empty list."""
    return reducer(values) if values else float('nan')


# Threshold suggested per language (mean - 0.5 * std of all scores).
# NOTE: when several input files match, a later file overwrites the entry
# written by an earlier one.
res_thres = dict()
for fname in files:
    # Skip files whose name contains 'neg' or does not contain 'val'.
    if 'neg' in fname or 'val' not in fname:
        continue

    # Tag taken from the file name ('<prefix>-<tag>.<ext>'); it is also the
    # prefix of the per-example result files read below.
    lang = fname.split('.')[0].split('-')[1]
    print('*****************************')
    print(lang)

    # Scores of examples with / without a True entry in item['ais'],
    # grouped by the example's query language.
    pos_ins = {code: [] for code in langs}
    neg_ins = {code: [] for code in langs}

    # Per-language confusion counters. Recall and precision counts are
    # collected but currently not printed.
    recall = dict.fromkeys(langs, 0)
    recall_tot = dict.fromkeys(langs, 0)
    precision = dict.fromkeys(langs, 0)
    precision_tot = dict.fromkeys(langs, 0)
    accuracy = dict.fromkeys(langs, 0)
    accuracy_tot = dict.fromkeys(langs, 0)

    with open(dir_path_input + fname, encoding='utf-8') as f:
        # NOTE: yes/no questions (item['prediction'] in ["yes", "no"]) are
        # not filtered out. The idx-th record is paired with the result file
        # '<lang>-<idx>.json', so every record must advance idx.
        for idx, item in enumerate(jsonlines.Reader(f)):
            save_path = dir_path_output + lang + '-' + str(idx) + '.json'
            with open(save_path, encoding='utf-8') as r:
                res_pecora = json.load(r)

            score = _mean_topk_score(res_pecora['cci_scores'])
            query_lang = item['query_language']

            # Gold label: at least one entry of item['ais'] is True.
            mark_concat = True in item['ais']
            (pos_ins if mark_concat else neg_ins)[query_lang].append(score)

            # Prediction: the averaged score reaches the language threshold.
            mark_attri = bool(score >= threshold[query_lang])

            accuracy_tot[query_lang] += 1
            if mark_attri == mark_concat:
                accuracy[query_lang] += 1
            if mark_concat:
                recall_tot[query_lang] += 1
                if mark_attri:
                    recall[query_lang] += 1
            if mark_attri:
                precision_tot[query_lang] += 1
                if mark_concat:
                    precision[query_lang] += 1

    for lang_code in langs:
        pos_scores = pos_ins[lang_code]
        neg_scores = neg_ins[lang_code]
        all_scores = pos_scores + neg_scores

        print(lang_code)
        if not all_scores:
            # Nothing to report for this language in this file. Skipping it
            # avoids zero divisions, a roc_auc_score error, and overwriting a
            # previously computed threshold with NaN.
            print("no examples for this language; skipped")
            print("=============")
            continue

        # roc_auc_score raises ValueError unless both classes are present.
        if pos_scores and neg_scores:
            label = np.array([1] * len(pos_scores) + [0] * len(neg_scores))
            predict = np.array(all_scores)
            roc_auc = roc_auc_score(label, predict)
        else:
            roc_auc = float('nan')

        print("pos_num:     {}".format(len(pos_scores)))
        print("neg_num:     {}".format(len(neg_scores)))
        print()
        print("Accuracy:    {}/{}={}".format(accuracy[lang_code], accuracy_tot[lang_code], accuracy[lang_code]/accuracy_tot[lang_code]))
        print("ROC_AUC:     {}".format(roc_auc))
        print()

        print("pos_mean:    {}".format(_describe(pos_scores, np.mean)))
        print("neg_mean:    {}".format(_describe(neg_scores, np.mean)))
        print("pos_median:  {}".format(_describe(pos_scores, np.median)))
        print("neg_median:  {}".format(_describe(neg_scores, np.median)))

        print("all_mean:    {}".format(np.mean(all_scores)))
        print("all_median:  {}".format(np.median(all_scores)))

        # Suggested threshold: mean minus half a standard deviation of all
        # scores (positives and negatives together).
        res_thres[lang_code] = np.mean(all_scores) - 0.5 * np.std(all_scores)
        print("=============")

print(res_thres)

# Legacy reporting code for the "highest"/"mean" sentence scores, kept only
# for reference. It is a bare string literal, so nothing in it is executed;
# the names it uses (pos_ins_highest, neg_ins_highest, ...) are not defined
# anywhere in the live code of this script.
'''
    label = np.array([1 for _ in range(len(pos_ins_highest))] + [0 for _ in range(len(neg_ins_highest))])
    predict = np.array(pos_ins_highest+neg_ins_highest)
    roc_auc = roc_auc_score(label, predict)
    print("ROC_AUC_highest: {}".format(roc_auc))
    print()
    #label = np.array([1 for _ in range(len(pos_ins_mean))] + [0 for _ in range(len(neg_ins_mean))])
    #predict = np.array(pos_ins_mean+neg_ins_mean)
    #roc_auc = roc_auc_score(label, predict)
    #print("ROC_AUC_mean: {}".format(roc_auc))
    
    #print("Accuracy: {}/{}={}".format(accuracy, accuracy_tot, accuracy/accuracy_tot))
    
    print("pos_highest_mean: {}".format(np.mean(pos_ins_highest)))
    print("neg_highest_mean: {}".format(np.mean(neg_ins_highest)))
    print("pos_highest_median: {}".format(np.median(pos_ins_highest)))
    print("neg_highest_median: {}".format(np.median(neg_ins_highest)))
    
    print(pos_ins_highest)
    print(neg_ins_highest)
'''
