import re, string, os
from typing import List, Union, Literal
from enum import Enum
import json
import copy
import itertools
import tiktoken
import numpy as np
from langchain import OpenAI, Wikipedia
from langchain.llms.base import BaseLLM
from langchain.agents.react.base import DocstoreExplorer
from langchain.docstore.base import Docstore
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
import sys
from langchain.schema import (
    ChatMessage,
    ChatResult,
    AIMessage,
    HumanMessage,
    SystemMessage,
)

sys.path.append("..")
from .prompts_game_of_24 import standard_prompt, cot_prompt, propose_prompt, value_prompt,value_last_step_prompt, propose_prompt_123,GAME24_TEMPLATES,REFLEXION_PROMPT_TEMPLATES,REFLECTION_HEADER
import overall_utils
from agents.Agents import PLAYER

import logging
import sympy
# Every log record carries the source file, function and line number.
logging.basicConfig(
    format='%(asctime)s [%(filename)s:%(funcName)s:%(lineno)d] %(levelname)s: %(message)s',
    level=logging.INFO,
)
# Module-level token counters.
# NOTE(review): agents keep their own per-instance counters; no code visible
# here updates these globals (PROPOSER.get_completion uses locals of the same name).
prompt_tokens = 0
completion_tokens = 0

# Fixed seed so the numpy-based branch selection is reproducible between runs.
np.random.seed(0)
class ThoughtTreeNode:
    """A single thought (one reasoning line) in a ThoughtTree.

    Nodes only point upwards to their parent; the chain of thoughts from the
    top-most ancestor down to a node is the partial solution it represents.
    """

    def __init__(self, parent, thought):
        self.parent = parent
        self.children = []  # not populated by any code visible here
        self.thought = thought
        self.value = 0
        self.isEndOfChain = False

    def get_logic_chain_str(self):
        """Return the thoughts from the top-most ancestor down to this node.

        The parentless ancestor contributes its thought verbatim (no newline);
        every node below it contributes its thought followed by ``"\\n"``.
        With the usual empty-string root this yields one line per thought.
        """
        pieces = []
        cursor = self
        while cursor.parent:
            pieces.append(cursor.thought + "\n")
            cursor = cursor.parent
        pieces.append(cursor.thought)
        return "".join(reversed(pieces))
 
class ThoughtTree:
    """Holder for the root of a tree of ThoughtTreeNode objects."""

    def __init__(self):
        self.root = self.getNode()
        # Both counters start at 1 (the root / its single empty chain); no code
        # visible here updates them afterwards.
        self.total_number_nodes = 1
        self.total_chains = 1

    def getNode(self):
        """Return a new parentless node whose thought is the empty string."""
        return ThoughtTreeNode(parent=None, thought="")

class PROPOSER(PLAYER):
    """Proposer agent that draws ``config["n"]`` completions, in batches.

    Each batch loads a model configured for that batch size via
    ``overall_utils._load_model`` and sends the agent's conversation to it.
    Token usage reported by the backend is added to the agent's counters.
    """

    # Upper bound on completions requested from one backend call; larger ``n``
    # values are split across several calls.
    # NOTE(review): presumably an API-side limit on ``n`` -- confirm.
    MAX_COMPLETIONS_PER_CALL = 20

    def __init__(self, config, llm_agent, model_name, name):
        super().__init__(config, llm_agent, model_name, name)

    def get_completion(self):
        """Return a list of ``config["n"]`` completion strings.

        For model names without "gpt", the conversation is sent as a list of the
        plain message contents instead of message objects.
        Missing token-usage information (some backends report
        ``llm_output=None``) is logged and counted as zero instead of crashing.
        """
        conversations = [self.conversations]
        if "gpt" not in self.model_name:
            conversations = [message.content for message in conversations[0]]

        remaining = self.config["n"]
        all_completions = []
        while remaining > 0:
            batch_size = min(remaining, self.MAX_COMPLETIONS_PER_CALL)
            remaining -= batch_size
            # The batch size is part of the model configuration, so a model is
            # loaded per batch from a copy of the agent's config.
            batch_config = copy.deepcopy(self.config)
            batch_config["n"] = batch_size
            batch_model = overall_utils._load_model(batch_config)

            result = batch_model.generate(conversations)
            generations = result.generations[0]
            completions = []
            if isinstance(generations, list):
                completions = [generation.text for generation in generations]
            all_completions.extend(completions)

            token_usage = (result.llm_output or {}).get('token_usage')
            if token_usage is None:
                logging.warning(f"{self.name}: backend reported no token usage; cost will be undercounted")
                token_usage = {}
            prompt_tokens = token_usage.get('prompt_tokens', 0)
            completion_tokens = token_usage.get('completion_tokens', 0)
            self.prompt_token_used += prompt_tokens
            self.completion_token_used += completion_tokens

            logging.debug(f"completion: {result}")
            logging.debug(f"number of completions: {len(completions)}")
            logging.debug(f"llm_output: {result.llm_output}")

        return all_completions
class EVALUATOR(PLAYER):
    """Evaluator agent built on the ``PLAYER`` base class.

    Unlike PROPOSER it does not override ``get_completion``; completions come
    from the inherited ``PLAYER`` implementation (defined elsewhere).
    """

    def __init__(self, config, llm_agent, model_name, name):
        # Pure pass-through: PLAYER performs all initialisation.
        super().__init__(config, llm_agent, model_name, name)
        

def get_current_numbers(y: str) -> str:
    """Extract the remaining numbers from the last line of a partial solution.

    A step line looks like ``"4 + 8 = 12 (left: 2 12 13)"``. The text after the
    final ``"left: "`` and up to the next ``")"`` is returned (``"2 12 13"``).
    If the last line has no ``"left: "`` marker, the whole line up to its first
    ``")"`` is returned, so the raw puzzle input (e.g. ``"1 2 3 4"``) passes
    through unchanged.
    """
    final_line = y.strip().rpartition('\n')[2]
    after_marker = final_line.rpartition('left: ')[2]
    numbers, _, _ = after_marker.partition(')')
    return numbers


class Game_of_24_Agent:
    """Solve one Game-of-24 puzzle with LLM proposer/evaluator agents.

    The search strategy is chosen by ``configs["strategy"]``:

    * ``"bfs"`` -- tree-of-thought search. Each step, every kept node is
      expanded by a proposer. The candidates are scored by evaluators (or
      picked at random), and the best ``configs["n_select_sample"]`` are kept.
    * ``"CoT"`` -- a single proposal round where each completion is a whole
      chain of thought.
    * ``"Reflexion"`` -- propose a solution, ask an evaluator for feedback,
      then re-propose conditioned on the accumulated reflections.

    Every LLM call goes through a freshly created PROPOSER/EVALUATOR that is
    appended to ``self.agents`` for cost accounting and conversation logging.
    """

    def __init__(self,
                    question: str,
                    steps: int = 4,
                    chat=True,
                    configs = None,
                    ) -> None:
        """Store the puzzle and select prompt templates for the prompt strategy.

        Args:
            question: the puzzle input (presumably the four numbers, e.g. "1 2 3 4").
            steps: number of search steps run by ``run`` for bfs/Reflexion.
            chat: stored; not used by code visible here.
            configs: dict-like run configuration; must provide at least
                ``"prompt_strategy"`` and ``"strategy"``.
        """
        self.question = question
        self.steps = steps
        self.chat = chat
        self.configs = configs
        self.prompt_token_used = 0
        self.completion_token_used = 0
        self.cost = 0
        self.answers = []
        self.stops = ['\n'] * 4
        # Maps a value prompt to its computed score (see get_value).
        self.value_cache = {}
        # Every agent created during the run, in creation order.
        self.agents = []

        self.cur_step = 0
        self.infos = []
        self.prompt_strategy = self.configs["prompt_strategy"]
        self.strategy = self.configs["strategy"]
        if self.prompt_strategy in GAME24_TEMPLATES:
            self.propose_prompt = GAME24_TEMPLATES[self.prompt_strategy][0]
            self.cot_prompt = GAME24_TEMPLATES[self.prompt_strategy][1]
            self.value_prompt = GAME24_TEMPLATES[self.prompt_strategy][2]
            self.value_last_step_prompt = GAME24_TEMPLATES[self.prompt_strategy][3]
        elif self.prompt_strategy in REFLEXION_PROMPT_TEMPLATES:
            # Reflexion tuples are ordered (propose, feedback, re-propose); the
            # feedback prompt is stored as value_prompt and the re-propose
            # prompt as cot_prompt (see step_reflexion).
            self.propose_prompt = REFLEXION_PROMPT_TEMPLATES[self.prompt_strategy][0]
            self.cot_prompt = REFLEXION_PROMPT_TEMPLATES[self.prompt_strategy][2]
            self.value_prompt = REFLEXION_PROMPT_TEMPLATES[self.prompt_strategy][1]

    def _create_proposer(self, name):
        """Return a new PROPOSER named ``name`` using the proposer model settings."""
        return PROPOSER(self.proposer_config, self.proposer_llm_agent, self.proposer_model_name, name)

    def _create_evaluator(self, name):
        """Return a new EVALUATOR named ``name`` using the evaluator model settings."""
        return EVALUATOR(self.evaluator_config, self.evaluator_llm_agent, self.evaluator_model_name, name)

    def _init_models(self):
        """Load proposer/evaluator configs and models from ``configs["llms"]``."""
        self.proposer_config = self.configs["llms"]["proposer"]
        self.evaluator_config = self.configs["llms"]["evaluator"]
        self.proposer_model_name = self.proposer_config["model_name"]
        self.evaluator_model_name = self.evaluator_config["model_name"]

        self.proposer_llm_agent = overall_utils._load_model(self.proposer_config)
        self.evaluator_llm_agent = overall_utils._load_model(self.evaluator_config)

    def _init_tree(self):
        """Start a fresh thought tree whose root has an empty thought."""
        self.thoughttree = ThoughtTree()

    def run(self, same_first_step_as_i) -> tuple:
        """Run the configured strategy and return ``(solutions, info)``.

        Args:
            same_first_step_as_i: ``None``, or the info dict of an earlier run.
                When given, its ``["steps"][0]["new_steps"]`` are replayed as
                the first bfs/CoT proposal instead of querying the proposer.

        Returns:
            ``solutions`` is a list of strings: full logic chains for bfs/CoT,
            or raw proposer outputs for Reflexion. ``info`` is
            ``{"steps": [...], "agents": {...}}``.

        Raises:
            ValueError: if ``configs["strategy"]`` is not supported.
        """
        self.same_first_step_as_i = same_first_step_as_i
        if "random_choose_branch" in self.configs:
            self.random_choose_branch = self.configs["random_choose_branch"]
        else:
            self.random_choose_branch = False

        if self.strategy == "bfs":
            self._init_models()
            self._init_tree()
            intermediate_thoughts = [self.thoughttree.root]
            for i in range(self.steps):
                logging.info(f"step {i}")
                intermediate_thoughts = self.step(i, intermediate_thoughts)
                self.cur_step += 1
        elif self.strategy == "CoT":
            self._init_models()
            self._init_tree()
            # One proposal round: each completion is already a full chain.
            intermediate_thoughts = self.step(0, [self.thoughttree.root])
            self.cur_step += 1
        elif self.strategy == "Reflexion":
            self.all_feedbacks = []
            self.all_solutions = []
            self._init_models()
            intermediate_thoughts = []
            for i in range(self.steps):
                logging.info(f"\n---------step {i}")
                intermediate_thoughts = self.step_reflexion(i, intermediate_thoughts)
                self.cur_step += 1
            self._compute_cost()
            agent_dict = self._log_agents()
            return self.all_solutions, {'steps': self.infos, "agents": agent_dict}
        else:
            raise ValueError(f"unknown strategy: {self.strategy!r}")

        self._compute_cost()
        agent_dict = self._log_agents()
        return [node.get_logic_chain_str() for node in intermediate_thoughts], {'steps': self.infos, "agents": agent_dict}

    def step(self, step, intermediate_thoughts, to_print=False) -> list:
        """Expand, score and select nodes for one bfs/CoT step.

        Args:
            step: index of the current step.
            intermediate_thoughts: ThoughtTreeNodes kept from the previous step.
            to_print: if true, log candidates ranked by value.

        Returns:
            The selected new ThoughtTreeNodes.

        Raises:
            NotImplementedError: for ``method_generate="sample"`` or
                ``method_evaluate="vote"`` (their helpers are stubs).
            ValueError: for an unknown generate/evaluate/select method.
        """
        intermediate_thoughts_str = [node.get_logic_chain_str() for node in intermediate_thoughts]

        method_generate = self.configs["method_generate"]
        if method_generate == "propose":
            new_steps = [self.propose_inter_step(node, step, idx) for idx, node in enumerate(intermediate_thoughts)]
        elif method_generate == "sample":
            # sample_inter_step is an unimplemented stub; fail loudly rather
            # than crash later on a None result.
            raise NotImplementedError("method_generate='sample' is not implemented")
        else:
            raise ValueError(f"unknown method_generate: {method_generate!r}")

        new_steps = list(itertools.chain.from_iterable(new_steps))
        new_steps_str = [node.get_logic_chain_str() for node in new_steps]
        logging.debug(f"new_steps: {new_steps_str}")
        ids = list(range(len(new_steps)))

        values = []
        if self.random_choose_branch:
            # Sampling without replacement cannot draw more items than exist.
            size = min(self.configs["n_select_sample"], len(ids))
            select_ids = np.random.choice(ids, size=size, replace=False).tolist()
            logging.info(f"select_ids: {select_ids}")
        else:
            # evaluation
            method_evaluate = self.configs["method_evaluate"]
            if method_evaluate == 'vote':
                raise NotImplementedError("method_evaluate='vote' is not implemented (get_votes is a stub)")
            elif method_evaluate == 'value':
                values = self.get_values(step, new_steps)
            else:
                raise ValueError(f"unknown method_evaluate: {method_evaluate!r}")
            # selection
            method_select = self.configs["method_select"]
            if method_select == 'sample':
                total = sum(values)
                # All-zero scores would divide by zero; fall back to uniform (p=None).
                ps = np.array(values, dtype=float) / total if total > 0 else None
                select_ids = np.random.choice(ids, size=self.configs["n_select_sample"], p=ps).tolist()
            elif method_select == 'greedy':
                select_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:self.configs["n_select_sample"]]
            else:
                raise ValueError(f"unknown method_select: {method_select!r}")
        select_new_steps = [new_steps[select_id] for select_id in select_ids]

        # log
        select_new_steps_log = [node.get_logic_chain_str() for node in select_new_steps]
        if to_print and values:
            ranked = sorted(zip(new_steps_str, values), key=lambda pair: pair[1], reverse=True)
            sorted_new_ys, sorted_values = zip(*ranked)
            logging.info(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_steps_log}\n')

        self._log(step, intermediate_thoughts_str, new_steps_str, values, select_new_steps_log)
        return select_new_steps

    def step_reflexion(self, step, intermediate_thoughts, to_print=True) -> list:
        """Run one Reflexion round: reflect on the previous answer, then re-propose.

        At step 0 a single proposer answers the question directly. At later
        steps an evaluator critiques ``intermediate_thoughts[0]``. Each critique,
        combined with all earlier feedback, conditions its own proposer.

        Returns:
            All proposer outputs produced this step, also appended to
            ``self.all_solutions``.
        """
        # reflect
        reflect_responses = []
        if step > 0:
            evaluator = self._create_evaluator(f"evaluator{step}_{0}")
            self.agents.append(evaluator)
            prompt = self.value_prompt.format(input=self.question, prev_ans=intermediate_thoughts[0])
            logging.info(f"evaluator{step}_{0} prompt: {prompt}")
            evaluator.add_user_message(prompt)
            evaluator_response = evaluator.get_completion()
            evaluator.add_assistant_message(evaluator_response)

            for reflect_response in evaluator_response:
                cur_reflections = self.all_feedbacks + [reflect_response]
                reflection_str = REFLECTION_HEADER + 'Reflections:\n- ' + '\n- '.join([r.strip() for r in cur_reflections])
                reflect_responses.append(reflection_str)
            self.all_feedbacks.extend(evaluator_response)

        # propose
        proposer_response = []
        if step == 0:
            proposer = self._create_proposer(f"proposer{step}_{0}")
            self.agents.append(proposer)
            prompt = self.propose_prompt.format(input=self.question) + 'Steps:'
            logging.info(f"proposer{step} prompt: {prompt}")
            proposer.add_user_message(prompt)
            proposer_response = proposer.get_completion()
            proposer.add_assistant_message(proposer_response)
        else:
            for j, reflect_response in enumerate(reflect_responses):
                proposer = self._create_proposer(f"proposer{step}_{j}")
                self.agents.append(proposer)
                prompt = self.cot_prompt.format(reflections=reflect_response, input=self.question) + 'Steps:'
                logging.info(f"proposer{step}_{j} prompt: {prompt}")
                proposer.add_user_message(prompt)
                response = proposer.get_completion()
                proposer.add_assistant_message(response)
                # Keep every proposer's output, not only the last one's.
                proposer_response.extend(response)
        self.all_solutions.extend(proposer_response)

        # log
        select_new_steps_log = proposer_response[:1]
        self._log(step, [], reflect_responses, [], select_new_steps_log)
        return proposer_response

    def propose_inter_step(self, prev_step, step, idx):
        """Expand ``prev_step`` with a fresh proposer and return the new leaves.

        Args:
            prev_step: the ThoughtTreeNode being expanded.
            step: index of the current step (used in the agent name and to
                decide whether to replay ``same_first_step_as_i``).
            idx: position of ``prev_step`` among the nodes expanded this step.

        Returns:
            For "CoT", one leaf per completion, each line of the completion
            becoming a node of its chain. Otherwise, one child of ``prev_step``
            per line of the first completion.
        """
        prompt = self._build_propose_prompt(prev_step)
        logging.info(f"propose prompt: {prompt}")
        proposer = self._create_proposer(f"proposer{step}_{idx}")
        self.agents.append(proposer)
        proposer.add_user_message(prompt)
        replay = getattr(self, "same_first_step_as_i", None)
        if replay is not None and step == 0:
            # Reuse an earlier run's first-step proposals instead of calling the LLM.
            replayed = [line.strip() for line in replay["steps"][0]["new_steps"]]
            proposer_response = ["\n".join(replayed)]
        else:
            proposer_response = proposer.get_completion()
        proposer.add_assistant_message(proposer_response)

        if self.strategy == "CoT":
            leaf_nodes = []
            for proposal in proposer_response:
                node = prev_step
                for line in proposal.split('\n'):
                    node = ThoughtTreeNode(node, line)
                leaf_nodes.append(node)
            return leaf_nodes

        output = proposer_response[0]
        logging.info(f"output: {output}")
        return [ThoughtTreeNode(prev_step, line) for line in output.split('\n')]

    def sample_inter_step(self, prev_step):
        """Unimplemented stub for i.i.d. sampling of next steps; returns None."""
        pass

    def value_outputs_unwrap(self, y: str, value_outputs: list) -> float:
        """Score evaluator outputs by keywords in their last lines.

        Each occurrence of "impossible"/"likely"/"sure" in the lowercased last
        line of an output adds 0.001/1/20. ``y`` is accepted but unused.
        """
        value_names = [output.split('\n')[-1].lower() for output in value_outputs]
        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc
        value = sum(value * sum([value_name.count(name) for value_name in value_names]) for name, value in value_map.items())
        return value

    def get_value(self, step, overall_step_idx, idx, cache_value=True):
        """Return the evaluator score of one candidate node.

        A chain of 4 lines with no "answer" scores 0 without an LLM call
        (presumably: all steps used but no final answer). With
        ``cache_value``, a cached score for the same value prompt is returned
        before any evaluator is created, so no LLM call is paid for it.
        """
        value_prompt = self._build_value_prompt(step)
        y = step.get_logic_chain_str()
        if len(y.strip().split('\n')) == 4 and 'answer' not in y.lower():
            logging.info("return 0 because there are 4 lines and 'answer' not inside.")
            return 0
        if cache_value and value_prompt in self.value_cache:
            return self.value_cache[value_prompt]

        evaluator = self._create_evaluator(f"evaluator{overall_step_idx}_{idx}")
        self.agents.append(evaluator)
        evaluator.add_user_message(value_prompt)
        evaluator_response = evaluator.get_completion()
        evaluator.add_assistant_message(evaluator_response)

        logging.info(f"value_prompt: {value_prompt}")
        value_outputs = evaluator_response
        logging.info(f"value_outputs: {value_outputs}")
        value = self.value_outputs_unwrap(y, value_outputs)
        if cache_value:
            self.value_cache[value_prompt] = value
        logging.debug(f"value: {value}")
        return value

    def get_values(self, overall_step_idx, new_steps, cache_value=True):
        """Score every candidate; repeated chains within this call score 0."""
        values = []
        local_value_cache = {}
        for idx, step in enumerate(new_steps):  # each partial output
            step_str = step.get_logic_chain_str()
            if step_str in local_value_cache:  # avoid duplicate candidates
                value = 0
            else:
                value = self.get_value(step, overall_step_idx, idx, cache_value=cache_value)
                local_value_cache[step_str] = value
            values.append(value)
        return values

    def get_votes(self, task, x, ys, n_generate_sample):
        """Unimplemented stub for vote-based evaluation; returns None."""
        pass

    def _build_propose_prompt(self, prev_step, output_str=False) -> str:
        """Build the proposer prompt for expanding ``prev_step``.

        Uses the CoT prompt (question plus the chain so far) when the numbers
        left are exactly "24" or the strategy is "CoT". Otherwise it uses the
        propose prompt on the numbers still left.
        """
        prev_step_string = prev_step.get_logic_chain_str()
        current_numbers = get_current_numbers(prev_step_string if prev_step_string else self.question)
        if current_numbers == '24' or self.strategy == "CoT":
            prompt = self.cot_prompt.format(input=self.question) + 'Steps:' + prev_step_string
        else:
            prompt = self.propose_prompt.format(input=current_numbers)
        return prompt

    def _build_value_prompt(self, step, output_str=False) -> str:
        """Build the evaluator prompt for ``step``.

        A last line without "left: " (or any CoT chain) is treated as a final
        answer and judged with the last-step prompt. Otherwise the remaining
        numbers are judged with the value prompt.
        """
        step_str = step.get_logic_chain_str()
        last_line = step.thought
        if 'left: ' not in last_line or self.strategy == "CoT":  # last step
            ans = last_line.lower().replace('answer: ', '')
            prompt = self.value_last_step_prompt.format(input=self.question, answer=ans)
        else:
            current_numbers = get_current_numbers(step_str)
            prompt = self.value_prompt.format(input=current_numbers)
        return prompt

    def is_finished(self) -> bool:
        """Return ``self.finished`` if it was ever set, else False."""
        return getattr(self, "finished", False)

    def test_output(self, output: str):
        """Check a solution string; return ``{'r': 1}`` if correct, else ``{'r': 0}``.

        The expression is the last line with "answer: " removed, cut at the
        first "=". It must use exactly the puzzle's numbers (as a multiset) and
        simplify to 24 under sympy. Any parsing error counts as wrong.
        """
        data = self.question
        expression = output.strip().split('\n')[-1].lower().replace('answer: ', '').split('=')[0]
        numbers = re.findall(r'\d+', expression)
        problem_numbers = re.findall(r'\d+', data)
        if sorted(numbers) != sorted(problem_numbers):
            return {'r': 0}
        try:
            return {'r': int(sympy.simplify(expression) == 24)}
        except Exception:
            # Deliberate best-effort: unparsable expressions count as wrong.
            return {'r': 0}

    def _compute_cost(self):
        """Add every agent's token usage and cost to this solver's totals."""
        for agent in self.agents:
            self.prompt_token_used += agent.prompt_token_used
            self.completion_token_used += agent.completion_token_used
            self.cost += overall_utils.gpt_usage(agent.prompt_token_used, agent.completion_token_used, agent.model_name)["cost"]

    def _log(self, step, intermediate_thoughts_str, new_steps_str, values, select_new_steps_log):
        """Append one step record to ``self.infos``."""
        self.infos.append({'step': step, 'x': self.question, 'cur_steps': intermediate_thoughts_str, 'new_steps': new_steps_str, 'values': values, 'select_new_ys': select_new_steps_log})

    def _log_agents(self):
        """Map each agent name to its conversation as ``[type, content]`` pairs."""
        return {
            agent.name: [[message.type, message.content] for message in agent.conversations]
            for agent in self.agents
        }

# class Game_of_24_Agent_GPT35:


