
from main import *

def multiclass_acc(preds, truths):
    """
    Fraction of samples whose rounded prediction equals the rounded groundtruth.

    :param preds: Float array of predictions, dimension (N,)
    :param truths: Float/int array of groundtruth classes, dimension (N,)
    :return: Classification accuracy in [0, 1]
    """
    total = float(len(truths))
    # Rounding both sides turns continuous scores into integer class labels.
    hits = np.round(preds) == np.round(truths)
    return np.sum(hits) / total


def weighted_accuracy(test_preds_emo, test_truth_emo):
    """
    Compute class-balanced binary accuracy: the mean of the true-positive rate
    and the true-negative rate, treating values ``> 0`` as the positive class.

    :param test_preds_emo: Array of predictions, dimension (N,)
    :param test_truth_emo: Array of groundtruth values, dimension (N,)
    :return: Balanced accuracy in [0, 1]. When the groundtruth contains only one
        class, the rate for that class alone is returned (the other rate is
        undefined); for empty input, NaN.
    """
    true_label = (test_truth_emo > 0)
    predicted_label = (test_preds_emo > 0)
    tp = float(np.sum((true_label == 1) & (predicted_label == 1)))
    tn = float(np.sum((true_label == 0) & (predicted_label == 0)))
    p = float(np.sum(true_label == 1))
    n = float(np.sum(true_label == 0))

    if p > 0 and n > 0:
        # Algebraically (tp/p + tn/n) / 2; kept in the original form.
        return (tp * (n / p) + tn) / (2 * n)
    # Single-class groundtruth: average only over the class that is present,
    # instead of dividing by zero.
    if p > 0:
        return tp / p
    if n > 0:
        return tn / n
    return float('nan')


def eval_mosi(results, truths, exclude_zero=False):
    """
    Print regression and classification metrics for continuous sentiment scores.

    :param results: Tensor of predicted scores; flattened to (N,)
    :param truths: Tensor of groundtruth scores; flattened to (N,)
    :param exclude_zero: If True, samples whose groundtruth is exactly 0 are left
        out of the binary (F1 / accuracy) metrics. MAE, correlation and the
        7-/5-class accuracies always use every sample.
    :return: Dict of the printed metrics: ``mae``, ``corr``, ``mult_acc_7``,
        ``mult_acc_5``, ``f1``, ``acc_2``
    """
    test_preds = results.reshape(-1).detach().cpu().numpy()
    test_truth = truths.reshape(-1).detach().cpu().numpy()

    # Boolean mask rather than an index list: an empty list turned into
    # np.array([]) is float-typed and cannot be used as an index.
    if exclude_zero:
        keep = test_truth != 0
    else:
        keep = np.ones(test_truth.shape, dtype=bool)

    # Clipping then rounding maps scores onto the 7 bins -3..3 and the 5 bins -2..2.
    test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.)
    test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.)
    test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.)
    test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.)

    mae = np.mean(np.absolute(test_preds - test_truth))  # Average L1 distance between preds and truths
    corr = np.corrcoef(test_preds, test_truth)[0][1]
    mult_a7 = multiclass_acc(test_preds_a7, test_truth_a7)
    mult_a5 = multiclass_acc(test_preds_a5, test_truth_a5)

    binary_truth = test_truth[keep] > 0
    binary_preds = test_preds[keep] > 0
    # sklearn's signature is (y_true, y_pred); with average='weighted' the class
    # weights come from the support of y_true, so the order matters.
    f_score = f1_score(binary_truth, binary_preds, average='weighted')
    acc_2 = accuracy_score(binary_truth, binary_preds)

    print("MAE: ", mae)
    print("Correlation Coefficient: ", corr)
    print("mult_acc_7: ", mult_a7)
    print("mult_acc_5: ", mult_a5)
    print("F1 score: ", f_score)
    print("Accuracy: ", acc_2)

    print("-" * 50)

    return {
        "mae": mae,
        "corr": corr,
        "mult_acc_7": mult_a7,
        "mult_acc_5": mult_a5,
        "f1": f_score,
        "acc_2": acc_2,
    }


def scores(results, truths):
    """
    Print weighted F1 and accuracy separately for each of four emotion labels.

    :param results: Logits tensor, reshaped to (N, 4, 2): two class scores per emotion
    :param truths: Label tensor, reshaped to (N, 4): one class index per emotion
    """
    emos = ["Neutral", "Happy", "Sad", "Angry"]
    # .cpu() is required because .numpy() rejects tensors that live on a CUDA device.
    preds = results.view(-1, 4, 2).detach().cpu().numpy()
    label = truths.view(-1, 4).detach().cpu().numpy()
    for emo_ind, emo in enumerate(emos):
        print(f"{emo}: ")
        # Predicted class = index of the larger of the two scores for this emotion.
        test_preds_i = np.argmax(preds[:, emo_ind], axis=1)
        test_truth_i = label[:, emo_ind]
        f1 = f1_score(test_truth_i, test_preds_i, average='weighted')
        acc = accuracy_score(test_truth_i, test_preds_i)
        print("  - F1 Score: ", f1)
        print("  - Accuracy: ", acc)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Selects which evaluation branch (loss + metrics) the script below runs.
dataset = 'iemocap'

def evaluate(model, test_loader, valid_loader, dataset='iemocap', test=False):
    """
    Run ``model`` over the test or validation loader without gradient tracking.

    :param model: Module called as ``model(text, audio, vision)``
    :param test_loader: Loader used when ``test`` is True
    :param valid_loader: Loader used when ``test`` is False
    :param dataset: ``'iemocap'`` uses cross-entropy over 2-way logits; any
        other value uses L1 loss on the raw outputs
    :param test: Evaluate on the test loader instead of the validation loader
    :return: ``(avg_loss, results, truths)``. ``avg_loss`` is the loss averaged
        over samples; ``results`` and ``truths`` are the concatenated predictions
        and targets, on ``device``.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss() if dataset == 'iemocap' else nn.L1Loss()
    loader = test_loader if test else valid_loader

    total_loss = 0.0
    n_samples = 0
    results = []
    truths = []

    with torch.no_grad():
        for batch_X, batch_Y in loader:
            _sample_ind, text, audio, vision = batch_X
            text = text.to(device)
            audio = audio.to(device)
            vision = vision.to(device)
            # Targets must share the predictions' device in BOTH branches;
            # previously the L1 branch left them on the CPU.
            eval_attr = batch_Y.squeeze(dim=-1).to(device)

            preds = model(text, audio, vision).to(device)
            if dataset == 'iemocap':
                # One 2-way classification per label slot: flatten targets and logits to match.
                eval_attr = eval_attr.long().view(-1)
                preds = preds.view(-1, 2)

            batch_size = text.size(0)
            # criterion returns a per-batch mean; weighting by batch size makes the
            # final division a true per-sample average even for a short last batch.
            total_loss += criterion(preds, eval_attr).item() * batch_size
            n_samples += batch_size

            results.append(preds)
            truths.append(eval_attr)

    results = torch.cat(results)
    truths = torch.cat(truths)
    avg_loss = total_loss / max(n_samples, 1)
    return avg_loss, results, truths

# NOTE(review): torch.load unpickles the full model object, so only load trusted
# checkpoints. On torch >= 2.6 the default weights_only=True rejects full-model
# pickles; pass weights_only=False there if loading fails.
# map_location maps the checkpoint's tensors onto `device`, so a GPU-saved
# checkpoint also loads on a CPU-only machine.
model = torch.load('model_best_perform.pt', map_location=device)

# Forward `dataset` so the loss criterion matches the metric branch chosen below.
_, results, truths = evaluate(model, test_loader, valid_loader, dataset=dataset, test=True)

if dataset == 'iemocap':
    scores(results, truths)
else:
    eval_mosi(results, truths, True)