from torch import Tensor
import torch
import torch.nn.functional as F
from typing import List, Tuple, Dict, Any, Optional, Union, Type, TypeVar
from .utils import log1mexp, ExpEi, reparam_trick, Bessel

# NOTE(review): presumably a numerical-stability epsilon for a tanh-based
# computation; it is not referenced by any code visible here -- confirm
# against the rest of the project before removing.
tanh_eps = 1e-20
# Euler-Mascheroni constant. Used by the ``_log_soft_volume_adjusted`` methods,
# which shift each side length by ``2 * euler_gamma * gumbel_beta``.
euler_gamma = 0.57721566490153286060


def _box_shape_ok(t: Tensor, learnt_temp=False) -> bool:
    if len(t.shape) < 2:
        return False
    if not learnt_temp:
        if t.size(-2) != 2:
            return False
        return True
    else:
        if t.size(-2) != 4:
            return False

        return True


def _shape_error_str(tensor_name, expected_shape, actual_shape):
    return "Shape of {} has to be {} but is {}".format(
        tensor_name, expected_shape, tuple(actual_shape)
    )


# Type variable bound to BoxTensor. Annotating ``cls``/``self`` and the return
# value of a method with it lets a type checker infer that a method called on
# a subclass returns that subclass rather than the base class.
# see: https://realpython.com/python-type-checking/#type-hints-for-methods
# for why a TypeVar is needed here.
TBoxTensor = TypeVar("TBoxTensor", bound="BoxTensor")


class BoxTensor(object):
    """A wrapper that holds a single tensor representing one or more boxes.

    Composition is used instead of inheritance because it is not safe to
    inherit from :class:`torch.Tensor`: creating an instance of such a class
    will always make it a leaf node. This works for
    :class:`torch.nn.Parameter` but won't work for a general box tensor.
    """

    def __init__(self, data: Tensor, learnt_temp: bool = False) -> None:
        """
        .. todo:: Validate the values of z, Z ? z < Z

        Arguments:
            data: Tensor of shape (**, zZ, num_dims). Here zZ=2, where index 0
                along that dimension is the bottom left corner and index 1
                is the top right corner of the box. When ``learnt_temp`` is
                true the dimension must have size 4 instead.
            learnt_temp: selects which of the two shapes is accepted.

        Raises:
            ValueError: if ``data`` does not have the required shape.
        """
        if not _box_shape_ok(data, learnt_temp):
            # Report the shape that was actually required for this mode.
            expected = "(**,4,num_dims)" if learnt_temp else "(**,2,num_dims)"
            raise ValueError(_shape_error_str("data", expected, data.shape))
        self.data = data
        super().__init__()

    def __repr__(self):
        return "box_tensor_wrapper(" + self.data.__repr__() + ")"

    @property
    def z(self) -> Tensor:
        """Lower left coordinate as Tensor"""

        return self.data[..., 0, :]

    @property
    def Z(self) -> Tensor:
        """Top right coordinate as Tensor"""

        return self.data[..., 1, :]

    @property
    def box_type(self):
        """Name of the class"""
        return "BoxTensor"

    @property
    def centre(self) -> Tensor:
        """Centre coordinate as Tensor"""

        return (self.z + self.Z) / 2

    @classmethod
    def from_zZ(cls: "Type[TBoxTensor]", z: Tensor, Z: Tensor) -> "TBoxTensor":
        """Create a box by stacking z and Z along the -2 dim.

        That is, if z.shape == Z.shape == (**, num_dim),
        then the result is a box of shape (**, 2, num_dim).

        Raises:
            ValueError: if ``z`` and ``Z`` have different shapes.
        """

        if z.shape != Z.shape:
            raise ValueError(
                "Shape of z and Z should be same but is {} and {}".format(
                    z.shape, Z.shape
                )
            )
        box_val: Tensor = torch.stack((z, Z), -2)

        return cls(box_val)

    @classmethod
    def from_split(cls: "Type[TBoxTensor]", t: Tensor, dim: int = -1) -> "TBoxTensor":
        """Create a BoxTensor by splitting ``t`` on dimension ``dim`` at midpoint.

        The first half becomes z and the second half becomes Z.

        Args:
            t: input
            dim: dimension to split on

        Returns:
            BoxTensor: output BoxTensor

        Raises:
            ValueError: if the size of ``dim`` is not even
        """
        len_dim = t.size(dim)

        if len_dim % 2 != 0:
            raise ValueError(
                "dim has to be even to split on it but is {}".format(len_dim)
            )
        half = len_dim // 2
        # index_select (rather than slicing) keeps the halves as fresh copies.
        z = t.index_select(dim, torch.arange(half, device=t.device))
        Z = t.index_select(dim, torch.arange(half, len_dim, device=t.device))

        return cls.from_zZ(z, Z)

    def _intersection(
        self: "TBoxTensor",
        other: "TBoxTensor",
        gumbel_beta: float = 1.0,
        bayesian: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Corners of the intersection of ``self`` and ``other``.

        Args:
            other: the box(es) to intersect with.
            gumbel_beta: temperature of the smooth max/min; only used when
                ``bayesian`` is true.
            bayesian: when true, use ``beta * logaddexp(a / beta, b / beta)``
                as a smooth maximum for z (and its mirrored form as a smooth
                minimum for Z). The results are then bounded by the hard
                max/min so the smooth box is never larger than the hard one.

        Returns:
            Tuple ``(z, Z)`` of the intersection corners. The result may be
            flipped (z > Z) when the boxes do not overlap.
        """
        # Read each corner once; subclasses may compute them on access.
        z1, Z1 = self.z, self.Z
        z2, Z2 = other.z, other.Z

        if bayesian:
            # Errors are allowed to propagate: swallowing them would leave
            # z/Z undefined and only fail later with a less useful message.
            z = gumbel_beta * torch.logaddexp(z1 / gumbel_beta, z2 / gumbel_beta)
            z = torch.max(z, torch.max(z1, z2))
            Z = -gumbel_beta * torch.logaddexp(-Z1 / gumbel_beta, -Z2 / gumbel_beta)
            Z = torch.min(Z, torch.min(Z1, Z2))
        else:
            z = torch.max(z1, z2)
            Z = torch.min(Z1, Z2)

        return z, Z

    def _box_pool_layer(
        self: "TBoxTensor",
        keys: "TBoxTensor",
        weights: Tensor,
        aggregate: str = "attention",
        pool: str = "avg",
        threshold: float = -121.0,
    ):
        """Pool the data of ``keys`` using ``weights``.

        Args:
            keys: boxes whose ``data`` is pooled.
            weights: pooling scores. Entries below ``threshold`` are treated
                as ``-inf``. The tensor passed in is not modified.
            aggregate: when ``"attention"``, a softmax over dim 1 is applied
                to the (thresholded) weights.
            pool: when ``"avg"``, returns
                ``tensordot(weights, keys.data, dims=1) / keys.data.shape[0]``;
                for any other value ``self`` is returned unchanged.
            threshold: cut-off below which a weight is replaced by ``-inf``.
        """
        # Out-of-place so the caller's tensor (possibly a leaf that requires
        # grad) is left untouched.
        weights = weights.masked_fill(weights < threshold, float("-inf"))
        if aggregate == "attention":
            weights = torch.softmax(weights, dim=1)

        if pool == "avg":
            return torch.tensordot(weights, keys.data, dims=1) / keys.data.shape[0]

        return self

    def gumbel_intersection_log_volume(
        self: "TBoxTensor",
        other: "TBoxTensor",
        volume_temp=1.0,
        intersection_temp: float = 1.0,
        scale=1.0,
    ) -> Tensor:
        """Log soft volume of the Gumbel (smooth) intersection with ``other``.

        Args:
            other: the box(es) to intersect with.
            volume_temp: passed as ``temp`` (softplus ``beta``) to the volume.
            intersection_temp: Gumbel beta of the smooth intersection. It is
                also forwarded to the volume so that the
                ``2 * euler_gamma * gumbel_beta`` shift matches the beta the
                intersection was computed with.
            scale: multiplicative scale of the volume (added in log space).
        """
        z, Z = self._intersection(other, gumbel_beta=intersection_temp, bayesian=True)

        return self._log_soft_volume_adjusted(
            z, Z, temp=volume_temp, gumbel_beta=intersection_temp, scale=scale
        )

    def intersection(self: "TBoxTensor", other: "TBoxTensor") -> "TBoxTensor":
        """Gives intersection of self and other.

        .. note:: This function can give flipped boxes, i.e. where z[i] > Z[i]
        """
        z, Z = self._intersection(other)

        return self.from_zZ(z, Z)

    def join(self: "TBoxTensor", other: "TBoxTensor") -> "TBoxTensor":
        """Gives the join: the smallest box containing both self and other."""
        z = torch.min(self.z, other.z)
        Z = torch.max(self.Z, other.Z)

        return self.from_zZ(z, Z)

    def get(self: "TBoxTensor", indices: torch.LongTensor, dim: int = 0) -> "TBoxTensor":
        """Get boxes at particular indices on a particular dimension.

        Args:
            indices: indices to pick along ``dim``. They are handed to
                :meth:`torch.Tensor.index_select`, which expects a 1-D
                integer tensor.
            dim: dimension of ``data`` to select on.

        Returns:
            A new instance of the same class wrapping the selected data.
        """

        return self.__class__(self.data.index_select(dim, indices))

    def clamp_volume(self) -> Tensor:
        """Volume of boxes. Returns 0 where boxes are flipped.

        Returns:

            Tensor of shape (**, ) when self has shape (**, 2, num_dims)
        """

        return torch.prod((self.Z - self.z).clamp_min(0), dim=-1)

    @classmethod
    def _in_zero_one(cls, t: Union[float, Tensor]) -> bool:
        """Return True when every value of ``t`` lies in the interval (0, 1]."""
        if isinstance(t, Tensor):
            return bool(((t > 0) & (t <= 1)).all())

        if isinstance(t, (int, float)):
            return 0 < t <= 1

        return False

    @classmethod
    def _log_scale(cls, scale: Union[float, Tensor], like: Tensor) -> Tensor:
        """Return ``log(scale)`` as a tensor.

        Plain Python numbers (int or float) are first converted to a tensor
        with the dtype and device of ``like``; tensors are used as they are.
        """
        if isinstance(scale, Tensor):
            return torch.log(scale)

        return torch.log(
            torch.tensor(float(scale), dtype=like.dtype, device=like.device)
        )

    @classmethod
    def _soft_volume(
        cls, z: Tensor, Z: Tensor, temp: float = 1.0, scale: Union[float, Tensor] = 1.0
    ) -> Tensor:
        """Soft volume of the box with corners ``z`` and ``Z``.

        Each side length is ``softplus(Z - z, beta=temp)``; the product over
        the last dimension is multiplied by ``scale``.

        Raises:
            ValueError: if ``scale`` is not in (0, 1].
        """

        if not cls._in_zero_one(scale):
            raise ValueError("Scale should be in (0,1] but is {}".format(scale))
        side_lengths = F.softplus(Z - z, beta=temp)

        return torch.prod(side_lengths, dim=-1) * scale

    def soft_volume(
        self, temp: float = 1.0, scale: Union[float, Tensor] = 1.0
    ) -> Tensor:
        """Volume of boxes. Uses softplus instead of ReLU/clamp

        Returns:
            Tensor of shape (**, ) when self has shape (**, 2, num_dims)
        """

        return self._soft_volume(self.z, self.Z, temp, scale)

    def intersection_soft_volume(
        self, other: "TBoxTensor", temp: float = 1.0, scale: Union[float, Tensor] = 1.0
    ) -> Tensor:
        """Computes the soft volume of the (hard) intersection box

        Return:
            Tensor of shape(**,) when self and other have shape (**, 2, num_dims)
        """
        z, Z = self._intersection(other)

        return self._soft_volume(z, Z, temp, scale)

    @classmethod
    def _log_soft_volume(
        cls, z: Tensor, Z: Tensor, temp: float = 1.0, scale: Union[float, Tensor] = 1.0
    ) -> Tensor:
        """Log of the soft volume of the box with corners ``z`` and ``Z``.

        Computed as ``sum(log(softplus(Z - z, beta=temp) + 1e-23)) + log(scale)``
        with the sum taken over the last dimension.
        """
        # The small additive constant keeps log (and its derivative) finite
        # when softplus underflows to zero.
        side_lengths = F.softplus(Z - z, beta=temp) + 1e-23

        return torch.sum(torch.log(side_lengths), dim=-1) + cls._log_scale(scale, z)

    def log_soft_volume(
        self, temp: float = 1.0, scale: Union[float, Tensor] = 1.0
    ) -> Tensor:
        """Log soft volume of the boxes held by this instance."""

        return self._log_soft_volume(self.z, self.Z, temp=temp, scale=scale)

    @classmethod
    def _log_soft_volume_adjusted(
        cls,
        z: Tensor,
        Z: Tensor,
        temp: float = 1.0,
        gumbel_beta: float = 1.0,
        scale: Union[float, Tensor] = 1.0,
    ) -> Tensor:
        """Log soft volume with each side shortened by ``2 * euler_gamma * gumbel_beta``.

        Otherwise identical to :meth:`_log_soft_volume`.
        """
        side_lengths = (
            F.softplus(Z - z - 2 * euler_gamma * gumbel_beta, beta=temp) + 1e-23
        )

        return torch.sum(torch.log(side_lengths), dim=-1) + cls._log_scale(scale, z)

    def intersection_log_soft_volume(
        self,
        other: "TBoxTensor",
        temp: float = 1.0,
        gumbel_beta: float = 1.0,
        bayesian: bool = False,
        scale: Union[float, Tensor] = 1.0,
    ) -> Tensor:
        """Log soft volume of the intersection of ``self`` and ``other``.

        ``gumbel_beta`` and ``bayesian`` are forwarded to the intersection;
        ``temp`` and ``scale`` to the (unadjusted) log soft volume.
        """
        z, Z = self._intersection(other, gumbel_beta, bayesian)

        return self._log_soft_volume(z, Z, temp=temp, scale=scale)

    @classmethod
    def _pick_dim(cls, t: torch.Tensor, method="max"):
        """Reduce the last dimension of ``t`` (shape (**, num_dims)) to shape (**).

        Args:
            t: input tensor.
            method: ``"max"`` or ``"min"``.

        Raises:
            ValueError: for any other ``method``.
        """
        if method == "max":
            return torch.max(t, dim=-1)[0]

        if method == "min":
            return torch.min(t, dim=-1)[0]

        raise ValueError("method has to be 'max' or 'min' but is {!r}".format(method))

    @classmethod
    def cat(cls: "Type[TBoxTensor]", tensors: "Tuple[TBoxTensor, ...]") -> "TBoxTensor":
        """Concatenate the data of ``tensors`` along the last dimension."""

        return cls(torch.cat(tuple(box.data for box in tensors), -1))

    @classmethod
    def _scaled_box(
        cls, z_F: Tensor, Z_F: Tensor, z_R: Tensor, Z_R: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Map corners ``(z_F, Z_F)`` into the reference box ``(z_R, Z_R)``.

        With ``L_R`` the non-negative side lengths of the reference box:
        ``z_S = z_R + z_F * L_R`` and ``Z_S = Z_R + (Z_F - 1) * L_R``.
        """
        L_R = (Z_R - z_R).clamp_min(0)
        z_S = z_R + z_F * L_R
        Z_S = Z_R + (Z_F - 1) * L_R

        return z_S, Z_S

    def scaled_box(self, ref_box: "TBoxTensor") -> "BoxTensor":
        """Scale this box into ``ref_box``; always returns a plain BoxTensor."""
        z, Z = self._scaled_box(self.z, self.Z, ref_box.z, ref_box.Z)

        return BoxTensor.from_zZ(z, Z)

    @classmethod
    def get_wW(cls, z, Z):
        """Return the stored parameters for corners ``z`` and ``Z``.

        For this base class the corners are stored directly.
        """
        return z, Z

    @classmethod
    def _weights_init(cls, weights: torch.Tensor):
        """An in-place weight initializer method
        which can be used to do sensible init
        of weights depending on box type.
        For this base class, this method does nothing"""
        pass


def inv_sigmoid(v: Tensor) -> Tensor:
    """Inverse of the sigmoid (the logit): ``log(v / (1 - v))``, elementwise."""
    odds = v / (1.0 - v)  # type:ignore

    return torch.log(odds)


class BoxTensorLearntTemp(BoxTensor):
    """Same as BoxTensor but the volume and intersection temperatures are learnt.

    ``data`` has shape (**, 4, num_dims). Along the -2 dimension, index 0 and
    1 are read as z and Z by the inherited properties, index -2 holds the
    (pre-sigmoid) intersection temperature and index -1 the (pre-sigmoid)
    volume temperature.
    """

    def __init__(self, data: Tensor) -> None:
        """
        .. todo:: Validate the values of z, Z ? z < Z

        Arguments:
            data: Tensor of shape (**, 4, num_dims); see the class docstring
                for the meaning of the four rows.

        Raises:
            ValueError: if ``data`` does not have that shape.
        """
        if not _box_shape_ok(data, learnt_temp=True):
            raise ValueError(_shape_error_str("data", "(**, 4, num_dims)", data.shape))
        super().__init__(data, learnt_temp=True)

    @property
    def int_temp(self, _min: float = 0.0005, _max: float = 500.0) -> Tensor:
        """Intersection temp of the box as Tensor.

        The stored row is squashed by a sigmoid into (``_min``, ``_max``).
        """

        return (_max - _min) * torch.sigmoid(self.data[..., -2, :]) + _min

    @property
    def vol_temp(self, _min: float = 0.0005, _max: float = 1000.0) -> Tensor:
        """Volume temp as Tensor.

        The stored row is squashed by a sigmoid into (``_min``, ``_max``).
        """

        return (_max - _min) * torch.sigmoid(self.data[..., -1, :]) + _min

    @property
    def box_type(self):
        """ Name of the class """
        return "BoxTensorLearntTemp"

    @classmethod
    def from_zZ(
        cls: "Type[TBoxTensor]", z: Tensor, Z: Tensor, int_temp: Tensor, vol_temp: Tensor
    ) -> "TBoxTensor":
        """Create a box by stacking z, Z, int_temp and vol_temp along the -2 dim.

        That is, if all inputs have shape (**, num_dim), the result is a box
        of shape (**, 4, num_dim). ``int_temp`` and ``vol_temp`` are stored
        as given; the properties of the same names apply a sigmoid on read,
        so these are the pre-sigmoid values.

        Raises:
            ValueError: if ``z`` and ``Z`` have different shapes.
        """

        if z.shape != Z.shape:
            raise ValueError(
                "Shape of z and Z should be same but is {} and {}".format(
                    z.shape, Z.shape
                )
            )
        box_val: Tensor = torch.stack((z, Z, int_temp, vol_temp), -2)

        return cls(box_val)

    @classmethod
    def from_split(cls: "Type[TBoxTensor]", t: Tensor, dim: int = -1) -> "TBoxTensor":
        """Create a box by splitting ``t`` on dimension ``dim`` into four equal parts.

        The parts are used, in order, as the z row, the Z row, the
        intersection-temperature row and the volume-temperature row.

        Args:
            t: input
            dim: dimension to split on

        Returns:
            BoxTensor: output BoxTensor

        Raises:
            ValueError: if the size of ``dim`` is not divisible by 4
        """
        len_dim = t.size(dim)

        if len_dim % 4 != 0:
            raise ValueError(
                "dim has to be divisible by 4 to split on it but is {}".format(len_dim)
            )
        quarter = len_dim // 4
        parts = tuple(
            t.index_select(
                dim, torch.arange(k * quarter, (k + 1) * quarter, device=t.device)
            )
            for k in range(4)
        )
        box_val: Tensor = torch.stack(parts, -2)

        return cls(box_val)

    def _intersection(
        self: "TBoxTensor",
        other: "TBoxTensor",
        gumbel_beta: Optional[Tensor] = None,
        bayesian: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Corners of the intersection of ``self`` and ``other``.

        Args:
            other: the box(es) to intersect with.
            gumbel_beta: accepted for signature compatibility with the base
                class but not used; the Gumbel beta is always the mean of
                the two boxes' ``int_temp``.
            bayesian: when true, use the smooth (logaddexp) max/min bounded
                by the hard max/min; otherwise the hard max/min.

        Returns:
            Tuple ``(z, Z)`` of the intersection corners.
        """
        z1, Z1 = self.z, self.Z
        z2, Z2 = other.z, other.Z

        if bayesian:
            beta = (self.int_temp + other.int_temp) / 2
            z = beta * torch.logaddexp(torch.div(z1, beta), torch.div(z2, beta))
            z = torch.max(z, torch.max(z1, z2))
            # The smooth minimum must be taken over the upper corners (Z).
            Z = -beta * torch.logaddexp(torch.div(-Z1, beta), torch.div(-Z2, beta))
            Z = torch.min(Z, torch.min(Z1, Z2))
        else:
            z = torch.max(z1, z2)
            Z = torch.min(Z1, Z2)

        return z, Z

    @classmethod
    def _log_soft_volume_adjusted(
        cls,
        z: Tensor,
        Z: Tensor,
        temp: Tensor,
        gumbel_beta: Tensor,
        scale: Union[float, Tensor] = 1.0,
    ) -> Tensor:
        """Log soft volume with each side shortened by ``2 * euler_gamma * gumbel_beta``.

        Each side length is ``softplus((Z - z - shift) / temp) * temp``,
        clamped from below by the smallest positive normal number of
        ``z.dtype`` so the log stays finite. ``log(scale)`` is added to the
        sum over the last dimension.
        """
        eps = torch.finfo(z.dtype).tiny  # type: ignore

        if isinstance(scale, Tensor):
            s = scale
        else:
            # Accept plain ints as well as floats.
            s = torch.tensor(float(scale), dtype=z.dtype, device=z.device)

        shifted = Z - z - 2 * euler_gamma * gumbel_beta
        side_lengths = (F.softplus(torch.div(shifted, temp)) * temp).clamp_min(eps)

        return torch.sum(torch.log(side_lengths), dim=-1) + torch.log(s)

    def log_soft_volume(
        self, temp: float = 1.0, scale: Union[float, Tensor] = 1.0
    ) -> Tensor:
        """Log soft volume of the boxes held by this instance."""

        return self._log_soft_volume(self.z, self.Z, temp=temp, scale=scale)

    def gumbel_intersection_log_volume(
        self: "TBoxTensor",
        other: "TBoxTensor",
        volume_temp: float = 1.0,
        intersection_temp: float = 1.0,
        scale: float = 1.0,
    ) -> Tensor:
        """Log soft volume of the Gumbel (smooth) intersection with ``other``.

        ``volume_temp`` and ``intersection_temp`` are accepted for signature
        compatibility with the base class but are not used: both
        temperatures are the mean of the two boxes' learnt temperatures.
        """
        gumbel_beta = (self.int_temp + other.int_temp) / 2
        learnt_volume_temp = (self.vol_temp + other.vol_temp) / 2
        z, Z = self._intersection(other, gumbel_beta=gumbel_beta, bayesian=True)

        return self._log_soft_volume_adjusted(
            z, Z, temp=learnt_volume_temp, gumbel_beta=gumbel_beta, scale=scale
        )

    ##[tODO]##
    def gumbel_intersection_log_volume_w_marginal(
        self: "TBoxTensor", other: "TBoxTensor", scale=1.0
    ) -> Tuple[Tensor, Tensor]:
        """Log volume of the Gumbel intersection and of ``other`` itself.

        [todo]:
        Stack the tensors to take intersection and marginal.

        Returns:
            Tuple ``(vol_intersection, other_volume)``, both computed with
            the mean of the two boxes' learnt temperatures.
        """
        gumbel_beta = (self.int_temp + other.int_temp) / 2
        volume_temp = (self.vol_temp + other.vol_temp) / 2
        z, Z = self._intersection(other, gumbel_beta=gumbel_beta, bayesian=True)
        vol_intersection = self._log_soft_volume_adjusted(
            z, Z, temp=volume_temp, gumbel_beta=gumbel_beta, scale=scale
        )
        other_volume = self._log_soft_volume_adjusted(
            other.z, other.Z, temp=volume_temp, gumbel_beta=gumbel_beta, scale=scale
        )

        return vol_intersection, other_volume


class SigmoidBoxTensor(BoxTensor):
    """Same as BoxTensor but with a different parameterization: (**,wW, num_dims)

    z = sigmoid(w)
    Z = z + sigmoid(W) * (1-z)

    w = inv_sigmoid(z)
    W = inv_sigmoid((Z - z)/(1-z))
    """

    @property
    def z(self) -> Tensor:
        """Lower left coordinate: ``sigmoid(w)``."""
        return torch.sigmoid(self.data[..., 0, :])

    @property
    def Z(self) -> Tensor:
        """Top right coordinate: ``z + sigmoid(W) * (1 - z)``."""
        z = self.z
        Z = z + torch.sigmoid(self.data[..., 1, :]) * (1.0 - z)  # type: ignore

        return Z

    @classmethod
    def from_zZ(cls: "Type[TBoxTensor]", z: Tensor, Z: Tensor) -> "TBoxTensor":
        """This method is blocked for now.

        Raises:
            RuntimeError: always. Use :meth:`get_wW` to obtain the
                parameters for given corners.
        """
        raise RuntimeError("Do not use from_zZ method of SigmoidBoxTensor")

    @classmethod
    def get_wW(cls, z, Z):
        """Return the parameters ``(w, W)`` that reproduce corners ``z`` and ``Z``.

        ``w = inv_sigmoid(z)`` and ``W = inv_sigmoid((Z - z) / (1 - z))``.
        Both arguments of ``inv_sigmoid`` are clamped away from 0 and 1 so
        the result stays finite.

        Raises:
            ValueError: if ``z`` and ``Z`` have different shapes.
        """
        if z.shape != Z.shape:
            raise ValueError(
                "Shape of z and Z should be same but is {} and {}".format(
                    z.shape, Z.shape
                )
            )
        finfo = torch.finfo(z.dtype)  # type: ignore
        lower = finfo.tiny
        # ``1.0 - tiny`` rounds to exactly 1.0, which would make the upper
        # clamp a no-op; machine epsilon gives a bound strictly below 1.
        upper = 1.0 - finfo.eps
        w = inv_sigmoid(z.clamp(lower, upper))
        W = inv_sigmoid(((Z - z) / (1.0 - z)).clamp(lower, upper))  # type:ignore

        return w, W

    @classmethod
    def from_split(cls: "Type[TBoxTensor]", t: Tensor, dim: int = -1) -> "TBoxTensor":
        """Create a box by splitting ``t`` on dimension ``dim`` at midpoint.

        The two halves are stored directly as the parameters w and W.

        Args:
            t: input
            dim: dimension to split on

        Returns:
            BoxTensor: output BoxTensor

        Raises:
            ValueError: if the size of ``dim`` is not even
        """
        len_dim = t.size(dim)

        if len_dim % 2 != 0:
            raise ValueError(
                "dim has to be even to split on it but is {}".format(len_dim)
            )
        half = len_dim // 2
        w = t.index_select(dim, torch.arange(half, device=t.device))
        W = t.index_select(dim, torch.arange(half, len_dim, device=t.device))
        box_val: Tensor = torch.stack((w, W), -2)

        return cls(box_val)

    def intersection(self: "TBoxTensor", other: "TBoxTensor") -> "TBoxTensor":
        """Gives intersection of self and other as a plain BoxTensor.

        .. note:: This function can give flipped boxes, i.e. where z[i] > Z[i]
        """
        z, Z = self._intersection(other)

        return BoxTensor.from_zZ(z, Z)


class DeltaBoxTensor(SigmoidBoxTensor):
    """Same as BoxTensor but with a different parameterization: (**,wW, num_dims)

    z = w
    Z = z + delta, where delta = softplus(W, beta=10) is always positive
    """

    @property
    def z(self) -> Tensor:
        """Lower left coordinate, stored directly as w."""
        return self.data[..., 0, :]

    @property
    def Z(self) -> Tensor:
        """Top right coordinate: z plus the softplus of the stored W."""
        delta = torch.nn.functional.softplus(self.data[..., 1, :], beta=10)

        return self.z + delta

    @classmethod
    def from_zZ(cls: Type[TBoxTensor], z: Tensor, Z: Tensor) -> TBoxTensor:
        """Create a box whose corners are ``z`` and ``Z``.

        Raises:
            ValueError: if ``z`` and ``Z`` have different shapes.
        """
        if z.shape != Z.shape:
            message = "Shape of z and Z should be same but is {} and {}"
            raise ValueError(message.format(z.shape, Z.shape))

        # Convert corners to the stored parameterization before stacking.
        lower, raw_delta = cls.get_wW(z, Z)  # type:ignore
        stacked: Tensor = torch.stack((lower, raw_delta), -2)

        return cls(stacked)

    @classmethod
    def get_wW(cls, z, Z):
        """Return the parameters ``(w, W)`` that reproduce corners ``z`` and ``Z``.

        ``w`` is ``z`` itself and ``W`` is the inverse softplus (beta=10) of
        the side length ``Z - z``.

        Raises:
            ValueError: if ``z`` and ``Z`` have different shapes.
        """
        if z.shape != Z.shape:
            message = "Shape of z and Z should be same but is {} and {}"
            raise ValueError(message.format(z.shape, Z.shape))

        side = Z - z
        raw_delta = _softplus_inverse(side, beta=10.0)  # type:ignore

        return z, raw_delta


def _softplus_inverse(t: torch.Tensor, beta=1.0, threshold=20):
    below_thresh = beta * t < threshold
    res = t
    res[below_thresh] = torch.log(torch.exp(beta * t[below_thresh]) - 1.0) / beta

    return res
