import overall_utils
from gsm8k.evaluations.eval_gsm import solve_math_problems,parse_answer
from gsm8k.evaluations.eval_MATH import last_boxed_only_string,remove_boxed,grade_answer
import copy
import time
import logging
from langchain.schema import (
    ChatMessage,
    ChatResult,
    AIMessage,
    HumanMessage,
    SystemMessage,
)
import openai


def create_player(config,llm_agent, model_name, name):
    """Instantiate the player wrapper that matches *model_name*.

    Any model whose name contains "gpt" (case-insensitive) is wrapped in
    ``PLAYER``; every other model is wrapped in ``PLAYER_OPENSOURCE``. The
    remaining arguments are forwarded unchanged to the chosen constructor.
    """
    is_gpt_model = "gpt" in model_name.lower()
    player_cls = PLAYER if is_gpt_model else PLAYER_OPENSOURCE
    return player_cls(config, llm_agent, model_name, name)
    
class PLAYER_OPENSOURCE(object):
    """Conversation holder and sampler for non-GPT (open-source) models.

    ``llm_agent`` is unpacked as ``(model, tokenizer)``. There are two ways to
    sample:

    * ``get_completion_local`` drives ``model.generate`` and decodes the
      output with ``tokenizer``.
    * ``get_completion`` calls ``model.chat.completions.create``, an
      OpenAI-style client. This is presumably an OpenAI-compatible server for
      the open model; confirm with the caller.

    System messages are stored as user turns because this backend is not
    treated as a chat model.
    """

    def __init__(self,config,llm_agent,model_name,name):
        self.config=config
        self.model = llm_agent[0]
        self.tokenizer = llm_agent[1]
        # NOTE(review): not read anywhere in this class; kept in case callers use it.
        self.llm_agents = []
        self.name = name
        self.model_name = model_name
        self.conversations = []
        self.prompt_token_used = 0
        self.completion_token_used = 0

    def _add_messages(self, role, messages):
        """Append *messages* (a str or a list of them) as turns with the given *role*.

        Raises:
            ValueError: if *messages* is neither a list nor a str.
        """
        if isinstance(messages, list):
            self.conversations.extend({"role": role, "content": message} for message in messages)
        elif isinstance(messages, str):
            self.conversations.append({"role": role, "content": messages})
        else:
            raise ValueError("messages must be a list or a string")

    def add_system_message(self, messages):
        """Record *messages* as user turns (this backend has no system role)."""
        self.add_user_message(messages)

    def add_user_message(self, messages):
        """Append *messages* as user turns."""
        self._add_messages("user", messages)

    def add_assistant_message(self, messages):
        """Append *messages* as assistant turns."""
        self._add_messages("assistant", messages)

    def get_completion_local(self):
        """Sample ``config["n"]`` completions for the single stored prompt.

        Uses ``config["temperature"]``, ``config["top_p"]`` and
        ``config["max_tokens"]`` (as ``max_new_tokens``).

        Returns:
            list[str]: one decoded string per sample, containing only the
            newly generated tokens (the prompt tokens are stripped).
        """
        self.model.eval()
        # The local model is used as a plain completion model, not a chat
        # model, so exactly one prompt turn is expected.
        assert len(self.conversations) == 1
        # Messages are dicts; the previous `conv.content` always raised AttributeError.
        prompt_str = "".join(conv["content"] for conv in self.conversations)
        # NOTE(review): `devices` is not the usual HF attribute (`device` is);
        # presumably the model is a wrapper exposing a device list — confirm.
        tokenized = self.tokenizer(prompt_str, return_tensors="pt").to(self.model.devices[0])
        temperature = self.config["temperature"]
        top_p = self.config["top_p"]
        max_tokens = self.config["max_tokens"]
        prompt_len = tokenized["input_ids"].shape[-1]

        completions = []
        for _ in range(self.config["n"]):
            output = self.model.generate(
                **tokenized,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,  # previously read from config but silently ignored
                max_new_tokens=max_tokens,
            )
            # Keep only the newly generated part of the first sequence.
            new_tokens = output[:, prompt_len:]
            completions.append(self.tokenizer.decode(new_tokens[0]))
        return completions

    def get_completion(self, temperature=None,n=None):
        """Request ``n`` chat completions for the stored conversation.

        Args:
            temperature: sampling temperature; defaults to ``config["temperature"]``.
            n: number of completions; defaults to ``config["n"]``.

        Returns:
            list[str]: the message content of each returned choice.

        Raises:
            Exception: whatever the client raised on the final failed attempt.
            Each failed attempt shrinks ``max_tokens`` by 200 (never below 1)
            and doubles the sleep delay, starting at 5 seconds.
        """
        conversations = self.conversations
        model_name = self.config["model_name"]
        if temperature is None:
            temperature = self.config["temperature"]
        if n is None:
            n = self.config["n"]
        max_tokens = self.config["max_tokens"]
        # Always make at least one attempt. With max_retries <= 0 the old loop
        # never ran and `completion` was left unbound.
        attempts_left = max(1, self.config["max_retries"])
        delay = 5
        while True:
            try:
                completion = self.model.chat.completions.create(model=model_name,
                                                                messages = conversations,
                                                                temperature=temperature,
                                                                n = n,
                                                                max_tokens = max_tokens)
                break
            except Exception as e:
                attempts_left -= 1
                # A too-long request is a common failure; ask for fewer tokens next time.
                max_tokens = max(1, max_tokens - 200)
                print(f"An unexpected error occurred: {e} retry...")
                if attempts_left <= 0:
                    raise
                time.sleep(delay)
                delay *= 2

        prompt_tokens = completion.usage.prompt_tokens
        completion_tokens = completion.usage.completion_tokens
        completions = [choice.message.content for choice in completion.choices]
        self.prompt_token_used += prompt_tokens
        self.completion_token_used += completion_tokens

        return completions

class PLAYER(object):
    """Conversation holder for GPT models reached through an OpenAI-style client.

    ``llm_agent`` is unpacked as ``(client, tokenizer)``. ``client`` must
    expose ``chat.completions.create``. ``tokenizer`` is stored but not used
    by this class. Token usage reported by the API is accumulated in
    ``prompt_token_used`` / ``completion_token_used``.
    """

    def __init__(self,config,llm_agent,model_name,name):
        self.config=config
        self.model = llm_agent[0]
        self.tokenizer = llm_agent[1]
        self.name = name
        self.model_name = model_name
        self.conversations = []
        self.prompt_token_used = 0
        self.completion_token_used = 0

    def _add_messages(self, role, messages):
        """Append *messages* (a str or a list of them) as turns with the given *role*.

        Raises:
            ValueError: if *messages* is neither a list nor a str.
        """
        if isinstance(messages, list):
            self.conversations.extend({"role": role, "content": message} for message in messages)
        elif isinstance(messages, str):
            self.conversations.append({"role": role, "content": messages})
        else:
            raise ValueError("messages must be a list or a string")

    def add_system_message(self, messages):
        """Append *messages* as system turns."""
        self._add_messages("system", messages)

    def add_user_message(self, messages):
        """Append *messages* as user turns."""
        self._add_messages("user", messages)

    def add_assistant_message(self, messages):
        """Append *messages* as assistant turns."""
        self._add_messages("assistant", messages)

    def get_completion(self, temperature=None,n=None):
        """Request ``n`` chat completions for the stored conversation.

        Args:
            temperature: sampling temperature; defaults to ``config["temperature"]``.
            n: number of completions; defaults to ``config["n"]``.

        Returns:
            list[str]: the message content of each returned choice.

        Raises:
            Exception: whatever the client raised on the final failed attempt.
            Each failed attempt shrinks ``max_tokens`` by 200 (never below 1)
            and doubles the sleep delay, starting at 5 seconds.
        """
        conversations = self.conversations
        model_name = self.config["model_name"]
        if temperature is None:
            temperature = self.config["temperature"]
        if n is None:
            n = self.config["n"]
        max_tokens = self.config["max_tokens"]
        # Always make at least one attempt. With max_retries <= 0 the old loop
        # never ran and `completion` was left unbound.
        attempts_left = max(1, self.config["max_retries"])
        delay = 5
        while True:
            try:
                completion = self.model.chat.completions.create(model=model_name,
                                                                messages = conversations,
                                                                temperature=temperature,
                                                                n = n,
                                                                max_tokens = max_tokens)
                break
            except Exception as e:
                attempts_left -= 1
                # A too-long request is a common failure; ask for fewer tokens next time.
                max_tokens = max(1, max_tokens - 200)
                print(f"An unexpected error occurred: {e} retry...")
                if attempts_left <= 0:
                    raise
                time.sleep(delay)
                delay *= 2

        prompt_tokens = completion.usage.prompt_tokens
        completion_tokens = completion.usage.completion_tokens
        completions = [choice.message.content for choice in completion.choices]
        self.prompt_token_used += prompt_tokens
        self.completion_token_used += completion_tokens

        return completions

    
class DEBATE_PLAYER(PLAYER):
    """PLAYER that also builds multi-agent-debate follow-up prompts.

    Each ``construct_message_from_other_players*`` method returns one
    ``{"role": "user", "content": ...}`` message. The message quotes entry
    ``round`` (default ``-1``, the latest) of every other player's
    conversation, then appends a dataset-specific instruction. When there are
    no other players, a "double check your answer" message is returned instead.
    """

    def __init__(self,config,llm_agent,model_name,name):
        super().__init__(config,llm_agent,model_name,name)

    @staticmethod
    def _other_solutions_text(other_players, round_idx):
        """Return the shared header followed by one fenced solution per other player."""
        parts = ["These are the solutions to the problem from other agents: "]
        for player in other_players:
            agent_response = player.conversations[round_idx]["content"]
            parts.append("\n\n One agent solution: ```{}```".format(agent_response))
        return "".join(parts)

    def construct_message_from_other_players(self, other_players,question,use_llama_format=False, round=-1):
        """Build the debate prompt for numeric math problems (GSM-style).

        The answer is requested as ``\\boxed{answer}``, or for
        ``use_llama_format`` as a trailing number after a "Question/Answer"
        scaffold.
        """
        other_players = list(other_players)
        if not other_players:
            # Not passed through .format, so single braces are required here;
            # the old "{{answer}}" leaked doubled braces into the prompt.
            return {"role": "user", "content":"Can you double check that your answer is correct. Please reiterate your answer, with your final answer a single numerical number, in the form \\boxed{answer}."}

        prefix_string = self._other_solutions_text(other_players, round)
        if use_llama_format:
            prefix_string += """\n\n Using the solutions from other agents as additional information, can you provide your answer to the math problem? \n Your final answer should be a single numerical number at the end of your response. \nQuestion: {}\nAnswer: Let's think step by step.""".format(question)
        else:
            prefix_string += """\n\n Using the solutions from other agents as additional information, can you provide your answer to the math problem? \n The original math problem is {} Explain your reasonings. Your final answer should be a single numerical number, in the form \\boxed{{answer}}, at the end of your response.""".format(question)
        return {"role": "user", "content":prefix_string}

    def construct_message_from_other_players_theoremQA(self, other_players,question,use_llama_format=False,round=-1):
        """Build the debate prompt for TheoremQA-style problems.

        Allowed answer forms: a bare number, a list of numbers, True/False, or
        an option letter.
        """
        other_players = list(other_players)
        if not other_players:
            # Not passed through .format, so single braces are required here.
            return {"role": "user", "content": "Can you double check that your answer is correct. Please reiterate your answer, with your final answer a single numerical number, in the form \\boxed{answer}."}

        prefix_string = self._other_solutions_text(other_players, round)
        if use_llama_format:
            # Leading blank line keeps the instruction from running straight
            # into the closing ``` of the last quoted solution.
            prefix_string += """\n\nCan you solve the following math problem? Explain your reasoning. The final answer can only be one of the following forms:
1. a numerical value like 0.1, no symbol and no unit at all.
2. a list of number like [2, 3, 4].
3. True/False.
4. an option like (a), (b), (c), (d)
at the end of your response.
Question: {}
Answer: Let's think step by step.""".format(question)
        else:
            prefix_string += """\n\n Using the solutions from other agents as additional information, can you provide your answer to the math problem? \n The original math problem is {}. The final answer can only be one of the following forms:
1. a numerical value like 0.1, no symbol and no unit at all.
2. a list of number like [2, 3, 4].
3. True/False.
4. an option like (a), (b), (c), (d)
Your final answer should be in the form \\boxed{{answer}}, at the end of your response.""".format(question)
        return {"role": "user", "content":prefix_string}

    def construct_message_from_other_players_hotpotqa(self, other_players,question,context,fewshots,round=-1):
        """Build the debate prompt for HotpotQA-style question answering.

        The answer is requested as ``[answer]``. ``fewshots`` are embedded
        when non-empty. ``context`` is accepted for interface compatibility
        but is not used.
        """
        other_players = list(other_players)
        if not other_players:
            return {"role": "user", "content": "Can you double check that your answer is correct. Please reiterate your answer, with your final answer a single phrase, in the form [answer]."}

        prefix_string = self._other_solutions_text(other_players, round)
        if len(fewshots) == 0:
            prefix_string += """\n\n Using the solutions from other agents as additional information, can you provide your answer to the question answering problem? \n Your final answer should be a single name phrase, in the form [answer], at the end of your response.\nQuestion: {}\nAnswer:""".format(question)
        else:
            prefix_string += """\n\n Using the solutions from other agents as additional information, can you provide your answer to the question answering problem? \n Your final answer should be a single name phrase, in the form [answer], at the end of your response. \n\n Here are the few-shot examples: {}\n(END OF EXAMPLES) Question: {}\nAnswer:""".format(fewshots, question)
        return {"role": "user", "content":prefix_string}

    def construct_message_from_other_players_CSQA(self, other_players,question,round=-1):
        """Build the debate prompt for CommonsenseQA-style multiple choice (a-e)."""
        other_players = list(other_players)
        if not other_players:
            return {"role": "user", "content": "Can you double check that your answer is correct. Please reiterate your answer, with your final answer a single choice, in the form [choice]."}

        prefix_string = self._other_solutions_text(other_players, round)
        prefix_string += """\n\n Using the solutions from other agents as additional information, can you provide your answer to the question answering problem?\nYou final answer should be of the form 'So the answer is (choice).', at the end of your response, where choice = a/b/c/d/e\nQuestion: {}\nAnswer:""".format(question)
        return {"role": "user", "content":prefix_string}

    def get_completion(self, partial_context = False, use_long_context = True, temperature=None,n=None):
        """Request ``n`` chat completions for this player's conversation.

        Args:
            partial_context: if True, send only the last conversation turn.
            use_long_context: accepted for backward compatibility; unused.
            temperature: sampling temperature; defaults to ``config["temperature"]``.
            n: number of completions; defaults to ``config["n"]``.

        Returns:
            list[str]: the message content of each returned choice.

        Raises:
            Exception: whatever the client raised on the final failed attempt.
            Each failed attempt shrinks ``max_tokens`` by 300 (never below 1)
            and doubles the sleep delay, starting at 5 seconds.
        """
        # self.model may be a list of clients; only the first one is used.
        model = self.model[0] if isinstance(self.model, list) else self.model
        conversations = self.conversations[-1:] if partial_context else self.conversations
        model_name = self.config["model_name"]
        if temperature is None:
            temperature = self.config["temperature"]
        if n is None:
            n = self.config["n"]
        max_tokens = self.config["max_tokens"]
        # Always make at least one attempt. With max_retries <= 0 the old loop
        # never ran and `completion` was left unbound.
        attempts_left = max(1, self.config["max_retries"])
        delay = 5
        while True:
            try:
                completion = model.chat.completions.create(model=model_name,
                                                           messages = conversations,
                                                           temperature=temperature,
                                                           n = n,
                                                           max_tokens = max_tokens)
                break
            except Exception as e:
                attempts_left -= 1
                logging.info(f"An unexpected error occurred: {e} retry after {delay}...")
                # A too-long request is a common failure; ask for fewer tokens next time.
                max_tokens = max(1, max_tokens - 300)
                if attempts_left <= 0:
                    raise
                time.sleep(delay)
                delay *= 2

        prompt_tokens = completion.usage.prompt_tokens
        completion_tokens = completion.usage.completion_tokens
        completions = [choice.message.content for choice in completion.choices]
        self.prompt_token_used += prompt_tokens
        self.completion_token_used += completion_tokens

        return completions


def _answer_from_dict_literal(text, key):
    """Return ``text``'s ``key`` entry if *text* parses as a dict literal, else ''.

    JSON is tried first, then Python literal syntax (e.g. single-quoted keys).
    Both parsers only accept literals. Unlike the previous ``eval``, they
    never execute code contained in model output.
    """
    import ast
    import json

    try:
        parsed = json.loads(text)
    except (TypeError, ValueError):
        try:
            parsed = ast.literal_eval(text)
        except (TypeError, ValueError, SyntaxError, MemoryError, RecursionError):
            return ''
    if not isinstance(parsed, dict):
        return ''
    return parsed.get(key, '')


def parse_pred_answer(dataset: str, generated_response: "str | list[str]", use_json_format=False,json_key="Answer"):
    """Extract the final answer from each model response.

    Args:
        dataset: dataset name that selects the parsing rule.
        generated_response: a single response string, or a list of them.
        use_json_format: for "gsm8k" only, parse each response as a dict
            literal and read ``json_key``. Unparseable responses, non-dict
            results and missing keys all yield ''.
        json_key: key to read when ``use_json_format`` is set.

    Returns:
        list: one parsed answer per response. "counterintuitive_AR" uses
        ``parse_answer``; every other dataset (and gsm8k without JSON)
        extracts the last ``\\boxed{...}`` content.
    """
    if isinstance(generated_response, str):
        generated_response = [generated_response]

    if dataset == "gsm8k" and use_json_format:
        return [_answer_from_dict_literal(gr, json_key) for gr in generated_response]
    if dataset == "counterintuitive_AR":
        return [parse_answer(gr) for gr in generated_response]
    # gsm8k (non-JSON), MATH and all other datasets share the \boxed{} rule.
    return [remove_boxed(last_boxed_only_string(gr)) for gr in generated_response]



class ZEROSHOTMAD(object):
    """Base state for a zero-shot multi-agent debate on a single example.

    Subclasses implement ``run``. ``models_tokenizers[0]`` backs every debate
    player. ``_log`` snapshots the players' conversations and folds their
    token usage and cost into this object's totals.
    """

    def __init__(self, models_tokenizers,question, answer,example,config) -> None:
        self.question, self.answer = question, answer
        self.configs = config
        # Aggregated usage, filled in by _log().
        self.prompt_token_used = self.completion_token_used = 0
        self.cost = 0
        # Debate shape and prompting strategy come straight from config.
        self.num_rounds = config["num_rounds"]
        self.num_agents = config["num_agents"]
        self.players = []
        self.models_tokenizers = models_tokenizers
        self.prompt_strategy = config["prompt_strategy"]
        self.example = example

        # Per-run logging state.
        self.step_num = 0
        self.infos = []
        self.all_raw_solutions = []
        self.all_solutions = []

    def run(self):
        raise NotImplementedError

    def _create(self, config, llm_agent, model_name, name):
        """Factory hook: build one debate player."""
        return DEBATE_PLAYER(config,llm_agent, model_name, name)

    def _create_debate_player(self, name):
        """Build a debate player from the player config resolved by _init_models."""
        return self._create(self.player_config, self.llm_agent_player, self.player_model_name, name)

    def _init_models(self):
        """Resolve the player config, model name and (model, tokenizer) pair."""
        player_config = self.configs["llms"]["llm_agent_player"]
        self.player_config = player_config
        self.llm_agent_player = self.models_tokenizers[0]
        self.player_model_name = player_config["model_name"]

    def _init_players(self):
        """Create ``num_agents`` round-0 debate players."""
        self._init_models()
        for idx in range(self.num_agents):
            player = self._create_debate_player(f"player{idx}_round0")
            self.players.append(player)

    def _log(self):
        """Append a snapshot of this run to ``self.infos`` and accumulate usage.

        NOTE(review): each call adds every player's cumulative token counts
        and cost again, so calling this more than once per run double-counts.
        Confirm it is called once.
        """
        recordings = dict(
            x=self.question,
            answer=self.answer,
            all_solutions=self.all_solutions,
            all_raw_solutions=self.all_raw_solutions,
            agents={},
        )
        if "answer_type" in self.example:
            recordings["answer_type"] = self.example["answer_type"]
            recordings["field"] = self.example["field"]
        for player in self.players:
            player_name = player.name
            self.prompt_token_used += player.prompt_token_used
            self.completion_token_used += player.completion_token_used
            usage = overall_utils.gpt_usage(player.prompt_token_used, player.completion_token_used, player.model_name)
            self.cost += usage["cost"]
            recordings["agents"][player_name] = [
                [turn["role"], turn["content"]] for turn in player.conversations
            ]

        self.infos.append(recordings)








class ZEROSHOTMADJUDGE(object):
    """Base state for a two-player debate with a moderator (judge) on one example.

    Subclasses implement ``run``. ``models_tokenizers[0]`` backs both
    debaters and ``models_tokenizers[1]`` backs the moderator. The
    ``example`` argument is accepted but not stored.
    """

    def __init__(self, models_tokenizers,question, answer,example,config) -> None:
        self.question = question
        self.answer = answer
        self.configs = config

        # Aggregated usage, filled in by _log().
        self.prompt_token_used, self.completion_token_used, self.cost = 0, 0, 0
        self.num_rounds = config["num_rounds"]
        # Always two debaters; the moderator is appended after them.
        self.num_agents = 2
        self.players = []
        # Passed as use_json_format when parsing judge replies.
        self.json_format = True
        self.models_tokenizers = models_tokenizers

        # Per-run logging state.
        self.infos = []
        self.correct = False
        self.base_solution = ""
        self.all_raw_solutions = []
        self.all_solutions = []
        self.final_solution = ""

    def _create(self, config, llm_agent, model_name, name):
        """Factory hook: build one player via create_player."""
        return create_player(config,llm_agent, model_name, name)

    def _create_player(self, name):
        """Build a debater from the player config."""
        return self._create(self.player_config, self.llm_agent_player, self.player_model_name, name)

    def _create_judge(self, name):
        """Build the moderator from the judge config."""
        return self._create(self.judge_config, self.llm_agent_judge, self.judge_model_name, name)

    def init_players(self):
        """Resolve player/judge configs and create the debaters plus the moderator."""
        llm_configs = self.configs["llms"]
        self.player_config = llm_configs["llm_agent_player"]
        self.judge_config = llm_configs["llm_agent_judge"]
        self.player_model_name = self.player_config["model_name"]
        self.judge_model_name = self.judge_config["model_name"]

        self.llm_agent_player, self.llm_agent_judge = self.models_tokenizers[0], self.models_tokenizers[1]
        self.players.extend(self._create_player(f"player{idx}") for idx in range(self.num_agents))
        self.players.append(self._create_judge("moderator"))

    def _append_to_all_solutions(self, judge_response):
        """Print *judge_response* and record its parsed "Debate Answer"."""
        print(f"judge_response: {judge_response}")
        self.all_solutions.append(
            parse_pred_answer(
                self.configs["dataset"],
                judge_response,
                use_json_format=self.json_format,
                json_key="Debate Answer",
            )
        )

    def run(self):
        raise NotImplementedError

    def _round_dct(self, num: int):
        """Return the English ordinal ("first" … "tenth") for 1 <= num <= 10; KeyError otherwise."""
        ordinals = dict(enumerate(
            ("first", "second", "third", "fourth", "fifth",
             "sixth", "seventh", "eighth", "ninth", "tenth"),
            start=1,
        ))
        return ordinals[num]

    def _log(self):
        """Append a snapshot of this run to ``self.infos`` and accumulate usage.

        NOTE(review): each call re-adds every player's cumulative usage and
        cost, so calling this more than once per run double-counts. Confirm.
        """
        parsed_reference = parse_pred_answer(self.configs["dataset"], self.answer)

        recordings = {
            'x': self.question,
            "answer": self.answer,
            "answer_value": parsed_reference,
            "base_agent": self.base_solution,
            "agents": {},
            "final_solution": self.final_solution,
            "all_solutions": self.all_solutions,
            "all_raw_solutions": self.all_raw_solutions,
        }
        agents_log = recordings["agents"]
        for player in self.players:
            player_name = player.name
            self.prompt_token_used += player.prompt_token_used
            self.completion_token_used += player.completion_token_used
            usage = overall_utils.gpt_usage(player.prompt_token_used, player.completion_token_used, player.model_name)
            self.cost += usage["cost"]
            agents_log[player_name] = [[turn["role"], turn["content"]] for turn in player.conversations]
        self.infos.append(recordings)


        



class ZEROSHOT_FEEDBACK(object):
    """Base state for a zero-shot solve → feedback → revise loop on one example.

    Subclasses implement ``run``. ``models_tokenizers[0]`` backs the solving
    agents and ``models_tokenizers[1]`` backs the reflection/feedback agents.
    """

    def __init__(self, models_tokenizers,question, answer,example,config) -> None:
        self.question = question
        self.answer = answer
        self.configs = config
        self.prompt_token_used = 0
        self.completion_token_used = 0
        self.cost = 0
        self.num_rounds = config["num_rounds"]
        self.num_agents = config["num_agents"]
        self.strategy = config["strategy"]
        self.players = []
        self.prompt_strategy = config["prompt_strategy"]
        self.cot_resulst_json = False
        self.example = example
        self.models_tokenizers = models_tokenizers
        ## logging
        self.step_num = 0
        self.infos = []
        self.all_raw_solutions = []
        self.all_raw_values = []
        self.all_raw_feedbacks = []
        self.all_solutions = []
        self.all_values = [] # all the valuations
        self.all_feedbacks = []
        ###

    def _create(self, config,llm_agent, model_name, name):
        """Factory hook: build one player via create_player."""
        return create_player(config,llm_agent, model_name, name)

    def _create_player(self, name):
        """Build a solving agent from the player config."""
        return self._create(self.player_config,self.llm_agent_player, self.player_model_name, name)

    def _create_feedback(self, name):
        """Build a feedback/reflection agent from the feedback config."""
        return self._create(self.feedback_config, self.llm_agent_feedback, self.feedback_model_name, name)

    def _init_models(self):
        """Resolve player/feedback configs, model names and (model, tokenizer) pairs."""
        self.player_config = self.configs["llms"]["llm_agent_player"]
        self.feedback_config = self.configs["llms"]["llm_agent_feedback"]
        self.player_model_name = self.player_config["model_name"]
        self.feedback_model_name = self.feedback_config["model_name"]

        self.llm_agent_player = self.models_tokenizers[0]
        self.llm_agent_feedback = self.models_tokenizers[1]

    def _init_players(self):
        """Create one solving agent and one reflection agent per round."""
        self._init_models()
        for i in range(self.num_rounds):
            self.players.append(self._create_player(f"agent{i}"))
            self.players.append(self._create_feedback(f"relfection_agent{i}"))

    def _init_CoT_players(self):
        """Create the single solving agent used for plain chain-of-thought."""
        self._init_models()
        self.players.append(self._create_player("agent0"))

    def run(self):
        raise NotImplementedError

    def _check_correctness(self, answer, pred_answer):
        """Return True if the answer parsed from *pred_answer* equals *answer*.

        ``parse_pred_answer`` always returns a list. The old code compared
        that list directly with the reference, which could never match a
        scalar answer. The single parsed element is now compared instead.
        """
        parsed_answers = parse_pred_answer(self.configs["dataset"], pred_answer)
        # Keep the old list-vs-reference match for callers whose reference is already a list.
        if parsed_answers == answer:
            return True
        if len(parsed_answers) != 1:
            return False
        parsed = parsed_answers[0]
        return parsed is not None and parsed == answer

    def _log(self):
        """Append a snapshot of this run to ``self.infos`` and accumulate usage.

        ``answer_value`` is recorded as None when no subclass has set it;
        this class never assigns it. NOTE(review): each call re-adds every
        player's cumulative usage and cost, so calling this more than once
        per run double-counts. Confirm.
        """
        recordings = {'x': self.question, "answer": self.answer,
                      "answer_value": getattr(self, "answer_value", None),
                      "agents":{},
                      "all_solutions":self.all_solutions,
                      "all_raw_solutions":self.all_raw_solutions,
                      "all_values":self.all_values}
        if "answer_type" in self.example:
            recordings["answer_type"] = self.example["answer_type"]
            recordings["field"] = self.example["field"]
        for agent in self.players:
            self.prompt_token_used += agent.prompt_token_used
            self.completion_token_used += agent.completion_token_used
            self.cost += overall_utils.gpt_usage(agent.prompt_token_used, agent.completion_token_used,agent.model_name)["cost"]
            recordings["agents"][agent.name] = [[a["role"], a["content"]] for a in agent.conversations]
        self.infos.append(recordings)