import argparse
import json

from translate_utils import *


def _build_parser():
	"""Create the argument parser shared by every ``--mode`` of this script.

	Options ending in ``_suff`` are suffixes appended to the input path with its
	``.tsv`` extension stripped. Options containing ``%s`` / ``%d`` are format
	templates: the Jia parser paths and ``--refrel_fn`` are filled in by
	``main``; the rest are handed unformatted to the ``translate_utils`` helpers.
	"""
	parser = argparse.ArgumentParser()
	parser.add_argument('--input', type=str, default='./directional_implications/directional_implications.tsv')
	parser.add_argument('--txt_suff', type=str, default='_in_lines_%d.txt')
	parser.add_argument('--mapping_suff', type=str, default='_mapping.json')
	parser.add_argument('--trans_txt_suff', type=str, default='_in_lines_translated_%d.txt')

	parser.add_argument('--pos_in_suff', type=str, default='_pos_input.json')
	parser.add_argument('--pos_out_suff', type=str, default='_pos_output.json')
	parser.add_argument('--trans_formatted_suff', type=str, default='_translated.tsv')
	parser.add_argument('--raw_out_suff', type=str, default='_translated_raw.tsv')
	parser.add_argument('--rel_levy_mapping_suff', type=str, default='_rellevy_mapping.txt')
	parser.add_argument('--typed_rel_suff', type=str, default='_all_rels.txt')

	parser.add_argument('--trans_split_suff', type=str, default='_%s_translated.tsv')
	parser.add_argument('--raw_split_suff', type=str, default='_%s_translated_raw.tsv')
	parser.add_argument('--typed_rel_split_suff', type=str, default='_%s_rels.txt')
	parser.add_argument('--rel_levy_mapping_split_suff', type=str, default='_%s_rellevy_mapping.txt')
	parser.add_argument('--split_mapping_suff', type=str, default='_split_mapping.json')

	parser.add_argument('--jia_parser_input_fn', type=str, default='/Users/teddy/PycharmProjects/open-entity-relation-extraction/data/levy_%s.json')
	parser.add_argument('--jia_parser_output_fn', type=str, default='/Users/teddy/PycharmProjects/open-entity-relation-extraction/data/levy_%s_result.json')
	# Reference relation file used by SORTBT; the placeholder is replaced by the input file stem.
	parser.add_argument('--refrel_fn', type=str, default='/Users/teddy/eclipse-workspace/entgraph_eval/gfiles/ent/%s_rels.txt',
						help='SORTBT only: reference relation file template, filled with the input file stem.')

	parser.add_argument('--num_splits', type=int, default=16)
	parser.add_argument('--mode', type=str, default='DUMP',
						help='DUMP/SORT/SORTBT/SORT2/SORT2BSL/SORTJIA/SORT2JIA/DEVTEST')
	# The integer flags below are treated as booleans (> 0 means on).
	parser.add_argument('--shuffle', type=int, default=0)
	parser.add_argument('--exhaust', type=int, default=0)
	parser.add_argument('--fine_only', type=int, default=0)
	parser.add_argument('--add_crossed', type=int, default=0)
	return parser


def _extend_root(root, exhaust=False, fine_only=False):
	"""Return ``root`` with '_exhaust' and then '_fineonly' appended when the matching flag is set."""
	if exhaust:
		root += '_exhaust'
	if fine_only:
		root += '_fineonly'
	return root


def main():
	"""Run one stage of the translation pipeline, selected by ``--mode``.

	Modes (each derives its file names from the ``--input`` path minus ``.tsv``):
		DUMP     -- convert_tsv_to_txt: dump the input TSV to text files + mapping JSON.
		SORT     -- merge_txt_to_dict, then contruct_translated_doc_for_postag.
		SORTBT   -- merge_backtranslation, then construct_tsv_backtranslation.
		SORT2    -- construct_translated_tsv from the POS output (amend=True).
		SORT2BSL -- construct_translated_tsv with amend=False, files tagged '_bsl'.
		SORTJIA  -- merge_txt_to_dict and write the result as JSON for the Jia et al. parser.
		SORT2JIA -- read the Jia et al. parser output and call construct_translated_tsv_jia.
		DEVTEST  -- split_eval_to_devtest on the translated/raw/relation files.

	Raises:
		SystemExit: (via ``parser.error``) if ``--input`` does not end in ``.tsv``.
		AssertionError: if ``--mode`` is not one of the modes above.
	"""
	parser = _build_parser()
	args = parser.parse_args()
	# Validate explicitly: a bare ``assert`` is stripped when Python runs with -O.
	if not args.input.endswith('.tsv'):
		parser.error(f"--input must be a .tsv file, got {args.input!r}")
	root = args.input[:-4]
	subset_name = root.split('/')[-1]
	# Normalize every integer on/off flag to a bool in one place.
	args.add_crossed = args.add_crossed > 0
	args.exhaust = args.exhaust > 0
	args.fine_only = args.fine_only > 0
	args.shuffle = args.shuffle > 0

	if args.mode == 'DUMP':
		print(f"Dumping {args.input} into text file and mapping json!")
		txt_fn = root + args.txt_suff
		mapping_fn = root + args.mapping_suff
		trans_fn = root + args.trans_txt_suff
		convert_tsv_to_txt(args.input, txt_fn, mapping_fn, trans_fn, args.num_splits)
	elif args.mode == 'SORT':
		trans_fn = root + args.trans_txt_suff
		mapping_fn = root + args.mapping_suff
		postag_in_fn = root + args.pos_in_suff
		postag_nosame_in_fn = root + '_nosame' + args.pos_in_suff
		print(f"Sorting {trans_fn} back into SVO structures!")
		mapped = merge_txt_to_dict(trans_fn, mapping_fn, args.num_splits)
		contruct_translated_doc_for_postag(mapped, args.input, postag_in_fn, postag_nosame_in_fn)
	elif args.mode == 'SORTBT':  # sorting for back-translation
		# Use the input file stem, as the Jia modes do; indexing the path from the
		# front picks a directory name for './dir/name.tsv' and fails for 'name.tsv'.
		corpus_name = subset_name
		backtransed_fn = f'levyholts_{corpus_name}_bt_gen.json'
		mapping_fn = root + args.mapping_suff
		out_fn = f'{corpus_name}_backtranslated.tsv'
		outrel_fn = f'{corpus_name}_backtranslated_rels.txt'
		refrel_fn = args.refrel_fn % corpus_name
		print(f"Sorting {backtransed_fn} into TSV file!")
		mapped = merge_backtranslation(backtransed_fn, mapping_fn)
		construct_tsv_backtranslation(mapped, args.input, out_fn, outrel_fn, refrel_fn)
	elif args.mode == 'SORT2':
		root_extended = _extend_root(root, exhaust=args.exhaust, fine_only=args.fine_only)
		# The POS output is always read from the un-tagged root; only outputs carry the tags.
		postag_out_fn = root + args.pos_out_suff
		trans_formated_fn = root_extended + args.trans_formatted_suff
		trans_raw_out_fn = root_extended + args.raw_out_suff
		rel_levy_mapping_fn = root_extended + args.rel_levy_mapping_suff
		construct_translated_tsv(postag_out_fn, trans_formated_fn, trans_raw_out_fn, rel_levy_mapping_fn, amend=True,
								 fine_only=args.fine_only, exhaust=args.exhaust, add_crossed=args.add_crossed)
	elif args.mode == 'SORT2BSL':
		postag_out_fn = root + args.pos_out_suff
		trans_formated_fn = root + '_bsl' + args.trans_formatted_suff
		trans_raw_out_fn = root + '_bsl' + args.raw_out_suff
		rel_levy_mapping_fn = root + '_bsl' + args.rel_levy_mapping_suff
		construct_translated_tsv(postag_out_fn, trans_formated_fn, trans_raw_out_fn, rel_levy_mapping_fn, amend=False,
								 fine_only=args.fine_only, exhaust=args.exhaust, add_crossed=args.add_crossed)
	elif args.mode == 'SORTJIA':
		trans_fn = root + args.trans_txt_suff
		mapping_fn = root + args.mapping_suff
		parsing_input_fn = args.jia_parser_input_fn % subset_name

		mapped = merge_txt_to_dict(trans_fn, mapping_fn, args.num_splits)
		with open(parsing_input_fn, 'w', encoding='utf8') as input_fp:
			json.dump(mapped, fp=input_fp, indent=4, ensure_ascii=False)

		# Next, go to "/Users/teddy/PycharmProjects/open-entity-relation-extraction" to parse the levy dev/test sets
		# with the Jia et al. parser before coming back to SORT2JIA.
	elif args.mode == 'SORT2JIA':
		root_extended = _extend_root(root, exhaust=args.exhaust)

		parsing_output_fn = args.jia_parser_output_fn % subset_name
		trans_formated_fn = root_extended + '_jia' + args.trans_formatted_suff
		trans_raw_out_fn = root_extended + '_jia' + args.raw_out_suff
		rel_levy_mapping_fn = root_extended + '_jia' + args.rel_levy_mapping_suff

		with open(parsing_output_fn, 'r', encoding='utf8') as result_fp:
			mapped_parsed = json.load(result_fp)

		construct_translated_tsv_jia(mapped_parsed, args.input, trans_formated_fn, trans_raw_out_fn, rel_levy_mapping_fn, exhaust=args.exhaust)
	elif args.mode == 'DEVTEST':
		trans_fn = root + args.trans_formatted_suff
		raw_fn = root + args.raw_out_suff
		rels_fn = root + args.typed_rel_suff
		rellevy_mapping_fn = root + args.rel_levy_mapping_suff
		trans_ofn = root + args.trans_split_suff
		raw_ofn = root + args.raw_split_suff
		rels_ofn = root + args.typed_rel_split_suff
		rellevy_mapping_ofn = root + args.rel_levy_mapping_split_suff
		split_mapping_fn = root + args.split_mapping_suff
		split_eval_to_devtest(trans_fn, raw_fn, rels_fn, rellevy_mapping_fn, trans_ofn, raw_ofn, rels_ofn, rellevy_mapping_ofn, shuffle=args.shuffle, split_mapping_fn=split_mapping_fn)
	else:
		raise AssertionError(f"Unknown --mode {args.mode!r}")


if __name__ == "__main__":
	# Script entry point: parse the CLI arguments and run the selected --mode.
	main()
