# -*- coding: utf-8 -*-

import json
from typing import Optional, Dict
from collections import defaultdict

import torch
from tqdm import tqdm
from datasets import Dataset
from transformers import AutoTokenizer
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

from .modeling_t5 import T5ForGenerativeDeduplication


# Registry of generative model classes that GenDedup can instantiate, keyed by
# the string accepted as its ``model_class_name`` argument.
ALL_SUPPORT_GEN_MODELS = dict(
    T5ForGenerativeDeduplication=T5ForGenerativeDeduplication,
)


class GenDedup:
    """Generative deduplication driver around a seq2seq model.

    ``fit`` trains the model to generate each example's ``labels`` text from
    its ``sentence``. ``dedup`` then generates a single token per example and
    groups sentences whose confident prediction matches their label. Within
    each label group the first sentence is kept; the others are reported as
    duplicates.
    """

    def __init__(self,
                 model_name_or_path: str,
                 model_class_name: str = 'T5ForGenerativeDeduplication',
                 pretrained_model_name_or_path: Optional[str] = None,
                 max_length: Optional[int] = None,
                 gaussian_noise_prob: float = 0.1,
                 gaussian_noise_variance: float = 0.1):
        """Load the tokenizer and model, then configure the model's noise settings.

        Args:
            model_name_or_path: source of the tokenizer; also the source of the
                model weights when ``pretrained_model_name_or_path`` is None.
            model_class_name: key into ``ALL_SUPPORT_GEN_MODELS``.
            pretrained_model_name_or_path: optional separate source for the
                model weights.
            max_length: passed as ``max_length`` (with truncation) to the
                tokenizer calls in ``prepare_ds``.
            gaussian_noise_prob: forwarded to ``model.setup_gd``.
            gaussian_noise_variance: forwarded to ``model.setup_gd``.

        Raises:
            KeyError: if ``model_class_name`` is not a registered model class.
        """
        try:
            model_class = ALL_SUPPORT_GEN_MODELS[model_class_name]
        except KeyError:
            raise KeyError(
                f'Unsupported model_class_name {model_class_name!r}; '
                f'expected one of {sorted(ALL_SUPPORT_GEN_MODELS)}') from None
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = model_class.from_pretrained(pretrained_model_name_or_path or model_name_or_path)
        self.model.setup_gd(
            gaussian_noise_prob=gaussian_noise_prob,
            gaussian_noise_variance=gaussian_noise_variance)
        self.max_length = max_length
        # Dataset prepared by the most recent ``fit``; reused by ``dedup``.
        self.ds = None
        # Names of the features the tokenizer produces for a plain input
        # (probed with a dummy string); ``dedup`` feeds exactly these to generate.
        self.feature_columns = self.tokenizer('gen dedup').keys()

    def prepare_ds(self, ds: Dataset) -> Dataset:
        """Tokenize a dataset that has ``sentence`` and ``labels`` text columns.

        The mapped dataset gains the tokenizer features for ``sentence``. Its
        ``labels`` column is set to the target token ids. The original label
        strings are kept in a new ``label_text`` column, which ``dedup`` uses.
        """
        def preprocess_function(examples):
            targets = examples['labels']
            model_inputs = self.tokenizer(
                examples['sentence'], max_length=self.max_length, truncation=True)
            labels = self.tokenizer(
                text_target=targets, max_length=self.max_length, truncation=True)
            model_inputs['labels'] = labels['input_ids']
            # 'labels' now holds token ids; keep the raw strings for matching.
            model_inputs['label_text'] = targets
            return model_inputs

        return ds.map(preprocess_function, batched=True)

    def fit(self,
            ds: Dataset,
            output_dir: str,
            batch_size: int = 32,
            epochs: int = 1,
            learning_rate: float = 1e-4,
            warmup_steps: int = 1000,
            logging_steps: int = 10,
            weight_decay: float = 0.01,
            gradient_accumulation_steps: int = 1,
            fp16: Optional[bool] = None,
            argument_kwargs: Optional[Dict] = None,
            trainer_kwargs: Optional[Dict] = None):
        """Fine-tune the model on ``ds`` with a ``Seq2SeqTrainer``.

        The prepared dataset is cached on ``self.ds`` so that ``dedup`` can run
        without being given a dataset. The model is left in eval mode.

        Args:
            ds: dataset with ``sentence`` and ``labels`` text columns.
            output_dir: trainer output / checkpoint directory.
            batch_size: per-device train batch size.
            epochs: number of training epochs.
            learning_rate: optimizer learning rate.
            warmup_steps: learning-rate warmup steps.
            logging_steps: logging interval in steps.
            weight_decay: optimizer weight decay.
            gradient_accumulation_steps: gradient accumulation steps.
            fp16: forwarded to ``Seq2SeqTrainingArguments(fp16=...)``.
            argument_kwargs: extra ``Seq2SeqTrainingArguments`` fields. They
                override the values derived from the parameters above.
            trainer_kwargs: extra ``Seq2SeqTrainer`` keyword arguments. They
                override the defaults set here.
        """
        self.ds = self.prepare_ds(ds)
        data_collator = DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)

        # Build defaults first and then apply caller overrides. A key supplied
        # both ways therefore wins instead of raising "multiple values for
        # keyword argument". The evaluation strategy is left at its default
        # ("no"): the old ``evaluation_strategy`` keyword was renamed to
        # ``eval_strategy`` and is rejected by newer transformers releases.
        args_config = dict(
            output_dir=output_dir,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            weight_decay=weight_decay,
            save_total_limit=1,
            num_train_epochs=epochs,
            logging_steps=logging_steps,
            warmup_steps=warmup_steps,
            gradient_accumulation_steps=gradient_accumulation_steps,
            fp16=fp16,
            predict_with_generate=True,
        )
        args_config.update(argument_kwargs or {})
        training_args = Seq2SeqTrainingArguments(**args_config)

        trainer_config = dict(
            model=self.model,
            args=training_args,
            train_dataset=self.ds,
            tokenizer=self.tokenizer,
            data_collator=data_collator,
        )
        # Forward caller-supplied trainer options (e.g. callbacks).
        trainer_config.update(trainer_kwargs or {})
        trainer = Seq2SeqTrainer(**trainer_config)

        trainer.train()
        self.model.eval()

    def dedup(self,
              save_path: str,
              ds: Optional[Dataset] = None,
              threshold: float = 0.5,
              device: Optional[str] = None,
              generate_kwargs: Optional[Dict] = None) -> Dict:
        """Detect duplicate sentences and write them to ``save_path`` as JSON lines.

        For every example, the model generates one token. The example's sentence
        is grouped under its label text when all of the following hold:

        - the decoded token is non-blank;
        - the token occurs in the label text;
        - the token's probability is at least ``threshold``.

        In each group the first sentence seen is kept. Every further distinct
        sentence is written out as a duplicate.

        Args:
            save_path: output file. It receives one
                ``{"text": ..., "labels": ...}`` JSON object per line, UTF-8
                encoded.
            ds: dataset with ``sentence`` and ``labels`` columns. Defaults to
                the dataset prepared by the last ``fit`` call.
            threshold: minimum probability of the generated token. None
                disables the check.
            device: device to run on. When given, the model is moved there too.
                Defaults to the model's current device.
            generate_kwargs: extra keyword arguments for ``model.generate``.

        Returns:
            Mapping of label text to the list of duplicate sentences written.

        Raises:
            ValueError: if ``ds`` is None and ``fit`` has not been called.
        """
        if generate_kwargs is None:
            generate_kwargs = {}

        if ds is not None:
            ds = self.prepare_ds(ds)
        elif self.ds is not None:
            ds = self.ds
        else:
            raise ValueError('No dataset to deduplicate: pass `ds` or call `fit` first.')

        self.model.eval()
        if device is None:
            device = self.model.device
        else:
            # Inputs are moved to `device`, so the model must live there as well.
            self.model.to(device)

        # label text -> {sentence: None}. The insertion-ordered dict acts as an
        # ordered set, so the kept "original" of each group is the first
        # sentence seen rather than an arbitrary set element.
        duplicate_map = defaultdict(dict)
        for obj in tqdm(ds):
            inputs = {name: torch.LongTensor([obj[name]]).to(device)
                      for name in self.feature_columns}
            outputs = self.model.generate(
                **inputs,
                max_length=1,
                output_scores=True,
                return_dict_in_generate=True,
                **generate_kwargs)
            predict_id = outputs.sequences[0][-1]
            gen_label = self.tokenizer.decode([predict_id], skip_special_tokens=True)
            # Special tokens decode to '' here, and '' is a substring of every
            # label, so a blank prediction must not count as a match.
            if not gen_label.strip() or gen_label not in obj['label_text']:
                continue
            if threshold is not None:
                scores = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
                if scores[0, predict_id].item() < threshold:
                    continue
            duplicate_map[obj['label_text']][obj['sentence']] = None

        duplicates = {}
        with open(save_path, 'w', encoding='utf-8') as writer:
            for label, texts in duplicate_map.items():
                # The first sentence of each group is kept; the rest are duplicates.
                extra = list(texts)[1:]
                for text in extra:
                    writer.write(json.dumps({'text': text, 'labels': label}, ensure_ascii=False) + '\n')
                if extra:
                    duplicates[label] = extra
        total = sum(len(texts) for texts in duplicates.values())
        print(f'{total} duplicate text detected!')
        print(f'Duplicate text has been saved to {save_path}')
        return duplicates
