import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.stats import linregress
import seaborn as sns
from project_root import join_with_root

def fix_nan(l):
    """Return a new list with every NaN entry of ``l`` replaced by 0.

    Entries that are not NaN are copied through unchanged. Each entry is
    tested with ``np.isnan``, so the entries must be values that function
    accepts.
    """
    cleaned = []
    for value in l:
        if np.isnan(value):
            cleaned.append(0)
        else:
            cleaned.append(value)
    return cleaned

def compareZSOS(task, model):
    """Plot per-prompt Kendall correlations of two evaluation runs side by side.

    Loads the "zero_shot" and "few_shot" correlation JSON files, keeps the rows
    matching ``task`` and ``model``, joins them on ``ID`` and draws, for both
    runs, a scatter of the Kendall values, a horizontal median line and a
    dashed least-squares regression line. Points are ordered by the Kendall
    value of the zero-shot file. The means of both series are printed, the
    figure is written to ``outputs/plots/ZS_vs_OS_train_{task}_{model}.pdf``
    and then shown.

    Args:
        task: Value matched against the ``task`` column of both files.
        model: Value matched against the ``model`` column of both files.

    Raises:
        ValueError: If no row survives the task/model filter and the join.
    """
    sns.set_theme()
    zero_shot = pd.read_json(join_with_root("outputs/evaluation/corr_zero_shot_train_avg_no_emotion.json"))
    few_shot = pd.read_json(join_with_root("outputs/evaluation/corr_few_shot_train_avg_no_emotion.json"))
    zero_shot = zero_shot[(zero_shot["task"] == task) & (zero_shot["model"] == model)]
    few_shot = few_shot[(few_shot["task"] == task) & (few_shot["model"] == model)]

    # The few-shot frame is the left operand, so after the merge its columns
    # carry the "_x" suffix and the zero-shot columns carry "_y".
    joint_df = pd.merge(left=few_shot, right=zero_shot, on="ID")
    if joint_df.empty:
        # Fail early with a clear message instead of deep inside linregress.
        raise ValueError(f"No matching prompts for task={task!r}, model={model!r}")

    # Pair element 0 (few-shot file) is drawn in orange ("OS" in the original
    # note -- presumably one-shot; confirm), element 1 (zero-shot file) in blue.
    pairs = sorted(
        zip(fix_nan(joint_df["kendall_x"].tolist()), fix_nan(joint_df["kendall_y"].tolist())),
        key=lambda pair: pair[1],
    )
    few_shot_scores = [pair[0] for pair in pairs]
    zero_shot_scores = [pair[1] for pair in pairs]
    # x positions are ranks in the zero-shot ordering, not the raw "ID" values.
    positions = list(range(len(pairs)))

    print(np.mean(few_shot_scores), np.mean(zero_shot_scores))

    # Use a dedicated figure per call: relying on pyplot's implicit current
    # figure lets successive calls draw on top of each other whenever
    # plt.show() does not dispose of the figure (e.g. non-interactive backends).
    fig, ax = plt.subplots()
    try:
        ax.scatter(positions, few_shot_scores, s=3, color="orange")
        ax.scatter(positions, zero_shot_scores, s=3, color="blue")
        ax.plot([np.median(zero_shot_scores)] * len(pairs), color="blue", linewidth=2)
        ax.plot([np.median(few_shot_scores)] * len(pairs), color="red", linewidth=2)

        reg = linregress(positions, few_shot_scores)
        ax.axline(xy1=(0, reg.intercept), slope=reg.slope, linestyle="--", color="red", linewidth=2)
        reg = linregress(positions, zero_shot_scores)
        ax.axline(xy1=(0, reg.intercept), slope=reg.slope, linestyle="--", color="blue", linewidth=2)
        ax.set_xlabel("Prompt ID")
        ax.set_ylabel("Kendall correlation")

        fig.tight_layout()
        fig.savefig(join_with_root(f"outputs/plots/ZS_vs_OS_train_{task}_{model}.pdf"))
        plt.show()
    finally:
        # Release the figure so repeated calls neither overlap nor pile up.
        plt.close(fig)

if __name__ == '__main__':
    # Produce one comparison plot per (model, task) pair, model by model.
    for model_name in ("Platypus70B", "Nous", "OpenOrca"):
        for task_name in ("en_de", "zh_en", "summarization"):
            compareZSOS(task_name, model_name)