import tensorflow as tf

def model_BiLSTM_multi(x, lstmUnitNum, layer_num, forget_bias, input_keep_prob, output_keep_prob, attn_length=-1,
                       sequence_length=None):
    """Run a multi-layer bidirectional LSTM built from two independent MultiRNNCell stacks.

    A forward and a backward ``MultiRNNCell`` are each built from ``layer_num``
    layers of ``BasicLSTMCell`` (optionally wrapped in ``AttentionCellWrapper``,
    always wrapped in ``DropoutWrapper``) and run with
    ``tf.nn.bidirectional_dynamic_rnn``. The two directions are stacked
    separately and only combined at the very end.

    Args:
        x: Batch-major input, expected as [batch_size, n_steps, n_dimensions].
        lstmUnitNum: Number of units in every LSTM layer.
        layer_num: Number of stacked layers per direction.
        forget_bias: Forget-gate bias passed to each ``BasicLSTMCell``.
        input_keep_prob: Dropout keep probability on each layer's input.
        output_keep_prob: Dropout keep probability on each layer's output.
        attn_length: Attention window length for ``AttentionCellWrapper``.
            ``-1`` (default) or ``None`` disables attention; any other value is
            forwarded to the wrapper unchanged.
        sequence_length: Optional [batch_size] integer vector of true sequence
            lengths. When given, padded steps are not processed and the backward
            pass starts at each sequence's real end. ``None`` (default) treats
            every sequence as full length, matching the previous behaviour.

    Returns:
        Time-major tensor [n_steps, batch_size, 2 * lstmUnitNum]. Along the last
        axis, the forward outputs come first, followed by the backward outputs.
    """
    # -1 is the historical "no attention" sentinel; None is accepted as well.
    use_attention = attn_length is not None and attn_length != -1

    def _make_layer():
        # One LSTM layer: BasicLSTMCell -> optional attention -> dropout.
        cell = tf.contrib.rnn.BasicLSTMCell(lstmUnitNum, forget_bias=forget_bias, state_is_tuple=True,
                                            activation=None, reuse=None)
        if use_attention:
            cell = tf.contrib.rnn.AttentionCellWrapper(cell, attn_length, state_is_tuple=True)
        return tf.contrib.rnn.DropoutWrapper(cell=cell, input_keep_prob=input_keep_prob,
                                             output_keep_prob=output_keep_prob)

    # A fresh cell object per layer and per direction, so no weights are shared.
    # The fw/bw interleaving matches the original construction order.
    cell_fw, cell_bw = [], []
    for _ in range(layer_num):
        cell_fw.append(_make_layer())
        cell_bw.append(_make_layer())

    mul_cell_fw = tf.contrib.rnn.MultiRNNCell(cell_fw, state_is_tuple=True)
    mul_cell_bw = tf.contrib.rnn.MultiRNNCell(cell_bw, state_is_tuple=True)

    # Outputs come back as a (fw, bw) pair of [batch_size, n_steps, lstmUnitNum]
    # tensors; the final states are not used here.
    (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
        mul_cell_fw, mul_cell_bw, inputs=x, sequence_length=sequence_length, dtype=tf.float32)
    hiddens = tf.concat([out_fw, out_bw], axis=2)
    # Convert to time-major: [n_steps, batch_size, 2 * lstmUnitNum].
    return tf.transpose(hiddens, [1, 0, 2])

def model_BiLSTM_stack(x, lstmUnitNum, layer_num, forget_bias, input_keep_prob, output_keep_prob, attn_length=-1,
                       sequence_length=None):
    """Run a stacked bidirectional LSTM via ``stack_bidirectional_dynamic_rnn``.

    ``layer_num`` forward and ``layer_num`` backward layers are built. Each layer
    is a ``BasicLSTMCell``, optionally wrapped in ``AttentionCellWrapper`` and
    always wrapped in ``DropoutWrapper``. The layers are run with
    ``tf.contrib.rnn.stack_bidirectional_dynamic_rnn``. Unlike two separate
    ``MultiRNNCell`` stacks, each layer there receives the concatenated fw+bw
    outputs of the layer below.

    Args:
        x: Batch-major input, expected as [batch_size, n_steps, n_dimensions].
        lstmUnitNum: Number of units in every LSTM layer.
        layer_num: Number of bidirectional layers.
        forget_bias: Forget-gate bias passed to each ``BasicLSTMCell``.
        input_keep_prob: Dropout keep probability on each layer's input.
        output_keep_prob: Dropout keep probability on each layer's output.
        attn_length: Attention window length for ``AttentionCellWrapper``.
            ``-1`` (default) or ``None`` disables attention; any other value is
            forwarded to the wrapper unchanged.
        sequence_length: Optional [batch_size] integer vector of true sequence
            lengths. When given, padded steps are not processed and the backward
            pass starts at each sequence's real end. ``None`` (default) treats
            every sequence as full length, matching the previous behaviour.

    Returns:
        Time-major tensor [n_steps, batch_size, 2 * lstmUnitNum] holding the
        top layer's concatenated forward and backward outputs.
    """
    # -1 is the historical "no attention" sentinel; None is accepted as well.
    use_attention = attn_length is not None and attn_length != -1

    def _make_layer():
        # One LSTM layer: BasicLSTMCell -> optional attention -> dropout.
        cell = tf.contrib.rnn.BasicLSTMCell(lstmUnitNum, forget_bias=forget_bias, state_is_tuple=True,
                                            activation=None, reuse=None)
        if use_attention:
            cell = tf.contrib.rnn.AttentionCellWrapper(cell, attn_length, state_is_tuple=True)
        return tf.contrib.rnn.DropoutWrapper(cell=cell, input_keep_prob=input_keep_prob,
                                             output_keep_prob=output_keep_prob)

    # A fresh cell object per layer and per direction, so no weights are shared.
    # The fw/bw interleaving matches the original construction order.
    cell_fw, cell_bw = [], []
    for _ in range(layer_num):
        cell_fw.append(_make_layer())
        cell_bw.append(_make_layer())

    # stack_bidirectional_dynamic_rnn already returns ONE tensor with the fw and
    # bw outputs depth-concatenated ([batch_size, n_steps, 2 * lstmUnitNum]),
    # so no tf.concat is needed. The final fw/bw states are not used here.
    hiddens, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
        cell_fw, cell_bw, inputs=x, sequence_length=sequence_length, dtype=tf.float32)
    # Convert to time-major: [n_steps, batch_size, 2 * lstmUnitNum].
    return tf.transpose(hiddens, [1, 0, 2])
