from .metric_utils import ClassificationMetrics, PredictionStats
from sklearn.neighbors import KNeighborsClassifier
from torch.utils.tensorboard import SummaryWriter
from sklearn.decomposition import PCA
from ..model_utils import MetricModel
from betacal import BetaCalibration
from typing import Iterator, Union
import torch.nn as nn
from tqdm import tqdm
import torch


def get_embeddings(model: MetricModel, data_loader: Iterator) -> tuple:
    """
    Computes the feature representation of all samples of a provided dataset.

    Parameters
    ----------
    model: MetricModel object, whose feature_extractor is used to compute the feature representations.
    data_loader: Data loader, yielding (data, labels) batches, for which the feature representations
        are computed. It has to support len(), since the length is used for the progress bar.

    Returns
    -------
    features: Feature representations of all batches, concatenated along dim 0 and stored on the cpu.
    targets: Labels, corresponding to the feature representations, concatenated along dim 0 and stored on the cpu.
    samples: Raw data samples of all batches, collected in one flat list.
    """

    # initialize lists, in which features, targets and data samples are stored
    feature_list, target_list, samples = [], [], []

    # a pure forward pass does not need any gradients
    with torch.no_grad():

        # iterate through the provided data set
        for data, labels in tqdm(data_loader, 'Computing Embeddings...', total=len(data_loader)):

            # keep the raw data samples of this batch
            samples.extend(data)

            # convert data types if necessary (labels may arrive as a plain python list)
            if isinstance(labels, list):
                labels = torch.LongTensor(labels)

            # forward pass of the model
            features = model.feature_extractor(data)

            # move every batch to the cpu right away, so that the embeddings of the whole
            # dataset are not accumulated on the device of the model
            feature_list.append(features.detach().cpu())
            target_list.append(labels.detach().cpu())

    # get the correct representation of features and targets
    features = torch.cat(feature_list, dim=0)
    targets = torch.cat(target_list, dim=0)
    return features, targets, samples
    
    
def calibrate_model(predictions: torch.Tensor, targets: torch.Tensor, bc: Union[None, "BetaCalibration"]) -> tuple:
    """
    Calibrates predicted probabilities with a beta calibrator.

    Parameters
    ----------
    predictions: Predicted probabilities, one value per sample.
    targets: Labels, corresponding to the predictions. Only used, if a new calibrator has to be fitted.
    bc: Calibrator, which is applied to the predictions. If None is provided, a new
        BetaCalibration(parameters="abm") object is created and fitted on the provided predictions and targets.

    Returns
    -------
    bc: The calibrator that has been used (either the provided or the newly fitted one).
    predictions_calibrated: Calibrated predictions, converted from the numpy output of bc.predict.
    """

    # convert the predictions once, so that fitting and predicting receive the very same
    # numpy representation (one column, one row per sample)
    scores = predictions.detach().cpu().numpy().reshape(-1, 1)

    # initialize and fit a beta calibration object, if none is provided
    if bc is None:
        bc = BetaCalibration(parameters="abm")
        bc.fit(scores, targets.detach().cpu().numpy())

    # calibrate the predictions
    predictions_calibrated = torch.from_numpy(bc.predict(scores))
    return bc, predictions_calibrated
    
    
def fit_knn_classifier(features_train: torch.Tensor, features_val: torch.Tensor,
                       tar_train: torch.Tensor, num_neighbors: int, var_threshold: float) -> tuple:
    """
    Reduces the dimensionality of the features with a PCA and fits a k-nearest-neighbor classifier.

    The PCA is fitted on the concatenation of the training and the validation features. Afterwards, only
    the leading principal components are kept and a distance-weighted KNeighborsClassifier is fitted
    on the reduced training features.

    Parameters
    ----------
    features_train: Feature representations of the training data (cpu tensor, one row per sample).
    features_val: Feature representations of the validation data (cpu tensor, one row per sample).
    tar_train: Labels, corresponding to the training features.
    num_neighbors: Number of neighbors, used by the k-nearest-neighbor classifier.
    var_threshold: Ratio of the total variance, which has to be preserved by the kept principal components.

    Returns
    -------
    classifier: Fitted KNeighborsClassifier object.
    features_train: Training features after the dimensionality reduction (numpy array).
    features_val: Validation features after the dimensionality reduction (numpy array).
    """

    # create a pca object and fit it on all available features
    pca = PCA()
    features_all = torch.cat([features_train, features_val], dim=0).numpy()
    pca.fit(features_all)

    # get the cumulated variance ratio, which is preserved by the first k principal components
    var_ratio_cumsum = pca.explained_variance_ratio_.cumsum().tolist()

    # extract the minimal number of dimensions, which maintains the pre-defined ratio of the total variance,
    # i.e. the first k for which the cumulated ratio reaches the threshold. If the threshold is never
    # reached (e.g. due to rounding errors at a threshold of 1.0), all dimensions are kept.
    num_features = next(
        (k for k, var in enumerate(var_ratio_cumsum, start=1) if var >= var_threshold),
        len(var_ratio_cumsum),
    )

    # at least two dimensions
    num_features = max(num_features, 2)

    # apply the dimensionality reduction
    features_train = pca.transform(features_train.numpy())[:, :num_features]
    features_val = pca.transform(features_val.numpy())[:, :num_features]

    # fit a knn classifier
    classifier = KNeighborsClassifier(n_neighbors=num_neighbors, weights='distance')
    classifier.fit(features_train, tar_train)
    return classifier, features_train, features_val


def val_epoch_metric(model: MetricModel, train_loader: Iterator, val_loader: Iterator, num_neighbors: int,
                     metrics: ClassificationMetrics, class_names: list, writer: SummaryWriter,
                     time_step: int, var_threshold: float, ps: PredictionStats):
    """
    Validates a metric model with a k-nearest-neighbor classifier, fitted on its embeddings.

    Parameters
    ----------
    model: MetricModel object, which is switched to the evaluation mode and used to compute the embeddings.
    train_loader: Data loader, providing the data on which the k-nearest-neighbor classifier is fitted.
    val_loader: Data loader, providing the data on which the classifier is evaluated.
    num_neighbors: Number of neighbors, used by the k-nearest-neighbor classifier.
    metrics: Callable, invoked as metrics(probabilities, targets, class_names, writer, time_step).
    class_names: Names of the classes, forwarded to metrics.
    writer: Summary writer, forwarded to metrics and ps.
    time_step: Current time step, forwarded to metrics and ps.
    var_threshold: Ratio of the variance, which is preserved in the dimensionality reduction step.
    ps: Callable, invoked as ps(probabilities, targets, None, time_step, writer).

    Returns
    -------
    The entry ['scalar']['macro avg']['f1-score'] of the dictionary, returned by metrics.
    """

    # no training behaviour during the validation
    model.eval()

    # embeddings and labels of both data splits (the raw samples are not needed here)
    train_embeddings, train_labels, _ = get_embeddings(model, train_loader)
    val_embeddings, val_labels, _ = get_embeddings(model, val_loader)

    # fit the classifier; only the reduced validation features are required afterwards
    knn, _, val_reduced = fit_knn_classifier(
        train_embeddings, val_embeddings, train_labels, num_neighbors, var_threshold
    )

    # column 1 of predict_proba holds the probability of the second class known to the classifier
    class_probabilities = torch.Tensor(knn.predict_proba(val_reduced))
    scores_class_1 = class_probabilities[:, 1]

    # evaluate the predictions and pick the macro averaged f1 score
    report = metrics(scores_class_1, val_labels, class_names, writer, time_step)
    macro_f1 = report['scalar']['macro avg']['f1-score']

    # statistics over the predictions
    ps(scores_class_1, val_labels, None, time_step, writer)
    return macro_f1


def val_epoch_standard(model: nn.Module, val_loader: Iterator, metrics: ClassificationMetrics, writer: SummaryWriter,
                       epoch: int, ps: PredictionStats, bc: Union[None, BetaCalibration] = None) -> tuple:
    """
    Runs one validation epoch for a standard classification model.

    Parameters
    ----------
    model: Model, which maps a batch of data to logits (one column per class).
    val_loader: Data loader, yielding (data, labels) batches. It has to support len().
    metrics: Callable, invoked as metrics(predictions, targets, class_names, writer, epoch).
    writer: Summary writer, forwarded to metrics and ps.
    epoch: Current epoch, forwarded to metrics and ps.
    ps: Callable, invoked as ps(predictions, targets, None, epoch, writer).
    bc: Calibrator, forwarded to calibrate_model. If None, calibrate_model fits a new one
        on the predictions of this epoch.

    Returns
    -------
    results: The entry ['scalar']['macro avg']['f1-score'] of the dictionary, returned by metrics.
    bc: The calibrator, returned by calibrate_model.
    predictions: Calibrated predictions (softmax output of the column with index 1).
    targets: Labels, corresponding to the predictions, stored on the cpu.
    data_list: Raw data samples of all batches, collected in one flat list.
    """

    # switch to the evaluation mode
    model.eval()

    # lists, in which predictions, targets and data samples are stored
    predictions, targets, data_list = [], [], []

    # perform computations without storing gradients
    with torch.no_grad():

        # iterate through the validation loader
        for data, labels in tqdm(val_loader, 'Validation Epoch', total=len(val_loader)):

            # forward pass
            logits = model(data)

            # store the probability of the class with index 1
            probabilities = torch.softmax(logits, dim=1)[:, 1]
            predictions.append(probabilities)

            # store the targets; labels provided as a plain list are converted first
            # (the accumulator list itself must stay a list, otherwise append fails)
            if isinstance(labels, list):
                labels = torch.LongTensor(labels)
            targets.append(labels)

            # store the data samples
            data_list.extend(data)

        # store predictions and targets on the cpu
        predictions = torch.cat(predictions, dim=0).detach().cpu()
        targets = torch.cat(targets, dim=0).detach().cpu()

    # calibrate the model
    bc, predictions = calibrate_model(predictions, targets, bc)

    # store and visualize the results
    results = metrics(predictions, targets, ['Normal', 'Hatespeech'], writer, epoch)
    ps(predictions, targets, None, epoch, writer)
    results = results['scalar']['macro avg']['f1-score']
    return results, bc, predictions, targets, data_list


def val_epoch_output(model: MetricModel, val_loader: Iterator, metrics: ClassificationMetrics, writer: SummaryWriter,
                     epoch: int, ps: PredictionStats, bc: Union[None, BetaCalibration] = None) -> tuple:
    """
    Runs one validation epoch for a metric model, using its own classification head.

    Parameters
    ----------
    model: MetricModel object; its feature_extractor computes the features and its classifier maps
        them to logits (one column per class).
    val_loader: Data loader, yielding (data, labels) batches. It has to support len().
    metrics: Callable, invoked as metrics(predictions, targets, class_names, writer, epoch).
    writer: Summary writer, forwarded to metrics and ps.
    epoch: Current epoch, forwarded to metrics and ps.
    ps: Callable, invoked as ps(predictions, targets, None, epoch, writer).
    bc: Calibrator, forwarded to calibrate_model. If None, calibrate_model fits a new one
        on the predictions of this epoch.

    Returns
    -------
    results: The entry ['scalar']['macro avg']['f1-score'] of the dictionary, returned by metrics.
    bc: The calibrator, returned by calibrate_model.
    predictions: Calibrated predictions (softmax output of the column with index 1).
    targets: Labels, corresponding to the predictions, stored on the cpu.
    data_list: Raw data samples of all batches, collected in one flat list.
    """

    # switch to the evaluation mode
    model.eval()

    # lists, in which predictions, targets and data samples are stored
    predictions, targets, data_list = [], [], []

    # perform computations without storing gradients
    with torch.no_grad():

        # iterate through the validation loader
        for data, labels in tqdm(val_loader, 'Validation Epoch', total=len(val_loader)):

            # forward pass through the feature extractor
            features = model.feature_extractor(data)

            # labels provided as a plain list are converted first, since a list can not be moved to a device
            if isinstance(labels, list):
                labels = torch.LongTensor(labels)

            # use the correct device
            labels = labels.to(features.device)

            # forward pass through the classification head
            logits = model.classifier(features)

            # store the probability of the class with index 1
            probabilities = torch.softmax(logits, dim=1)[:, 1]
            predictions.append(probabilities)

            # store the targets
            targets.append(labels)

            # store the data samples
            data_list.extend(data)

        # store predictions and targets on the cpu
        predictions = torch.cat(predictions, dim=0).detach().cpu()
        targets = torch.cat(targets, dim=0).detach().cpu()

    # calibrate the model
    bc, predictions = calibrate_model(predictions, targets, bc)

    # store and visualize the results
    results = metrics(predictions, targets, ['Normal', 'Hatespeech'], writer, epoch)
    ps(predictions, targets, None, epoch, writer)
    results = results['scalar']['macro avg']['f1-score']
    return results, bc, predictions, targets, data_list
