import torch
import clip
import random as rd
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
from PIL import Image
from torch import nn, optim
from tqdm import tqdm
from argparse import ArgumentParser
from lib.snliveloader import SNLIVE

# Command-line interface. Each entry is (flag, keyword arguments forwarded to
# ArgumentParser.add_argument); options are registered in table order.
# NOTE(review): --bitfit, --lr and --epoch are not read anywhere in this
# evaluation script as visible — presumably shared with the training script.
_CLI_OPTIONS = (
	("--bitfit", dict(action="store_true")),
	# CLIP backbone, e.g. 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16'
	("--clip", dict(type=str, default="RN50x16")),
	# Width of the CLIP embeddings; the classifier head is built with 5x this many inputs.
	("--clip_size", dict(type=int, default=768)),
	# The checkpoint is read from f"{mount}{save_to}/{identifier}".
	("--mount", dict(type=str, default="/mnt")),
	("--save_to", dict(type=str, default="/DATA/SNLI_MODELS_BEST")),
	("--lr", dict(type=float, default=2e-5)),
	("--batch_size", dict(type=int, default=256)),
	("--epoch", dict(type=int, default=3)),
	("--identifier", dict(type=str, default="1024_128_3")),
	# Settings of the MLP classification head.
	("--dropout", dict(type=float, default=0)),
	("--size1", dict(type=int, default=1024)),
	("--size2", dict(type=int, default=128)),
	# Inclusive range of epoch ids iterated by the evaluation loop.
	("--start", dict(type=int, default=1)),
	("--end", dict(type=int, default=1)),
)

parser = ArgumentParser(prog="SNLI-VE")
for _flag, _options in _CLI_OPTIONS:
	parser.add_argument(_flag, **_options)

args = parser.parse_args()

class MLP(nn.Module):
	"""Classification head: (Linear, BatchNorm1d, Dropout, ReLU) x 2, then Linear.

	Args:
		input_size: width of each input feature vector.
		num_class: number of output logits.
		size1: width of the first hidden layer; defaults to ``args.size1``.
		size2: width of the second hidden layer; defaults to ``args.size2``.
		dropout: dropout probability used after each BatchNorm; defaults to
			``args.dropout``.

	The layer order inside ``self.linear`` (and hence the ``linear.<index>``
	keys of ``state_dict``) is fixed so that saved checkpoints keep loading.
	"""

	def __init__(self, input_size, num_class, size1=None, size2=None, dropout=None):
		super().__init__()
		# Fall back to the command-line settings so existing two-argument
		# call sites behave exactly as before.
		size1 = args.size1 if size1 is None else size1
		size2 = args.size2 if size2 is None else size2
		dropout = args.dropout if dropout is None else dropout
		self.linear = nn.Sequential(
			nn.Linear(input_size, size1),
			nn.BatchNorm1d(size1),
			nn.Dropout(dropout),
			nn.ReLU(inplace=True),
			nn.Linear(size1, size2),
			nn.BatchNorm1d(size2),
			nn.Dropout(dropout),
			nn.ReLU(inplace=True),
			nn.Linear(size2, num_class),
		)

	def forward(self, x):
		"""Map a 2-D (batch, input_size) tensor to (batch, num_class) logits."""
		return self.linear(x)


def fusion(vector1, vector2):
	"""Combine two batches of embeddings into one feature tensor.

	Concatenates, along dimension 1 and in this order: the sum, the
	difference ``vector1 - vector2``, ``vector1``, ``vector2`` and the
	element-wise product. For 2-D ``(batch, d)`` inputs the result is
	``(batch, 5 * d)``.
	"""
	pieces = [
		torch.add(vector1, vector2),
		torch.sub(vector1, vector2),
		vector1,
		vector2,
		torch.mul(vector1, vector2),
	]
	return torch.cat(pieces, dim=1)

def bilinear_fusion(vector1, vector2, maxpool=None, avgpool=None):
	"""Fuse two batches of vectors through their per-sample outer product.

	Args:
		vector1: 2-D tensor ``(batch, d1)``.
		vector2: 2-D tensor ``(batch, d2)``.
		maxpool: callable applied to the ``(batch, d1, d2)`` outer-product
			tensor; its output is ``squeeze(-1)``-ed and concatenated. When
			None, a global max over the last axis is used (same result as
			``nn.AdaptiveMaxPool1d(1)``).
		avgpool: like ``maxpool`` but for average pooling. When None, a
			global mean over the last axis is used (same result as
			``nn.AdaptiveAvgPool1d(1)``).

	Returns:
		``cat(vector1, vector2, max_pooled, avg_pooled)`` along dimension 1.
	"""
	# (batch, d1, 1) @ (batch, 1, d2) -> (batch, d1, d2) outer products.
	bilinear_mat = torch.matmul(vector1.unsqueeze(2), vector2.unsqueeze(1))
	# The None defaults previously crashed when called; treat them as global pooling.
	if maxpool is None:
		maxpoolembedding = bilinear_mat.max(dim=-1).values
	else:
		maxpoolembedding = maxpool(bilinear_mat).squeeze(-1)
	if avgpool is None:
		avgpoolembedding = bilinear_mat.mean(dim=-1)
	else:
		avgpoolembedding = avgpool(bilinear_mat).squeeze(-1)
	return torch.cat((vector1, vector2, maxpoolembedding, avgpoolembedding), 1)


# Module-level setup: dataset, CLIP encoder, classifier head and data splits.
# NOTE(review): the dataset root is hard-coded and does not use --mount —
# confirm it is correct on the machine this runs on.
snlive_dataset = SNLIVE(root_path='/home/data/datasets/SNLI-VE/data')

device = "cuda" if torch.cuda.is_available() else "cpu"
# jit=False yields a regular nn.Module, on which the evaluation loop later
# calls load_state_dict with the fine-tuned CLIP weights.
clip_model, preprocess = clip.load(args.clip, device, jit=False)
# Cast all CLIP weights to float32 (clip.load may hand back half-precision
# weights on CUDA — verify for the installed clip version).
clip_model = clip_model.float()

# fusion() concatenates five vectors of the embedding width, hence an input
# size of clip_size*5; the head predicts 3 classes.
OutputLayer = MLP(args.clip_size*5, 3).to(device)

print(f'Evaluating {args.identifier}...\n')

# 0: neutral, 1: entailment, 2: contradiction
snlive_valid = snlive_dataset.valid_set
snlive_test = snlive_dataset.test_set


def _iter_batches(split, batch_size):
	"""Yield consecutive slices of ``split`` holding at most ``batch_size`` items.

	Every item is visited exactly once; only the final slice may be shorter.
	"""
	for start in range(0, len(split), batch_size):
		yield split[start:start + batch_size]


def _accuracy_percent(truth, preds):
	"""Return accuracy as a percentage rounded to two decimals (e.g. 78.12)."""
	# Scale before rounding so float artifacts like 78.12000000000001 are not printed.
	return round(accuracy_score(truth, preds) * 100, 2)


def _evaluate_split(split):
	"""Predict labels for every example of one SNLI-VE split.

	Each item of ``split`` is read through the keys 'Image_path',
	'sentence1', 'sentence2' and 'label_id'. Premise, hypothesis and image are
	encoded with ``clip_model``, L2-normalised, fused and classified by
	``OutputLayer``.

	Returns:
		(ground_truth, text_predictions, image_predictions): three equally long
		lists of integer class ids, in ``split`` order.
	"""
	ground_truth, text_predictions, image_predictions = [], [], []
	with torch.no_grad():
		for batch in _iter_batches(split, args.batch_size):
			images, premises, hypotheses = [], [], []
			for item in batch:
				# Preprocess inside the `with` so the image file is closed as
				# soon as its pixels have been read.
				with Image.open(item['Image_path']) as img:
					images.append(preprocess(img))
				premises.append(item['sentence1'])
				hypotheses.append(item['sentence2'])
				ground_truth.append(item['label_id'])

			sent1_ids = clip.tokenize(premises, truncate=True).to(device)
			sent2_ids = clip.tokenize(hypotheses, truncate=True).to(device)
			image_ids = torch.stack(images, dim=0).to(device)

			sent1_features = clip_model.encode_text(sent1_ids)
			sent2_features = clip_model.encode_text(sent2_ids)
			image_features = clip_model.encode_image(image_ids)

			sent1_embedding = sent1_features / sent1_features.norm(dim=-1, keepdim=True)
			sent2_embedding = sent2_features / sent2_features.norm(dim=-1, keepdim=True)
			image_embedding = image_features / image_features.norm(dim=-1, keepdim=True)

			# NOTE(review): both heads zero out their first input, so each
			# prediction depends only on the hypothesis (sentence2) embedding
			# and the two heads receive identical inputs. This looks like a
			# deliberate hypothesis-only ablation — confirm before using this
			# script for full premise/image + hypothesis evaluation.
			logits = OutputLayer(fusion(sent1_embedding * 0, sent2_embedding))
			image_logits = OutputLayer(fusion(image_embedding * 0, sent2_embedding))

			text_predictions += torch.argmax(logits, dim=-1).cpu().tolist()
			image_predictions += torch.argmax(image_logits, dim=-1).cpu().tolist()
	return ground_truth, text_predictions, image_predictions


for eid in range(args.start, args.end + 1):
	# NOTE(review): the checkpoint path does not depend on `eid`, so every
	# iteration reloads and evaluates the same weights — confirm whether a
	# per-epoch suffix was intended.
	# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
	checkpoint = torch.load(f"{args.mount}{args.save_to}/{args.identifier}", map_location=device)
	OutputLayer.load_state_dict(checkpoint['model_state_dict'])
	clip_model.load_state_dict(checkpoint['clip_state_dict'])

	# Inference mode for both networks (BatchNorm running stats, no dropout).
	OutputLayer.eval()
	clip_model.eval()

	# SNLI-VE validation set
	dev_truth, dev_text_preds, dev_image_preds = _evaluate_split(snlive_valid)
	text_acc_dev = _accuracy_percent(dev_truth, dev_text_preds)
	image_acc_dev = _accuracy_percent(dev_truth, dev_image_preds)

	# SNLI-VE test set
	test_truth, test_text_preds, test_image_preds = _evaluate_split(snlive_test)
	text_acc_test = _accuracy_percent(test_truth, test_text_preds)
	image_acc_test = _accuracy_percent(test_truth, test_image_preds)

	# Accuracies are printed as dev/test; the confusion matrix is for the test set.
	print(f"  Epoch {eid}, "
		  f" Text Test Acc: {text_acc_dev}/{text_acc_test},"
		  f" Image Test Acc: {image_acc_dev}/{image_acc_test},"
		  f" Image Test Confusion Matrix:\n{confusion_matrix(test_truth, test_image_preds)} \n")
