#!/usr/bin/env python
#-*- coding: utf-8 -*-

"""
Function: plot latent variables
"""


import torch
import sklearn
from sklearn import manifold
import scipy
import random
import numpy as np
import matplotlib 
import matplotlib.pyplot as plt
# Select the non-interactive Agg backend so figures are rendered straight to
# image files (e.g. via savefig) without needing a display.
plt.switch_backend('Agg')


def _to_numpy(z):
    """Return *z* as a NumPy array; torch tensors are detached and moved to the CPU first."""
    if isinstance(z, torch.Tensor):
        # np.asarray fails on tensors that require grad or live on a GPU.
        z = z.detach().cpu()
    return np.asarray(z)


def latentPlot(z_p, z_q, name, method='isomap'):
    """
    Embed two groups of latent variables in 2-D and save a scatter plot.

    The rows of ``z_p`` and ``z_q`` are stacked, reduced to two dimensions
    with a scikit-learn manifold learner, and drawn as two point clouds:
    ``z_p`` as red right-pointing triangles and ``z_q`` as blue upward
    triangles. The figure is written to ``'embed' + name + '.png'``.

    Args:
        z_p: array-like (or torch tensor) of latent vectors, one per row.
            Presumably shape (n_samples, n_features) -- TODO confirm with callers.
        z_q: array-like (or torch tensor) with the same number of rows as ``z_p``.
        name: string appended to ``'embed'`` to build the output file name.
        method: embedding algorithm, one of ``'isomap'`` (default),
            ``'tSNE'`` or ``'MDS'``.

    Raises:
        AssertionError: if ``z_p`` and ``z_q`` have different lengths.
        ValueError: if ``method`` is not one of the supported names.
    """
    # Kept as an assert for backward compatibility; the plot split below
    # uses len(z_p) directly, so it stays correct even if asserts are stripped.
    assert len(z_p) == len(z_q), "z_p and z_q must contain the same number of samples"

    p = _to_numpy(z_p)
    q = _to_numpy(z_q)
    latent = np.concatenate([p, q], axis=0)

    if method == 'tSNE':
        reducer = manifold.TSNE(init='pca', verbose=1)
    elif method == 'MDS':
        reducer = manifold.MDS(n_components=2, verbose=1, max_iter=1000, n_init=1)
    elif method == 'isomap':
        reducer = manifold.Isomap()
    else:
        raise ValueError(
            "unknown method %r; expected 'isomap', 'tSNE' or 'MDS'" % (method,)
        )
    approx = reducer.fit_transform(latent)

    # Rows [0, len(p)) come from z_p and the remainder from z_q.
    split = len(p)

    fig, ax = plt.subplots()
    try:
        ax.plot(approx[:split, 0], approx[:split, 1], 'r>', alpha=0.2)
        ax.plot(approx[split:, 0], approx[split:, 1], 'b^', alpha=0.2)
        ax.set_title("latent variable cluster map")
        fig.savefig('embed' + name + '.png')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print('picture generatead!')
    
