from torch import nn
import numpy as np

from src.models.stochastic.util import get_encoder


class Classifier(nn.Module):
    """
    Classifier that takes an input text (and optionally a rationale z) and
    predicts y from the final encoder state, i.e. it models p(y|x,z).

    The pipeline is: embedding + dropout -> encoder (built by `get_encoder`)
    -> dropout + linear projection to `output_size`.

    No final activation is applied in this module: `forward` returns the raw
    output of the linear layer. Any sigmoid (regression) or softmax /
    cross-entropy (classification) has to be applied by the caller.
    """

    def __init__(self,
                 embed:        nn.Embedding = None,
                 hidden_size:  int = 200,
                 output_size:  int = 1,
                 dropout:      float = 0.1,
                 layer:        str = "lstm",
                 ):
        """
        :param embed: embedding layer; its weight's second dimension is used
            as the encoder input size. Required, despite the `None` default.
        :param hidden_size: hidden size passed to `get_encoder`
        :param output_size: number of outputs of the final linear layer
        :param dropout: dropout probability, applied after the embedding and
            before the output projection
        :param layer: encoder type name passed to `get_encoder`
        :raises ValueError: if `embed` is None
        """
        # Fail early with a clear message instead of an AttributeError on
        # `None.weight` further down.
        if embed is None:
            raise ValueError(
                "Classifier requires an nn.Embedding instance for `embed`")

        super().__init__()

        emb_size = embed.weight.shape[1]

        self.embed_layer = nn.Sequential(
            embed,
            nn.Dropout(p=dropout)
        )

        self.enc_layer = get_encoder(layer, emb_size, hidden_size)

        # Encoders exposing a `cnn` attribute report their own output size;
        # every other encoder is assumed to output 2 * hidden_size features
        # (presumably bidirectional -- confirm against `get_encoder`).
        if hasattr(self.enc_layer, "cnn"):
            enc_size = self.enc_layer.cnn.out_channels
        else:
            enc_size = hidden_size * 2

        self.output_layer = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(enc_size, output_size)
        )

        self.report_params()

    def report_params(self):
        """Print the number of trainable, non-embedding parameters."""
        # This has 1604 fewer params compared to the original, since only 1
        # aspect is trained, not all. The original code has 5 output classes,
        # instead of 1, and then only supervises 1 output class.
        count = sum(
            p.numel()
            for name, p in self.named_parameters()
            if p.requires_grad and "embed" not in name
        )
        print("{} #params: {}".format(self.__class__.__name__, count))

    def forward(self, x, mask, z=None):
        """
        Encode the (optionally rationale-masked) input and project the final
        encoder state to the output space.

        :param x: input passed to the embedding layer
            (presumably token ids of shape [B, T] -- confirm with caller)
        :param mask: [B, T] mask, non-zero/True at valid positions
        :param z: optional [B, T'] rationale; may be continuous. When given,
            `mask` and `z` are truncated to their common length, the
            embeddings are scaled by `mask * z`, and positions where that
            product is 0 are masked out for the encoder. When None, the
            embeddings and `mask` are used unchanged.
        :return: output of the final linear layer (no activation applied)
        """
        if z is not None:
            # mask and z may differ in length; align both on the shorter one
            max_len = min(mask.size(-1), z.size(-1))
            mask = mask[:, :max_len]
            z = z[:, :max_len]

        rnn_mask = mask
        emb = self.embed_layer(x)

        # lengths come from the input mask, not from z
        lengths = mask.long().sum(1)

        # apply z to main inputs
        if z is not None:
            z_mask = (mask.float() * z).unsqueeze(-1)  # [B, T, 1]
            # z is also used to control when the encoder layer is active
            rnn_mask = z_mask.squeeze(-1) > 0.  # z could be continuous
            # cut the embeddings to the longest sequence before masking
            emb = emb[:, :int(lengths.max())] * z_mask

        # encode the sentence
        _, final = self.enc_layer(emb, rnn_mask, lengths)

        # predict from final state(s)
        logits = self.output_layer(final)

        return logits