# coding=utf-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BlenderbotForConditionalGeneration, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import BaseModelOutput
from my_rewards3 import BertScore
from myWork.utils import *






class Model(nn.Module):
    '''
    MCCL model: a pretrained seq2seq backbone trained with token-level NLL plus
    contrastive losses over perturbed encoder states.

    ``forward`` dispatches on ``model_inputs['train_mode']``:

    * ``'s2s'`` (default) -- cross-entropy plus contrastive losses between pooled
      source/target representations, a gradient-based positive perturbation and
      a sampled latent-noise negative perturbation of the encoder states.
    * ``'kl'`` -- KL divergence between decoder predictions on clean and
      positively-perturbed encoder states.
    * ``'rl'`` -- reward-weighted log-probability loss for the latent noise
      generator (REINFORCE-style), rewarded by ``BertScore.score_genBatch``.

    Args:
        args: config namespace. This class reads ``tau``, ``pos_eps``,
            ``neg_eps``, ``eta``, ``neg_mode``, ``bert_model`` (passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``), ``hidden_size``,
            ``max_tgt_length``, ``warm_epoch`` and ``num_train_epochs``;
            ``BertScore(args)`` may read further fields.
    '''
    def __init__(self, args):
        super(Model, self).__init__()

        # Perturbation / contrastive hyper-parameters.
        self.tau = args.tau  # not read inside this class; kept for external users
        self.pos_eps = args.pos_eps  # step size of the gradient-based positive perturbation
        self.neg_eps = args.neg_eps  # scale of the normalised noise when neg_mode is falsy
        self.eta = args.eta  # scale of the raw noise when neg_mode is truthy
        self.neg_mode = args.neg_mode

        self.bot_model = AutoModelForSeq2SeqLM.from_pretrained(args.bert_model)

        # Projection heads applied before mean pooling in the contrastive losses.
        self.src_projection = MLPLayer(args.hidden_size, activation='relu')
        self.tgt_projection = MLPLayer(args.hidden_size, activation='relu')

        # Reward model used by forward_rl; constructed exactly once.
        self.reward_function = BertScore(args)

        self.max_sample_len = args.max_tgt_length

        # Not read inside this class; kept for external users.
        self.encoder_warm = args.warm_epoch
        self.max_epoch = args.num_train_epochs

        self.lsCELoss = LabelSmoothing(0.1)

        # temperature control parameter (fixed at 0.1; args.cos_temp is not consulted)
        self.cos_sim = Similarity(0.1)

        # Latent Gaussian noise generator for the negative perturbation.
        self.z_size = args.hidden_size // 10
        self.c2z = Hidden2Gaussian(args.hidden_size, self.z_size)
        self.z2noise = nn.Linear(self.z_size, args.hidden_size)

    def hidden2vocab(self, hiddens):
        """Map decoder hidden states to vocabulary logits.

        Applies the backbone's ``lm_head`` and adds its ``final_logits_bias``
        (comment in the original code: ``[b, t, v]`` output).
        """
        return self.bot_model.lm_head(hiddens) + self.bot_model.final_logits_bias

    def forward(self, model_inputs):
        """Dispatch to the objective named by ``model_inputs['train_mode']``.

        Args:
            model_inputs: dict of batch tensors/metadata; ``train_mode`` is one
                of ``'s2s'`` (default), ``'kl'`` or ``'rl'``.

        Returns:
            The loss returned by the selected objective.

        Raises:
            ValueError: if ``train_mode`` is not recognised.
        """
        train_mode = model_inputs.get('train_mode', 's2s')
        if train_mode == 's2s':
            return self.forward_s2s(model_inputs)
        if train_mode == 'kl':
            return self.forward_kl(model_inputs)
        if train_mode == 'rl':
            return self.forward_rl(model_inputs)
        # Raise rather than print()+exit(): callers can catch it, and it does not
        # depend on the site-provided exit() builtin.
        raise ValueError('error train mode: {!r}'.format(train_mode))

    def _encode_decode(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask):
        """Run the backbone encoder then decoder; return (encoder_states, decoder_states)."""
        encoder_outputs = self.bot_model.get_encoder()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = encoder_outputs['last_hidden_state']  # [b, t, d]
        decoder_outputs = self.bot_model.get_decoder()(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
        )
        return hidden_states, decoder_outputs[0]  # decoder states: [b, t, d]

    def forward_s2s(self, model_inputs):
        """Seq2seq NLL plus adversarial contrastive losses.

        Reads ``input_ids``, ``attention_mask``, ``decoder_input_ids``,
        ``decoder_attention_mask`` and ``labels``. If ``nll_only`` is truthy,
        only the cross-entropy term is computed and returned.

        Returns:
            ``nll + 0.5 * (neg_cont_loss + new_cont_loss)`` as a scalar tensor.
        """
        input_ids = model_inputs['input_ids']
        attention_mask = model_inputs['attention_mask']
        decoder_input_ids = model_inputs['decoder_input_ids']
        decoder_attention_mask = model_inputs['decoder_attention_mask']
        labels = model_inputs['labels']
        nll_only = model_inputs.get('nll_only', False)

        batch_size = input_ids.size(0)
        hidden_states, sequence_output = self._encode_decode(
            input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)

        lm_logits = self.hidden2vocab(sequence_output)
        vocab_size = lm_logits.size(-1)
        nll = nn.CrossEntropyLoss()(lm_logits.view(-1, vocab_size), labels.view(-1))
        if nll_only:
            return nll

        cos_sim = self.cos_sim

        # Sampled negative first, then gradient-based positive (order kept: the
        # negative consumes RNG state).
        per_neg_enc, _ = self.generate_encoder_perturbation(
            hidden_states, attention_mask, sequence_output, decoder_attention_mask)
        per_pos_enc = self.generate_pos_perturbation(
            hidden_states, decoder_input_ids, decoder_attention_mask, attention_mask, labels)

        # Projected, mask-averaged representations.
        avg_src = self.src_proj_pool(hidden_states, attention_mask)
        avg_tgt = self.tgt_proj_pool(sequence_output, decoder_attention_mask)
        avg_neg_src = self.src_proj_pool(per_neg_enc, attention_mask)
        avg_pos_src = self.src_proj_pool(per_pos_enc, attention_mask)

        cont_crit = self.lsCELoss
        # In-batch similarity matrix [b, b]; row i's positive is column i.
        sim_matrix = cos_sim(avg_src.unsqueeze(1), avg_tgt.unsqueeze(0))
        cont_labels = torch.arange(batch_size, device=input_ids.device)

        # Two extra negative columns per row: the negative-perturbed source, once
        # with the target side detached and once with the perturbed side detached.
        neg_logit1 = cos_sim(avg_tgt.detach(), avg_neg_src).unsqueeze(1)
        neg_logit2 = cos_sim(avg_tgt, avg_neg_src.detach()).unsqueeze(1)
        adv_sim = torch.cat([neg_logit1, neg_logit2], 1)
        neg_cont_loss = cont_crit(torch.cat([sim_matrix, adv_sim], 1), cont_labels)

        # Same logits, but the diagonal entries are replaced by the similarity
        # between the target and the positively-perturbed source.
        identity = torch.eye(batch_size, device=input_ids.device)
        pos_sim = identity * cos_sim(avg_tgt, avg_pos_src).unsqueeze(-1)
        neg_sim = sim_matrix.masked_fill(identity == 1, 0)
        new_logits = torch.cat([pos_sim + neg_sim, adv_sim], 1)
        new_cont_loss = cont_crit(new_logits, cont_labels)

        cont_loss = 0.5 * (neg_cont_loss + new_cont_loss)
        return nll + cont_loss

    def generate_encoder_perturbation(self, enc_hiddens, attention_mask, dec_hiddens, decoder_attention_mask):
        """Perturb encoder states with noise decoded from a sampled latent z.

        Adapted from
        https://github.com/snakeztc/NeuralDialog-LaRL/blob/master/latent_dialog/models_deal.py#L258

        ``attention_mask``, ``dec_hiddens`` and ``decoder_attention_mask`` are
        accepted for interface compatibility but are not used.

        Returns:
            (perturbed_enc, joint_logpz): the perturbed states and the sample's
            log-probability summed over the latent dim and divided by
            ``sqrt(z_size)`` (consumed by forward_rl).
        """
        # ``detach`` comes from myWork.utils (presumably Tensor.detach -- confirm).
        enc_hiddens = detach(enc_hiddens)

        p_mu, p_logvar = self.c2z(enc_hiddens)
        # std = exp(0.5 * logvar): equal to sqrt(exp(logvar)) but exp() cannot
        # overflow for large log-variances.
        sample_z = torch.normal(p_mu, torch.exp(0.5 * p_logvar))

        noise = self.z2noise(sample_z)
        if self.neg_mode:
            perturbed_enc = noise * self.eta + enc_hiddens
        else:
            perturbed_enc = enc_hiddens + self.neg_eps * norm(noise)

        # Second argument to gradient_logprob is a zero tensor (presumably a
        # zero log-variance -- confirm against utils). Allocated on p_mu's device
        # instead of torch.cuda.FloatTensor so CPU runs work; float32 as before.
        zero_tensor = torch.zeros(1, device=p_mu.device)
        logprob_sample_z = gaussian_logprob(p_mu, zero_tensor, sample_z)
        joint_logpz = torch.sum(logprob_sample_z, dim=-1) / np.sqrt(self.z_size)

        return perturbed_enc, joint_logpz

    def forward_kl(self, model_inputs):
        """KL consistency loss between clean and positively-perturbed decoding.

        Returns ``kl_pos + kl_neg``; the negative-perturbation KL term is
        currently disabled (``kl_neg = 0.0``).
        """
        input_ids = model_inputs['input_ids']
        attention_mask = model_inputs['attention_mask']
        decoder_input_ids = model_inputs['decoder_input_ids']
        decoder_attention_mask = model_inputs['decoder_attention_mask']
        labels = model_inputs['labels']

        hidden_states, sequence_output = self._encode_decode(
            input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
        lm_logits = self.hidden2vocab(sequence_output)

        per_pos_enc = self.generate_pos_perturbation(
            hidden_states, decoder_input_ids, decoder_attention_mask, attention_mask, labels)

        # Disabled: a KL term on the negative perturbation
        # (forward_kl_base over generate_encoder_perturbation's output).
        kl_neg = 0.0
        kl_pos = self.forward_kl_base(model_inputs, lm_logits, per_pos_enc)

        return kl_pos + kl_neg

    def forward_kl_base(self, model_inputs, lm_logits, kl_hiddens):
        """Token-averaged KL(clean || perturbed) between decoder predictions.

        Re-runs the decoder with ``kl_hiddens`` as encoder states and compares
        its log-probabilities with ``softmax(lm_logits)``. Positions where
        ``decoder_attention_mask`` is 0 get an all-zero target.

        Returns:
            Summed KL divided by the number of unmasked decoder positions.
        """
        attention_mask = model_inputs['attention_mask']
        decoder_input_ids = model_inputs['decoder_input_ids']
        decoder_attention_mask = model_inputs['decoder_attention_mask']
        vocab_size = lm_logits.size(-1)

        perturb_hidden = self.bot_model.get_decoder()(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=kl_hiddens,
            encoder_attention_mask=attention_mask,
        )[0]
        perturb_log_probs = F.log_softmax(self.hidden2vocab(perturb_hidden), -1)

        true_probs = F.softmax(lm_logits, -1) * decoder_attention_mask.unsqueeze(-1).float()

        kl_crit = nn.KLDivLoss(reduction="sum")
        kl_loss = kl_crit(perturb_log_probs.view(-1, vocab_size), true_probs.view(-1, vocab_size))
        return kl_loss / torch.sum(decoder_attention_mask).float()

    def forward_rl(self, model_inputs):
        """Reward-weighted log-probability loss for the latent noise generator.

        Samples a response (without gradients) from the negatively-perturbed
        encoder states, scores it with ``self.reward_function.score_genBatch``,
        and returns ``-sum(reward * logp(z))`` over unmasked source positions,
        divided by the number of source tokens. The reward is the (detached)
        softmax probability of column 1 of the reward logits.
        """
        input_ids = model_inputs['input_ids']
        attention_mask = model_inputs['attention_mask']
        src_text = model_inputs['src_text']

        encoder_outputs = self.bot_model.get_encoder()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = encoder_outputs['last_hidden_state']  # [b, t, d]
        perturbed_enc, noise_prob = self.generate_encoder_perturbation(hidden_states, None, None, None)

        with torch.no_grad():
            sample_ids, _, _ = sample(input_ids, attention_mask, perturbed_enc, self.bot_model, self.max_sample_len)
        sample_ids = sample_ids.detach()

        reward_baseline = 0.0
        reward_logits = self.reward_function.score_genBatch(src_text, sample_ids).to(input_ids.device)
        rewards = F.softmax(reward_logits, dim=1)[:, 1].detach() - reward_baseline

        # Zero log-probs at padded source positions; out-of-place so the autograd
        # graph that produced noise_prob is not modified in place.
        noise_prob = noise_prob.masked_fill(attention_mask == 0, 0.0)
        reward_loss = -(rewards * noise_prob.transpose(0, 1))
        return torch.sum(reward_loss) / torch.sum(attention_mask).float()

    def generate_pos_perturbation(self, enc_hiddens,
                 decoder_input_ids, decoder_attention_mask,
                 attention_mask, labels):
        """Shift detached encoder states by ``pos_eps`` along the normalised NLL gradient.

        The decoder is run on a detached copy of ``enc_hiddens``; the gradient of
        the token cross-entropy w.r.t. that copy is L2-normalised over the last
        dim and added with step ``self.pos_eps``.

        ``torch.autograd.grad`` computes only that gradient, so model parameter
        ``.grad`` buffers are neither filled nor cleared here (calling
        ``backward()`` followed by ``zero_grad()`` would discard any gradients
        already accumulated on the model).
        """
        enc_hiddens = enc_hiddens.detach().requires_grad_(True)

        # The gradient is needed even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            sequence_output = self.bot_model.get_decoder()(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=enc_hiddens,
                encoder_attention_mask=attention_mask,
            )[0]  # [b, t, d]
            lm_logits = self.hidden2vocab(sequence_output)
            vocab_size = lm_logits.size(-1)
            loss = nn.CrossEntropyLoss()(lm_logits.view(-1, vocab_size), labels.view(-1))
            enc_grad, = torch.autograd.grad(loss, enc_hiddens)

        enc_grad = enc_grad / (torch.norm(enc_grad, dim=-1, keepdim=True) + 1e-12)
        return enc_hiddens + self.pos_eps * enc_grad

    def avg_pool(self, hidden_states, mask):
        """Masked mean over axis 1.

        Positions where ``mask == 0`` are zeroed before summing, and the sum is
        divided by the per-row mask count. Rows whose mask is entirely zero
        return zeros instead of NaN (the count is clamped to at least 1).
        """
        mask = mask.unsqueeze(2)
        summed = hidden_states.masked_fill(mask == 0, 0.0).sum(dim=1)
        length = mask.sum(dim=1).float().clamp(min=1.0)
        return summed / length

    def src_proj_pool(self, hidden_states, mask):
        """Apply ``src_projection`` then masked mean pooling."""
        return self.avg_pool(self.src_projection(hidden_states), mask)

    def tgt_proj_pool(self, hidden_states, mask):
        """Apply ``tgt_projection`` then masked mean pooling."""
        return self.avg_pool(self.tgt_projection(hidden_states), mask)
