#!/usr/bin/env python

import argparse
import math
import random
import time

import numpy
import numpy as np
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
from orion.client import report_results

import treeOM
from data import PTBLoader
from hinton import plot
from utils import get_batch, repackage_hidden, generate_idx, generate_ground_truth

parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='data/dependency/UD_English-PTB/en_ptb-ud',
                    help='location of the data corpus')
parser.add_argument('--lm_loss', type=float, default=1.0,
                    help='weight of the language-model loss in the combined training loss')
parser.add_argument('--structure_loss', type=float, default=1.0,
                    help='weight of the structure loss in the combined training loss')
parser.add_argument('--truth_rate', type=float, default=None,
                    help='if set, epoch e feeds the ground-truth structure with probability '
                         '1/truth_rate**(e-1); a value of 1 also feeds it during evaluation')
parser.add_argument('--semantic_size', type=int, default=300,
                    help='size of word embeddings')
parser.add_argument('--syntax_size', type=int, default=100,
                    help='number of hidden units per layer')
parser.add_argument('--nslot', type=int, default=15,
                    help='number of memory slots')
parser.add_argument('--lr', type=float, default=0.003,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=200,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=70,
                    help='sequence length')
parser.add_argument('--dropoute', type=float, default=0.1,
                    help='dropout to remove words from embedding layer (0 = no dropout)')
parser.add_argument('--dropout', type=float, default=0.3,
                    help='dropout for rnn layers (0 = no dropout)')
parser.add_argument('--dropouto', type=float, default=0.5,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--distribution', type=str, default='softmax',
                    help='parsing distribution')
parser.add_argument('--seed', type=int, default=141,
                    help='random seed')
parser.add_argument('--nonmono', type=int, default=2,
                    help='LR-plateau patience (adam) and non-monotone window for the '
                         'SGD-to-ASGD switch / finetuning early stop, in epochs')
parser.add_argument('--shuffle', action='store_true',
                    help='shuffle the training data when it is batchified each epoch')
parser.add_argument('--sample_structure', action='store_true',
                    help='enable sample_structure in treeOM.OrderedMemory')
parser.add_argument('--independent', action='store_true',
                    help="pass the '<s>' token id to the model as sos and exclude "
                         "'<s>' targets from all losses")
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='report interval')
# Default checkpoint name derived from the current timestamp so runs do not collide.
randomhash = ''.join(str(time.time()).split('.'))
parser.add_argument('--save', type=str, default=randomhash + '.pt',
                    help='path to save the final model')
parser.add_argument('--wdecay', type=float, default=1.2e-5,
                    help='weight decay applied to all weights')
parser.add_argument('--resume', type=str, default='',
                    help='path of model to resume')
parser.add_argument('--optimizer', type=str, default='adam', choices=['sgd', 'adam'],
                    help='optimizer to use (sgd, adam)')
parser.add_argument('--when', nargs="+", type=int, default=[-1],
                    help='When (which epochs) to divide the learning rate by 10 - accepts multiple')
parser.add_argument('--finetuning', type=int, default=500,
                    help='epoch at which to restart ASGD for finetuning')
parser.add_argument('--philly', action='store_true',
                    help='Use philly cluster')
parser.add_argument('--device', type=int, default=0, help='select GPU')
parser.add_argument('--margin', type=float, default=1.0,
                    help='margin at rank loss')
parser.add_argument('--test-only', action='store_true')

args = parser.parse_args()
args.tied = True

# Validate through argparse rather than `assert`: asserts are stripped under
# `python -O` and would not print a usage message.
if not 0.0 <= args.margin <= 1.0:
    parser.error('--margin must be in [0, 1], got {}'.format(args.margin))

# Seed every RNG the script draws from (torch, Python's `random`, NumPy) so
# that runs with the same --seed are reproducible.
for _seed_fn in (torch.manual_seed, random.seed, numpy.random.seed):
    _seed_fn(args.seed)

if torch.cuda.is_available():
    torch.cuda.set_device(args.device)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    else:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")


###############################################################################
# Checkpoint helpers and data loading
###############################################################################

def model_save(fn):
    """Pickle the global ``[model, scheduler, optimizer]`` triple to ``fn``.

    With ``--philly`` the path is taken relative to ``$PT_OUTPUT_DIR``.
    """
    target = os.path.join(os.environ['PT_OUTPUT_DIR'], fn) if args.philly else fn
    with open(target, 'wb') as handle:
        torch.save([model, scheduler, optimizer], handle)


def model_load(fn):
    """Restore the global ``model``, ``scheduler`` and ``optimizer`` from ``fn``.

    With ``--philly`` the path is taken relative to ``$PT_OUTPUT_DIR``. Tensors
    are mapped onto the GPU when ``--cuda`` is set, otherwise onto the CPU.

    NOTE: ``torch.load`` unpickles arbitrary objects; only load trusted
    checkpoints. NOTE(review): PyTorch >= 2.6 defaults to ``weights_only=True``,
    which cannot restore these pickled objects; confirm the torch version.
    """
    global model, scheduler, optimizer
    source = os.path.join(os.environ['PT_OUTPUT_DIR'], fn) if args.philly else fn
    with open(source, 'rb') as handle:
        location = torch.device('cuda' if args.cuda else 'cpu')
        model, scheduler, optimizer = torch.load(handle, map_location=location)


import os
import hashlib

# Batch sizes for the evaluation splits. The test split is evaluated one
# sequence at a time.
eval_batch_size = args.batch_size
test_batch_size = 1

print("loading data ...")
corpus = PTBLoader(data_path=args.data)

# The training split is batchified inside train() on every epoch; the
# validation and test splits are batchified once here.
valid_data, valid_dist = corpus.batchify('valid', eval_batch_size, args.cuda)
test_data, test_dist = corpus.batchify('test', test_batch_size, args.cuda)
vocab = corpus.dictionary

args.vocab_size = len(vocab)

print("done loading, vocabulary size: {}".format(args.vocab_size))

###############################################################################
# Build the model
###############################################################################

# Both criteria return per-position losses (reduction='none') so that a mask
# can be applied before summing/averaging.
criterion = nn.CrossEntropyLoss(reduction='none')
structure_criterion = nn.NLLLoss(reduction='none')

ntokens = args.vocab_size

# With --independent, the id of the '<s>' token is handed to the model as
# `sos` (and '<s>' targets are masked out of the losses); otherwise no sos.
sos = None
if args.independent:
    assert '<s>' in vocab
    sos = vocab['<s>']
    if isinstance(sos, list):
        assert len(sos) == 1
        sos = sos[0]

model = treeOM.OrderedMemory(
    args.semantic_size,
    args.syntax_size,
    args.nslot,
    args.vocab_size,
    dropoute=args.dropoute,
    dropout=args.dropout,
    dropouto=args.dropouto,
    sos=sos,
    sample_structure=args.sample_structure,
    distribution=args.distribution,
)

if args.cuda:
    model, criterion, structure_criterion = (
        model.cuda(), criterion.cuda(), structure_criterion.cuda())
###

if args.resume:
    print('Resuming model ...')
    # Replaces the globals model / scheduler / optimizer with the checkpoint's.
    model_load(args.resume)
    if args.cuda:
        model.cuda()


params = list(model.parameters())
semantic_params = model.semantic_parameters()
syntax_params = model.syntax_parameters()
# Every parameter must belong to exactly one of the two groups.
assert set(params) == set(semantic_params + syntax_params)
assert len(set(semantic_params) & set(syntax_params)) == 0

# numel() counts every element. The previous size()[0] * size()[1] formula
# undercounted tensors with 3+ dimensions and skipped 0-dim parameters.
total_params = sum(p.numel() for p in params)
print('Args:', args)
print('Model total parameters:', total_params)

if not args.resume:
    # model_save() always pickles [model, scheduler, optimizer], so `scheduler`
    # must be bound even when no LR scheduler is used (plain SGD). Without it,
    # saving under --optimizer sgd raised NameError after truncating the file.
    scheduler = None
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0, 0.999),
                                     eps=1e-9, weight_decay=args.wdecay)
        # Halve the LR after `args.nonmono` epochs without improvement of the
        # monitored validation loss (threshold=0: any decrease counts).
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', 0.5,
                                                   patience=args.nonmono, threshold=0)


###############################################################################
# Training code
###############################################################################
def _print_last_row_dist(data, ctrl_idx, p, loss):
    """Print per-token structure diagnostics for ``data[-1]``.

    One line per token: the word, ``exp`` of its per-token loss, whether
    ``argmax(p)`` matches the reference ``ctrl_idx``, both indices, and a text
    rendering of ``p`` produced by ``hinton.plot``.
    """
    tokens = data[-1].cpu().numpy()
    indexes = ctrl_idx[-1].cpu().numpy()
    p0 = p[-1].cpu().numpy()
    nll = loss[-1].cpu().numpy()
    for token_id, idx, dist, token_nll in zip(tokens, indexes, p0, nll):
        best = numpy.argmax(dist)
        print('%15s\t%10.2f\t%s\t%2d\t%2d\t%s' % (vocab.idx2word[token_id], math.exp(token_nll),
                                                  (idx == best), idx, best,
                                                  plot(dist, max_val=1.)))


@torch.no_grad()
def evaluate(data_source, dist_source, batch_size=10, print_dist=False):
    """Evaluate the global ``model`` on a batchified split without training it.

    Args:
        data_source: token tensor from ``corpus.batchify``, walked in
            ``args.bptt``-long windows along dimension 1.
        dist_source: structure tensor from ``corpus.batchify`` aligned with
            ``data_source``.
        batch_size: passed to ``model.init_hidden``; presumably must equal
            ``data_source.size(0)`` — TODO confirm.
        print_dist: if True, print per-token diagnostics for ``data[-1]`` of
            every window.

    Returns:
        Tuple of Python floats ``(lm_loss, structure_loss, acc_1, acc_0)``:
        summed LM loss, summed structure NLL (ctrl + trg), and accuracy of
        ``p`` vs ``ctrl_idx`` ("1 look-ahead") and ``q`` vs ``trg_idx``
        ("0 look-ahead"). Each is divided by the number of counted tokens
        (all positions, or non-``sos`` positions when ``sos`` is set).
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    # Float accumulators keep the result well-defined even if the loop body
    # never runs (an int accumulator has no `.item()`).
    total_loss = 0.
    total_structure_loss = 0.
    total_structure_acc = 0.
    total_output_acc = 0.
    hidden, prev_structure = model.init_hidden(batch_size)
    for i in range(0, data_source.size(1) - 1, args.bptt):
        data, targets = get_batch(data_source, i, args.bptt)
        structure_first = get_batch(dist_source, i, args.bptt + 1, token_data=False)

        if args.truth_rate == 1:
            # Feed the ground-truth structure during evaluation as well.
            ground_truth = generate_ground_truth(prev_structure, structure_first, model.nslot)
            output, probs, hidden = model(data, hidden, ground_truth[:, :-1], ground_truth[:, 1:])
        else:
            output, probs, hidden = model(data, hidden)

        structure_p, _, p, q = probs

        # Reference slot indices derived from the model's own structure prediction.
        structure = structure_p.max(dim=-1)[1]
        structure_idx = generate_idx(prev_structure, structure, structure_first, model.nslot)
        prev_structure = structure
        ctrl_idx = structure_idx[:, :-1]
        trg_idx = structure_idx[:, 1:]

        loss = criterion(output, targets.reshape(-1)).view_as(data)

        if print_dist:
            _print_last_row_dist(data, ctrl_idx, p, loss)

        # With an sos token, '<s>' targets are excluded from every statistic.
        if sos is None:
            mask = torch.ones_like(loss)
        else:
            mask = (targets != sos).float()
        flat_mask = mask.reshape(-1)

        total_loss += (loss * mask).sum()
        total_structure_loss += (
            structure_criterion(p.clamp(min=1e-6).log().reshape(-1, model.nslot), ctrl_idx.reshape(-1)) * flat_mask
            + structure_criterion(q.clamp(min=1e-6).log().reshape(-1, model.nslot), trg_idx.reshape(-1)) * flat_mask
        ).sum()
        total_structure_acc += ((p.max(dim=-1)[1] == ctrl_idx).float() * mask).sum()
        total_output_acc += ((q.max(dim=-1)[1] == trg_idx).float() * mask).sum()

        hidden = repackage_hidden(hidden)

    if sos is None:
        # NOTE(review): counts every position of data_source although targets
        # are shifted by one; confirm against utils.get_batch.
        total_length = data_source.size(0) * data_source.size(1)
    else:
        total_length = float((data_source != sos).float().sum())
    return (float(total_loss) / total_length,
            float(total_structure_loss) / total_length,
            float(total_structure_acc) / total_length,
            float(total_output_acc) / total_length)


def train(truth_rate):
    """Run one training epoch over a freshly batchified training split.

    The training tensor is walked along dimension 1 in windows whose length is
    drawn randomly around ``args.bptt`` (``args.bptt`` with probability 0.95,
    otherwise half of it, plus N(0, 5) noise, never below 5). For each window
    the first param group's LR is scaled by ``seq_len / args.bptt`` and restored
    after the optimizer step. The loss minimised is
    ``args.lm_loss * lm_loss + args.structure_loss * structure_loss``, both
    averaged over unmasked target tokens. Running averages are printed every
    ``args.log_interval`` batches.

    Args:
        truth_rate: per-window probability of feeding the ground-truth
            structure (from ``generate_ground_truth``) into the model.

    Relies on module-level globals: ``model``, ``optimizer``, ``corpus``,
    ``params``, ``criterion``, ``structure_criterion``, ``sos``, ``args`` and
    ``epoch`` (the latter only for logging).
    """
    # Turn on training mode which enables dropout.
    model.train()

    # Re-batchified every call, optionally shuffled (--shuffle).
    train_data, train_dist = corpus.batchify('train', args.batch_size, args.cuda, args.shuffle)

    # Running sums reset at every log interval.
    total_lm_loss = 0
    total_structure_loss = 0
    total_structure_acc = 0
    total_output_acc = 0
    start_time = time.time()
    hidden, prev_structure = model.init_hidden(args.batch_size)
    batch, i = 0, 0
    length = train_data.size(1)
    while i < length - 1 - 1:
        # NOTE: the order of RNG draws (np.random here, `random` below) affects
        # reproducibility for a given --seed.
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence
        # length resulting in OOM; a cap such as
        # seq_len = min(seq_len, args.bptt + 10) is intentionally not applied.

        # Scale the LR with the window length (presumably so longer windows take
        # proportionally larger steps); restored after optimizer.step() below.
        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        data, targets = get_batch(train_data, i, seq_len)
        # One extra position of structure information is fetched.
        structure_first = get_batch(train_dist, i, seq_len + 1, token_data=False)

        # Starting each batch, we detach the hidden state from how it was
        # previously produced. If we didn't, the model would try
        # backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()

        # With probability `truth_rate`, feed the ground-truth structure.
        if random.random() < truth_rate:
            ground_truth = generate_ground_truth(prev_structure, structure_first, model.nslot)
            ctrl_idx = ground_truth[:, :-1]
            trg_idx = ground_truth[:, 1:]
            output, probs, hidden = model(data, hidden, ctrl_idx, trg_idx)
        else:
            output, probs, hidden = model(data, hidden)
        structure_p, output_p, p, q = probs

        # Supervision targets for p / q come from the model's own structure
        # prediction (overwriting any ground-truth indices used above).
        structure = structure_p.max(dim=-1)[1]
        structure_idx = generate_idx(prev_structure, structure, structure_first, model.nslot)
        prev_structure = structure
        ctrl_idx = structure_idx[:, :-1]
        trg_idx = structure_idx[:, 1:]

        # Per-position losses; clamp avoids log(0) on the distributions.
        lm_loss = criterion(output, targets.reshape(-1))
        structure_loss = \
            structure_criterion(p.clamp(min=1e-6).log().reshape(-1, model.nslot), ctrl_idx.reshape(-1)) \
            + structure_criterion(q.clamp(min=1e-6).log().reshape(-1, model.nslot), trg_idx.reshape(-1))

        # Exclude '<s>' targets when an sos token is configured. Note: without
        # sos the mask keeps the integer dtype of `targets` (evaluate() uses a
        # float mask); arithmetic below promotes it to float.
        if sos is None:
            mask = torch.ones_like(targets)
        else:
            mask = (targets != sos).float()

        # Average over unmasked target tokens.
        lm_loss = (lm_loss * mask.reshape(-1)).sum() / mask.sum()
        structure_loss = (structure_loss * mask.reshape(-1)).sum() / mask.sum()


        combined_loss = lm_loss * args.lm_loss + structure_loss * args.structure_loss
        combined_loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        if args.clip > 0:
            torch.nn.utils.clip_grad_norm_(params, args.clip)
        optimizer.step()

        # Accumulate per-window means; the accuracies are what the log prints as
        # "1 look-ahead" (p vs ctrl_idx) and "0 look-ahead" (q vs trg_idx).
        total_lm_loss += lm_loss.data
        total_structure_loss += structure_loss.data
        total_structure_acc += ((p.max(dim=-1)[1] == ctrl_idx).float() * mask).sum() / mask.sum()
        total_output_acc += ((q.max(dim=-1)[1] == trg_idx).float() * mask).sum() / mask.sum()

        # Undo the window-length LR scaling applied above.
        optimizer.param_groups[0]['lr'] = lr2
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_lm_loss.item() / args.log_interval
            cur_structure_loss = total_structure_loss.item() / args.log_interval
            cur_structure_acc = total_structure_acc.item() / args.log_interval
            cur_output_acc = total_output_acc.item() / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f} | structure loss {:5.2f} | '
                  '1 look-ahead {:3.3f} | 0 look-ahead {:3.3f}'.format(
                epoch, batch, train_data.size(1) // args.bptt, optimizer.param_groups[0]['lr'],
                              elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss),
                cur_structure_loss, cur_structure_acc, cur_output_acc
            ))
            total_lm_loss = 0
            total_structure_loss = 0
            total_structure_acc = 0
            total_output_acc = 0
            start_time = time.time()
        ###
        batch += 1
        i += seq_len


# Loop over epochs.
# Validation losses used by the non-monotone checks (SGD->ASGD switch and the
# finetuning early stop).
best_val_loss = []
# Best validation loss seen so far; a checkpoint is saved whenever it improves.
stored_loss = 100000000

# At any point you can hit Ctrl + C to break out of training early.
if not args.test_only:
    try:
        for epoch in range(1, args.epochs + 1):
            epoch_start_time = time.time()
            if args.truth_rate is None:
                train(0)
            else:
                # Probability of feeding the ground-truth structure decays
                # geometrically: 1 / truth_rate ** (epoch - 1).
                train(1 / math.pow(args.truth_rate, epoch - 1))

            if 't0' in optimizer.param_groups[0]:
                # ASGD is active: evaluate/save the averaged weights ('ax'),
                # then restore the raw weights to keep training.
                tmp = {}
                for prm in model.parameters():
                    tmp[prm] = prm.data.clone()
                    prm.data = optimizer.state[prm]['ax'].clone()

                # evaluate() returns four values; unpacking three raised ValueError.
                val_loss2, val_structure_loss2, val_structure_acc2, val_output_acc2 \
                    = evaluate(valid_data, valid_dist, eval_batch_size)
                print('-' * 89)
                print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                      'valid ppl {:8.2f} | valid bpc {:8.3f} | valid structure {:5.2f} | '
                      'valid 1 look ahead {:5.3f} | valid 0 look ahead {:5.3f}'.format(
                          epoch, (time.time() - epoch_start_time), val_loss2,
                          math.exp(val_loss2), val_loss2 / math.log(2), val_structure_loss2,
                          val_structure_acc2, val_output_acc2))
                print('-' * 89)

                if val_loss2 < stored_loss:
                    model_save(args.save)
                    print('Saving Averaged!')
                    stored_loss = val_loss2

                for prm in model.parameters():
                    prm.data = tmp[prm].clone()

                if epoch == args.finetuning:
                    print('Switching to finetuning')
                    optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr,
                                                 t0=0, lambd=0.,
                                                 weight_decay=args.wdecay)
                    best_val_loss = []

                if (epoch > args.finetuning and
                        len(best_val_loss) > args.nonmono and
                        val_loss2 > min(best_val_loss[:-args.nonmono])):
                    print('Done!')
                    # Stop training but still run the final evaluation and
                    # report_results below (sys.exit(1) skipped both and
                    # signalled failure).
                    break

                # Record the averaged validation loss. Without this, the
                # early-stop check above could never fire after finetuning
                # reset the list.
                best_val_loss.append(val_loss2)

            else:
                val_lm_loss, val_structure_loss, val_structure_acc, val_output_acc \
                    = evaluate(valid_data, valid_dist, eval_batch_size)
                print('-' * 89)
                print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                      'valid ppl {:8.2f} | valid structure {:5.2f} | '
                      'valid 1 look ahead {:5.3f} | valid 0 look ahead {:5.3f}'.format(
                          epoch, (time.time() - epoch_start_time), val_lm_loss,
                          math.exp(val_lm_loss), val_structure_loss, val_structure_acc, val_output_acc))
                print('-' * 89)

                # Model-selection criterion: LM loss alone when its weight is 1,
                # else structure loss alone when its weight is 1, else the
                # weighted combination.
                if args.lm_loss == 1:
                    val_loss = val_lm_loss
                elif args.structure_loss == 1:
                    val_loss = val_structure_loss
                else:
                    val_loss = args.lm_loss * val_lm_loss + args.structure_loss * val_structure_loss

                if val_loss < stored_loss:
                    model_save(args.save)
                    print('Saving model (new best validation)')
                    stored_loss = val_loss

                if args.optimizer == 'adam':
                    scheduler.step(val_loss)

                # Non-monotone trigger: switch SGD to ASGD once the loss is worse
                # than the best loss from more than `nonmono` epochs ago.
                if (args.optimizer == 'sgd' and
                        't0' not in optimizer.param_groups[0] and
                        (len(best_val_loss) > args.nonmono and
                         val_loss > min(best_val_loss[:-args.nonmono]))):
                    print('Switching to ASGD')

                    optimizer = torch.optim.ASGD(
                        model.parameters(), lr=args.lr, t0=0, lambd=0.,
                        weight_decay=args.wdecay)

                if epoch in args.when:
                    print('Saving model before learning rate decreased')
                    model_save('{}.e{}'.format(args.save, epoch))
                    print('Dividing learning rate by 10')
                    optimizer.param_groups[0]['lr'] /= 10.

                best_val_loss.append(val_loss)

            print("PROGRESS: {}%".format((epoch / args.epochs) * 100))

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

# Load the best saved model.
model_load(args.save)

# Run on validation and test data.
print('=' * 89)

if valid_data is not None:
    valid_loss, valid_structure_loss, valid_structure_acc, valid_output_acc \
        = evaluate(valid_data, valid_dist, eval_batch_size, print_dist=False)
    print('| End of training | valid ppl {:8.2f} | valid structure {:5.2f} '
          '| valid 1 look ahead {:5.3f} | valid 0 look ahead {:5.3f}'.format(
              math.exp(valid_loss), valid_structure_loss, valid_structure_acc, valid_output_acc))

if test_data is not None:
    test_loss, test_structure_loss, test_structure_acc, test_output_acc \
        = evaluate(test_data, test_dist, test_batch_size, print_dist=False)
    print('| End of training | test ppl {:8.2f} | test structure {:5.2f} '
          '| test 1 look ahead {:5.3f} | test 0 look ahead {:5.3f}'.format(
              math.exp(test_loss), test_structure_loss, test_structure_acc, test_output_acc))

print('=' * 89)

# Report the validation perplexity to Orion. `valid_loss` only exists when a
# validation split was evaluated above.
if not args.test_only and valid_data is not None:
    report_results([dict(
        name='valid_ppl',
        type='objective',
        value=math.exp(valid_loss))])
