from argparse import ArgumentParser
from copy import deepcopy

import torch
import torch.nn as nn
from transformers import XLMRobertaConfig, BertConfig
from utils.myQueue import Queue

from models.base_model import BaseModel


class ContrastiveLearning(torch.nn.Module):
    """Wrap ``BaseModel`` with an optional queue-based contrastive objective.

    Every forward pass runs the base model once for the supervised outputs.
    When ``args.contrastive_learning`` is set, it runs the model a second time
    on another view of the batch. The sentence-level (``cls``) and token-level
    (``embedded``) representations of both views are then contrasted against
    negatives drawn from two queues of previously enqueued, detached
    representations (``cls_queue`` / ``embed_queue``).
    """

    def __init__(self, args, num_intents, num_slots):
        """Build the wrapped model, classifier heads and negative queues.

        Args:
            args: configuration namespace. Read here: ``gpu``, ``lambda1-3``,
                ``temperature``, ``negative_num``, ``batch_size``,
                ``max_seq_length``, ``load_weights``, ``load_encoder_weights``,
                ``train`` and ``base_model``.
            num_intents: number of intent classes.
            num_slots: number of slot labels.
        """
        super().__init__()
        self.args = args
        self.num_intents = num_intents
        self.num_slots = num_slots
        self.device = torch.device('cuda' if args.gpu else 'cpu')
        # Weights of the global (cls-cls), local (token-token) and
        # global-local (cls-token) contrastive terms in forward().
        self.lambda1 = args.lambda1
        self.lambda2 = args.lambda2
        self.lambda3 = args.lambda3
        self.temperature = args.temperature
        self.negative_num = args.negative_num
        self.slot_pad_idx = -100
        self.batch_size = args.batch_size

        self.max_seq_length = args.max_seq_length
        self.model = BaseModel(args, num_intents, num_slots)
        # Checkpoints are only restored for training runs.
        if (args.load_weights or args.load_encoder_weights) and args.train:
            self.load_model()
        self.model.to(self.device)

        if args.base_model == 'XLMRoberta':
            self.config = XLMRobertaConfig.from_pretrained('sentence-transformers/paraphrase-xlm-r-multilingual-v1')
        else:
            self.config = BertConfig.from_pretrained('bert-base-multilingual-uncased')
        self.hidden_size = self.config.hidden_size

        # Submodules are registered in the same order as before, so state_dict
        # key order is unchanged. The CrossEntropyLoss modules hold no
        # parameters, so they need no device placement.
        self.intent_classifier = nn.Linear(self.hidden_size, num_intents).to(self.device)
        self.slot_classifier = nn.Linear(self.hidden_size, num_slots).to(self.device)
        self.intent_criterion = nn.CrossEntropyLoss()
        self.slot_criterion = nn.CrossEntropyLoss(ignore_index=self.slot_pad_idx)
        self.intent_layer = nn.Linear(num_intents, self.hidden_size).to(self.device)
        self.cls_queue = Queue(self.hidden_size, 1, maxsize=args.negative_num, batch_size=self.batch_size)
        self.embed_queue = Queue(self.hidden_size, self.max_seq_length, maxsize=args.negative_num,
                                 batch_size=self.batch_size)

    def load_model(self):
        """Load checkpoint weights into ``self.model`` as configured by ``args``.

        * ``args.load_weights``: load a full checkpoint from
          ``saved_model_dir + restore_from``. When ``restore_from`` is None,
          the path is ``saved_model_dir + load_model_name`` instead. The
          checkpoint is loaded with ``strict=False``.
        * ``args.load_encoder_weights``: load encoder weights from
          ``saved_model_dir + load_encoder_dir``. When that is empty, the path
          is ``saved_model_dir + saved_encoder_name + ".encoder"``.

        Both checkpoint files are indexed with ``['state_dict']``. Tensors are
        mapped onto ``self.device``, so a GPU-saved checkpoint also loads on CPU.
        """
        saved_dir = self.args.saved_model_dir

        if self.args.load_weights:
            ckpt_name = self.args.load_model_name if self.args.restore_from is None else self.args.restore_from
            model_CKPT = torch.load(saved_dir + ckpt_name, map_location=self.device)
            self.model.load_state_dict(model_CKPT['state_dict'], strict=False)

        if self.args.load_encoder_weights:
            if self.args.load_encoder_dir:
                encoder_checkpoint_dir = saved_dir + self.args.load_encoder_dir
            else:
                # NOTE(review): this class defines no `saved_encoder_name`
                # attribute of its own. The name is assumed to be supplied on
                # args like every other path piece here; confirm.
                encoder_checkpoint_dir = saved_dir + self.args.saved_encoder_name + ".encoder"
            model_CKPT = torch.load(encoder_checkpoint_dir, map_location=self.device)
            print(f"load encoder from {encoder_checkpoint_dir}")
            self.model.load_encoder_state(model_CKPT['state_dict'])

    def get_ids(self, batch, datatype):
        """Collate one view of a batch into ``torch.long`` tensors on ``self.device``.

        Args:
            batch: iterable of examples. ``example[datatype]`` must map
                ``'input_ids'``, ``'attention_mask'``, ``'token_type_ids'`` and
                ``'label_ids'`` to integer sequences of equal length across
                the batch.
            datatype: key selecting the view inside each example.

        Returns:
            Tuple ``(input_ids, attention_mask, token_type_ids, slot_labels)``.
        """
        views = [example[datatype] for example in batch]

        def collate(key):
            return torch.tensor([view[key] for view in views], dtype=torch.long).to(self.device)

        # Each tensor is built from its own key. In particular,
        # token_type_ids is taken from 'token_type_ids', not the attention mask.
        return collate('input_ids'), collate('attention_mask'), collate('token_type_ids'), collate('label_ids')

    def forward(self, batch, evaluate=False):
        """Run the base model and, if enabled, add the contrastive losses.

        Returns:
            ``OUT`` holding the first pass's ``embedded``, ``cls`` and logits,
            and the total loss. The total is the base-model loss plus the
            weighted contrastive terms, which are only added once
            ``cls_queue`` is non-empty.
        """
        out = self.model.forward(batch, evaluate=evaluate)
        embedded, cls = out.embedded, out.cls
        intent_logits, slot_logits = out.intent_logits, out.slot_logits
        total_loss = out.loss

        # NOTE(review): this branch also runs when evaluate=True, so
        # evaluation batches are contrasted and enqueued too; confirm intended.
        if self.args.contrastive_learning:
            # Under CoSDA the first pass is the positive view and the
            # 'original' view is fetched now. Otherwise the roles are swapped.
            # NOTE(review): the second positional argument of BaseModel.forward
            # is presumably a view selector; confirm against BaseModel.
            if self.args.use_cosda:
                extra = self.model.forward(batch, 'original')
                origin_embedded, origin_cls = extra.embedded, extra.cls
                pos_embedded, pos_cls = embedded, cls
            else:
                extra = self.model.forward(batch, 'positive')
                origin_embedded, origin_cls = embedded, cls
                pos_embedded, pos_cls = extra.embedded, extra.cls

            if self.cls_queue.size > 0:
                negative_embedded = self.embed_queue.negative_encode(len(batch))
                negative_cls = self.cls_queue.negative_encode(len(batch))
                global_loss = self.contrastive_loss_global(origin_cls, pos_cls, negative_cls)
                local_loss = self.contrastive_loss_local(origin_embedded, pos_embedded, negative_embedded)
                # The original view's cls is contrasted with the tokens of both
                # views; the two terms are averaged below (``/ 2``).
                global_local_loss = (
                    self.contrastive_loss_global_local(origin_cls, origin_embedded, negative_embedded)
                    + self.contrastive_loss_global_local(origin_cls, pos_embedded, negative_embedded)
                )
                # Out-of-place add, so the base model's loss tensor is not mutated.
                total_loss = total_loss + (
                    self.lambda1 * global_loss + self.lambda2 * local_loss + self.lambda3 * global_local_loss / 2
                )

            # Both views become future negatives. They are detached so no
            # gradient flows through the queues.
            self.cls_queue.enqueue_batch_tensor(origin_cls.detach())
            self.cls_queue.enqueue_batch_tensor(pos_cls.detach())
            self.embed_queue.enqueue_batch_tensor(origin_embedded.detach())
            self.embed_queue.enqueue_batch_tensor(pos_embedded.detach())
        return OUT(embedded, cls, intent_logits, slot_logits, total_loss)

    def _info_nce(self, logits):
        """Temperature-scaled cross-entropy where column 0 of every row is the positive."""
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        return nn.functional.cross_entropy(torch.div(logits, self.temperature), labels)

    def contrastive_loss_global(self, origin, positive, negative):
        """Sentence-level InfoNCE between ``origin`` and ``positive`` cls vectors.

        ``origin`` and ``positive`` are reshaped to ``(N, 1, C)`` / ``(N, C, 1)``.
        ``negative`` is reshaped to ``(N, C, K)``, with ``K = cls_queue.size``.
        """
        N, C = origin.shape[0], origin.shape[1]
        K = self.cls_queue.size
        anchor = origin.view(N, 1, C)
        l_pos = torch.bmm(anchor, positive.view(N, C, 1)).view(N, 1)
        l_neg = torch.bmm(anchor, negative.view(N, C, K)).view(N, K)
        return self._info_nce(torch.cat((l_pos, l_neg), dim=1))

    def contrastive_loss_global_local(self, origin, positive, negative):
        """InfoNCE between a cls vector and every token of ``positive``, divided by L.

        ``origin`` is ``(N, C)``. ``positive`` is ``(N, L, C)``, the token
        layout that ``contrastive_loss_local`` reads from its inputs.
        ``negative`` is reshaped to ``(N, C, K * L)``.
        NOTE(review): that ``view`` assumes ``embed_queue.negative_encode``
        returns memory already ordered as ``(N, C, L, K)``; confirm.
        """
        N, C = origin.shape[0], origin.shape[1]
        L = positive.shape[1]
        # K is read from cls_queue. embed_queue receives the same number of
        # enqueues with the same maxsize.
        K = self.cls_queue.size
        anchor = origin.view(N, 1, C)
        # Transpose (N, L, C) -> (N, C, L). A ``view`` here would reinterpret
        # memory and mix features across tokens.
        l_pos = torch.bmm(anchor, positive.transpose(1, 2)).view(N * L, 1)
        l_neg = torch.bmm(anchor, negative.view(N, C, K * L)).view(N * L, K)
        return self._info_nce(torch.cat((l_pos, l_neg), dim=1)) / L

    def contrastive_loss_local(self, origin, positive, negative):
        """Token-level InfoNCE over all (origin token, positive token) pairs, divided by L*L.

        ``origin`` and ``positive`` are ``(N, L, C)``. ``negative`` is reshaped
        to ``(N, C, K * L)``.
        NOTE(review): as in the global-local loss, that ``view`` assumes the
        queue output is already laid out as ``(N, C, L, K)``; confirm.
        """
        N, L, C = origin.shape[0], origin.shape[1], origin.shape[2]
        K = self.cls_queue.size
        tokens = origin.view(N, L, C)
        # Transpose (N, L, C) -> (N, C, L) so bmm yields token-token dot products.
        l_pos = torch.bmm(tokens, positive.transpose(1, 2)).view(N * L * L, 1)
        l_neg = torch.bmm(tokens, negative.view(N, C, K * L)).view(N * L * L, K)
        return self._info_nce(torch.cat((l_pos, l_neg), dim=1)) / (L * L)

    def pad(self, outputs, seq_len=128):
        """Right-pad every list in ``outputs`` with ``0`` up to ``seq_len``.

        Lists already at least ``seq_len`` long are kept as-is (not truncated).
        Each returned list is a deep copy, independent of the input.
        """
        return [deepcopy(seq + [0] * (seq_len - len(seq))) for seq in outputs]

class OUT:
    """Plain container for the results of ``ContrastiveLearning.forward``.

    Every constructor argument is stored under the same attribute name,
    except ``total_loss``, which is exposed as ``loss``.
    """

    def __init__(self, embedded, cls, intent_logits, slot_logits, total_loss):
        self.embedded, self.cls = embedded, cls
        self.intent_logits, self.slot_logits = intent_logits, slot_logits
        self.loss = total_loss