# -*- coding: utf-8 -*-

import functools
import itertools
import sys
import time
import warnings
from abc import ABCMeta

import dataclasses

from .logger import LOGGER

###############################################################################
#                               Useful Functions                              #
###############################################################################


def identity(x):
    """Return the argument itself, unchanged (a no-op callable)."""
    result = x
    return result


def set_current_process_name(newname):
    """Best-effort rename of the current process.

    Two mechanisms are tried, in order:

    1. ``setproctitle.setproctitle(newname)`` when the optional
       ``setproctitle`` package is importable;
    2. ``prctl(PR_SET_NAME, ...)`` called through glibc (``libc.so.6``) via
       ctypes, which only exists on Linux/glibc systems.

    Both are attempted.  A failure of the ctypes path is logged only when
    ``setproctitle`` did not already rename the process, so platforms without
    ``libc.so.6`` no longer log a spurious error after a successful rename.

    Args:
        newname: the new process name (``str``).
    """
    renamed = False
    try:
        from setproctitle import setproctitle
    except ModuleNotFoundError:
        LOGGER.warning('Module "setproctitle" not found. try fallback method')
    else:
        setproctitle(newname)
        renamed = True

    try:
        from ctypes import byref, cdll, create_string_buffer
        libc = cdll.LoadLibrary('libc.so.6')
        # Size the buffer from the *encoded* bytes: a buffer of
        # len(newname) + 1 is too small for non-ASCII names and ctypes would
        # raise "byte string too long".  The kernel truncates long names.
        buff = create_string_buffer(newname.encode('utf-8'))
        libc.prctl(15, byref(buff), 0, 0, 0)  # 15 == PR_SET_NAME
    except Exception:
        if not renamed:
            LOGGER.exception('Cannot set current process name')


def query_yes_no(question, default='yes', with_choice_all=False):
    """Ask a yes/no question on stdin until a valid answer is given.

    Args:
        question: text printed before the choice prompt.
        default: answer used when the user just presses Enter: ``'yes'``,
            ``'no'``, or ``None`` to require an explicit answer.
        with_choice_all: also accept ``!y``/``!yes`` and ``!n``/``!no``
            (meaning "yes to all" / "no to all").

    Returns:
        One of the *strings* ``'yes'`` or ``'no'``, or, only when
        ``with_choice_all`` is true, ``'all-yes'`` or ``'all-no'``.
        Note these are not booleans: ``'no'`` is truthy.

    Raises:
        ValueError: if ``default`` is not ``'yes'``, ``'no'`` or ``None``.
    """
    valid = {'yes': 'yes', 'y': 'yes', 'no': 'no', 'n': 'no'}
    if default is None:
        prompt = '[y/n{}] '
    elif default == 'yes':
        prompt = '[Y/n{}] '
    elif default == 'no':
        prompt = '[y/N{}] '
    else:
        raise ValueError('invalid default answer: "{}"'.format(default))

    prompt = prompt.format('/!y/!n' if with_choice_all else '')
    if with_choice_all:
        valid.update({'!yes': 'all-yes', '!y': 'all-yes', '!no': 'all-no', '!n': 'all-no'})
    while True:
        print(question, prompt, end='')
        choice = input().lower().strip()
        if default is not None and choice == '':
            # Empty input picks the default; the capitalized letter in the
            # prompt shows which one.
            return valid[default]
        elif choice in valid:
            return valid[choice]
        else:
            print('Please respond with \'yes\' or \'no\' ', end='')


###############################################################################
#                             Useful Base Classes                             #
###############################################################################


class Singleton:
    """Base class giving each subclass at most one instance.

    The first call ``Cls(...)`` creates the instance; later calls return the
    same object.  Python still runs ``__init__`` on every call (it is invoked
    whenever ``__new__`` returns an instance of the class), so subclasses
    with state should make ``__init__`` idempotent.

    ``copy.copy`` and ``copy.deepcopy`` return the instance itself.
    """

    __instance = None

    def __new__(cls, *args, **kwargs):
        # Look only in this class's own namespace.  With an inherited lookup,
        # a subclass would get its parent's already-created instance (even of
        # the wrong type).  `cls.__instance` is name-mangled to
        # `_Singleton__instance`.
        instance = vars(cls).get('_Singleton__instance')
        if instance is None:
            # Do not forward args: object.__new__ rejects extra arguments when
            # __new__ is overridden.  They are meant for __init__, which
            # type.__call__ invokes with them afterwards.
            instance = object.__new__(cls)
            cls.__instance = instance
        return instance

    def __copy__(self):
        return self

    def __deepcopy__(self, memo):
        return self

    @staticmethod
    def create(name):
        """Return a new, empty Singleton subclass named *name*."""
        return type(name, (Singleton,), {})


class DotDict(dict):
    """A ``dict`` whose items can also be accessed as attributes.

    For names that are not regular attributes or methods of ``dict``,
    ``d.name`` reads ``d['name']``.  Assignment ``d.name = v`` always stores
    the item ``d['name']``, and ``del d.name`` removes it.  Missing keys raise
    ``AttributeError``, so ``getattr(d, name, default)`` and ``hasattr`` work.
    """

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            # Drop the KeyError context: the caller asked for an attribute.
            raise AttributeError(name) from None

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, name):
        if name in self:
            del self[name]
        else:
            # Fall back to normal attribute deletion (raises AttributeError).
            super().__delattr__(name)


class IdentityDict(dict):
    """A dict keyed by object identity, like Java's ``IdentityHashMap``.

    Keys are stored as ``id(key)``.  Two equal but distinct objects are
    therefore different keys, and unhashable objects can be used as keys.
    Only the methods overridden below translate keys.  Other ``dict`` methods
    (``keys()``, ``items()``, ``update()`` ...) see or accept the raw ``id``
    integers.

    Caveat: the dict does not keep its keys alive.  Once a key object is
    garbage-collected, its ``id`` may be reused by a new object, which would
    then alias the stale entry.  Keep the key objects referenced elsewhere.
    Pickling is refused (see ``__getstate__``).
    """

    def __init__(self, **elements):
        # Keyword names are str objects; they are keyed by their id as well.
        super().__init__((id(key), value) for key, value in elements.items())

    def __setitem__(self, key, value):
        super().__setitem__(id(key), value)

    def __getitem__(self, item):
        return super().__getitem__(id(item))

    def __delitem__(self, key):
        super().__delitem__(id(key))

    def __contains__(self, key):
        return super().__contains__(id(key))

    def get(self, key, default=None):
        return super().get(id(key), default)

    def pop(self, key, *default):
        return super().pop(id(key), *default)

    def setdefault(self, key, default=None):
        return super().setdefault(id(key), default)

    def __getstate__(self):
        raise NotImplementedError(f'Cannot pickle {self.__class__.__name__} object.')


class AverageMeter:
    """Accumulate named running averages, with min/max per name.

    Values are added with :meth:`add`.  Averages can be read through
    :meth:`avgs`, through :meth:`cell` or as attributes (``meter.loss``).

    Args:
        save_data: if true, every added value is also recorded in the cell's
            ``values`` list (see ``CacheCell``).
    """

    class Cell:
        """Running total/count plus the min/max of the added values."""
        __slots__ = ('min', 'max', 'value', 'count', 'values')

        def __init__(self, initial_value):
            self.value = initial_value
            # min/max read as 0 until the first value arrives (see add()).
            self.min = self.max = self.count = 0

        def __str__(self):
            return f'#<Cell avg/min/max={self.avg()}/{self.min}/{self.max}>'

        def avg(self):
            return (self.value / self.count) if self.count > 0 else 0

        def add(self, key, value, count):
            # NOTE(review): avg() treats `value` as a total over `count`
            # samples, whereas CacheCell records `value` `count` times as if
            # it were per-sample.  The two agree only for count == 1; confirm
            # which convention callers use.
            is_first = self.count == 0
            self.value += value
            self.count += count
            if is_first:
                # Seed from the first sample: comparing against the initial 0
                # would pin min at <= 0 and max at >= 0.
                self.min = self.max = value
            else:
                self.min = min(self.min, value)
                self.max = max(self.max, value)

    class CacheCell(Cell):
        """A Cell that also keeps every added value in ``values``."""

        def __init__(self, initial_value):
            super().__init__(initial_value)
            self.values = []

        def add(self, key, value, count):
            super().add(key, value, count)

            self.values.extend(itertools.repeat(value, count))

    def __init__(self, save_data=False):
        self._data = {}
        self._save_data = save_data

    def __getstate__(self):
        return {'_data': self._data, '_save_data': self._save_data}

    def __setstate__(self, state):
        self._data = state['_data']
        self._save_data = state['_save_data']

    def __getattr__(self, key):
        # Read _data via __dict__: on an instance whose _data is not set yet
        # (e.g. created with __new__ during copying), `self._data` would
        # re-enter __getattr__ and recurse forever.
        data = self.__dict__.get('_data')
        if data is not None and key in data:
            return data[key].avg()
        raise AttributeError(key)

    def make_cell(self, initial_value):
        if self._save_data:
            return self.CacheCell(initial_value)
        return self.Cell(initial_value)

    def cell(self, key):
        return self._data[key]

    def items(self):
        return zip(self.keys(), self.values())

    def clear(self):
        """Reset every known key to an empty cell (the keys are kept)."""
        for key in self._data:
            self._data[key] = self.make_cell(0)

    def keys(self):
        return self._data.keys()

    def values(self):
        return self._data.values()

    def avgs(self):
        return {name: cell.avg() for name, cell in self._data.items()}

    def add(self, key, value, count=1):
        """Add ``value`` (weighted by ``count``) to ``key``; ``count == 0`` is ignored."""
        if count == 0:
            return
        cell = self._data.get(key)
        if cell is None:
            # Only build a cell when the key is new (setdefault would build
            # and discard one on every call).
            cell = self._data[key] = self.make_cell(0)
        cell.add(key, value, count)


class ProgressReporter:
    """Write ``prompt: current/stop`` progress reports to a text stream.

    Use it as a context manager and call :meth:`tick`, iterate it directly
    (``for i in reporter``), or wrap an iterable (``for x in reporter(xs)``).

    Args:
        stop: expected final counter value; values < 1 are shown as ``?``.
        step: a report is written whenever the counter crosses a multiple of
            ``step``.  Defaults to ``max(10, stop // 5)``.
        start: initial counter value; iterating the reporter yields
            ``start .. stop - 1``.
        message_fn: optional callable ``fn(reporter)`` whose result is
            appended to each report.  If it raises, ``<error occurs>`` is
            shown instead.
        prompt: text printed before the counter.
        stream: text stream to write to.  ``None`` (the default) means the
            ``sys.stderr`` that is current when the reporter is created.
        print_time: also show the average number of seconds per tick.
        newline: end each report with a newline instead of a carriage return
            (which makes successive reports overwrite one terminal line).
    """

    def __init__(self, stop=-1, step=None, start=0,
                 message_fn=None, prompt='progress',
                 stream=None, print_time=False, newline=False):
        self.current = start
        self.stop = stop
        if step is None:
            step = max(10, stop // 5)
        self._start = start
        self._step = step
        # Resolve the default at construction time so that a redirected
        # sys.stderr is honoured (a `stream=sys.stderr` default would be
        # frozen when the module was imported).
        self._stream = sys.stderr if stream is None else stream
        self._fn = message_fn
        self._prompt = prompt
        self._print_time = print_time
        self._printed = False
        self._newline = newline
        # Initialised here as well, so tick() without start()/`with` does not
        # fail when print_time is enabled.
        self._start_time = time.time()

    def _write(self, text):
        self._stream.write(text)
        # Reports ending in '\r' contain no newline, so a buffered stream may
        # not display them without an explicit flush.
        flush = getattr(self._stream, 'flush', None)
        if flush is not None:
            flush()

    def _report(self):
        try:
            suffix = self._fn(self) if self._fn else ''
        except Exception:
            suffix = '<error occurs>'
        ticks = self.current - self._start  # ticks since start(), not the raw counter
        if self._print_time and ticks > 0:
            speed = ' ({:.3f}s/tick)'.format((time.time() - self._start_time) / ticks)
        else:
            speed = ''
        stop = '?' if self.stop < 1 else self.stop
        message = f'{self._prompt}: {self.current}/{stop}{speed} {suffix}'
        self._write(message + ('\n' if self._newline else '\r'))

    def start(self):
        self.current = self._start
        self._start_time = time.time()

    def finish(self):
        if not self._printed:
            self._report()
        if not self._newline:
            self._write('\n')

    def __enter__(self, *_):
        self.start()
        return self

    def __exit__(self, *_):
        self.finish()

    def __iter__(self):
        with self:
            # Start at `start` so the yielded index matches self.current.
            for i in range(self._start, self.stop):
                yield i
                self.tick()

    def __call__(self, iterable):
        with self:
            for value in iterable:
                yield value
                self.tick()

    def tick(self, count=1):
        previous = self.current
        self.current += count
        # Report when a multiple of step is crossed, not only when it is hit
        # exactly, so tick(count > 1) cannot skip a report.  For count == 1
        # this equals `self.current % step == 0`.
        if previous // self._step != self.current // self._step:
            self._printed = True
            self._report()
        else:
            self._printed = False


class MethodFactory:
    """Registry mapping names to callables (functions or classes).

    A registered name may declare *suffix arguments*.  For example, after
    ``register('resize', suffix_args=['size'])``, ``invoke('resize/224', x)``
    calls the method as ``method(x, size='224')``.  Suffix values are passed
    as strings.

    Args:
        include_none: pre-register ``'none'`` as the identity function.
    """

    def __init__(self, include_none=False):
        self._methods = {}
        self._arg_keys = {}

        if include_none:
            self._methods['none'] = identity

    def __repr__(self):
        return repr(self._methods)

    def __iter__(self):
        return iter(self._methods)

    def items(self):
        return self._methods.items()

    def values(self):
        return self._methods.values()

    def keys(self):
        return self._methods.keys()

    def register(self, name, override=False, fn=None, suffix_args=None):
        """Register a callable under *name*.

        Usable as a decorator (``@factory.register('foo')``) or directly
        (``factory.register('foo', fn=foo)``).  The callable is returned
        unchanged.

        Args:
            name: key to register under.
            override: allow replacing an existing registration.
            fn: register this callable now instead of returning a decorator.
            suffix_args: a name or list of names bound, in order, to the
                ``/``-separated values following *name*.

        Raises:
            ValueError: when the registration is applied, if *name* is
                already registered and *override* is false.
        """
        if isinstance(suffix_args, str):
            suffix_args = [suffix_args]

        def wrapper(cls):
            # An explicit check, unlike `assert`, survives `python -O`.
            if not override and name in self._methods:
                raise ValueError(f'{name} is used {self}')
            self._methods[name] = cls
            if suffix_args is not None:
                self._arg_keys[name] = list(suffix_args)
            return cls

        if fn is not None:
            return wrapper(fn)
        return wrapper

    def _parse_name(self, name):
        """Split ``'base/v1/v2'`` into ``('base', {key1: 'v1', key2: 'v2'})``."""
        base, *values = name.split('/')
        if not values:
            return name, {}
        keys = self._arg_keys.get(base)
        if keys is None:
            raise KeyError(f'{base!r} takes no suffix arguments (got {name!r})')
        if len(values) > len(keys):
            raise ValueError(f'too many suffix arguments in {name!r}; '
                             f'{base!r} accepts {keys}')
        return base, dict(zip(keys, values))

    def invoke(self, name, *args, **kwargs):
        """Call the method registered under *name* (suffix arguments allowed)."""
        name, suffix_args = self._parse_name(name)
        return self._methods[name](*args, **kwargs, **suffix_args)

    def normalize(self, name):
        """Return *name* if callable, else the method it refers to (suffix ignored)."""
        if callable(name):
            return name
        return self._methods[name.split('/', 1)[0]]


class lazy_property:
    r"""Decorator for an attribute computed on first access and then cached.

    On first access from an instance, the wrapped method is called and its
    result is stored on the instance under the method's name.  Because this
    is a non-data descriptor, later lookups find the cached value first.
    Accessed on the class, the descriptor itself is returned.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        functools.update_wrapper(self, wrapped)

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        wrapped = self.wrapped
        value = wrapped(instance)
        name = wrapped.__name__

        # Respect the class's own __setattr__ first.
        setattr(instance, name, value)
        # Some classes override __setattr__ so that the value never reaches
        # the instance __dict__ (e.g. dict subclasses storing attributes as
        # items).  Check the instance dict directly: hasattr() would re-enter
        # this descriptor and recompute, recursing forever.
        instance_dict = getattr(instance, '__dict__', None)
        if instance_dict is not None and name not in instance_dict:
            object.__setattr__(instance, name, value)

        return value


class DataClassMeta(ABCMeta):
    """Metaclass that makes every class created with it a dataclass.

    ``ABCMeta.__new__`` builds the class (so abstract methods are supported),
    and ``dataclasses.dataclass`` with default options is then applied to the
    result.  Subclasses inherit the metaclass and are converted too.
    """

    def __new__(mcs, *args, **kwargs):
        created = super().__new__(mcs, *args, **kwargs)
        make_dataclass = dataclasses.dataclass
        return make_dataclass(created)


###############################################################################
#                                  Decorators                                 #
###############################################################################


def deprecated(func):
    """Decorator marking *func* as deprecated.

    Every call emits a ``DeprecationWarning`` attributed to the caller
    (``stacklevel=2``), then returns ``func(*args, **kwargs)``.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        # DeprecationWarning is ignored by default outside __main__, so force
        # it to be shown for this one warn() call.  catch_warnings() restores
        # the previous filter list afterwards.  The old code left a
        # 'default' filter installed globally, overriding the user's own
        # warning configuration.  (catch_warnings is not thread-safe.)
        with warnings.catch_warnings():
            warnings.simplefilter('always', DeprecationWarning)
            warnings.warn(f'Call to deprecated function {func.__name__}.',
                          category=DeprecationWarning,
                          stacklevel=2)
        return func(*args, **kwargs)
    return new_func


def WIP(func):
    """Decorator marking *func* as a work in progress.

    Each call of the decorated function first emits a ``UserWarning``
    (``'WIP for function <name>'``, attributed to the caller through
    ``stacklevel=2``) and then returns ``func(*args, **kwargs)``.
    """

    def _wip_call(*args, **kwargs):
        warnings.warn(
            f'WIP for function {func.__name__}',
            UserWarning,
            stacklevel=2,
        )
        return func(*args, **kwargs)

    return functools.wraps(func)(_wip_call)
