import torch
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union, List
from transformers import BatchEncoding, Trainer, AutoTokenizer, TrainerCallback, TrainerState, TrainerControl
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from trl.models import PreTrainedModelWrapper
from torch.nn import functional as F
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.train.d2o_ema.scheduler import FixedSamplerScheduler, LinearSamplerScheduler, ExponentialSamplerScheduler, IncreasingDensityScheduler
from contextlib import contextmanager, nullcontext
from llmtuner.train.d2o_ema.ema_utils import moving_average
import deepspeed

if TYPE_CHECKING:
    from transformers import PreTrainedModel

class EMACallback(TrainerCallback):
    """Periodically blend the policy weights into the trainer's EMA reference model.

    Every ``update_interval`` calls to ``on_step_end`` this runs
    ``moving_average(trainer.model, trainer.ref_model_ema, beta=..., zero_stage=trainer.deepspeed_stage)``.
    """

    def __init__(self, trainer, update_interval=100, beta=0.992):
        if update_interval < 1:
            raise ValueError(f"update_interval must be >= 1, got {update_interval}.")
        # Keep the trainer itself rather than its models: the callback is built inside the
        # trainer's __init__, before the EMA model and DeepSpeed stage exist, so both are
        # looked up lazily at update time.
        self.trainer = trainer
        self.update_interval = update_interval
        self.counter = 0  # on_step_end calls since the last EMA update
        self.beta = beta  # forwarded to moving_average as its `beta`

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Count training steps and run one EMA update each time ``update_interval`` is reached."""
        self.counter += 1
        if self.counter < self.update_interval:
            return
        self.counter = 0

        # The trainer only has an EMA model when it was given a separate reference model;
        # without one there is nothing to update.
        ref_model_ema = getattr(self.trainer, "ref_model_ema", None)
        if ref_model_ema is None:
            return

        if state.is_world_process_zero:
            print(f"Interval reached at step {state.global_step}, performing EMA operation.")
        # NOTE(review): moving_average is defined in ema_utils (not visible here). If it pairs
        # parameters positionally, the policy and the EMA model must expose identical
        # parameter lists — confirm this holds for PEFT/LoRA policies.
        moving_average(
            self.trainer.model,
            ref_model_ema,
            beta=self.beta,
            zero_stage=getattr(self.trainer, "deepspeed_stage", 0),
        )

class D2OTrainer(DPOTrainer):
    """DPO-style trainer comparing ``multiple_K`` chosen responses with one rejected response per prompt.

    Batch layout expected by ``concatenated_forward``, the reference forward and the
    ``formula_*`` losses, for a batch of ``B`` prompts:

    * rows ``[0, multiple_K * B)`` are chosen responses; rows ``[i * B, (i + 1) * B)`` hold
      the i-th chosen sample of every prompt;
    * the last ``B`` rows are the rejected responses.

    With a separate reference model, chosen rows are scored by ``ref_model`` and rejected
    rows by ``ref_model_ema``, which ``EMACallback`` periodically moves towards the policy.
    """

    def __init__(
        self,
        beta: float,
        multiple_K: int,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        disable_dropout: Optional[bool] = True,
        loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
        label_smoothing: Optional[float] = 0.0,
        loss_num: int = 2,
        alpha: Optional[float] = None,
        ref_model_ema: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        ema_update_interval: int = 1000,
        ema_beta: float = 0.992,
        **kwargs
    ):
        """
        Args:
            beta: Scale on the log-ratios (chosen side in formula 2, both sides in formulas 1 and 3).
            multiple_K: Number of chosen responses per prompt in each batch (must be >= 1).
            model: Policy model being trained.
            ref_model: Frozen reference model for the chosen responses. When None, reference
                log-probs come from ``self.model`` inside the inherited ``null_ref_context``.
            disable_dropout: Disable dropout in every model passed in.
            loss_type: Stored for DPOTrainer compatibility; the ``formula_*`` losses do not read it.
            label_smoothing: Stored for DPOTrainer compatibility; not read by the ``formula_*`` losses.
            loss_num: Which loss ``dpo_loss`` dispatches to (1, 2 or 3).
            alpha: Scale on the rejected log-ratio in formula 2; defaults to ``beta``.
            ref_model_ema: Optional separately loaded model used as the EMA reference. When None
                (and ``ref_model`` is given), the EMA reference is the prepared ``ref_model``
                itself, i.e. the two share weights.
            ema_update_interval: Training steps (``on_step_end`` calls) between EMA updates.
            ema_beta: ``beta`` passed to ``moving_average`` for EMA updates.
            **kwargs: Forwarded to ``transformers.Trainer.__init__``.

        Raises:
            ValueError: If ``multiple_K`` < 1.
            AttributeError: If the installed ``transformers`` Trainer has no ``accelerator``.
        """
        if multiple_K < 1:
            raise ValueError(f"multiple_K must be >= 1, got {multiple_K}.")
        if disable_dropout:
            disable_dropout_in_model(model)
            for extra_model in (ref_model, ref_model_ema):
                if extra_model is not None:
                    disable_dropout_in_model(extra_model)

        self.deepspeed_stage = 0  # overwritten by _prepare_deepspeed; read by EMACallback
        self.multiple_K = multiple_K
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.ref_model = ref_model
        self.ref_model_ema = None  # set below once the reference model(s) are prepared
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.beta = beta
        self.loss_type = loss_type
        self.label_smoothing = label_smoothing
        self.loss_num = loss_num
        self.alpha = beta if alpha is None else alpha
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        # DPOTrainer.__init__ is bypassed (Trainer.__init__ is called directly below), so set
        # the DPOTrainer attributes that inherited methods and compute_reference_log_probs read.
        # No PEFT bf16 casting happens in this constructor.
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False

        ema_callback = EMACallback(self, update_interval=ema_update_interval, beta=ema_beta)
        # Build a new list so the caller's callbacks list is not mutated.
        kwargs["callbacks"] = list(kwargs.get("callbacks") or []) + [ema_callback]

        Trainer.__init__(self, model=model, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        # Place the reference model(s) on device.
        if ref_model is not None:
            self.ref_model = self._prepare_reference_model(ref_model)
            if ref_model_ema is not None:
                self.ref_model_ema = self._prepare_reference_model(ref_model_ema)
            else:
                # NOTE(review): without a separately loaded ``ref_model_ema`` the EMA reference
                # IS the reference model, so EMA updates also move the reference used for the
                # chosen responses. Pass ``ref_model_ema`` to keep them independent.
                self.ref_model_ema = self.ref_model

        # Expose the policy, accelerator and reference model to the data collator (its code is
        # not visible here; it reads these attributes).
        self.data_collator.model = self.model
        self.data_collator.accelerator = self.accelerator
        self.data_collator.steps = 0
        self.data_collator.ref_model = self.ref_model
        # Separate left-padded tokenizer for the collator's sampling.
        self.data_collator.sample_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer.name_or_path)
        self.data_collator.sample_tokenizer.padding_side = "left"  # restore padding side
        self.data_collator.sample_tokenizer.init_kwargs["padding_side"] = "left"
        # Sampling schedule is disabled. Alternatives tried previously:
        #   IncreasingDensityScheduler(warmup_steps=16000, initial_interval=3200, increase_factor=1.1)
        #   FixedSamplerScheduler(warmup_steps=16000, sample_interval=3200)
        #   ExponentialSamplerScheduler(warmup_steps=500, scale=2)
        #   FixedSamplerScheduler(warmup_steps=1600, sample_interval=100)
        self.data_collator.scheduler = None

    def _prepare_reference_model(self, ref):
        """Prepare a frozen reference model: a DeepSpeed engine, or via accelerate otherwise."""
        if self.is_deepspeed_enabled:
            if getattr(ref, "is_loaded_in_8bit", False) or getattr(ref, "is_loaded_in_4bit", False):
                return ref  # quantized models are already set on the correct device
            return self._prepare_deepspeed(ref)
        return self.accelerator.prepare_model(ref, evaluation_mode=True)

    def _prompts_per_batch(self, num_rows: int) -> int:
        """Return ``B`` for a batch of ``multiple_K * B`` chosen rows followed by ``B`` rejected rows.

        Derived from the actual batch so eval batches and a smaller final batch split correctly.
        """
        group_size = self.multiple_K + 1
        if num_rows <= 0 or num_rows % group_size != 0:
            raise ValueError(
                f"Expected multiple_K * B chosen rows followed by B rejected rows "
                f"(multiple_K={self.multiple_K}), but the batch has {num_rows} rows."
            )
        return num_rows // group_size

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the loss selected by ``self.loss_num`` (``formula_1``/``formula_2``/``formula_3``).

        Args:
            policy_chosen_logps: Policy log-probs of the chosen responses. Shape: (multiple_K * B,)
            policy_rejected_logps: Policy log-probs of the rejected responses. Shape: (B,)
            reference_chosen_logps: Reference log-probs of the chosen responses. Shape: (multiple_K * B,)
            reference_rejected_logps: Reference log-probs of the rejected responses. Shape: (B,)
            reference_free: If True, the reference log-probs are left out of the loss logits.

        Returns:
            (losses, chosen_rewards, rejected_rewards), each of shape (B,).

        Raises:
            ValueError: If ``self.loss_num`` is not 1, 2 or 3.
        """
        formulas = {1: self.formula_1, 2: self.formula_2, 3: self.formula_3}
        if self.loss_num not in formulas:
            raise ValueError(f"Unsupported loss_num: {self.loss_num}. Expected one of {sorted(formulas)}.")
        return formulas[self.loss_num](
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            reference_free=reference_free,
        )

    def formula_1(
            self,
            multiple_policy_chosen_logps: torch.FloatTensor,
            policy_rejected_logps: torch.FloatTensor,
            multiple_reference_chosen_logps: torch.FloatTensor,
            reference_rejected_logps: torch.FloatTensor,
            reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Mean over the K chosen samples of the sigmoid DPO loss against the shared rejected response."""
        num_samples = self.multiple_K
        # (K, B): row i is the i-th chosen sample of every prompt.
        policy_chosen = multiple_policy_chosen_logps.reshape(num_samples, -1)
        reference_chosen = multiple_reference_chosen_logps.reshape(num_samples, -1)

        pi_logratios = policy_chosen - policy_rejected_logps  # broadcasts (B,) over K
        if reference_free:
            ref_logratios = 0
        else:
            ref_logratios = reference_chosen - reference_rejected_logps
        logits = pi_logratios - ref_logratios

        losses = (-F.logsigmoid(self.beta * logits)).mean(dim=0)
        chosen_rewards = self.beta * (policy_chosen - reference_chosen).detach().mean(dim=0)
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return losses, chosen_rewards, rejected_rewards

    def formula_2(
            self,
            multiple_policy_chosen_logps: torch.FloatTensor,
            policy_rejected_logps: torch.FloatTensor,
            multiple_reference_chosen_logps: torch.FloatTensor,
            reference_rejected_logps: torch.FloatTensor,
            reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Sigmoid loss on ``beta * mean_k(chosen log-ratio_k) - alpha * rejected log-ratio``."""
        num_samples = self.multiple_K
        chosen_diff = multiple_policy_chosen_logps.reshape(num_samples, -1)  # (K, B)
        rejected_diff = policy_rejected_logps
        if not reference_free:
            chosen_diff = chosen_diff - multiple_reference_chosen_logps.reshape(num_samples, -1)
            rejected_diff = rejected_diff - reference_rejected_logps
        mean_chosen_diff = chosen_diff.mean(dim=0)

        logits = self.beta * mean_chosen_diff - self.alpha * rejected_diff
        losses = -F.logsigmoid(logits)
        chosen_rewards = self.beta * mean_chosen_diff.detach()
        rejected_rewards = self.alpha * (policy_rejected_logps - reference_rejected_logps).detach()
        return losses, chosen_rewards, rejected_rewards

    def formula_3(
            self,
            multiple_policy_chosen_logps: torch.FloatTensor,
            policy_rejected_logps: torch.FloatTensor,
            multiple_reference_chosen_logps: torch.FloatTensor,
            reference_rejected_logps: torch.FloatTensor,
            reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Loss built from ``a = mean_k exp(beta * chosen_ratio_k)`` and ``b = exp(beta * rejected_ratio)``.

        ``logits = log(a / (a + b))`` is computed in log space (logsumexp / logaddexp) so large
        ratios do not overflow.
        """
        import math

        num_samples = self.multiple_K
        chosen_diff = multiple_policy_chosen_logps.reshape(num_samples, -1)  # (K, B)
        rejected_diff = policy_rejected_logps
        if not reference_free:
            chosen_diff = chosen_diff - multiple_reference_chosen_logps.reshape(num_samples, -1)
            rejected_diff = rejected_diff - reference_rejected_logps

        log_chosen_score = torch.logsumexp(self.beta * chosen_diff, dim=0) - math.log(num_samples)  # log a
        log_rejected_score = self.beta * rejected_diff  # log b (previously exponentiated twice)
        logits = log_chosen_score - torch.logaddexp(log_chosen_score, log_rejected_score)
        # NOTE(review): logits is already a log-probability (<= 0); applying logsigmoid on top
        # keeps the optimisation direction but may not be the intended objective — confirm.
        losses = -F.logsigmoid(logits)

        # Rewards use true means over the K samples, comparable with rejected_rewards.
        mean_policy_chosen = multiple_policy_chosen_logps.reshape(num_samples, -1).mean(dim=0)
        mean_reference_chosen = multiple_reference_chosen_logps.reshape(num_samples, -1).mean(dim=0)
        chosen_rewards = self.beta * (mean_policy_chosen - mean_reference_chosen).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return losses, chosen_rewards, rejected_rewards

    def _reference_logps(self, batch: Dict) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Reference log-probs for the chosen and rejected rows of ``batch``.

        Callers are responsible for the ``torch.no_grad()`` (and any autocast) context.
        """
        if self.ref_model is None:
            with self.null_ref_context():
                reference_chosen_logps, reference_rejected_logps, _, _ = self.concatenated_forward(self.model, batch)
            return reference_chosen_logps, reference_rejected_logps

        # Separate reference models: chosen rows -> ref_model, rejected rows -> ref_model_ema.
        split_point = self.multiple_K * self._prompts_per_batch(batch["input_ids"].size(0))
        chosen_batch, rejected_batch = {}, {}
        for key, value in batch.items():
            if not torch.is_tensor(value):
                continue  # single_forward_ref only reads tensor fields
            chosen, rejected = value.split(split_point, dim=0)
            chosen_batch[key] = chosen.detach().clone()
            rejected_batch[key] = rejected.detach().clone()

        reference_chosen_logps, _ = self.single_forward_ref(self.ref_model, BatchEncoding(chosen_batch))
        ema_model = self.ref_model_ema if self.ref_model_ema is not None else self.ref_model
        reference_rejected_logps, _ = self.single_forward_ref(ema_model, BatchEncoding(rejected_batch))
        return reference_chosen_logps, reference_rejected_logps

    def compute_reference_log_probs(self, padded_batch: Dict) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Compute reference log-probs (chosen, rejected) for one padded batch without gradients."""
        autocast_context = (
            torch.cuda.amp.autocast if getattr(self, "_peft_has_been_casted_to_bf16", False) else nullcontext
        )
        with torch.no_grad(), autocast_context():
            return self._reference_logps(padded_batch)

    def single_forward_ref(
        self,
        model: Optional[torch.nn.Module] = None,
        batch: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run ``model`` once over ``batch``; return (per-row summed log-probs, float32 logits)."""
        all_logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            return_dict=True
        ).logits.to(torch.float32)

        all_logps = self._get_batch_logps(
            all_logits,
            batch["labels"],
            average_log_prob=False
        )
        return all_logps, all_logits

    def concatenated_forward(
        self,
        model: Optional[torch.nn.Module] = None,
        batch: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run ``model`` over the whole batch and split outputs into chosen and rejected parts.

        Returns:
            (chosen_logps, rejected_logps, chosen_logits, rejected_logits); chosen parts have
            ``multiple_K * B`` rows, rejected parts ``B`` rows.
        """
        # Feed detached copies to the model (the original author's workaround: "avoid error").
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})

        all_logits = model(
            input_ids=batch_copied["input_ids"],
            attention_mask=batch_copied["attention_mask"],
            return_dict=True
        ).logits.to(torch.float32)

        all_logps = self._get_batch_logps(
            all_logits,
            batch["labels"],
            average_log_prob=False
        )
        chosen_rows = self.multiple_K * self._prompts_per_batch(all_logps.size(0))
        chosen_logps, rejected_logps = all_logps.split(chosen_rows, dim=0)
        chosen_logits, rejected_logits = all_logits.split(chosen_rows, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the mean loss and logging metrics for one batch (train or eval)."""
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        # Use reference log-probs carried in the batch if present; otherwise run the reference model(s).
        if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
            reference_chosen_logps = batch["reference_chosen_logps"]
            reference_rejected_logps = batch["reference_rejected_logps"]
        else:
            with torch.no_grad():
                reference_chosen_logps, reference_rejected_logps = self._reference_logps(batch)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

        return losses.mean(), metrics

    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
        """Wrap a frozen model in a DeepSpeed engine (ZeRO-3 if training uses it, else stage 0).

        Also records the chosen stage in ``self.deepspeed_stage`` for EMA updates.
        """
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        from copy import deepcopy
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        self.deepspeed_stage = config_kwargs["zero_optimization"]["stage"]
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model
