{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)\n",
      "03:23:37:BUS-stop:INFO: ***Logging start***\n",
      "03:23:37:BUS-stop:INFO: os.environ['CUDA_VISIBLE_DEVICES'] = 0\n",
      "03:23:37:BUS-stop:INFO: devices = ['/device:GPU:0']\n",
      "03:23:37:BUS-stop:INFO: Number of devices: 1\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# CUDA_VISIBLE_DEVICES must be set before CUDA is initialized, so set it before\n",
    "# TensorFlow (or any project module that imports it) is loaded.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import sys\n",
    "import logging\n",
    "import argparse\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import preliminary\n",
    "import pt_modeler\n",
    "import preprocessing as pp\n",
    "\n",
    "from collections import deque\n",
    "from pt_modeler import ConstructPtModeler\n",
    "from huggingface_utils import MODELS\n",
    "from scipy.special import softmax\n",
    "from sklearn.utils import shuffle\n",
    "from sklearn.metrics import log_loss, f1_score, accuracy_score\n",
    "from scipy.spatial.distance import cosine,euclidean\n",
    "\n",
    "logger = logging.getLogger('BUS-stop')\n",
    "formatter = logging.Formatter('%(asctime)s:%(name)s:%(levelname)s: %(message)s',\"%H:%M:%S\")\n",
    "logger.setLevel(logging.DEBUG)\n",
    "\n",
    "# logging.getLogger returns the same object on every call, so re-running this cell\n",
    "# would otherwise stack duplicate handlers and emit every message several times.\n",
    "for old_handler in list(logger.handlers):\n",
    "    logger.removeHandler(old_handler)\n",
    "    old_handler.close()\n",
    "\n",
    "log_dir = './logs'\n",
    "os.makedirs(log_dir, exist_ok=True)  # FileHandler raises if the directory is missing\n",
    "fhandler = logging.FileHandler(filename=os.path.join(log_dir, 'run-cell-by-cell.log'), mode='w')\n",
    "fhandler.setFormatter(formatter)\n",
    "fhandler.setLevel(logging.INFO)\n",
    "logger.addHandler(fhandler)\n",
    "\n",
    "consoleHandler = logging.StreamHandler(sys.stdout)\n",
    "consoleHandler.setFormatter(formatter)\n",
    "consoleHandler.setLevel(logging.DEBUG)\n",
    "logger.addHandler(consoleHandler)\n",
    "\n",
    "#Variables for preprocessing\n",
    "GLOBAL_SEED = 0\n",
    "pt_model = \"TFBertModel\"\n",
    "pt_model_checkpoint = \"./params/bert_base/\"\n",
    "\n",
    "# Pick the (model, tokenizer, config) classes registered under `pt_model`;\n",
    "# as with a plain loop over MODELS, the last matching entry wins.\n",
    "matches = [entry for entry in MODELS if entry[0].__name__ == pt_model]\n",
    "if not matches:\n",
    "    raise ValueError(\"pt_model {!r} not found in huggingface_utils.MODELS\".format(pt_model))\n",
    "TFModel, Tokenizer, Config = matches[-1]\n",
    "\n",
    "tokenizer = Tokenizer.from_pretrained(pt_model_checkpoint)\n",
    "\n",
    "# The GPUs listed in CUDA_VISIBLE_DEVICES are renumbered from 0 inside the process,\n",
    "# so the logical device names are GPU:0..GPU:n-1 regardless of the physical ids.\n",
    "visible_gpus = [g.strip() for g in os.environ[\"CUDA_VISIBLE_DEVICES\"].split(',') if g.strip()]\n",
    "devices = ['/device:GPU:{}'.format(i) for i in range(len(visible_gpus))]\n",
    "\n",
    "# An empty device list falls back to MirroredStrategy's default device selection.\n",
    "strategy = tf.distribute.MirroredStrategy(devices=devices or None)\n",
    "gpus = strategy.num_replicas_in_sync\n",
    "\n",
    "logger.info(\"***Logging start***\")\n",
    "logger.info(\"os.environ['CUDA_VISIBLE_DEVICES'] = {}\".format(os.environ['CUDA_VISIBLE_DEVICES']))\n",
    "logger.info(\"devices = {}\".format(devices))\n",
    "logger.info(\"Number of devices: {}\".format(gpus))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "03:23:48:BUS-stop:INFO: *******************\n",
      "03:23:48:BUS-stop:INFO: ***Preprocessing***\n",
      "03:23:48:BUS-stop:INFO: *******************\n",
      "03:23:48:BUS-stop:INFO: In file ./data/SST-2/labeled.tsv, we read 100 samples, \n",
      "03:23:48:BUS-stop:INFO: where the class distribution is {'0': 50, '1': 50, None: 0}.\n",
      "03:23:48:BUS-stop:INFO: In file ./data/SST-2/test_with_gold.tsv, we read 1000 samples, \n",
      "03:23:48:BUS-stop:INFO: where the class distribution is {'0': 200, '1': 800, None: 0}.\n",
      "03:23:48:BUS-stop:INFO: In file ./data/SST-2/unlabeled.tsv, we read 1000 samples, \n",
      "03:23:48:BUS-stop:INFO: where the class distribution is {'0': 0, '1': 0, None: 1000}.\n",
      "03:23:49:BUS-stop:INFO: Labeled//Test//Unlabeled matrix shape = (100, 64) // (1000, 64) // (1000, 64)\n",
      "03:23:49:BUS-stop:INFO: ***Labeled***\n",
      "03:23:49:BUS-stop:INFO: Example 0\n",
      "03:23:49:BUS-stop:INFO: Label 0\n",
      "03:23:49:BUS-stop:INFO: Token ids [  101  1037 18856 18163  6588  7609 14427 17312  1010  1037  2806  1011\n",
      "  2489  6912  1999 16924  1998 26865  1012   102     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0]\n",
      "03:23:49:BUS-stop:INFO: Tokens ['[CLS]', 'a', 'cl', '##ums', '##ily', 'manufactured', 'exploitation', 'flick', ',', 'a', 'style', '-', 'free', 'exercise', 'in', 'manipulation', 'and', 'mayhem', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n",
      "03:23:49:BUS-stop:INFO: Token mask [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "03:23:49:BUS-stop:INFO: ***Labeled***\n",
      "03:23:49:BUS-stop:INFO: Example 1\n",
      "03:23:49:BUS-stop:INFO: Label 0\n",
      "03:23:49:BUS-stop:INFO: Token ids [  101  2004 17266  7507 11467  2004  2009  2003  4487 13102  8820  3468\n",
      "  1012   102     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0]\n",
      "03:23:49:BUS-stop:INFO: Tokens ['[CLS]', 'as', 'sac', '##cha', '##rine', 'as', 'it', 'is', 'di', '##sp', '##osa', '##ble', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n",
      "03:23:49:BUS-stop:INFO: Token mask [1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "03:23:49:BUS-stop:INFO: ***Test***\n",
      "03:23:49:BUS-stop:INFO: Example 0\n",
      "03:23:49:BUS-stop:INFO: Label 0\n",
      "03:23:49:BUS-stop:INFO: Token ids [ 101 2339 2002 2001 2445 2489 5853 2058 2023 2622 1011 1011 2002 2626\n",
      " 1010 2856 1010 5652 1998 2550 1011 1011 2003 3458 2033 1012  102    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0]\n",
      "03:23:49:BUS-stop:INFO: Tokens ['[CLS]', 'why', 'he', 'was', 'given', 'free', 'reign', 'over', 'this', 'project', '-', '-', 'he', 'wrote', ',', 'directed', ',', 'starred', 'and', 'produced', '-', '-', 'is', 'beyond', 'me', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n",
      "03:23:49:BUS-stop:INFO: Token mask [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "03:23:49:BUS-stop:INFO: ***Test***\n",
      "03:23:49:BUS-stop:INFO: Example 1\n",
      "03:23:49:BUS-stop:INFO: Label 0\n",
      "03:23:49:BUS-stop:INFO: Token ids [  101  2045  1005  1055  8636  1997 12537  1998  2059  2045  1005  1055\n",
      "  2919  3898 18560  1012  1012  1012  2023  2143 15173  1037  2813  7361\n",
      "  1997  1996  3732  1012   102     0     0     0     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0]\n",
      "03:23:49:BUS-stop:INFO: Tokens ['[CLS]', 'there', \"'\", 's', 'suspension', 'of', 'disbelief', 'and', 'then', 'there', \"'\", 's', 'bad', 'screen', '##writing', '.', '.', '.', 'this', 'film', 'packs', 'a', 'wall', '##op', 'of', 'the', 'latter', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n",
      "03:23:49:BUS-stop:INFO: Token mask [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "03:23:49:BUS-stop:INFO: ***Unlabeled***\n",
      "03:23:49:BUS-stop:INFO: Example 0\n",
      "03:23:49:BUS-stop:INFO: Token ids [ 101 2339 2002 2001 2445 2489 5853 2058 2023 2622 1011 1011 2002 2626\n",
      " 1010 2856 1010 5652 1998 2550 1011 1011 2003 3458 2033 1012  102    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0]\n",
      "03:23:49:BUS-stop:INFO: Tokens ['[CLS]', 'why', 'he', 'was', 'given', 'free', 'reign', 'over', 'this', 'project', '-', '-', 'he', 'wrote', ',', 'directed', ',', 'starred', 'and', 'produced', '-', '-', 'is', 'beyond', 'me', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n",
      "03:23:49:BUS-stop:INFO: Token mask [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "03:23:49:BUS-stop:INFO: ***Unlabeled***\n",
      "03:23:49:BUS-stop:INFO: Example 1\n",
      "03:23:49:BUS-stop:INFO: Token ids [  101  2045  1005  1055  8636  1997 12537  1998  2059  2045  1005  1055\n",
      "  2919  3898 18560  1012  1012  1012  2023  2143 15173  1037  2813  7361\n",
      "  1997  1996  3732  1012   102     0     0     0     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0     0     0     0     0     0     0     0     0\n",
      "     0     0     0     0]\n",
      "03:23:49:BUS-stop:INFO: Tokens ['[CLS]', 'there', \"'\", 's', 'suspension', 'of', 'disbelief', 'and', 'then', 'there', \"'\", 's', 'bad', 'screen', '##writing', '.', '.', '.', 'this', 'film', 'packs', 'a', 'wall', '##op', 'of', 'the', 'latter', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n",
      "03:23:49:BUS-stop:INFO: Token mask [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n"
     ]
    }
   ],
   "source": [
    "task = \"SST-2\"\n",
    "max_seq_length = 64\n",
    "\n",
    "# Normalize the task name *before* it is used to build the data path and pick the processor.\n",
    "task = task.strip()\n",
    "task_path = os.path.join(\"./data\",task)\n",
    "\n",
    "logger.info(\"*******************\")\n",
    "logger.info(\"***Preprocessing***\")\n",
    "logger.info(\"*******************\")\n",
    "Processor = pp.task_to_processor(task)\n",
    "processor = Processor(task, task_path, tokenizer, max_seq_length)\n",
    "\n",
    "label_list = processor.get_label_list()\n",
    "lab_examples = processor.tsv_to_examples('labeled.tsv')\n",
    "tst_examples = processor.tsv_to_examples('test_with_gold.tsv')\n",
    "unl_examples = processor.tsv_to_examples('unlabeled.tsv')\n",
    "\n",
    "X_lab,y_lab = processor.examples_to_features(lab_examples)\n",
    "X_tst,y_tst = processor.examples_to_features(tst_examples)\n",
    "X_unl,y_unl = processor.examples_to_features(unl_examples)\n",
    "\n",
    "lab_len, unl_len = len(lab_examples), len(unl_examples)\n",
    "num_labels = len(label_list)\n",
    "\n",
    "logger.info('Labeled//Test//Unlabeled matrix shape = {} // {} // {}'.format(\n",
    "    X_lab['input_ids'].shape,X_tst['input_ids'].shape,X_unl['input_ids'].shape))\n",
    "\n",
    "def log_examples(title, X, y=None, n=2):\n",
    "    \"\"\"Log the first `n` encoded examples of one split as a tokenization sanity check.\n",
    "\n",
    "    Args:\n",
    "        title: split name, logged as the \"***<title>***\" header before each example.\n",
    "        X: feature dict from `processor.examples_to_features`; must contain\n",
    "           \"input_ids\" and \"attention_mask\".\n",
    "        y: labels aligned with `X`, or None for the unlabeled split (no label line).\n",
    "        n: number of examples to log, capped at the number of rows in `X`.\n",
    "    \"\"\"\n",
    "    for i in range(min(n, len(X[\"input_ids\"]))):\n",
    "        logger.info(\"***{}***\".format(title))\n",
    "        logger.info(\"Example {}\".format(i))\n",
    "        if y is not None:\n",
    "            logger.info(\"Label {}\".format(y[i]))\n",
    "        logger.info(\"Token ids {}\".format(X[\"input_ids\"][i]))\n",
    "        logger.info(\"Tokens {}\".format(tokenizer.convert_ids_to_tokens(X[\"input_ids\"][i])))\n",
    "        logger.info(\"Token mask {}\".format(X[\"attention_mask\"][i]))\n",
    "\n",
    "log_examples(\"Labeled\", X_lab, y_lab)\n",
    "log_examples(\"Test\", X_tst, y_tst)\n",
    "log_examples(\"Unlabeled\", X_unl)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the weights of TFBertModel were initialized from the PyTorch model.\n",
      "If your task is similar to the task the model of the ckeckpoint was trained on, you can already use TFBertModel for predictions without further training.\n"
     ]
    }
   ],
   "source": [
    "drop_rate = 0.2  # shared by dense_dropout_prob and the encoder's attention/hidden dropout\n",
    "\n",
    "# Construct under the strategy scope so anything created here is placed by `strategy`.\n",
    "with strategy.scope():\n",
    "    modeler = ConstructPtModeler(\n",
    "        TFModel, Config, pt_model_checkpoint, max_seq_length, num_labels,\n",
    "        dense_dropout_prob=drop_rate,\n",
    "        attention_probs_dropout_prob=drop_rate,\n",
    "        hidden_dropout_prob=drop_rate,\n",
    "        word_freeze=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "03:24:00:BUS-stop:INFO: ***********************\n",
      "03:24:00:BUS-stop:INFO: ***Preliminary stage***\n",
      "03:24:00:BUS-stop:INFO: ***********************\n",
      "03:24:00:BUS-stop:DEBUG: Labels in the labeled set mixed evenly like this, ['0', '1', '0', '1', '...'].\n",
      "03:24:00:BUS-stop:DEBUG:  \n",
      "03:24:00:BUS-stop:DEBUG: 0-th run / total 3 runs\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "03:24:30:BUS-stop:DEBUG: For base_i 0, [val_acc,val_loss] = [0.76,0.5262].\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "03:24:59:BUS-stop:DEBUG: For base_i 1, [val_acc,val_loss] = [0.68,0.559].\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "03:25:24:BUS-stop:DEBUG: For base_i 2, [val_acc,val_loss] = [0.58,0.6962].\n",
      "03:25:30:BUS-stop:INFO: In 0-th run, [best_val_acc,best_val_loss] = [0.76,0.5262].\n",
      "03:25:30:BUS-stop:DEBUG:  \n",
      "03:25:30:BUS-stop:DEBUG: 1-th run / total 3 runs\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "03:26:01:BUS-stop:DEBUG: For base_i 0, [val_acc,val_loss] = [0.84,0.472].\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "03:26:33:BUS-stop:DEBUG: For base_i 1, [val_acc,val_loss] = [0.76,0.5066].\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "03:27:01:BUS-stop:DEBUG: For base_i 2, [val_acc,val_loss] = [0.64,0.6945].\n",
      "03:27:06:BUS-stop:INFO: In 1-th run, [best_val_acc,best_val_loss] = [0.84,0.472].\n",
      "03:27:06:BUS-stop:DEBUG:  \n",
      "03:27:06:BUS-stop:DEBUG: 2-th run / total 3 runs\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "03:27:36:BUS-stop:DEBUG: For base_i 0, [val_acc,val_loss] = [0.6,0.6395].\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "03:28:05:BUS-stop:DEBUG: For base_i 1, [val_acc,val_loss] = [0.82,0.4004].\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "03:28:33:BUS-stop:DEBUG: For base_i 2, [val_acc,val_loss] = [0.8,0.3643].\n",
      "03:28:38:BUS-stop:INFO: In 2-th run, [best_val_acc,best_val_loss] = [0.8,0.3643].\n",
      "03:28:38:BUS-stop:INFO: p_l_conf = [0.5317, 0.5524, 0.5553, ..., 0.9804, 0.9808, 0.9867]\n",
      "03:28:38:BUS-stop:INFO: class distribution of unlabeled data: pred [0.313 0.687] -> cali [0.1822 0.8178]\n"
     ]
    }
   ],
   "source": [
    "for banner in (\"***********************\", \"***Preliminary stage***\", \"***********************\"):\n",
    "    logger.info(banner)\n",
    "\n",
    "# verbose: 0 = silent, 1 = progress bar, 2 = one line per epoch\n",
    "preliminary_records = preliminary.run_stage(\n",
    "    strategy, modeler, processor, lab_examples, X_unl,\n",
    "    rand_seed=GLOBAL_SEED, epochs=30, patience=10, batch_size=16,\n",
    "    learning_rate=3e-5, val_ratio=0.5, T=3, n_base=3, verbose=0)\n",
    "p_l_conf, c_u_cali = preliminary.obtain_outputs(\n",
    "    preliminary_records, cali_acc_or_f1='f1', bias_lab_or_val='val')\n",
    "\n",
    "# Preview only the first and last three reference confidences.\n",
    "conf_preview = list(np.around(p_l_conf, 4))\n",
    "logger.info(\"p_l_conf = [{}, {}, {}, ..., {}, {}, {}]\".format(*(conf_preview[:3] + conf_preview[-3:])))\n",
    "\n",
    "pred_unl_dist = np.mean(preliminary_records['ulb_dist'], 0)\n",
    "logger.info(\"class distribution of unlabeled data: pred {} -> cali {}\".format(\n",
    "    np.around(pred_unl_dist, 4), np.around(c_u_cali, 4)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "03:28:46:BUS-stop:INFO: ****************\n",
      "03:28:46:BUS-stop:INFO: ***Main stage***\n",
      "03:28:46:BUS-stop:INFO: ****************\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "WARNING:tensorflow:Gradients do not exist for variables ['encoder/bert/pooler/dense/kernel:0', 'encoder/bert/pooler/dense/bias:0'] when minimizing the loss.\n",
      "4/4 [==============================] - 0s 51ms/step - acc: 0.8800 - loss: 0.6139\n",
      "32/32 [==============================] - 3s 83ms/step - acc: 0.7430 - loss: 0.6510\n",
      "03:29:07:BUS-stop:INFO: Epoch 1, s_conf=3.1572, s_class=0.8669, tst_acc=0.743, tst_loss=0.651\n",
      "4/4 [==============================] - 0s 51ms/step - acc: 0.9000 - loss: 0.4872\n",
      "32/32 [==============================] - 3s 83ms/step - acc: 0.8120 - loss: 0.5640\n",
      "03:29:14:BUS-stop:INFO: Epoch 2, s_conf=2.6461, s_class=0.9118, tst_acc=0.812, tst_loss=0.564\n",
      "4/4 [==============================] - 0s 51ms/step - acc: 0.9700 - loss: 0.2899\n",
      "32/32 [==============================] - 3s 84ms/step - acc: 0.8610 - loss: 0.4057\n",
      "03:29:22:BUS-stop:INFO: Epoch 3, s_conf=1.8832, s_class=0.9693, tst_acc=0.861, tst_loss=0.4057\n",
      "4/4 [==============================] - 0s 52ms/step - acc: 0.9900 - loss: 0.1101\n",
      "32/32 [==============================] - 3s 85ms/step - acc: 0.8570 - loss: 0.3246\n",
      "03:29:29:BUS-stop:INFO: Epoch 4, s_conf=1.4792, s_class=0.9832, tst_acc=0.857, tst_loss=0.3246\n",
      "4/4 [==============================] - 0s 51ms/step - acc: 1.0000 - loss: 0.0333\n",
      "32/32 [==============================] - 3s 84ms/step - acc: 0.8970 - loss: 0.2829\n",
      "03:29:36:BUS-stop:INFO: Epoch 5, s_conf=1.5341, s_class=0.9999, tst_acc=0.897, tst_loss=0.2829\n",
      "4/4 [==============================] - 0s 50ms/step - acc: 1.0000 - loss: 0.0089\n",
      "32/32 [==============================] - 3s 85ms/step - acc: 0.8600 - loss: 0.4012\n",
      "03:29:43:BUS-stop:INFO: Epoch 6, s_conf=1.8281, s_class=0.9911, tst_acc=0.86, tst_loss=0.4012\n",
      "4/4 [==============================] - 0s 51ms/step - acc: 1.0000 - loss: 0.0042\n",
      "32/32 [==============================] - 3s 84ms/step - acc: 0.8850 - loss: 0.4063\n",
      "03:29:50:BUS-stop:INFO: Epoch 7, s_conf=1.872, s_class=0.9986, tst_acc=0.885, tst_loss=0.4063\n",
      "4/4 [==============================] - 0s 52ms/step - acc: 1.0000 - loss: 0.0024\n",
      "32/32 [==============================] - 3s 84ms/step - acc: 0.8720 - loss: 0.4771\n",
      "03:29:56:BUS-stop:INFO: Epoch 8, s_conf=1.9279, s_class=0.9977, tst_acc=0.872, tst_loss=0.4771\n",
      "4/4 [==============================] - 0s 53ms/step - acc: 1.0000 - loss: 0.0016\n",
      "32/32 [==============================] - 3s 83ms/step - acc: 0.8450 - loss: 0.6400\n",
      "03:30:03:BUS-stop:INFO: Epoch 9, s_conf=1.968, s_class=0.9893, tst_acc=0.845, tst_loss=0.64\n",
      "03:30:03:BUS-stop:INFO: ***End training***\n",
      "03:30:03:BUS-stop:INFO: ***Load the model and Evaluate on test data***\n",
      "03:30:03:BUS-stop:INFO: BUS-stop's stop_epoch = 5\n",
      "32/32 [==============================] - 3s 84ms/step - acc: 0.8970 - loss: 0.2829\n",
      "03:30:06:BUS-stop:INFO: Final tst_acc : 0.897, tst_loss : 0.2829 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "epochs = 50\n",
    "batch_size = 16\n",
    "n_que = 5  # size of the s_class queue and patience (epochs) without an s_conf improvement\n",
    "\n",
    "logger.info(\"****************\")\n",
    "logger.info(\"***Main stage***\")\n",
    "logger.info(\"****************\")\n",
    "with strategy.scope():\n",
    "    model = modeler.build_model()\n",
    "    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08),\n",
    "                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name=\"acc\")])\n",
    "\n",
    "steps_per_epoch = lab_len//batch_size\n",
    "queue = deque(n_que*[0], n_que)\n",
    "best_conf, n_pat = np.inf, 0\n",
    "# Fallback so the restore below cannot raise NameError if s_class never beats the queue.\n",
    "best_weights, stop_epoch = model.get_weights(), 0\n",
    "\n",
    "# Downsample the unlabeled confidences to len(p_l_conf) evenly spaced positions.\n",
    "# Integer arithmetic guarantees exactly len(p_l_conf) indices; a float step in\n",
    "# np.arange can produce one extra element and break the euclidean() call.\n",
    "n_ref = len(p_l_conf)\n",
    "_ids = (np.arange(n_ref) * unl_len) // n_ref\n",
    "\n",
    "rand_indices = np.arange(lab_len)\n",
    "for epoch in range(1,epochs+1):\n",
    "    rand_indices = shuffle(rand_indices,random_state=GLOBAL_SEED)\n",
    "    for step in range(steps_per_epoch):\n",
    "        batch_indices = rand_indices[step*batch_size:(step+1)*batch_size]\n",
    "        X_bat = {key: pp.select_by_index(X_lab[key], batch_indices) for key in X_lab.keys()}\n",
    "        y_bat = pp.select_by_index(y_lab, batch_indices)\n",
    "        model.train_on_batch(X_bat,y_bat)\n",
    "\n",
    "    trn_loss,trn_acc = model.evaluate(X_lab, y_lab)\n",
    "    tst_loss,tst_acc = model.evaluate(X_tst, y_tst)\n",
    "\n",
    "    unl_probs = softmax(model.predict(X_unl),axis=1)\n",
    "    # p_l_conf is an ascending confidence profile, so the unlabeled confidences are sorted\n",
    "    # too: s_conf then compares matching quantiles instead of arbitrary samples.\n",
    "    unl_confs = np.sort(unl_probs.max(1))\n",
    "    unl_dist = unl_probs.mean(0)\n",
    "\n",
    "    s_conf = euclidean(unl_confs[_ids], p_l_conf)\n",
    "    s_class = 1.-cosine(unl_dist, c_u_cali)\n",
    "    logger.info(\"Epoch {}, s_conf={}, s_class={}, tst_acc={}, tst_loss={}\".format(\n",
    "                    epoch, round(s_conf,4), round(s_class,4), round(tst_acc,4), round(tst_loss,4)))\n",
    "\n",
    "    # Reset patience (and the s_class queue) whenever s_conf reaches a new minimum.\n",
    "    if s_conf < best_conf:\n",
    "        n_pat = 0\n",
    "        queue = deque(n_que*[0], n_que)\n",
    "        best_conf = s_conf\n",
    "    else:\n",
    "        n_pat += 1\n",
    "\n",
    "    # Within the patience window, keep the weights whenever s_class beats every value\n",
    "    # in the queue; stop once n_que epochs pass without an s_conf improvement.\n",
    "    if n_pat < n_que:\n",
    "        if s_class > max(queue):\n",
    "            best_weights = model.get_weights()\n",
    "            stop_epoch = epoch\n",
    "        queue.append(s_class)\n",
    "    else:\n",
    "        break\n",
    "\n",
    "logger.info('***End training***')\n",
    "\n",
    "logger.info('***Load the model and Evaluate on test data***')\n",
    "logger.info(\"BUS-stop's stop_epoch = {}\".format(stop_epoch))\n",
    "model.set_weights(best_weights)\n",
    "tst_loss,tst_acc = model.evaluate(X_tst, y_tst)\n",
    "logger.info('Final tst_acc : {}, tst_loss : {} \\n'.format(round(tst_acc,4),round(tst_loss,4)))\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py36_tf",
   "language": "python",
   "name": "py36_tf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
