# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..common.string_utils import camel_to_snake

# Lazily built registry of activation classes, keyed by the string that
# register_activation() derives from each class name. It stays None until the
# first _get_activations() call creates and fills the dict.
_ACTIVATIONS = None


def _get_activations():
    """Return the activation registry, building it on first use.

    On the first call the registry is seeded with every attribute of
    ``torch.nn.modules.activation`` that is a plain class (its class is
    exactly ``type``) and a subclass of ``nn.Module``. Later calls return the
    same dict object.
    """
    global _ACTIVATIONS

    if _ACTIVATIONS is not None:
        return _ACTIVATIONS

    # The empty dict must be bound before the loop runs: register_activation()
    # calls back into this function and has to see the dict being filled
    # rather than trigger a second initialisation.
    _ACTIVATIONS = {}
    for attr_name, attr in vars(nn.modules.activation).items():
        # The exact-type check comes first so that non-class objects never
        # reach issubclass().
        is_plain_class = attr.__class__ is type
        if not (is_plain_class and issubclass(attr, nn.Module)):
            continue
        register_activation(attr, attr_name)

    return _ACTIVATIONS


def register_activation(activation_class, name=None):
    """Add an activation class to the registry and return the class.

    Can be called directly or applied as a bare class decorator
    (``@register_activation``).

    Args:
        activation_class: The class to register.
        name: Name to derive the registry key from. Defaults to
            ``activation_class.__name__``.

    Returns:
        ``activation_class`` itself, unchanged. Returning it is required for
        decorator use: without it the decorated name would be rebound to
        ``None``.
    """
    if name is None:
        name = activation_class.__name__

    # 'ReLU' is rewritten to 'Relu' before conversion, presumably so that
    # camel_to_snake (defined elsewhere) treats it as one word -- confirm
    # against that helper.
    key = camel_to_snake(name.replace('ReLU', 'Relu'))
    _get_activations()[key] = activation_class

    return activation_class


# Spellings accepted as boolean arguments (matched case-insensitively).
_BOOL_LITERALS = {'true': True, 'false': False}


def _parse_activation_arg(raw):
    """Convert one '/'-separated argument string into a Python value."""
    try:
        return float(raw)
    except ValueError:
        pass
    # bool(raw) would be True for every non-empty string, including 'False',
    # so boolean literals are looked up explicitly. Anything else is passed
    # through as the original string.
    return _BOOL_LITERALS.get(raw.strip().lower(), raw)


def get_activation(activation):
    """Instantiate a registered activation from a spec string.

    The spec has the form ``name`` or ``name/arg1/arg2/...``. ``name`` is
    looked up in the activation registry and the remaining parts are passed
    to the class as positional constructor arguments after conversion:

    * anything ``float()`` accepts becomes a ``float``;
    * ``true`` / ``false`` (any letter case) become ``True`` / ``False``;
    * everything else is passed through as a string.

    Args:
        activation: Spec string, e.g. ``'name'`` or ``'name/0.1/true'``.

    Returns:
        A new instance of the registered activation class.

    Raises:
        KeyError: If ``name`` is not in the registry.
    """
    name, *raw_args = activation.split('/')
    args = [_parse_activation_arg(raw) for raw in raw_args]

    try:
        activation_class = _get_activations()[name]
    except KeyError:
        raise KeyError('Unknown activation: {!r}'.format(name)) from None

    return activation_class(*args)


@register_activation
class Mish(nn.Module):
    """Mish activation, computed as ``x * tanh(softplus(x))``."""

    def forward(self, x):
        # Deliberately kept as a single expression: per the original author's
        # measurement, avoiding a temporary variable saved about one second
        # per epoch on a V100 GPU.
        return x * F.softplus(x).tanh()


@register_activation
class Swish(nn.Module):
    """Swish activation, computed as ``x * sigmoid(x)``."""

    def forward(self, x):
        # torch.sigmoid is used instead of F.sigmoid: older PyTorch releases
        # emit a deprecation warning for the functional alias and point to
        # torch.sigmoid. The computed values are the same.
        return x * torch.sigmoid(x)
