# -*- coding: utf-8 -*-
"""
Created on Wed Sep 15 01:15:34 2021

@author: Anonymous
"""

import numpy as np
import functions as fs

# Dataset name and the directory that holds its log files.
data = 'SST-2'
data_path = './data/kfull/'+data

# File-name stems of the logs handled by this script, listed in the order
# in which they are loaded, reported and written back.
LOG_STEMS = (
    'main_pre_dev_confs',
    'main_pre_ulb_probs',
    'main_pre_dev_probs',
    'main_pre_dev_perfs',
    'main_pre_pe_epochs',
    'main_all_records',
    'main_ulb_records',
)


def _log_file(stem, suffix):
    """Return the path of log file `stem` carrying the range tag `suffix`."""
    return data_path+'/logs/'+stem+suffix+'.bin'


# Load every partial file; chunks[stem] receives one entry per range tag.
# NOTE(review): the '_siXtoY' tags presumably denote seed-index ranges --
# confirm against the script that produced these files.
chunks = {stem: [] for stem in LOG_STEMS}
for part_suffix in ('_si0to4', '_si4to10'):
    for stem in LOG_STEMS:
        chunks[stem].append(fs.load_all(_log_file(stem, part_suffix)))

    # Report the shape of what was just loaded for this range tag.
    print(' ')
    print(part_suffix)
    for stem in LOG_STEMS:
        print(np.shape(chunks[stem][-1]))

# Join the partial arrays along their first axis. All concatenations are
# done before anything is written, so a shape mismatch aborts the script
# without producing any output file.
merged = {stem: np.concatenate(chunks[stem], axis=0) for stem in LOG_STEMS}

# Report the shapes of the merged arrays.
print(' ')
print('=============')
for stem in LOG_STEMS:
    print(np.shape(merged[stem]))

# Write the merged arrays next to the partial files under the combined tag.
merged_suffix = '_si0to10'
for stem in LOG_STEMS:
    fs.dump_all(merged[stem], _log_file(stem, merged_suffix))
print('done')