# -*- coding: utf-8 -*-
"""
"""

import numpy as np
import json
import utils
import os
import argparse

# Seed NumPy's global RNG at import time so the np.random.choice draws used
# for splitting produce the same sequence on every run.
np.random.seed(42)

def split_data(data_list, output_dir, train_prnt=0.70, dev_prnt=0.15, test_prnt=0.15):
    """Randomly partition ``data_list`` into train/dev/test and write each part.

    Indices are drawn without replacement from NumPy's global RNG, so the
    split is reproducible only if that RNG was seeded beforehand. The train
    and dev sizes are ``round(len(data_list) * fraction)``; the test split
    receives every remaining item, so ``test_prnt`` is only used to validate
    that the fractions add up.

    Parameters
    ----------
    data_list : list
        Items to split. Each item is handed to ``utils.write_data`` unchanged.
    output_dir : str
        Directory that receives ``train.jsonl``, ``dev.jsonl`` and
        ``test.jsonl``. It is created if it does not exist yet.
    train_prnt, dev_prnt, test_prnt : float
        Fraction of the data for each split. Each must lie in [0, 1] and
        together they must sum to 1 (up to floating-point tolerance).

    Raises
    ------
    AssertionError
        If a fraction is out of range or the fractions do not sum to 1.
    """
    fractions = (train_prnt, dev_prnt, test_prnt)
    assert all(0.0 <= frac <= 1.0 for frac in fractions), \
        f"split fractions must lie in [0, 1], got {fractions}"
    # Compare with a tolerance: exact equality rejects valid inputs such as
    # 0.7 + 0.2 + 0.1, which evaluates to 0.9999999999999999 in binary floats.
    assert np.isclose(train_prnt + dev_prnt + test_prnt, 1.0), \
        f"split fractions must sum to 1.0, got {train_prnt + dev_prnt + test_prnt}"

    data_size = len(data_list)
    indices = list(range(data_size))

    # train
    train_count = int(np.round(data_size * train_prnt))
    train_inds = list(np.random.choice(indices, train_count, replace=False))
    rem_inds = list(set(indices) - set(train_inds))

    # dev: the train and dev counts are rounded independently, so their sum
    # can exceed data_size (e.g. 5 items at 0.3/0.7/0.0 rounds to 2 + 4).
    # Cap the dev count at what is left instead of letting the draw fail.
    dev_count = min(int(np.round(data_size * dev_prnt)), len(rem_inds))
    dev_inds = list(np.random.choice(rem_inds, dev_count, replace=False))

    # test: everything that was not drawn for train or dev
    test_inds = list(set(rem_inds) - set(dev_inds))

    # cross-check: the three splits are disjoint and together cover every item
    all_inds = train_inds + dev_inds + test_inds
    assert data_size == len(all_inds) == len(set(all_inds))

    # create files
    train_data = [data_list[i] for i in train_inds]
    dev_data = [data_list[i] for i in dev_inds]
    test_data = [data_list[i] for i in test_inds]

    print(f"data_size: {data_size}, train_data: {len(train_data)}, dev_data: {len(dev_data)}, test_data: {len(test_data)}")

    # An empty output_dir means "current directory" to os.path.join, while
    # os.makedirs("") would raise, so only create non-empty paths.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    utils.write_data(train_data, os.path.join(output_dir, "train.jsonl"))
    utils.write_data(dev_data, os.path.join(output_dir, "dev.jsonl"))
    utils.write_data(test_data, os.path.join(output_dir, "test.jsonl"))
    
    

if __name__ == "__main__":
    # Command-line entry point: read the input file, then write the three
    # split files into the output directory.
    parser = argparse.ArgumentParser(
        description="Split a data file into train/dev/test files."
    )
    parser.add_argument("--input_path", help="path to the input data file", required=True)
    # Required: the output directory is joined onto each output file name, and
    # a missing value (None) would only fail with a TypeError after the input
    # has already been read.
    parser.add_argument("--output_dir", help="path to the output directory", required=True)
    args = parser.parse_args()

    input_path = args.input_path
    output_dir = args.output_dir

    # Step 1: Read input
    data_list = utils.read_data(input_path)  # data_list: list of json object

    # Step 2: prepare data splits
    split_data(data_list, output_dir, train_prnt=0.70, dev_prnt=0.15, test_prnt=0.15)
    
    