import base64
import functools
import io
import json
import os
import pathlib
import pickle
import platform
import warnings
import zipfile
from typing import List, Optional, Tuple
import cloudpickle

import flax.linen as nn
import jax
import numpy as np
import optax
import torch as th
import random

from common.type_aliases import *
from common.type_aliases import TensorDict, Params


def set_seed(seed: int):
    """
    Seed every random number generator this module relies on, for reproducible runs.

    The same seed is applied, in order, to Python's `random`, NumPy's global
    generator, `torch.manual_seed` and `torch.cuda.manual_seed_all`.

    Args:
        seed (`int`): The seed to set.
    """
    seeders = (
        random.seed,
        np.random.seed,
        th.manual_seed,
        th.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class MLP(nn.Module):
    """
    Fully connected network built from a list of layer widths.

    Every entry of ``net_arch`` except the last produces a hidden block
    (Dense -> optional BatchNorm -> optional LayerNorm -> activation -> Dropout).
    The last entry produces a plain Dense output layer, optionally followed by tanh.
    An empty ``net_arch`` leaves the input untouched (apart from the optional tanh).
    """
    # Widths of the Dense layers; the last one is the output layer.
    net_arch: List
    # Callable applied after each hidden layer (e.g. ``nn.relu``).
    activation_fn: nn.Module
    # Dropout rate used after every hidden activation.
    dropout: float = 0.0
    # When True, the network output is passed through tanh.
    squash_output: bool = False

    layer_norm: bool = False
    batch_norm: bool = False
    # Whether the Dense layers add a bias term.
    use_bias: bool = True
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_normal()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = False, training: bool = True):
        """
        Run the network on ``x``.

        :param x: Input array fed to the first Dense layer.
        :param deterministic: Forwarded to ``nn.Dropout``.
        :param training: When False, BatchNorm uses its running averages
            (``use_running_average=not training``).
        :return: Output of the last layer, squashed with tanh if ``squash_output`` is set.
        """
        # Shared Dense configuration. `use_bias` was previously declared on the
        # module but never forwarded, so `use_bias=False` had no effect.
        dense_kwargs = dict(
            use_bias=self.use_bias,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )

        for features in self.net_arch[:-1]:
            x = nn.Dense(features=features, **dense_kwargs)(x)
            if self.batch_norm:
                x = nn.BatchNorm(use_running_average=not training, momentum=0.9)(x)
            if self.layer_norm:
                x = nn.LayerNorm()(x)
            x = self.activation_fn(x)
            x = nn.Dropout(rate=self.dropout, deterministic=deterministic)(x)

        # Output layer: no normalisation, activation or dropout.
        if len(self.net_arch) > 0:
            x = nn.Dense(features=self.net_arch[-1], **dense_kwargs)(x)

        if self.squash_output:
            return nn.tanh(x)
        return x


class LeakyReLu:
    """Callable that applies ``nn.leaky_relu`` with a slope fixed at construction time."""

    def __init__(self, negative_slope: float = 1e-2):
        # Stored once and reused on every call.
        self.negative_slope = negative_slope

    def __call__(self, *args, **kwargs):
        """Forward all arguments to ``nn.leaky_relu`` together with the stored slope."""
        slope = self.negative_slope
        return nn.leaky_relu(*args, negative_slope=slope, **kwargs)


def create_mlp(
    output_dim: int,
    net_arch: List[int],
    activation_fn: Callable = nn.relu,
    dropout: float = 0.0,
    squash_output: bool = False,
    layer_norm: bool = False,
    batch_norm: bool = False,
    use_bias: bool = True,
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_normal(),
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros
):
    """
    Build an ``MLP`` module from hidden layer sizes and an output size.

    :param output_dim: Size of the output layer. When it is greater than 0 it is
        appended to a copy of ``net_arch``; otherwise ``net_arch`` is used as given.
    :param net_arch: Sizes of the layers preceding the output layer.
    :param activation_fn: Activation passed to the ``MLP``.
    :param dropout: Dropout rate passed to the ``MLP``.
    :param squash_output: Passed to the ``MLP``.
    :param layer_norm: Passed to the ``MLP``.
    :param batch_norm: Passed to the ``MLP``.
    :param use_bias: Passed to the ``MLP``.
    :param kernel_init: Kernel initializer passed to the ``MLP``.
    :param bias_init: Bias initializer passed to the ``MLP``.
    :return: The configured ``MLP`` instance.
    """
    # Build a new list so the caller's `net_arch` is never mutated.
    layer_sizes = [*net_arch, output_dim] if output_dim > 0 else net_arch

    mlp_config = dict(
        net_arch=layer_sizes,
        activation_fn=activation_fn,
        dropout=dropout,
        squash_output=squash_output,
        layer_norm=layer_norm,
        batch_norm=batch_norm,
        use_bias=use_bias,
        kernel_init=kernel_init,
        bias_init=bias_init,
    )
    return MLP(**mlp_config)


def get_basic_rngs(rng: jnp.ndarray) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
    """
    Split ``rng`` into a carry-over key and one key per named collection.

    :param rng: Key to split.
    :return: ``(new_rng, rngs)`` where ``rngs`` maps "params", "dropout" and
        "batch_stats" to their own keys.
    """
    next_rng, *collection_keys = jax.random.split(rng, 4)
    collection_names = ("params", "dropout", "batch_stats")
    return next_rng, dict(zip(collection_names, collection_keys))


@flax.struct.dataclass
class Model:
    """
    Bundle of a Flax module's apply function, its parameters and an optional
    optax optimizer with its state.

    Methods that change the model return a new instance built with ``replace``.
    """
    # Update counter: `create` starts it at 1 and `apply_gradient` adds 1 per call.
    step: int
    # `model_def.apply` of the wrapped module (not part of the pytree).
    apply_fn: Callable[..., Any] = flax.struct.field(pytree_node=False)
    # The "params" collection returned by `model_def.init`.
    params: Params
    # The "batch_stats" collection, or None when the module has none.
    batch_stats: Optional[Params]
    # Optimizer (not part of the pytree); None for models that are never trained.
    tx: Optional[optax.GradientTransformation] = flax.struct.field(pytree_node=False)
    # Optimizer state matching `tx`; None when `tx` is None.
    opt_state: Optional[optax.OptState] = None

    @classmethod
    def create(
        cls,
        model_def: nn.Module,
        inputs: Sequence[jnp.ndarray],
        tx: Optional[optax.GradientTransformation] = None,
        **kwargs
    ) -> 'Model':
        """
        Initialise ``model_def`` and wrap the result in a ``Model``.

        :param model_def: Module to initialise.
        :param inputs: Positional arguments forwarded to ``model_def.init``.
        :param tx: Optional optimizer; when given its state is initialised from the params.
        :param kwargs: Extra fields forwarded to the ``Model`` constructor.
        :return: A new ``Model`` with ``step=1``.
        """
        variables = model_def.init(*inputs)

        # Index instead of `pop`: works for both plain dicts and FrozenDicts
        # (FrozenDict.pop returns a `(rest, value)` pair rather than the value).
        # NOTE: the parameters are used as returned by `init`; they are not unfrozen
        # here. Some optax optimizers need a dict rather than a FrozenDict, see
        # https://github.com/deepmind/optax/issues/160
        params = variables["params"]
        batch_stats = variables.get("batch_stats", None)

        opt_state = tx.init(params) if tx is not None else None

        return cls(
            step=1,
            apply_fn=model_def.apply,
            params=params,
            batch_stats=batch_stats,
            tx=tx,
            opt_state=opt_state,
            **kwargs
        )

    def __call__(self, *args, **kwargs):
        """Call ``apply_fn`` with this model's params; ``batch_stats`` is not forwarded."""
        return self.apply_fn({"params": self.params}, *args, **kwargs)

    def apply_gradient(
        self,
        loss_fn: Optional[Callable[[Params], Any]] = None,
        grads: Optional[Any] = None,
        has_aux: bool = True
    ) -> Union[Tuple['Model', Any], 'Model']:
        """
        Perform one optimizer step and return the updated model.

        Either ``loss_fn`` or ``grads`` must be given. With ``loss_fn`` the gradients
        are computed with ``jax.grad`` w.r.t. ``self.params``; with ``grads`` they are
        used as-is and ``has_aux`` must be False because there is no auxiliary output.

        :param loss_fn: Function of the params returning the loss
            (or ``(loss, aux)`` when ``has_aux`` is True).
        :param grads: Pre-computed gradients; take precedence over ``loss_fn``.
        :param has_aux: Whether ``loss_fn`` returns an auxiliary output.
        :return: ``(new_model, aux)`` when ``has_aux`` is True, otherwise ``new_model``.
        :raises ValueError: If neither ``loss_fn`` nor ``grads`` is given, if ``grads``
            is combined with ``has_aux=True``, or if the model has no optimizer.
        """
        # The original `assert (cond, msg)` statements asserted a non-empty tuple and
        # therefore never fired; validate explicitly instead.
        if loss_fn is None and grads is None:
            raise ValueError('Either a loss function or grads must be specified.')
        if self.tx is None:
            raise ValueError('Cannot apply gradients: the model was created without an optimizer.')

        aux = None
        if grads is None:
            grad_fn = jax.grad(loss_fn, has_aux=has_aux)
            if has_aux:
                grads, aux = grad_fn(self.params)
            else:
                grads = grad_fn(self.params)
        elif has_aux:
            # Without this check the final `return new_model, aux` hit an unbound `aux`.
            raise ValueError('When grads are provided, expects no aux outputs.')

        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)
        new_model = self.replace(step=self.step + 1, params=new_params, opt_state=new_opt_state)

        if has_aux:
            return new_model, aux
        return new_model

    @staticmethod
    def _write_bytes(save_path: str, payload: bytes) -> None:
        """Write ``payload`` to ``save_path``, creating the parent directory if there is one."""
        parent_dir = os.path.dirname(save_path)
        # A bare file name has an empty dirname, for which `os.makedirs` would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(save_path, 'wb') as f:
            f.write(payload)

    def save_dict_from_path(self, save_path: str) -> Params:
        """Serialize ``params`` to ``save_path`` and return them."""
        self._write_bytes(save_path, flax.serialization.to_bytes(self.params))
        return self.params

    def load_dict_from_path(self, load_path: str) -> "Model":
        """Return a copy whose ``params`` are restored from ``load_path``."""
        with open(load_path, 'rb') as f:
            params = flax.serialization.from_bytes(self.params, f.read())
        return self.replace(params=params)

    def save_batch_stats_from_path(self, save_path: str) -> Params:
        """Serialize ``batch_stats`` to ``save_path`` and return them."""
        self._write_bytes(save_path, flax.serialization.to_bytes(self.batch_stats))
        return self.batch_stats

    def load_batch_stats_from_path(self, load_path: str) -> "Model":
        """Return a copy whose ``batch_stats`` are restored from ``load_path``."""
        with open(load_path, 'rb') as f:
            batch_stats = flax.serialization.from_bytes(self.batch_stats, f.read())
        return self.replace(batch_stats=batch_stats)

    def load_dict(self, params: bytes) -> 'Model':
        """Return a copy whose ``params`` are restored from serialized bytes."""
        params = flax.serialization.from_bytes(self.params, params)
        return self.replace(params=params)


def recursive_getattr(obj: Any, attr: str, *args) -> Any:
    """
    Resolve a dotted attribute path by applying ``getattr`` once per component.
    Adapted from https://stackoverflow.com/questions/31174295

    Ex:
    > MyObject.sub_object = SubObject(name='test')
    > recursive_getattr(MyObject, 'sub_object.name')  # return test
    :param obj: Object the lookup starts from.
    :param attr: Attribute to retrieve, with nested attributes separated by "."
    :param args: Optional default, passed to every ``getattr`` call along the path.
    :return: The attribute
    """
    current = obj
    for component in attr.split("."):
        current = getattr(current, component, *args)
    return current


def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
    """
    Set an attribute addressed by a dotted path.
    Adapted from https://stackoverflow.com/questions/31174295

    Ex:
    > MyObject.sub_object = SubObject(name='test')
    > recursive_setattr(MyObject, 'sub_object.name', 'hello')
    :param obj: Object the lookup starts from.
    :param attr: Attribute to set, with nested attributes separated by "."
    :param val: New value of the attribute
    """
    # Everything before the last dot identifies the owner; the rest is the attribute name.
    owner_path, _, leaf_name = attr.rpartition(".")
    owner = obj
    if owner_path:
        owner = recursive_getattr(obj, owner_path)
    return setattr(owner, leaf_name, val)


def is_json_serializable(item: Any) -> bool:
    """
    Test if an object is serializable into JSON

    :param item: The object to be tested for JSON serialization.
    :return: True if object is JSON serializable, false otherwise.
    """
    try:
        json.dumps(item)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type. ValueError: e.g. circular references.
        # Any of these means "not serializable", so report False instead of raising.
        return False
    return True


def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
    """
    Pickle ``obj`` to ``path``.

    The destination is opened through ``open_path`` in write mode with the
    preferred suffix "pkl".

    :param path: Destination: a str, a pathlib.Path or an already opened
        io.BufferedIOBase stream.
    :param obj: The object to save.
    :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
    """
    pkl_stream = open_path(path, "w", verbose=verbose, suffix="pkl")
    with pkl_stream as pkl_file:
        # Protocol >= 4 is needed for objects such as replay buffers of 4Gb or more;
        # see https://docs.python.org/3/library/pickle.html
        pickle.dump(obj, pkl_file, protocol=pickle.HIGHEST_PROTOCOL)


def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
    """
    Unpickle and return the object stored at ``path``.

    The source is opened through ``open_path`` in read mode with the
    preferred suffix "pkl".

    :param path: Source: a str, a pathlib.Path or an already opened
        io.BufferedIOBase stream.
    :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
    :return: The unpickled object.
    """
    pkl_stream = open_path(path, "r", verbose=verbose, suffix="pkl")
    with pkl_stream as pkl_file:
        # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
        loaded = pickle.load(pkl_file)
    return loaded


@functools.singledispatch
def open_path(
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    mode: str,
    verbose: int = 0,
    suffix: Optional[str] = None
):
    """
    Generic entry point for opening a path or validating an already open stream.

    This base implementation handles every type without a registered overload.
    It only accepts io.BufferedIOBase instances and checks that the stream is
    open and usable for the requested mode; ``verbose`` and ``suffix`` are not
    used here.

    :param path: An open io.BufferedIOBase stream.
    :param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading
        (case-insensitive).
    :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
    :param suffix: The preferred suffix.
    :return: ``path`` itself, unchanged.
    :raises TypeError: If ``path`` is not an io.BufferedIOBase.
    :raises ValueError: If the stream is closed, the mode is unknown, or the stream
        is not writable/readable as required by the mode.
    """
    if not isinstance(path, io.BufferedIOBase):
        raise TypeError("Path parameter has invalid type.", io.BufferedIOBase)
    if path.closed:
        raise ValueError("File stream is closed.")

    # Map long and short mode names onto the canonical one-letter form.
    mode_aliases = {"write": "w", "read": "r", "w": "w", "r": "r"}
    try:
        mode = mode_aliases[mode.lower()]
    except KeyError as e:
        raise ValueError("Expected mode to be either 'w' or 'r'.") from e

    # Only query the capability that matters for the requested mode.
    writing = "w" == mode
    usable = path.writable() if writing else path.readable()
    if not usable:
        e1 = "writable" if writing else "readable"
        raise ValueError(f"Expected a {e1} file.")
    return path


def data_to_json(data: Dict[str, Any]) -> str:
    """
    Serialize a dictionary of class parameters to a JSON string.

    Values that are JSON serializable are stored unchanged. Any other value is
    pickled with cloudpickle, base64-encoded and stored in a small dictionary
    together with a description of the object.

    :param data: Dictionary of class parameters to be stored.
    :return: JSON string of the data serialized.
    """
    serializable_data = {}
    for key, value in data.items():
        if is_json_serializable(value):
            # Stored as it is.
            serializable_data[key] = value
            continue

        # Not JSON serializable: keep a base64 string of the cloudpickled object.
        encoded_pickle = base64.b64encode(cloudpickle.dumps(value)).decode()

        # The ":"-wrapped keys cannot clash with the attribute names added below.
        # The type is stored so humans and other languages can tell what was saved.
        entry = {
            ":type:": str(type(value)),
            ":serialized:": encoded_pickle,
        }

        # Add the first level of the object's content for readability (no deeper,
        # to avoid deep nesting). Objects without attributes, e.g. numpy scalars,
        # contribute nothing here.
        if isinstance(value, dict):
            members = value.items()
        elif hasattr(value, "__dict__"):
            members = value.__dict__.items()
        else:
            members = ()

        for member_name, member_value in members:
            # Non-serializable members are represented by their string form.
            if is_json_serializable(member_value):
                entry[member_name] = member_value
            else:
                entry[member_name] = str(member_value)

        serializable_data[key] = entry

    return json.dumps(serializable_data, indent=4)


def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Turn JSON serialization of class-parameters back into dictionary.

    :param json_string: JSON serialization of the class-parameters
        that should be loaded.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :return: Loaded class parameters. Entries whose deserialization failed
        (a warning is emitted for each) are left out.
    :raises ValueError: If ``custom_objects`` is neither a dict nor None.
    """
    if custom_objects is not None and not isinstance(custom_objects, dict):
        raise ValueError("custom_objects argument must be a dict or None")

    json_dict = json.loads(json_string)
    # This will be filled with deserialized data
    return_data = {}
    for data_key, data_item in json_dict.items():
        if custom_objects is not None and data_key in custom_objects:
            # The caller-provided object replaces whatever was stored in the JSON.
            return_data[data_key] = custom_objects[data_key]
        elif isinstance(data_item, dict) and ":serialized:" in data_item:
            # A dictionary with a ":serialized:" key holds a cloudpickled object.
            serialization = data_item[":serialized:"]
            try:
                base64_object = base64.b64decode(serialization.encode())
                # NOTE(security): unpickling can execute arbitrary code; only load
                # data coming from trusted sources.
                deserialized_object = cloudpickle.loads(base64_object)
            except (RuntimeError, TypeError):
                warnings.warn(
                    f"Could not deserialize object {data_key}. "
                    + "Consider using `custom_objects` argument to replace "
                    + "this object."
                )
            else:
                # Store only on success. Previously the assignment also ran after a
                # failure and used an unbound (or stale) `deserialized_object`.
                return_data[data_key] = deserialized_object
        else:
            # Read as it is
            return_data[data_key] = data_item
    return return_data


@open_path.register(str)
def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
    """
    ``open_path`` overload for string paths.

    The string is converted to a ``pathlib.Path`` and dispatched again, so all
    path handling is done by the ``pathlib.Path`` overload.

    :param path: the path to open.
    :param mode: how to open the file. "w" for writing, "r" for reading.
    :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
    :param suffix: The preferred suffix, forwarded unchanged.
    :return: Whatever ``open_path`` returns for the converted path.
    """
    as_path = pathlib.Path(path)
    return open_path(as_path, mode, verbose, suffix)


@open_path.register(pathlib.Path)
def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0,
                      suffix: Optional[str] = None) -> io.BufferedIOBase:
    """
    ``open_path`` overload for ``pathlib.Path`` objects.

    Reading: the path is opened in binary mode. If it does not exist and a
    non-empty suffix was given, the lookup is retried once with
    "<path>.<suffix>"; otherwise the FileNotFoundError propagates.

    Writing: the suffix is appended when the path has none. If the path is a
    folder, "<path>_2" is tried instead; if the parent folder is missing, it is
    created and the open is retried.

    In every case the function ends by calling ``open_path`` again: with the
    opened stream on success, or with the corrected path to retry.

    :param path: the path to open.
    :param mode: how to open the file. "w" for writing, "r" for reading.
    :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
    :param suffix: The preferred suffix.
    :return: The result of the final ``open_path`` call.
    :raises ValueError: If ``mode`` is neither "w" nor "r".
    """
    if mode not in ("w", "r"):
        raise ValueError("Expected mode to be either 'w' or 'r'.")

    has_suffix = suffix is not None and suffix != ""

    if mode == "r":
        try:
            path = path.open("rb")
        except FileNotFoundError:
            if not has_suffix:
                raise
            newpath = pathlib.Path(f"{path}.{suffix}")
            if verbose == 2:
                warnings.warn(f"Path '{path}' not found. Attempting {newpath}.")
            # Clear the suffix so the retry cannot loop on a missing file.
            path, suffix = newpath, None
    else:
        try:
            if path.suffix == "" and has_suffix:
                path = pathlib.Path(f"{path}.{suffix}")
            if path.exists() and path.is_file() and verbose == 2:
                warnings.warn(f"Path '{path}' exists, will overwrite it.")
            path = path.open("wb")
        except IsADirectoryError:
            warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2")
            path = pathlib.Path(f"{path}_2")
        except FileNotFoundError:  # Occurs when the parent folder doesn't exist
            warnings.warn(f"Path '{path.parent}' does not exist. Will create it.")
            path.parent.mkdir(exist_ok=True, parents=True)

    # `path` is now either an open stream (handled by the generic `open_path`)
    # or a corrected Path, which dispatches back to this function for a retry.
    return open_path(path, mode, verbose, suffix)


def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
    """
    Collect OS, Python, JAX and NumPy version information for the current system.

    :param print_info: Whether to print the formatted summary.
    :return: A dictionary with one entry per component and the same information
        formatted as one "key: value" line per entry.
    """
    env_info = {
        "OS": f"{platform.platform()} {platform.version()}",
        "Python": platform.python_version(),
        "JAX": jax.__version__,
        "Numpy": np.__version__,
    }
    env_info_str = "".join(f"{name}: {version}\n" for name, version in env_info.items())
    if print_info:
        print(env_info_str)
    return env_info, env_info_str


def save_to_zip_file(
    save_path: Union[str, pathlib.Path, io.BufferedIOBase],
    data: Optional[Dict[str, Any]] = None,
    params: Optional[Dict[str, Params]] = None,
) -> None:
    """
    Save model data to a zip archive.

    The archive contains a "data" entry (JSON) when ``data`` is given, one entry
    per item of ``params`` ("<name>.jax" when Flax can serialize it, "<name>.pth"
    written with ``th.save`` otherwise) and a "system_info.txt" entry.

    :param save_path: Where to store the model: a str, a pathlib.Path or an
        already opened io.BufferedIOBase stream.
    :param data: Class parameters being stored (non-parameter variables)
    :param params: Model parameters being stored, one entry per name.
    """
    zip_stream = open_path(save_path, "w", verbose=0, suffix="zip")
    try:
        # data/params can be None, so do not try to serialize them blindly
        serialized_data = data_to_json(data) if data is not None else None

        with zipfile.ZipFile(zip_stream, mode="w") as archive:
            # Do not try to save "None" elements
            if serialized_data is not None:
                archive.writestr("data", serialized_data)
            if params is not None:
                for file_name, dict_ in params.items():
                    # Only the serialization attempt decides between the two formats;
                    # errors while writing to the archive are not swallowed.
                    try:
                        bytes_dict = flax.serialization.to_bytes(dict_)
                    except Exception:
                        bytes_dict = None

                    if bytes_dict is not None:  # jax
                        with archive.open(file_name + ".jax", mode="w") as param_file:
                            param_file.write(bytes_dict)
                    else:  # torch
                        with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
                            th.save(dict_, param_file)

            # Save system info about the current python env
            archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
    finally:
        # ZipFile does not close a file object it did not open itself. Close the
        # stream only when it was opened here; caller-owned streams are left open.
        if isinstance(save_path, (str, pathlib.Path)):
            zip_stream.close()


def load_from_zip_file(
    load_path: Union[str, pathlib.Path, io.BufferedIOBase],
    load_data: bool = True,
    custom_objects: Optional[Dict[str, Any]] = None,
    verbose: int = 0,
    print_system_info: bool = False,
) -> (Tuple[Optional[Dict[str, Any]], Dict[str, Params], Optional[TensorDict]]):
    """
    Load model data from a .zip archive

    :param load_path: Where to load the model from
    :param load_data: Whether we should load and return data
        (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
    :param print_system_info: Whether to print or not the system info
        about the saved model.
    :return: Class parameters, a dict mapping each stored name to its parameters
        (raw bytes for ".jax" entries, the ``th.load`` result for ".pth" entries),
        and the pytorch variables (always None in this implementation).
    :raises ValueError: If the file is not a zip archive.
    """
    zip_stream = open_path(load_path, "r", verbose=verbose, suffix="zip")
    # Fall back to the CPU so torch entries can also be loaded without CUDA.
    map_location = "cuda:0" if th.cuda.is_available() else "cpu"

    # Open the zip archive and load data
    try:
        with zipfile.ZipFile(zip_stream) as archive:
            namelist = archive.namelist()
            # If data or parameters is not in the
            # zip archive, assume they were stored
            # as None (save_to_zip_file allows this).
            data = None
            pytorch_variables = None

            # Debug system info first
            if print_system_info:
                if "system_info.txt" in namelist:
                    print("== SAVED MODEL SYSTEM INFO ==")
                    print(archive.read("system_info.txt").decode())
                else:
                    warnings.warn(
                        "The model was saved with SB3 <= 1.2.0 and thus cannot print system information.",
                        UserWarning,
                    )

            if "data" in namelist and load_data:
                # Load class parameters that are stored
                # with either JSON or pickle (not parameter files).
                json_data = archive.read("data").decode()
                data = json_to_data(json_data, custom_objects=custom_objects)

            # Every ".jax"/".pth" entry holds the parameters of one named object
            # (e.g. policy, policy.optimizer).
            params = dict()
            for file_path in namelist:
                # Strip only the final extension so dotted names such as
                # "policy.optimizer" keep their full name as the key.
                entry_name, extension = os.path.splitext(file_path)
                if extension not in (".jax", ".pth"):
                    continue

                with archive.open(file_path, mode="r") as param_file:
                    if extension == ".jax":  # Jax: returned as raw bytes
                        _params = param_file.read()
                    else:  # Torch
                        # th.load gets an in-memory copy of the archive entry.
                        file_content = io.BytesIO(param_file.read())
                        _params = th.load(file_content, map_location=map_location)

                params[entry_name] = _params

    except zipfile.BadZipFile as e:
        # load_path wasn't a zip file
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
    finally:
        # Close the stream only when it was opened here; caller-owned streams are left open.
        if isinstance(load_path, (str, pathlib.Path)):
            zip_stream.close()

    return data, params, pytorch_variables


@jax.jit
def target_update(model: Model, target: Model, tau: float) -> Model:
    """
    Soft-update the target parameters towards the model parameters.

    Each target leaf becomes ``tau * model_leaf + (1 - tau) * target_leaf``.

    :param model: Model providing the new parameters.
    :param target: Model whose parameters are updated.
    :param tau: Interpolation coefficient applied to the model parameters.
    :return: A copy of ``target`` with the interpolated parameters.
    """
    # `jax.tree_map` is deprecated and removed in recent JAX releases;
    # `jax.tree_util.tree_map` is the stable equivalent.
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau), model.params, target.params
    )
    return target.replace(params=new_target_params)
