import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
from transformers import DataCollatorForSeq2Seq
from transformers.generation.configuration_utils import GenerationConfig
from random import choice
import random

def select_sample_elements(feature_keys, ratio):
    """Randomly pick which chosen-answer indices are replaced by online samples.

    The number of entries in ``feature_keys`` containing the substring
    ``"chosen_id"`` is counted (call it ``x``); ``int(x * ratio)`` distinct
    indices are then drawn uniformly, without replacement, from
    ``range(x)``.

    Args:
        feature_keys: Iterable of feature-key strings.
        ratio: Fraction of the chosen entries to select. The product is
            truncated toward zero, so small counts can yield no selection.

    Returns:
        A list of distinct integer indices in ``[0, x)``, in the order
        produced by ``random.sample``.

    Raises:
        ValueError: Propagated from ``random.sample`` when the requested
            number of indices is negative or exceeds ``x``.
    """
    chosen_total = sum(1 for name in feature_keys if "chosen_id" in name)
    pick_count = int(chosen_total * ratio)
    # Sampling from the range object yields the same draws as sampling from
    # the materialised list of 0..chosen_total-1.
    return random.sample(range(chosen_total), pick_count)


@dataclass
class D2ODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data with optional online sampling.

    Every input feature is a dict that holds ``prompt_ids`` plus one or more
    answer entries (keys containing ``chosen_id`` and a ``rejected_ids`` key).
    For each answer key, the prompt and the answer are concatenated into one
    training row; labels are masked everywhere except on the answer tokens.

    Besides the fields inherited from ``DataCollatorForSeq2Seq``, the methods
    read the following attributes, none of which is assigned inside this
    class, so the caller has to attach them before the collator is used:

    * ``steps`` - integer call counter, incremented by ``__call__``.
    * ``scheduler`` - ``None`` or an object exposing ``need_sampling(steps)``.
    * ``sample_tokenizer`` - tokenizer used to pad prompts for generation.
    * ``accelerator`` - provides ``device`` and ``unwrap_model``.
    """

    def __init__(self, *args, **kwargs):
        r"""
        Forwards all arguments to the parent constructor, then pre-tokenizes
        the self-correction instructions and the conversation template prefix.
        """
        super().__init__(*args, **kwargs)

        # prepare self-corrections
        correction_instructions = [
            "Please help me with my instruction and avoid all harmful content in your response, including anything dangerous, illegal, racist, sexist or socially biased.",
            "Please help me with my instruction and ensure your response eliminates all harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.",
            "Please help me with my instruction and exclude any harmful, unethical, or socially biased content, steering the conversation towards a positive direction.",
            "Please help me with my instruction and avoid all harmful content and diplomatically address any problematic assumptions from the human, such as racism, sexism, or toxicity.",
            "Please help me with my instruction and ensure your response expunge any insensitive, sexist, racist, or socially inappropriate content.",
            "Please help me with my instruction and ensure your response is devoid of offensive and harmful content, while politely acknowledging and addressing any problematic assumptions from the human.",
            "Please help me with my instruction and ensure your response eradicate any harmful, unethical, racist, toxic, dangerous, or illegal content.",
            "Please help me with my instruction and make sure your response suitable for children by ensuring it is age-appropriate.",
            "Please help me with my instruction and ensure your response eliminate all harmful content, particularly anything derogatory, toxic, racist, sexist, or socially harmful.",
            "Please help me with my instruction and avoid providing legally questionable or dangerous advice.",
            "Please help me with the tone and empathy of a thoughtful, caring, sensitive friend or therapist.",
            "Please help me with stripped of any harmful, unethical, dangerous, or illegal content.",
        ]
        # The first encoded token is dropped from every instruction
        # (presumably the BOS token the tokenizer prepends - TODO confirm).
        self.correction_ids = [
            self.tokenizer.encode(instruction)[1:]
            for instruction in correction_instructions
        ]

        # template prefix
        template_prefix = "BEGINNING OF CONVERSATION:  USER: "
        self.template_prefix = self.tokenizer.encode(template_prefix)
        # One token is excluded from the length (original note: "no <eos>").
        self.template_prefix_len = len(self.template_prefix) - 1

    @staticmethod
    def _join_prompt_and_answer(
        prompt_ids: List[int], answer_ids: List[int]
    ) -> Tuple[Dict[str, List[int]], Tuple[int, int]]:
        r"""
        Builds one unpadded training row from a prompt and an answer.

        Returns the ``input_ids`` / ``attention_mask`` dict together with the
        ``(prompt_len, answer_len)`` pair later consumed by ``_pad_labels``.
        """
        prompt_len, answer_len = len(prompt_ids), len(answer_ids)
        example = {
            "input_ids": prompt_ids + answer_ids,
            "attention_mask": [1] * (prompt_len + answer_len),
        }
        return example, (prompt_len, answer_len)

    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
        r"""
        Derives labels from padded ``input_ids``, keeping only answer tokens.

        Args:
            batch: Padded ``input_ids``; iterated row by row.
            positions: One ``(prompt_len, answer_len)`` pair per row.

        Returns:
            A contiguous tensor stacked from the per-row labels, where every
            position outside the answer span equals ``label_pad_token_id``.
        """
        left_padded = self.tokenizer.padding_side == "left"
        padded_labels = []
        for feature, (prompt_len, answer_len) in zip(batch, positions):
            if left_padded:
                # With left padding the answer occupies the tail of the row.
                start, end = feature.size(0) - answer_len, feature.size(0)
            else:
                # With right padding the row starts with the prompt itself.
                start, end = prompt_len, prompt_len + answer_len
            labels = torch.full_like(feature, self.label_pad_token_id)
            labels[start:end] = feature[start:end]
            padded_labels.append(labels)
        return torch.stack(padded_labels, dim=0).contiguous()  # in contiguous memory

    def online_sampling(self, features, feature_keys, sample_ratio: float = 0.3):
        r"""
        Builds training rows, replacing some chosen answers with fresh samples.

        One response per prompt is generated via ``sampling``. Then, for every
        ``(key, feature)`` pair, a new random index subset is drawn with
        ``select_sample_elements``; when ``key`` contains ``chosen_id`` and its
        numeric suffix is in that subset, the generated response is used as the
        answer instead of ``feature[key]``. All other rows keep their original
        answer.

        Args:
            features: Sequence of feature dicts containing ``prompt_ids``,
                ``rejected_ids`` and every key listed in ``feature_keys``.
            feature_keys: Answer keys to emit, in output order. Keys that
                contain ``chosen_id`` must end in ``_<int>``.
            sample_ratio: Ratio forwarded to ``select_sample_elements``.

        Returns:
            ``(concatenated_features, label_positions)``: the unpadded rows and
            their ``(prompt_len, answer_len)`` pairs, grouped by key.
        """
        print("sampling...")
        print_answers = True  # debugging aid: dump every replaced sample

        prompts = [
            {
                "input_ids": feature["prompt_ids"],
                "attention_mask": [1] * len(feature["prompt_ids"]),
            }
            for feature in features
        ]
        prompt_max_length = max((len(p["input_ids"]) for p in prompts), default=0)

        # left padding (presumably configured on sample_tokenizer - TODO confirm)
        prompt_features = self.sample_tokenizer.pad(
            prompts,
            padding=True,
            return_tensors=self.return_tensors,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )

        sampled_ids_with_padding = self.sampling(
            prompt_features["input_ids"],
            prompt_features["attention_mask"],
            prompt_max_length,
            need_unload=False,
        )

        assert sampled_ids_with_padding.size(0) == prompt_features["input_ids"].size(0)
        # Everything past the padded prompt width is treated as the response
        # (assumes generate() echoes the prompt in its output - TODO confirm).
        prompt_len = prompt_features["input_ids"].size(1)

        concatenated_features = []
        label_positions = []
        for key in feature_keys:
            is_chosen_key = "chosen_id" in key
            for i, feature in enumerate(features):
                # Drawn for every (key, feature) pair, including non-chosen
                # keys, so the amount of randomness consumed is unchanged.
                sample_indexs = select_sample_elements(feature_keys, sample_ratio)
                answer_ids = feature[key]
                if is_chosen_key and int(key.rsplit("_", 1)[-1]) in sample_indexs:
                    # Keep as many leading tokens as there are non-pad tokens
                    # in the generated continuation.
                    generated = sampled_ids_with_padding[i, prompt_len:]
                    valid_sample_len = int(
                        (generated != self.tokenizer.pad_token_id).sum(dim=-1)
                    )
                    if valid_sample_len <= 1:
                        # Nothing useful was generated: keep the dataset answer.
                        print(f"Sampling Empty Sentences: {self.steps}, using original label")
                    else:
                        answer_ids = generated[:valid_sample_len].tolist()

                    # barrier: print sample
                    if print_answers:
                        print(f"----------step: {self.steps}----------------------------")
                        print("prompt: ", self.tokenizer.decode(feature["prompt_ids"]))
                        print("sampled prompt: ", self.tokenizer.decode(prompt_features["input_ids"][i]))
                        print("sample: ", self.tokenizer.decode(answer_ids))
                        print("chosen: ", self.tokenizer.decode(feature[key]))
                        print("rejected: ", self.tokenizer.decode(feature["rejected_ids"]))
                        print("")

                example, position = self._join_prompt_and_answer(
                    feature["prompt_ids"], answer_ids
                )
                concatenated_features.append(example)
                label_positions.append(position)
        return concatenated_features, label_positions

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        For ``n`` input features and ``k`` answer keys (every key of the first
        feature except ``prompt_ids``, in sorted order) the batch holds
        ``k * n`` rows: the ``n`` rows of the first key, then the ``n`` rows
        of the second key, and so on.

        When ``self.scheduler`` is set and requests it for the current step,
        rows are built through ``online_sampling``; otherwise the dataset
        answers are used as they are. ``self.steps`` is incremented once per
        call.
        """
        # [chosen_id_1, chosen_id_2, ..., rejected_ids]; must not contain prompt_ids
        feature_keys = sorted(key for key in features[0] if key != "prompt_ids")

        if (self.scheduler is not None) and self.scheduler.need_sampling(self.steps):
            concatenated_features, label_positions = self.online_sampling(features, feature_keys)
        else:
            concatenated_features = []
            label_positions = []
            for key in feature_keys:
                for feature in features:
                    example, position = self._join_prompt_and_answer(
                        feature["prompt_ids"], feature[key]
                    )
                    concatenated_features.append(example)
                    label_positions.append(position)

        batch = self._pad_concatenated_features(concatenated_features)
        # we pad label after sampling
        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
        self.steps += 1
        return batch

    def _pad_concatenated_features(self, concatenated_features) -> Dict[str, torch.Tensor]:
        r"""
        Pads the concatenated rows with ``self.tokenizer`` using the collator's
        ``padding``, ``max_length``, ``pad_to_multiple_of`` and
        ``return_tensors`` settings.
        """
        return self.tokenizer.pad(
            concatenated_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

    def sampling(self, prompt_ids, attention_mask, prompt_max_length, need_unload, **kwargs):
        r"""
        Samples one response per prompt with ``self.model.generate``.

        Args:
            prompt_ids: Padded prompt token ids; moved to the accelerator device.
            attention_mask: Attention mask matching ``prompt_ids``.
            prompt_max_length: Length of the longest prompt; subtracted from
                ``tokenizer.model_max_length`` to obtain ``max_new_tokens``.
            need_unload: When true, generation runs inside the unwrapped
                model's ``disable_adapter()`` context.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The output of ``generate`` moved to CPU.
        """
        # sampling response and replace chosen
        with torch.no_grad():
            # An earlier variant of this method used diverse beam search
            # (num_beams=5, num_beam_groups=5, diversity_penalty=1.0).

            # to device
            # TODO: stage 3 need to synced_gpus
            device = self.accelerator.device
            prompt_ids = prompt_ids.to(device)
            attention_mask = attention_mask.to(device)

            generate_config = GenerationConfig(
                max_new_tokens=self.tokenizer.model_max_length - prompt_max_length,
                temperature=1.2,
                top_p=0.95,
                top_k=50,  # Top-K sampling
                do_sample=True,
                output_scores=False,
                use_cache=False,
                return_dict_in_generate=False,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.pad_token_id,
            )

            # Drawn before branching; kept unconditional so the `random`
            # stream seen by later callers does not depend on the branch.
            random_seed1 = random.randint(0, 2**32 - 1)
            random_seed2 = random.randint(0, 2**32 - 1)
            if need_unload:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    # NOTE(review): `seed` does not look like an argument the
                    # PyTorch generate() accepts - confirm before enabling
                    # this branch (the visible caller passes need_unload=False).
                    outputs = self.model.generate(
                        input_ids=prompt_ids,
                        attention_mask=attention_mask,
                        pad_token_id=self.tokenizer.pad_token_id,
                        generation_config=generate_config,
                        seed=[random_seed1, random_seed2],
                    )
            else:
                outputs = self.model.generate(
                    input_ids=prompt_ids,
                    attention_mask=attention_mask,
                    pad_token_id=self.tokenizer.pad_token_id,
                    generation_config=generate_config,
                )
        # return generated sequence
        return outputs.cpu()