import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
import argparse
import random
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.pylab as pl
from itertools import combinations

# Command-line interface. The arguments are parsed at import time and the
# resulting module-level ``args`` is what the ``__main__`` section passes on
# to read_all() and run().
parser = argparse.ArgumentParser()
parser.add_argument('--datasets', nargs="*", type=str, default=['scifact_oracle', 'fever', 'cfever', 'scifact'], help='datasets name')
parser.add_argument('--d', type=str, default='./maple', help='directory')
# ``help`` must be a string: argparse %-formats it when rendering usage text.
parser.add_argument('--methods', nargs="*", type=str, default=['supervised'], help="training methods, e.g. 'supervised', 'nlpo'")
parser.add_argument('--models', nargs="*", type=str, default=['t5_small'], help="generator models, e.g. 't5_small'")
parser.add_argument('--directions', nargs="*", type=str, default=['ec', 'ce'], help="generation directions: 'ec', 'ce'")
parser.add_argument('--metric', type=str, default='semsim', help="language model used for similarity score such as "
                                                              "'semsim', 'bleurt', 'bartscore', 'rouge', "
                                                              "'meteor', 'sacrebleu','bleu',")
parser.add_argument('--ml_name', type=str, default='LR', help='machine learning model name')
# ``type=list`` would split a command-line value into single characters
# ('10' -> ['1', '0']); accept space-separated integers instead. The default
# is unchanged: epochs 0 through 20.
parser.add_argument('--epochs', nargs="*", type=int, default=list(range(21)),
                    help="epoch numbers whose score files are used as features, e.g. --epochs 0 5 10")
parser.add_argument('--output_dir', type=str, default="results", help="output directory")
parser.add_argument('--testing', action="store_true", help="if testing, only run experiments once (seed 123), "
                                                           "else run once per seed in 123..223")

args = parser.parse_args()


def read_data(trainset, test_DF, t, seed, features):
    """Build one few-shot train split and the matching full test split.

    ``sample_t`` picks ``t`` row positions per label class from ``trainset``
    (seeded by ``seed``); those rows form the training data. The test frame
    is used in full.

    Returns:
        ``(train_labels, train_features, test_labels, test_features)`` where
        the label entries are plain lists taken from the ``'label'`` column
        and the feature entries are the ``features`` columns of each frame.
    """
    picked_rows = sample_t(trainset['label'].tolist(), t=t, seed=seed)
    few_shot = trainset.iloc[picked_rows]

    x_train = few_shot[features]
    x_test = test_DF[features]
    y_train = few_shot['label'].tolist()
    y_test = test_DF['label'].tolist()
    return y_train, x_train, y_test, x_test



def sample_t(labels_train, t=10, seed=123):
    """Draw ``t`` positions per class from a list of labels.

    Positions are sampled without replacement, separately for the classes
    'SUPPORTS', 'NOT_ENOUGH_INFO' and 'REFUTES' (in that order), and
    concatenated into one list. Labels outside these three are ignored.

    Args:
        labels_train: sequence of label strings.
        t: number of positions to draw for each class.
        seed: seed for the sampling; the same seed gives the same positions.

    Returns:
        A list of ``3 * t`` integer positions into ``labels_train``.

    Raises:
        ValueError: if any class has fewer than ``t`` examples.
    """
    # A private generator produces the same draws as seeding the module-level
    # one, but does not reset the random state of the rest of the program.
    rng = random.Random(seed)

    all_indexes = []
    for wanted in ('SUPPORTS', 'NOT_ENOUGH_INFO', 'REFUTES'):
        pool = [i for i, label in enumerate(labels_train) if label == wanted]
        if t > len(pool):
            raise ValueError(
                f"cannot sample t={t} examples of class {wanted!r}: "
                f"only {len(pool)} available"
            )
        all_indexes.extend(rng.sample(pool, t))
    return all_indexes

def train(train_X, train_y, test_X, test_y, ml_name):
    """Fit the classifier selected by ``ml_name`` and score it on the test data.

    Args:
        train_X: training features passed to ``model.fit``.
        train_y: training labels passed to ``model.fit``.
        test_X: test features passed to ``model.predict``.
        test_y: true test labels the predictions are scored against.
        ml_name: one of 'MLP', 'SVM', 'DT', 'RF', 'Voting', 'LR'.

    Returns:
        ``(model, f1, accuracy)``: the fitted estimator, the macro-averaged
        F1 score and the accuracy on the test data.

    Raises:
        ValueError: if ``ml_name`` is not one of the supported names.
    """
    # Factories are lazy so that only the requested estimator is constructed.
    factories = {
        'MLP': lambda: MLPClassifier(random_state=0, max_iter=5000),
        'SVM': lambda: SVC(),
        'DT': lambda: DecisionTreeClassifier(random_state=0),
        'RF': lambda: RandomForestClassifier(random_state=0),
        # Hard-voting ensemble of three logistic-regression variants.
        'Voting': lambda: VotingClassifier(
            estimators=[('lr', LogisticRegression(random_state=0)),
                        ('lr2', LogisticRegression(random_state=0, multi_class='ovr')),
                        ('lr1', LogisticRegression(random_state=0, solver='liblinear'))],
            voting='hard'),
        'LR': lambda: LogisticRegression(random_state=0, max_iter=5000),
    }
    if ml_name not in factories:
        # Without this check an unknown name would only fail later with an
        # unhelpful "local variable 'model' referenced before assignment".
        raise ValueError(
            f"unknown ml_name {ml_name!r}; expected one of {sorted(factories)}"
        )
    model = factories[ml_name]()

    # fit the model
    model.fit(train_X, train_y)
    # get predictions
    yhat = model.predict(test_X)
    # evaluate predictions
    f1 = f1_score(test_y, yhat, average='macro')
    accuracy = accuracy_score(test_y, yhat)

    return model, f1, accuracy


def _score_paths(args, dataset, model, method, direction, epoch):
    """Return the (validation, test) score CSV paths for one configuration."""
    run_name = f"{method}_{model}_{dataset}"
    # Every direction except 'ce' has its name prefixed to the run folder.
    if direction != 'ce':
        run_name = f"{direction}_{run_name}"
    # 'semsim' scores live in 'aggregate', other metrics in 'aggregate_<metric>'.
    aggregate = 'aggregate' if args.metric == 'semsim' else f'aggregate_{args.metric}'
    folder = f"{args.d}/{direction}/{run_name}/{run_name}/{run_name}/{aggregate}"
    return (f"{folder}/epoch_{epoch}_val_scores.csv",
            f"{folder}/epoch_{epoch}_test_scores.csv")


def read_all(args):
    """Load the score CSVs of every dataset into one wide frame per split.

    For each dataset and each (model, method, direction, epoch) combination
    the ``*_val_scores.csv`` file feeds the training frame and the
    ``*_test_scores.csv`` file feeds the test frame. Three score columns are
    copied from each file and renamed to
    ``{model}_{method}_{direction}_{pair}_{epoch}`` with ``pair`` in
    'generated_claim', 'generated_evidence', 'claim_evidence'. For metrics
    other than 'semsim' the source columns carry a ``_{metric}`` suffix.

    Args:
        args: namespace providing ``datasets``, ``models``, ``methods``,
            ``directions``, ``epochs``, ``metric`` and the root folder ``d``.

    Returns:
        ``(trainsets, testsets)``: two dicts mapping dataset name to a
        DataFrame holding the renamed score columns plus a ``label`` column.

    Raises:
        ValueError: if models, methods, directions or epochs is empty, so no
            score file would be read for a dataset.
    """
    from itertools import product

    column_suffix = '' if args.metric == 'semsim' else f'_{args.metric}'
    pairs = ('generated_claim', 'generated_evidence', 'claim_evidence')

    trainsets = {}
    testsets = {}
    for dataset in args.datasets:
        print(dataset)
        train_columns = {}
        test_columns = {}
        train_scores = test_scores = None
        # product() iterates in the same order as nested loops over
        # models > methods > directions > epochs.
        for model, method, direction, epoch in product(
                args.models, args.methods, args.directions, args.epochs):
            train_file, test_file = _score_paths(
                args, dataset, model, method, direction, epoch)
            train_scores = pd.read_csv(train_file)
            test_scores = pd.read_csv(test_file)
            prefix = f'{model}_{method}_{direction}'
            for pair in pairs:
                train_columns[f'{prefix}_{pair}_{epoch}'] = train_scores[f'{pair}{column_suffix}']
            for pair in pairs:
                test_columns[f'{prefix}_{pair}_{epoch}'] = test_scores[f'{pair}{column_suffix}']

        if train_scores is None:
            raise ValueError(
                "nothing to read: models, methods, directions and epochs "
                "must all be non-empty"
            )
        # Labels come from the last file read for this dataset.
        # NOTE(review): assumes every score file of a dataset lists the same
        # rows in the same order - confirm against the files.
        train_columns['label'] = train_scores['label']
        test_columns['label'] = test_scores['label']
        trainsets[dataset] = pd.DataFrame(train_columns)
        testsets[dataset] = pd.DataFrame(test_columns)
    return trainsets, testsets

def run(trainsets, testsets, args):
    """Run the few-shot experiments for every dataset and collect the scores.

    For each dataset, each shot count ``t`` (1..9, then 10..50 in steps of
    10) and each seed, a few-shot split is drawn with ``read_data`` and a
    classifier is fitted and scored with ``train``. With ``args.testing``
    only seed 123 is used; otherwise seeds 123 through 223.

    Returns:
        A DataFrame with one row per (dataset, t, seed) holding the columns
        ``t``, ``f1``, ``acc``, ``seed``, ``train_set``, ``test_set`` and
        ``method``.
    """
    seeds = range(123, 124) if args.testing else range(123, 224)
    print(seeds)

    # The full CLI selections are used as one group each.
    directions = args.directions
    models = args.models
    methods = args.methods

    pairs = ['claim_evidence', 'generated_claim', 'generated_evidence']
    features = [f'{model_name}_{method_name}_{direction_name}_{pair}_{epoch}'
                for direction_name in directions
                for model_name in models
                for method_name in methods
                for pair in pairs
                for epoch in args.epochs]
    shot_counts = [*range(1, 10), *range(10, 60, 10)]

    results = pd.DataFrame()
    for dataset in args.datasets:
        scores = {'t': [], 'f1': [], 'acc': [], 'seed': []}
        for t in shot_counts:
            for seed in seeds:
                y_train, x_train, y_test, x_test = read_data(
                    trainsets[dataset], testsets[dataset], t=t, seed=seed, features=features)
                # The fitted estimator itself is not needed here.
                _, f1, acc = train(x_train, y_train, x_test, y_test, ml_name=args.ml_name)
                scores['t'].append(t)
                scores['f1'].append(f1)
                scores['acc'].append(acc)
                scores['seed'].append(seed)

        dataset_results = pd.DataFrame(scores)
        # Training and evaluation both use the same dataset.
        dataset_results['train_set'] = dataset
        dataset_results['test_set'] = dataset
        dataset_results['method'] = '_'.join(
            [args.ml_name, '-'.join(methods), '-'.join(models), '-'.join(directions)])
        results = pd.concat([results, dataset_results], axis=0)
    return results

if __name__ == '__main__':
    # Echo the parsed configuration so it appears in the run's output.
    print(args)

    # Load the score tables, then run the few-shot experiments on them.
    trainsets, testsets = read_all(args)
    results = run(trainsets, testsets, args)

    # Write one CSV named after the classifier, the metric and the datasets.
    os.makedirs(args.output_dir, exist_ok=True)
    dataset_tag = "-".join(args.datasets)
    results.to_csv(f'{args.output_dir}/{args.ml_name}_{args.metric}_{dataset_tag}.csv', index=False)