from model.transformer import *
from util.batch_generator import *
from util.files import *
from util.trainer import EMNLPTrainer
import os
from util.args import EMNLPArgument
from util.losses import *

import apex
from pytorch_transformers import WarmupLinearSchedule

def get_model(args):
    """Build the transformer model, set up its weights, and move it to ``args.device``.

    When ``args.model_checkpoint`` is empty (``""`` or ``None``) the weights are
    freshly initialised with ``Initializer('normal', 0.02, 0.1)``; otherwise
    the state dict stored at that path is loaded into the model.

    Args:
        args: Argument object providing the model hyper-parameters
            (``vocab_size``, ``batch_seqlen``, ``hidden_dim``, ``projection_dim``,
            ``n_heads``, ``head_dim``, ``n_layers``, ``cutoffs``, ``dropout_rate``,
            ``dropatt_rate``, ``padding_index``, ``relative_pos``,
            ``experimental_loss``) plus ``model_checkpoint`` and ``device``.

    Returns:
        The ``Transformer_Model`` instance, placed on ``args.device``.
    """
    model = Transformer_Model(args.vocab_size, args.batch_seqlen, args.hidden_dim, args.projection_dim, args.n_heads,
                              args.head_dim, args.n_layers, args.cutoffs, args.dropout_rate, args.dropatt_rate,
                              args.padding_index, rel_att=args.relative_pos, experimental_loss=args.experimental_loss)
    if not args.model_checkpoint:
        initializer = Initializer('normal', 0.02, 0.1)
        initializer.initialize(model)
    else:
        # Load onto the CPU first: the model itself has not been moved yet, and
        # this lets a checkpoint saved on one device (e.g. cuda:1) be restored
        # on a machine that does not have that device.
        state_dict = torch.load(args.model_checkpoint, map_location='cpu')
        model.load_state_dict(state_dict)

    model = model.to(args.device)
    return model

def get_batchfier(args):
    """Create the train and test batch iterators for the configured dataset.

    For the ``'bugs'`` dataset both iterators are ``Lyrics_Batchfier`` objects
    built with ``epoch_shuffle=True``; for any other dataset they are
    ``BpttIterator`` objects over the JSON content loaded from the given path.

    Args:
        args: Argument object providing ``dataset``, ``train_path``,
            ``test_path``, ``batch_size``, ``batch_seqlen`` and, depending on
            the dataset, ``padding_index`` or ``device``.

    Returns:
        A ``(train_batchfier, test_batchfier)`` tuple.
    """
    if args.dataset == 'bugs':
        def build(path):
            # The lyrics batchfier takes a list of paths.
            return Lyrics_Batchfier([path], args.batch_size, seq_len=args.batch_seqlen,
                                    padding_index=args.padding_index, epoch_shuffle=True)
    else:
        def build(path):
            return BpttIterator(load_json(path), args.batch_size, args.batch_seqlen, device=args.device)

    # Train split is built first, then the test split.
    train_batchfier = build(args.train_path)
    test_batchfier = build(args.test_path)
    return train_batchfier, test_batchfier

def get_loss(args):
    """Select the training criterion that corresponds to ``args.loss_type``.

    Supported values:
        * ``'experimental'`` / ``'experimental2'`` -> ``FactorizedLoss``
        * ``'plain'`` -> ``PlainLoss``
        * ``'unlikelihood-token'`` -> ``CandidateLoss`` with ``rank_alpha=1.0``
        * ``'face'`` -> ``FACELoss`` (configured with ``ft="out"``, ``wt="pre"``)
        * any other value containing ``'-seq'`` -> a
          ``(SequencePenaltyCriterion, CandidateLoss)`` tuple

    Args:
        args: Argument object providing ``loss_type`` and ``padding_index``;
            the ``'face'`` branch additionally reads ``vocab_size`` and
            ``train_phase``.

    Returns:
        A single loss object, or a ``(seq_loss, token_loss)`` tuple for the
        ``'-seq'`` loss types.

    Raises:
        NotImplementedError: If ``loss_type`` is not recognised, or if the
            ``'face'`` loss (built with ``ft="out"``) is requested while
            ``args.train_phase`` is ``"train"``.
    """
    lt = args.loss_type
    if lt in ('experimental', 'experimental2'):
        loss = FactorizedLoss(args.padding_index)
    elif lt == 'plain':
        loss = PlainLoss(args.padding_index)
    elif lt == 'unlikelihood-token':
        loss = CandidateLoss(rank_alpha=1.0, padding_idx=args.padding_index)
    elif lt == 'face':
        loss = FACELoss(padding_idx=args.padding_index, vocab_size=args.vocab_size,
                        ignore_freq_index=[args.padding_index], ft="out", wt="pre")
        if loss.ft == "out" and args.train_phase == "train":
            raise NotImplementedError("ft-out only can be used in fine-tune phase")
    elif "-seq" in lt:
        # NOTE(review): the positional arguments (4, 50, 100, "repeat") are
        # defined by SequencePenaltyCriterion in util.losses -- confirm their
        # meaning there before changing them.
        seq_loss = SequencePenaltyCriterion(4, 50, 100, "repeat")
        loss = CandidateLoss(rank_alpha=1.0, padding_idx=args.padding_index)
        loss = (seq_loss, loss)
    else:
        # Name the rejected value so a mistyped --loss_type is easy to spot.
        raise NotImplementedError("unsupported loss_type: {!r}".format(lt))
    return loss

def get_trainer(args, model, train_batchfier, test_batchfier):
    """Assemble optimizer, LR scheduler and loss into an ``EMNLPTrainer``.

    The ``'bugs'`` dataset is optimised with plain Adam; every other dataset
    uses AdamW with ``args.weight_decay``. When ``args.mixed_precision`` is
    set, model and optimizer are passed through ``apex.amp.initialize`` at
    opt level ``'O2'``.

    Args:
        args: Argument object providing ``dataset``, ``learning_rate``,
            ``weight_decay``, ``mixed_precision``, ``n_epoch``,
            ``warmup_step``, ``update_step``, ``clip_norm`` and whatever
            ``get_loss`` reads.
        model: The model whose parameters are optimised.
        train_batchfier: Training batch iterator; its ``len()`` is used to
            size the scheduler.
        test_batchfier: Evaluation batch iterator.

    Returns:
        The configured ``EMNLPTrainer``.
    """
    params = model.parameters()
    optimizer = (
        torch.optim.Adam(params, args.learning_rate)
        if args.dataset == 'bugs'
        else torch.optim.AdamW(params, args.learning_rate, weight_decay=args.weight_decay)
    )

    use_amp = args.mixed_precision
    if use_amp:
        print('mixed_precision')
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O2')

    # The schedule spans every batch of every configured epoch.
    total_steps = len(train_batchfier) * args.n_epoch
    scheduler = WarmupLinearSchedule(optimizer, args.warmup_step, total_steps)

    return EMNLPTrainer(model, train_batchfier, test_batchfier, optimizer, scheduler, args.update_step,
                        get_loss(args), args.clip_norm, use_amp)

if __name__ == '__main__':
    # Script entry point: build the model, data iterators and trainer from the
    # parsed arguments, then either train for ``args.n_epoch`` epochs (saving a
    # checkpoint after each one) or run a single fine-tuning pass.
    args = EMNLPArgument()
    print(args.learning_rate, 'experimental : {} cutoffs : {}'.format(
        args.experimental_loss, len(args.cutoffs)))
    print(args.__dict__)
    model = get_model(args)
    train_batchfier, test_batchfier = get_batchfier(args)
    print(args.savename)
    trainer = get_trainer(args, model, train_batchfier, test_batchfier)
    res = []  # test losses collected across epochs, printed at the end
    if args.finetune:
        # Fine-tuning runs exactly one pass of the loop below.
        # NOTE(review): the trainer (and its LR schedule) was already built
        # above with the original n_epoch -- confirm that is intended.
        args.n_epoch = 1
    for i in range(args.n_epoch):
        print('epoch {}'.format(i + 1))
        if not args.finetune:
            trainer.train_epoch()
            test_loss = trainer.test_epoch()
            # Keep the per-epoch test loss so the final summary is not empty.
            res.append(test_loss)
            # Checkpoint file names use the zero-based epoch index.
            savepath = args.savename + '_epoch_{}'.format(i)
            save_dir = os.path.dirname(savepath)
            # A bare file name has no directory part; os.makedirs('') would
            # raise, so only create the directory when there is one.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), savepath)

        else:
            if "-seq" in args.loss_type:
                trainer.seq_level_finetune(args.savename, args)
                test_loss = trainer.test_epoch()
                res.append(test_loss)
            if args.loss_type == "face":
                # Hyper-parameters for measuring the d-1 metric.
                args.nprefix = 50
                args.ngenerate = 100
                args.top_k = 1
                args.temperature = 1.0
                args.experimental_loss = False
                args.sampling_mode = 0
                trainer.finetune_face(args)
    print(res)
