import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

from modules.attention import ScaledDotAttention


class LabelSmoothing(nn.Module):
    """
    Label-smoothed KL-divergence loss.

    Integer targets are converted into a smoothed distribution q(w): the true
    class gets ``1 - smoothing`` and every other class gets
    ``smoothing / (n_classes - 1)``. The loss is the KL-divergence between q
    and the model distribution p(w) = softmax(output).
    """

    def __init__(self):
        super(LabelSmoothing, self).__init__()

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing: float = 0.0) -> torch.Tensor:
        """
        Build a smoothed one-hot matrix of shape ``(len(targets), n_classes)``.

        Args:
            targets: 1-D tensor of class indices (used as a ``scatter_`` index).
            n_classes: number of columns of the result.
            smoothing: probability mass moved off the true class, in ``[0, 1)``.

        Returns:
            A tensor of the default floating dtype on ``targets.device``,
            built without gradient tracking.

        Raises:
            ValueError: if ``targets`` is not 1-D or ``smoothing`` is out of range.
        """
        if targets.dim() != 1:
            raise ValueError(
                f"targets must be a 1-D tensor of class indices, got shape {tuple(targets.shape)}"
            )
        if not 0 <= smoothing < 1:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")

        # With a single class there is no other class to spread mass onto;
        # avoid dividing by zero and leave the off-target fill at 0.
        off_value = smoothing / (n_classes - 1) if n_classes > 1 else 0.0
        with torch.no_grad():
            smoothed = torch.full((targets.size(0), n_classes), off_value, device=targets.device)
            smoothed.scatter_(1, targets.unsqueeze(1), 1. - smoothing)
        return smoothed

    def forward(self, output, target, smoothing=0.1, ignore_index=-100, ignore_mask=None, reduction="mean"):
        """
        Compute the label-smoothed KL loss.

        Args:
            output (FloatTensor): ``batch_size x n_classes`` unnormalised scores.
            target (LongTensor): ``batch_size`` class indices. Rows whose target
                equals ``ignore_index`` are dropped before the loss is computed.
            smoothing: label-smoothing factor (see ``_smooth_one_hot``).
            ignore_index: target value marking rows to skip.
            ignore_mask: optional tensor; wherever it equals 0 the smoothed
                target probability is zeroed. NOTE(review): it is applied after
                the ``ignore_index`` row filtering without being filtered itself,
                so it presumably broadcasts per class (e.g. shape
                ``[n_classes]``). Confirm against callers.
            reduction: forwarded to ``F.kl_div``.

        Returns:
            The absolute value of ``F.kl_div(log_softmax(output), q, reduction)``.
        """
        cls_idx = (target != ignore_index).nonzero(as_tuple=False).view(-1)
        output = output.index_select(0, cls_idx)
        target = target.index_select(0, cls_idx)
        target_smoothing = LabelSmoothing._smooth_one_hot(target, output.size(-1), smoothing)

        if ignore_mask is not None:
            target_smoothing.masked_fill_(ignore_mask == 0, 0)
        # NOTE(review): after masking, q may no longer sum to 1, so the KL
        # term can become negative. abs() presumably guards against that.
        return torch.abs(F.kl_div(output.log_softmax(dim=-1), target_smoothing, reduction=reduction))


class PointerPredicator(nn.Module):
    """
    Output layer that scores a fixed vocabulary together with pointer sources.

    Indices ``< pointer_base`` (the vocabulary size of ``cls_embedding``) refer
    to vocabulary entries. Indices ``>= pointer_base`` address positions in the
    pointer sources, concatenated in the order the sources are given.
    """

    def __init__(self, embedding, d_hidden, dropout_p=0.3, ignore_index=-1, pointers=None):
        """
        Args:
            embedding: vocabulary embedding module; its ``num_embeddings`` is
                used as the vocabulary size and as ``pointer_base``.
            d_hidden: hidden size of the decoder state and attention vectors.
            dropout_p: dropout for the pointer attentions and the state projection.
            ignore_index: unused here (it belonged to a previous
                ``nn.CrossEntropyLoss``-based loss); kept for compatibility.
            pointers: sequence whose length gives the number of pointer
                sources. Only ``len(pointers)`` is read. Defaults to none.
        """
        super().__init__()
        # Avoid a shared mutable default; only the length is needed.
        pointers = () if pointers is None else pointers
        self.d_hidden = d_hidden
        self.dropout_p = dropout_p
        self.cls_embedding = embedding

        self.cross_entropy = LabelSmoothing()

        self.pointer_base = self.vocab_size = self.cls_embedding.num_embeddings
        self.ptr_attns = nn.ModuleList(
            [ScaledDotAttention(self.d_hidden, dropout_p=dropout_p) for _ in range(len(pointers))]
        )

        # The state projection consumes the decoder state plus one context
        # vector per pointer source.
        n_vec = 1 + len(pointers)
        self.state_proj = nn.Sequential(
            nn.Linear(self.d_hidden * n_vec, self.d_hidden),
            nn.Dropout(p=dropout_p),
            nn.Tanh(),
        )
        self.cls_predicator = nn.Linear(self.d_hidden, self.cls_embedding.num_embeddings)

    @staticmethod
    def _locate_pointer(offset, ptr_lens):
        """Map a global pointer offset to ``(source index, offset within that source)``."""
        for branch_idx, ptr_len in enumerate(ptr_lens):
            if offset < ptr_len:
                return branch_idx, offset
            offset -= ptr_len
        raise ValueError(f"pointer offset exceeds total pointer length {sum(ptr_lens)}")

    def embedding(self, x, pointers=None, debug=None, ex=None):
        """
        Embed a tensor that mixes vocabulary indices and pointer indices.

        Args:
            x: ``[bs, q_len]`` integer indices.
            pointers: optional list of ``[bs, ptr_len_k, dim]`` tensors. When
                ``None`` this is a plain ``cls_embedding(x)`` lookup.
            debug: optional nested ``[bs][q_len]`` expected in-source offsets,
                checked for every pointer index (debugging aid).
            ex: unused.

        Returns:
            ``[bs, q_len, dim]`` tensor on the device of ``pointers[0]``, or
            ``cls_embedding(x)`` when ``pointers`` is ``None``.

        Raises:
            ValueError: if a pointer index runs past all sources, or disagrees
                with ``debug``.
        """
        if pointers is None:
            return self.cls_embedding(x)

        bs, x_len = x.size(0), x.size(1)
        ptr_lens = [ptr.size(1) for ptr in pointers]
        dim = pointers[0].size(2)
        # Allocate on the pointers' device rather than hard-coding CUDA.
        x_embs = torch.zeros(bs, x_len, dim, device=pointers[0].device)

        for b in range(bs):
            for i, x_idx in enumerate(x[b]):
                if x_idx >= self.pointer_base:
                    branch_idx, offset = self._locate_pointer(x_idx.item() - self.pointer_base, ptr_lens)
                    if debug and offset != debug[b][i]:
                        raise ValueError(
                            f"pointer offset {offset} at ({b}, {i}) does not match debug value {debug[b][i]}"
                        )
                    e = pointers[branch_idx][b, offset, :]
                else:
                    e = self.cls_embedding(x_idx)
                x_embs[b, i, :] = e.view(-1)
        return x_embs

    def forward(self, x, mask, pointers, mode="pointer", with_cv=False, ex=None):
        """
        Score the vocabulary and, in ``"pointer"`` mode, every pointer position.

        Args:
            x: decoder states; the original comment gives ``[bs, x_cnt, dim]``.
                TODO confirm.
            mask: query mask for ``x``, forwarded to the pointer attentions.
            pointers: sequence of ``(ptr, ptr_mask)`` pairs, one per attention
                module in ``self.ptr_attns``.
            mode: ``"pointer"`` concatenates vocabulary and pointer scores along
                the last dim; ``"cls"`` returns vocabulary scores only.
            with_cv: when True, also return the projected context vector.
            ex: unused.

        Returns:
            ``(scores, score_mask)``, or ``((scores, score_mask), ctx_vec)`` when
            ``with_cv``. NOTE(review): the vocabulary mask is built as
            ``[x.size(0), vocab]`` regardless of the rank of ``x``.

        Raises:
            ValueError: if ``mode`` is not ``"pointer"`` or ``"cls"``.
        """
        if mode not in ("pointer", "cls"):
            raise ValueError(f"mode must be 'pointer' or 'cls', got {mode!r}")

        ctx_vecs = [x]
        ptr_score_list = []
        ptr_mask_list = []

        for i, (ptr, ptr_mask) in enumerate(pointers):
            ptr_vec, ptr_scores = self.ptr_attns[i](
                x, ptr, ptr, with_projection=True, with_output=False, q_mask=mask, k_mask=ptr_mask
            )
            ctx_vecs.append(ptr_vec)
            ptr_score_list.append(ptr_scores)
            ptr_mask_list.append(ptr_mask)

        ctx_vec = self.state_proj(torch.cat(ctx_vecs, dim=-1))

        cls_score = self.cls_predicator(ctx_vec)
        cls_mask = torch.ones(x.size(0), self.cls_embedding.num_embeddings, device=x.device, dtype=torch.long)

        if mode == "pointer":
            scores = torch.cat([cls_score] + ptr_score_list, dim=-1)
            score_mask = torch.cat([cls_mask] + ptr_mask_list, dim=-1)
        else:
            scores, score_mask = cls_score, cls_mask

        if with_cv:
            return (scores, score_mask), ctx_vec
        return scores, score_mask

    def loss(self, predict, label, ignore_index=0, ignore_mask=None, padding_mask=None, generate=True, debug=None):
        """
        Label-smoothed loss averaged over the samples that have real tokens.

        Args:
            predict: ``[bs, seq, n_classes]`` scores.
            label: ``[bs, seq]`` target indices.
            ignore_index: label value skipped by the loss.
            ignore_mask: optional per-sample masks. ``ignore_mask[b]`` is
                forwarded to ``LabelSmoothing`` as its ``ignore_mask``.
            padding_mask: ``[bs, seq]``; the count of non-zero entries is taken
                as the sample length (presumably a prefix of real tokens;
                confirm). Required.
            generate: when True, the prediction at step t is scored against
                label t + 1 (shift by one).
            debug: unused.

        Returns:
            A one-element list holding the summed per-sample loss divided by
            the number of non-empty samples, or a zero tensor when every
            sample is empty.

        Raises:
            ValueError: if ``padding_mask`` is not given.
        """
        if padding_mask is None:
            raise ValueError("padding_mask is required to compute the loss")

        cls_num = predict.size(2)
        offset = 1 if generate else 0
        total = 0.
        n_valid = 0

        for b in range(predict.size(0)):
            seq_len = (padding_mask[b] != 0).sum().item()
            if seq_len == 0:
                continue

            total = total + self.cross_entropy(
                predict[b, :seq_len - offset, :].contiguous().view(-1, cls_num),
                label[b, offset:seq_len].contiguous().view(-1),
                ignore_index=ignore_index,
                ignore_mask=None if ignore_mask is None else ignore_mask[b],
                reduction='sum',
            )
            n_valid += 1

        # Avoid dividing by zero when every sample in the batch is empty.
        if n_valid == 0:
            return [predict.new_zeros(())]
        return [total / n_valid]
