import torch
import torch.nn as nn
import random
from autoencoder import Decoder
import torch.nn.functional as F

# TODO: rename to RNNAE.


class BaseARDecoder(Decoder):
    """Base class for auto-regressive (RNN-style) decoders.

    Sub-classes provide the recurrent core by implementing
    ``_get_output_and_update_memory``, ``init_hidden_greedy``,
    ``init_hidden_batchwise``, ``_decode_all``, ``_hidden_from_beam`` and
    ``_hidden_to_beam``. On top of those hooks this class implements
    teacher-forced decoding, free-running decoding during training, and
    greedy / beam-search decoding at inference time.

    The output layer is weight-tied to the input side: logits are produced
    through ``input_projection.weight`` and the transposed embedding matrix
    (see ``_outlayer``).
    """

    def __init__(self, config):
        super().__init__(config)

        self.teacher_forcing_ratio = config.teacher_forcing_ratio
        self.unit_sphere = config.unit_sphere
        self.teacher_forcing_batchwise = config.teacher_forcing_batchwise
        self.teacher_forcing_batchwise_only = config.teacher_forcing_batchwise_only

        self.config = config
        self.device = config.device

        # One slot more than config.vocab_size; presumably to make room for the
        # padding index 0 used below — TODO confirm against the tokenizer.
        self.vocab_size = config.vocab_size + 1

        self.hidden_size = config.hidden_size
        self.max_sequence_len = config.max_sequence_len
        # NOTE(review): set from hidden_size, while input_projection below is
        # built from config.input_size — confirm which one sub-classes expect.
        self.input_size = config.hidden_size

        # Index 0 denotes padding (nn.Embedding keeps that row out of the gradient).
        self.embedding = nn.Embedding(
            self.vocab_size, config.embedding_size, padding_idx=0)

        # Applied to token embeddings before they enter the recurrent core and
        # reused (via its weight) by _outlayer. nn.Linear(input_size,
        # embedding_size) consumes input_size-dim vectors, so this presumably
        # assumes input_size == embedding_size — TODO confirm.
        self.input_projection = nn.Linear(
            config.input_size, config.embedding_size, bias=False)

        self.eos_idx = config.eos_idx
        self.sos_idx = config.sos_idx
        self.unk_idx = config.unk_idx

        # Probability of replacing an input token with <UNK> while training.
        self.word_dropout = config.word_dropout

    # ------------------------------------------------------------------
    # Hooks implemented by sub-classes
    # ------------------------------------------------------------------

    def _get_output_and_update_memory(self, embedded_input, state, bottleneck_embedding, t):
        """Run one decoding step.

        Must return ``(output, new_state, new_t)``; callers feed the returned
        ``t`` back in on the next step.
        """
        raise NotImplementedError("Needs to be implemented in sub-class.")

    def init_hidden_greedy(self, x):
        """Build the initial recurrent state from *x* for step-wise decoding."""
        raise NotImplementedError("Needs to be implemented in sub-class.")

    def init_hidden_batchwise(self, x):
        """Build the initial recurrent state from *x* for ``_decode_all``."""
        raise NotImplementedError("Needs to be implemented in sub-class.")

    def _decode_all(self, embedded_teacher, h, l):
        """Decode a whole teacher-forced input sequence in one call.

        Receives projected embeddings (batch first), the initial state and the
        per-sequence input lengths; must return a (batch, seq, hidden) tensor.
        """
        raise NotImplementedError("Needs to be implemented in sub-class.")

    def _hidden_from_beam(self, incomplete):
        """Stack the hidden states of all live beams into one batched state.

        *incomplete* maps batch index -> list of BeamNode; rows must follow the
        dict's iteration order, ``beam_width`` consecutive rows per entry.
        """
        raise NotImplementedError("Needs to be implemented in sub-class.")

    def _hidden_to_beam(self, h, indices):
        """Extract the hidden state of flat batch row(s) *indices* from *h*."""
        raise NotImplementedError("Needs to be implemented in sub-class.")

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _token_column(self, batch_size, token_idx):
        """Return a (batch_size, 1) long tensor filled with *token_idx*."""
        return torch.full((batch_size, 1), token_idx,
                          dtype=torch.long, device=self.device)

    def _prepend_sos_prediction(self, predictions):
        """Prepend a one-hot <SOS> row at position 0 of (batch, seq, vocab) logits.

        The decoding loops start predicting at position 1, so this makes the
        result cover every target position, beginning with <SOS>.
        """
        batch_size = predictions.shape[0]
        sos_padding = torch.zeros(
            (batch_size, 1, self.vocab_size), device=self.device)
        sos_padding[:, :, self.sos_idx] = 1
        return torch.cat((sos_padding, predictions), 1)

    def _outlayer(self, output):
        """Map decoder outputs to vocabulary logits with tied weights.

        Multiplies by ``input_projection.weight`` and then by the transposed
        embedding matrix; no separate output layer is learned.
        """
        output = output @ self.input_projection.weight
        return output @ self.embedding.weight.transpose(0, 1)

    def _get_initial_inputs(self, x):
        """Return the embedded inputs fed before auto-regressive decoding starts.

        The base implementation embeds one <SOS> token per batch element,
        giving a (batch, 1, embedding_size) tensor. Sub-classes may override
        this to feed additional leading inputs.
        """
        batch_size = x[0].shape[0]
        return self.embedding(self._token_column(batch_size, self.sos_idx))

    def _decode_eval(self, x, beam_width):
        """Inference decoding: greedy when beam_width == 1, beam search otherwise."""
        if beam_width == 1:
            return self.greedy_decode(x)
        return self.beam_decode(x, beam_width)

    # ------------------------------------------------------------------
    # Training-time decoding
    # ------------------------------------------------------------------

    def decode(self, x, train=False, actual=None, lengths=None, beam_width=1):
        """Decode latent code *x*.

        Args:
            x: latent code; the batch size is read as ``x[0].shape[0]``.
            train: False -> delegate to greedy/beam decoding and return token
                ids. True -> run free-running decoding for ``lengths.max()``
                positions and return logits.
            actual: target token ids indexed as ``actual[:, t]``; used for
                step-wise teacher forcing when ``train`` is True and the module
                is in training mode.
            lengths: 1-D tensor of target lengths; required when ``train``.
            beam_width: beam size used at inference.

        Returns:
            ``train=False``: one list of token ids per batch element, starting
            with <SOS> and ending with <EOS>.
            ``train=True``: logits of shape (batch, lengths.max(), vocab_size);
            position 0 is a one-hot <SOS> row and positions at or beyond a
            sequence's length are zeroed.
        """
        if self.unit_sphere:
            # Project the latent code onto the unit sphere (L2 over the last dim).
            # NOTE(review): decode_teacher_forcing / decode_train_greedy do not
            # normalise x — confirm whether that is intended.
            x = x / x.norm(p=None, dim=-1, keepdim=True)

        if not train:
            return self._decode_eval(x, beam_width)

        h = self.init_hidden_greedy(x)
        embedded_input = self._get_initial_inputs(x)
        batch_size = x[0].shape[0]

        predictions = []
        pos = 1
        for t in range(1, lengths.max()):
            embedded_input = self.input_projection(embedded_input)
            output, h, pos = self._get_output_and_update_memory(
                embedded_input, h, x, pos)
            res = self._outlayer(output.squeeze(1))

            # Zero the logits of sequences whose length has been reached.
            predictions.append(res * torch.gt(lengths.reshape(-1, 1), t).float())

            # Step-wise teacher forcing is skipped only when both
            # teacher_forcing_batchwise and teacher_forcing_batchwise_only are set.
            # The short-circuit order keeps random.random() calls unchanged.
            use_teacher = (
                self.training
                and (not self.teacher_forcing_batchwise
                     or not self.teacher_forcing_batchwise_only)
                and random.random() < self.teacher_forcing_ratio
            )
            if use_teacher:
                next_token = actual[:, t].reshape(-1, 1)
            else:
                next_token = res.topk(1)[1].detach()

            # Batch-level word dropout: the whole next input becomes <UNK>.
            if self.training and random.random() < self.word_dropout:
                next_token = self._token_column(batch_size, self.unk_idx)

            embedded_input = self.embedding(next_token)

        # is: seq, batch, pred -> want: batch, seq, pred
        predictions = torch.stack(predictions).permute(1, 0, 2)
        return self._prepend_sos_prediction(predictions)

    def decode_teacher_forcing(self, x, actual, lengths):
        """Teacher-forced decoding of the whole target sequence in one pass.

        Returns raw logits (no softmax; cross-entropy applies it) of shape
        (batch, seq, vocab_size), with a one-hot <SOS> row at position 0.
        """
        h = self.init_hidden_batchwise(x)

        # Feed everything but the final <EOS> so the network learns to predict
        # it: overwrite position lengths-1 with padding, then drop the last
        # column. `actual` holds token ids (it is fed to nn.Embedding), so it
        # cannot require grad and a plain clone() is sufficient.
        teacher_input = actual.clone()
        teacher_input[torch.arange(
            teacher_input.shape[0], device=self.device), lengths - 1] = 0

        # Per-token word dropout, only in training mode. (`self.train` is the
        # always-truthy nn.Module method; `self.training` is the mode flag.)
        # Noise is sampled as float: rand_like on the integer id tensor fails.
        if self.training and self.word_dropout > 0.:
            mask = torch.rand(
                teacher_input.shape, device=teacher_input.device) < self.word_dropout
            teacher_input[mask] = self.unk_idx

        embedded_teacher = self.embedding(
            teacher_input[:, :teacher_input.shape[1] - 1])
        embedded_teacher = self.input_projection(embedded_teacher)

        output = self._decode_all(embedded_teacher, h, lengths - 1)

        # Run the output layer on every timestep at once by flattening
        # (batch, seq, hidden) -> (batch*seq, hidden) and restoring afterwards.
        predictions = self._outlayer(
            output.contiguous().view(-1, output.shape[2])
        ).reshape(output.shape[0], output.shape[1], self.vocab_size)

        return self._prepend_sos_prediction(predictions)

    def decode_train_greedy(self, x, lengths):
        """Free-running decoding for ``lengths.max()`` positions.

        Each step feeds back the arg-max token of the previous step. Returns
        logits shaped like ``decode(..., train=True)``.
        """
        h = self.init_hidden_greedy(x)
        embedded_input = self._get_initial_inputs(x)

        predictions = []
        pos = 1
        for t in range(1, lengths.max()):
            # Project embeddings exactly as every other decoding path does.
            embedded_input = self.input_projection(embedded_input)
            output, h, pos = self._get_output_and_update_memory(
                embedded_input, h, x, pos)
            res = self._outlayer(output.squeeze(1))

            predictions.append(res * torch.gt(lengths.reshape(-1, 1), t).float())

            embedded_input = self.embedding(res.topk(1)[1].detach())

        # is: seq, batch, pred -> want: batch, seq, pred
        predictions = torch.stack(predictions).permute(1, 0, 2)
        return self._prepend_sos_prediction(predictions)

    # ------------------------------------------------------------------
    # Inference-time decoding
    # ------------------------------------------------------------------

    def clip_predictions(self, pred):
        """Truncate each sequence right after its first <EOS> (kept)."""
        clipped = []
        for seq in pred:
            kept = []
            for token in seq:
                kept.append(token)
                if token == self.eos_idx:
                    break
            clipped.append(kept)
        return clipped

    # TODO: implement temperature
    # https://nlp.stanford.edu/blog/maximum-likelihood-decoding-with-rnns-the-good-the-bad-and-the-ugly/
    def greedy_decode(self, x):
        """Greedy (arg-max) decoding up to ``max_sequence_len`` steps.

        Returns one list of token ids per batch element, starting with <SOS>;
        sequences that never emit <EOS> get one appended at the length limit.
        """
        batch_size = x[0].shape[0]
        h = self.init_hidden_greedy(x)
        embedded_input = self._get_initial_inputs(x)

        predictions = [[self.sos_idx] for _ in range(batch_size)]
        last_step = self.max_sequence_len - 1

        pos = 1
        for t in range(1, self.max_sequence_len):
            embedded_input = self.input_projection(embedded_input)
            output, h, pos = self._get_output_and_update_memory(
                embedded_input, h, x, pos)
            topi = self._outlayer(output.squeeze(1)).topk(1)[1]
            # One host transfer per step instead of one per batch element.
            step_tokens = topi[:, 0].tolist()

            done_count = 0
            for seq, token in zip(predictions, step_tokens):
                if seq[-1] == self.eos_idx:
                    done_count += 1
                    continue
                seq.append(token)
                # Out of room: force-terminate sequences without an <EOS>.
                if t == last_step and seq[-1] != self.eos_idx:
                    seq.append(self.eos_idx)
            if done_count == batch_size:
                break

            embedded_input = self.embedding(topi.detach())

        return self.clip_predictions(predictions)

    class BeamNode:
        """One beam-search hypothesis: a token plus a link to its prefix."""

        def __init__(self, hidden_state, previous_node, word_id, log_prob, length):
            # Recurrent state from which word_id was predicted.
            self.hidden_state = hidden_state
            # Parent hypothesis; None for the root <SOS> node.
            self.previous_node = previous_node
            # Token id of this node (a 0-d tensor).
            self.word_id = word_id
            # Cumulative log-probability of the sequence ending here.
            self.log_prob = log_prob
            # Number of tokens in the sequence, including <SOS>.
            self.length = length

    def _backtrack(self, node):
        """Return the token ids from the root node to *node*, followed by <EOS>."""
        tokens = []
        while node is not None:
            tokens.append(node.word_id.cpu().item())
            node = node.previous_node
        tokens.reverse()
        tokens.append(self.eos_idx)
        return tokens

    # Relies on the sub-class beam hooks; originally noted as LSTM-only.
    def beam_decode(self, x, beam_width=10):
        """Beam-search decoding with *beam_width* hypotheses per batch element.

        A search ends as soon as <EOS> appears among its top-k candidates;
        searches reaching ``max_sequence_len`` are cut at their best beam.
        Returns one list of token ids per batch element, from <SOS> to <EOS>.
        """
        batch_size = x[0].shape[0]
        h = self.init_hidden_greedy(x)
        decoded = [None] * batch_size

        # Every search starts from beam_width identical <SOS> root nodes.
        incomplete = {
            b: [self.BeamNode(h, None, torch.tensor(self.sos_idx, device=self.device), 0, 1)
                for _ in range(beam_width)]
            for b in range(batch_size)
        }

        # First step: all beams share the <SOS> prefix, so expand a single
        # distribution into the beam_width best first tokens.
        embedded_input = self.input_projection(self._get_initial_inputs(x))
        pos = 1
        decoder_output, h, pos = self._get_output_and_update_memory(
            embedded_input, h, x, pos)

        for b in range(batch_size):
            log_probs = F.log_softmax(
                self._outlayer(decoder_output[b]), dim=1).squeeze(0)
            k_log_probs, k_indices = torch.topk(log_probs, beam_width)
            incomplete[b] = [
                self.BeamNode(self._hidden_to_beam(h, b), root,
                              k_indices[i], k_log_probs[i], 2)
                for i, root in enumerate(incomplete[b])
            ]

        for t in range(2, self.max_sequence_len):
            if not incomplete:
                break

            # Flat batch layout: [ search_0 beams | search_1 beams | ... ].
            input_order = list(incomplete)
            word_ids = torch.tensor(
                [beam.word_id for batch in input_order for beam in incomplete[batch]],
                device=self.device)
            # (batch*beam,) -> (batch*beam, 1) -> embedded -> projected
            embedded_input = self.input_projection(
                self.embedding(word_ids.reshape(-1, 1)))

            h = self._hidden_from_beam(incomplete)
            decoder_output, h, pos = self._get_output_and_update_memory(
                embedded_input, h, x, pos)

            for batch_index, batch in enumerate(input_order):
                # Each batch element is an independent beam search.
                offset = batch_index * beam_width
                beams = incomplete[batch]
                log_probs = F.log_softmax(self._outlayer(
                    decoder_output[offset:offset + beam_width].squeeze(1)), dim=1)

                # Flatten (beam, vocab) into one vector of full-sequence log-probs.
                seq_probs = torch.cat(
                    [beams[i].log_prob + log_probs[i] for i in range(beam_width)])
                k_seq_probs, k_indices = torch.topk(seq_probs, beam_width)

                new_beams = []
                finished = False
                for seq_prob, index in zip(k_seq_probs, k_indices):
                    beam_index = index // self.vocab_size
                    word_index = index % self.vocab_size
                    prev_beam = beams[beam_index]
                    if word_index == self.eos_idx:
                        # <EOS> among the top-k completes this search.
                        decoded[batch] = self._backtrack(prev_beam)
                        finished = True
                        break
                    new_beams.append(self.BeamNode(
                        self._hidden_to_beam(h, offset + beam_index),
                        prev_beam,
                        word_index,
                        seq_prob,
                        prev_beam.length + 1))

                if finished:
                    del incomplete[batch]
                else:
                    incomplete[batch] = new_beams

        # Searches that hit the length limit are cut at their most probable
        # beam; beams are built in descending score order, so index 0 is best.
        for batch, beams in incomplete.items():
            decoded[batch] = self._backtrack(beams[0])

        return self.clip_predictions(decoded)
