{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "wlFbFLUghfjo" }, "source": [ "##### Copyright 2022 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T12:15:09.770598Z", "iopub.status.busy": "2022-12-14T12:15:09.770033Z", "iopub.status.idle": "2022-12-14T12:15:09.773773Z", "shell.execute_reply": "2022-12-14T12:15:09.773207Z" }, "id": "4FyfuZX-gTKS" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "I_16rv9EPhB_" }, "source": [ "# Passage Ranking using TFR-BERT\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "V8tMYn22vtDV" }, "source": [ "TensorFlow Ranking can handle heterogeneous dense and sparse features, and scales up to millions of data points. However, building and deploying a learning-to-rank model that operates at scale creates additional challenges beyond simply designing a model. The Ranking library provides workflow utility classes that support [distributed training](https://www.tensorflow.org/guide/distributed_training) for large-scale ranking applications. For more information about these features, see the TensorFlow Ranking [Overview](../overview).\n", "\n", "This tutorial shows you how to build a ranking model that uses BERT for scoring. [BERT](https://github.com/google-research/bert) is a highly effective pretrained model for encoding textual features into contextualized word embeddings. We use BERT to initialize the ranking model and fine-tune the model using a ranking loss.\n", "\n", "Note: An advanced version of this code is also available as a [Python script](https://github.com/tensorflow/ranking/blob/master/tensorflow_ranking/examples/keras/tfrbert_antique_train.py)." ] }, { "cell_type": "markdown", "metadata": { "id": "UxG7i8xbDIDF" }, "source": [ "## ANTIQUE dataset\n", "\n", "In this tutorial, you will build a ranking model for ANTIQUE, a question-answering dataset, using BERT as the scoring function. Bidirectional Encoder Representations from Transformers (BERT) is a transformer-based machine learning technique that has proven effective in many natural language processing (NLP) tasks. Recent work on [TFR-BERT](https://arxiv.org/abs/2004.08476) has shown BERT to be an effective scoring function for learning-to-rank tasks.\n", "\n", "Given a query and a list of answers, the objective of the ranking model is to order the answers so as to optimize rank-related metrics, such as NDCG. 
For more details about ranking metrics, see the [offline metrics](https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Offline_metrics) section of the Wikipedia article on evaluation measures in information retrieval.\n", "\n", "[ANTIQUE](https://ciir.cs.umass.edu/downloads/Antique/) is a publicly available dataset for open-domain non-factoid question answering, collected from Yahoo! Answers.\n", "Each question has a list of answers whose relevance is graded on a scale of 0 to 4, where 0 means irrelevant and 4 means fully relevant.\n", "The list size varies from query to query, so we use a fixed \"list size\" of 50: each list is either truncated or padded with default values to that length.\n", "The dataset is split into 2206 queries for training and 200 queries for testing. For more details, please read the technical paper on [arXiv](https://arxiv.org/abs/1905.08957)." ] }, { "cell_type": "markdown", "metadata": { "id": "ucWaXnFazZXD" }, "source": [ "## Setup\n", "\n", "Download and install the TensorFlow Ranking and TensorFlow Model Garden packages." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:15:09.777100Z", "iopub.status.busy": "2022-12-14T12:15:09.776526Z", "iopub.status.idle": "2022-12-14T12:15:37.556293Z", "shell.execute_reply": "2022-12-14T12:15:37.555222Z" }, "id": "aPmhLkMWgPLO" }, "outputs": [], "source": [ "!pip install -q tensorflow-ranking tf-models-official" ] }, { "cell_type": "markdown", "metadata": { "id": "9OKDJUjq0rnm" }, "source": [ "Import TensorFlow Ranking and the other libraries used throughout the notebook."
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:15:37.560776Z", "iopub.status.busy": "2022-12-14T12:15:37.560457Z", "iopub.status.idle": "2022-12-14T12:15:40.227692Z", "shell.execute_reply": "2022-12-14T12:15:40.227013Z" }, "id": "fmlaz2D5Ux3J" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 12:15:38.646771: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 12:15:38.646870: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 12:15:38.646879: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import os\n", "import tensorflow as tf\n", "import tensorflow_ranking as tfr\n", "from official.nlp.configs import encoders\n", "from tensorflow_ranking.extension.premade import tfrbert_task" ] }, { "cell_type": "markdown", "metadata": { "id": "JNilCoqq1jJn" }, "source": [ "## Data preparation\n", "\n", "Download training and test data." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:15:40.231844Z", "iopub.status.busy": "2022-12-14T12:15:40.231478Z", "iopub.status.idle": "2022-12-14T12:15:41.713034Z", "shell.execute_reply": "2022-12-14T12:15:41.712234Z" }, "id": "Mwxtsi4wqoOJ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2022-12-14 12:15:40-- https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_train_seq_64_elwc.tfrecords\r\n", "Resolving ciir.cs.umass.edu (ciir.cs.umass.edu)... " ] }, { "name": "stdout", "output_type": "stream", "text": [ "128.119.246.154\r\n", "Connecting to ciir.cs.umass.edu (ciir.cs.umass.edu)|128.119.246.154|:443... connected.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "HTTP request sent, awaiting response... " ] }, { "name": "stdout", "output_type": "stream", "text": [ "200 OK\r\n", "Length: 8743528 (8.3M)\r\n", "Saving to: ‘/tmp/train.tfrecords’\r\n", "\r\n", "\r", "/tmp/train.tfrecord 0%[ ] 0 --.-KB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "/tmp/train.tfrecord 12%[=> ] 1.02M 5.11MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "/tmp/train.tfrecord 53%[=========> ] 4.49M 11.2MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "/tmp/train.tfrecord 85%[================> ] 7.13M 11.8MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "/tmp/train.tfrecord 100%[===================>] 8.34M 12.0MB/s in 0.7s \r\n", "\r\n", "2022-12-14 12:15:41 (12.0 MB/s) - ‘/tmp/train.tfrecords’ saved [8743528/8743528]\r\n", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--2022-12-14 12:15:41-- https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_test_seq_64_elwc.tfrecords\r\n", "Resolving ciir.cs.umass.edu (ciir.cs.umass.edu)... 128.119.246.154\r\n", "Connecting to ciir.cs.umass.edu (ciir.cs.umass.edu)|128.119.246.154|:443... 
connected.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "HTTP request sent, awaiting response... 200 OK\r\n", "Length: 692072 (676K)\r\n", "Saving to: ‘/tmp/test.tfrecords’\r\n", "\r\n", "\r", "/tmp/test.tfrecords 0%[ ] 0 --.-KB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "/tmp/test.tfrecords 100%[===================>] 675.85K 3.93MB/s in 0.2s \r\n", "\r\n", "2022-12-14 12:15:41 (3.93 MB/s) - ‘/tmp/test.tfrecords’ saved [692072/692072]\r\n", "\r\n" ] } ], "source": [ "!wget -O \"/tmp/train.tfrecords\" \"https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_train_seq_64_elwc.tfrecords\"\n", "!wget -O \"/tmp/test.tfrecords\" \"https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_test_seq_64_elwc.tfrecords\"" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:15:41.716875Z", "iopub.status.busy": "2022-12-14T12:15:41.716199Z", "iopub.status.idle": "2022-12-14T12:15:49.094140Z", "shell.execute_reply": "2022-12-14T12:15:49.093325Z" }, "id": "D5vL2R7aOSoe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2022-12-14 12:15:41-- https://storage.googleapis.com/cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12.tar.gz\r\n", "Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.196.128, 142.251.107.128, 142.250.97.128, ...\r\n", "Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.196.128|:443... connected.\r\n", "HTTP request sent, awaiting response... 
" ] }, { "name": "stdout", "output_type": "stream", "text": [ "200 OK\r\n", "Length: 405351189 (387M) [application/octet-stream]\r\n", "Saving to: ‘/tmp/tfrbert/uncased_L-12_H-768_A-12.tar.gz’\r\n", "\r\n", "\r", " uncased_L 0%[ ] 0 --.-KB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " uncased_L- 2%[ ] 8.01M 38.8MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " uncased_L-1 10%[=> ] 40.01M 90.5MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " uncased_L-12 18%[==> ] 72.01M 104MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " uncased_L-12_ 27%[====> ] 105.16M 118MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " uncased_L-12_H 36%[======> ] 140.46M 128MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " uncased_L-12_H- 45%[========> ] 176.01M 129MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " uncased_L-12_H-7 54%[=========> ] 209.23M 134MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " uncased_L-12_H-76 62%[===========> ] 240.01M 135MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " uncased_L-12_H-768 70%[=============> ] 272.01M 136MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "uncased_L-12_H-768_ 78%[==============> ] 304.05M 138MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "ncased_L-12_H-768_A 86%[================> ] 336.01M 139MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "cased_L-12_H-768_A- 94%[=================> ] 366.19M 140MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "uncased_L-12_H-768_ 100%[===================>] 386.57M 143MB/s in 2.7s \r\n", "\r\n", "2022-12-14 12:15:44 (143 MB/s) - ‘/tmp/tfrbert/uncased_L-12_H-768_A-12.tar.gz’ saved [405351189/405351189]\r\n", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"tmp/temp_dir/raw/vocab.txt\r\n", "tmp/temp_dir/raw/bert_model.ckpt.index\r\n", "tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tmp/temp_dir/raw/bert_config.json\r\n" ] } ], "source": [ "!mkdir -p /tmp/tfrbert\n", "!wget \"https://storage.googleapis.com/cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12.tar.gz\" -P \"/tmp/tfrbert\"\n", "!mkdir -p /tmp/tfrbert/uncased_L-12_H-768_A-12\n", "!tar -xvf /tmp/tfrbert/uncased_L-12_H-768_A-12.tar.gz --strip-components 3 -C \"/tmp/tfrbert/uncased_L-12_H-768_A-12/\"" ] }, { "cell_type": "markdown", "metadata": { "id": "tFbFBTUh9WXf" }, "source": [ "## Overview of TFR-BERT in Orbit\n", "\n", "BERT-based ranking models ([TFR-BERT](https://arxiv.org/abs/2004.08476)) have been shown to be effective for learning-to-rank tasks when using the raw text of queries and passages in the MS MARCO passage ranking dataset.\n", "\n", "[Orbit](https://github.com/tensorflow/models/tree/master/orbit) is a flexible, lightweight library designed to make it easy to write custom training loops in TensorFlow. TensorFlow Ranking provides support for implementing ranking models with Orbit, particularly BERT-based ranking models." ] }, { "cell_type": "markdown", "metadata": { "id": "aQ-VTA56sOTA" }, "source": [ "## Create a Ranking Task for TFR-BERT\n", "\n", "We create a ranking task for the TFR-BERT model that can be trained using Orbit. The steps are:\n", "\n", "1. Define feature specifications\n", "2. Define datasets\n", "3. Set up the data and task configurations\n" ] }, { "cell_type": "markdown", "metadata": { "id": "at0nVKnts8Pn" }, "source": [ "### Specify Features\n", "\n", "[Feature specifications](https://www.tensorflow.org/api_docs/python/tf/io) are TensorFlow abstractions that capture information about each feature, such as its shape, data type, and default value. 
They tell the input pipeline how to parse each serialized feature, and they document the inputs a model expects for developers and researchers.\n", "\n", "Create feature specifications for context features, example features, and labels that are consistent with ranking input formats such as ELWC (ExampleListWithContext)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:15:49.098490Z", "iopub.status.busy": "2022-12-14T12:15:49.097997Z", "iopub.status.idle": "2022-12-14T12:15:49.103678Z", "shell.execute_reply": "2022-12-14T12:15:49.103089Z" }, "id": "nSXd4pEPqaQW" }, "outputs": [], "source": [ "SEQ_LENGTH = 64\n", "context_feature_spec = {}\n", "example_feature_spec = {\n", " 'input_word_ids': tf.io.FixedLenFeature(\n", " shape=(SEQ_LENGTH,), dtype=tf.int64,\n", " default_value=[0] * SEQ_LENGTH),\n", " 'input_mask': tf.io.FixedLenFeature(\n", " shape=(SEQ_LENGTH,), dtype=tf.int64,\n", " default_value=[0] * SEQ_LENGTH),\n", " 'input_type_ids': tf.io.FixedLenFeature(\n", " shape=(SEQ_LENGTH,), dtype=tf.int64,\n", " default_value=[0] * SEQ_LENGTH)}\n", "label_spec = (\n", " \"relevance\",\n", " tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64, default_value=-1)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "X2Iq_YA2HpCe" }, "source": [ "Note: The `default_value` of the `label_spec` feature is set to -1 so that padded items can be identified and masked out." ] }, { "cell_type": "markdown", "metadata": { "id": "nKoHvr1oj02f" }, "source": [ "### Define Datasets\n", "\n", "We define data configurations for the training and validation data, which specify parameters such as the input path, batch size, list size, and dataset format. These configurations are used to create the training and validation datasets."
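, "\n", "\n", "The cell below is a minimal, illustrative NumPy sketch (not TF-Ranking's actual implementation) of what a fixed list size means. Each query's variable-length list of graded answers is truncated or padded to the same length, and padded slots receive the label `-1`, matching the `default_value` in `label_spec`, so that losses and metrics can mask them out. The helper `pad_or_truncate` and the value `LIST_SIZE = 5` are hypothetical. Keeping the first items when truncating is also an assumption of this sketch; the library may choose which items to keep differently." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "LIST_SIZE = 5  # Hypothetical value for illustration; the configs below use list_size=2.\n", "\n", "\n", "def pad_or_truncate(labels, list_size, pad_label=-1):\n", "  # Hypothetical helper (not a TF-Ranking API): keeps the first list_size\n", "  # labels and pads with pad_label so every query has list_size entries.\n", "  labels = list(labels)[:list_size]\n", "  labels += [pad_label] * (list_size - len(labels))\n", "  return np.array(labels, dtype=np.int64)\n", "\n", "\n", "short_list = pad_or_truncate([4, 2, 0], LIST_SIZE)\n", "long_list = pad_or_truncate([3, 3, 1, 0, 2, 4, 1], LIST_SIZE)\n", "print(short_list)  # [ 4  2  0 -1 -1]\n", "print(long_list)   # [3 3 1 0 2]\n", "\n", "# Labels of -1 mark padded items; ranking losses and metrics should ignore them.\n", "valid_mask = short_list >= 0\n", "print(valid_mask)  # [ True  True  True False False]"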
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:15:49.106747Z", "iopub.status.busy": "2022-12-14T12:15:49.106526Z", "iopub.status.idle": "2022-12-14T12:15:49.110635Z", "shell.execute_reply": "2022-12-14T12:15:49.110089Z" }, "id": "fCJZaAqBj0YE" }, "outputs": [], "source": [ "# Set up data config\n", "# We use a small list size here for demo purposes only. Users can use a larger\n", "# list size on a machine with more memory to train TFR-BERT.\n", "train_data_config = tfrbert_task.TFRBertDataConfig(\n", " input_path=\"/tmp/train.tfrecords\",\n", " is_training=True,\n", " global_batch_size=8,\n", " list_size=2,\n", " dataset_fn='tfrecord',\n", " seq_length=SEQ_LENGTH)\n", "\n", "validation_data_config = tfrbert_task.TFRBertDataConfig(\n", " input_path=\"/tmp/test.tfrecords\",\n", " is_training=False,\n", " global_batch_size=8,\n", " list_size=2,\n", " dataset_fn='tfrecord',\n", " seq_length=SEQ_LENGTH)" ] }, { "cell_type": "markdown", "metadata": { "id": "8-O_z1S2j_s0" }, "source": [ "### Define Task\n", "\n", "We define a task configuration that specifies the training and validation datasets along with the model. This configuration is then used to create a `TFRBertTask` object that can be trained using Orbit."
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:15:49.113825Z", "iopub.status.busy": "2022-12-14T12:15:49.113447Z", "iopub.status.idle": "2022-12-14T12:15:49.119120Z", "shell.execute_reply": "2022-12-14T12:15:49.118589Z" }, "id": "uLly4jcykCC1" }, "outputs": [], "source": [ "# Set up task config\n", "task_config = tfrbert_task.TFRBertConfig(\n", " init_checkpoint='/tmp/tfrbert/uncased_L-12_H-768_A-12/bert_model.ckpt',\n", " train_data=train_data_config,\n", " validation_data=validation_data_config,\n", " model=tfrbert_task.TFRBertModelConfig(\n", " encoder=encoders.EncoderConfig(\n", " bert=encoders.BertEncoderConfig(num_layers=12))))\n", "\n", "# Set up TFRBertTask\n", "task = tfrbert_task.TFRBertTask(\n", " task_config,\n", " label_spec=label_spec,\n", " dataset_fn=tf.data.TFRecordDataset,\n", " logging_dir='/tmp/model_dir')" ] }, { "cell_type": "markdown", "metadata": { "id": "_EPvXEbomK29" }, "source": [ "## Train and evaluate the model\n", "\n", "We define the training loop here to train and evaluate the model. 
We define the metrics, create train and eval datasets and train the model for a specific number of training steps.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:15:49.122360Z", "iopub.status.busy": "2022-12-14T12:15:49.121874Z", "iopub.status.idle": "2022-12-14T12:20:46.865360Z", "shell.execute_reply": "2022-12-14T12:20:46.864638Z" }, "id": "BZr8MX6VmQSj" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 12:15:49.221104: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.\n", "Instructions for updating:\n", "Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.\n", "Instructions for updating:\n", "Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. 
https://github.com/tensorflow/tensorflow/issues/56089\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:`lr` is deprecated, please use `learning_rate` instead, or use the legacy optimizer, e.g.,tf.keras.optimizers.legacy.Adam.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 0 {'MAP': 0.9375, 'NDCG@1': 0.73214287, 'NDCG@5': 0.912364, 'NDCG@10': 0.912364, 'MRR@1': 0.875, 'MRR@5': 0.9375, 'MRR@10': 0.9375}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 0 {'MAP': 0.96875, 'NDCG@1': 0.66369045, 'NDCG@5': 0.89421266, 'NDCG@10': 0.89421266, 'MRR@1': 0.9375, 'MRR@5': 0.96875, 'MRR@10': 0.96875}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 1 {'MAP': 0.9583333, 'NDCG@1': 0.6507936, 'NDCG@5': 0.88817185, 'NDCG@10': 0.88817185, 'MRR@1': 0.9166667, 'MRR@5': 0.9583333, 'MRR@10': 0.9583333}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 2 {'MAP': 0.96875, 'NDCG@1': 0.6577381, 'NDCG@5': 0.89149714, 'NDCG@10': 0.89149714, 'MRR@1': 0.9375, 'MRR@5': 0.96875, 'MRR@10': 0.96875}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 3 {'MAP': 0.975, 'NDCG@1': 0.68095237, 'NDCG@5': 0.89981496, 'NDCG@10': 0.89981496, 'MRR@1': 0.95, 'MRR@5': 0.975, 'MRR@10': 0.975}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 4 {'MAP': 0.9791667, 'NDCG@1': 0.71031743, 'NDCG@5': 0.9095955, 'NDCG@10': 0.9095955, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 5 {'MAP': 0.98214287, 'NDCG@1': 0.7091837, 'NDCG@5': 0.9085163, 'NDCG@10': 0.9085163, 'MRR@1': 0.96428573, 'MRR@5': 0.98214287, 'MRR@10': 0.98214287}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 6 {'MAP': 
0.9765625, 'NDCG@1': 0.68526787, 'NDCG@5': 0.8999288, 'NDCG@10': 0.8999288, 'MRR@1': 0.953125, 'MRR@5': 0.9765625, 'MRR@10': 0.9765625}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 7 {'MAP': 0.9791667, 'NDCG@1': 0.7030423, 'NDCG@5': 0.90591866, 'NDCG@10': 0.90591866, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 8 {'MAP': 0.98125, 'NDCG@1': 0.7255953, 'NDCG@5': 0.9132517, 'NDCG@10': 0.9132517, 'MRR@1': 0.9625, 'MRR@5': 0.98125, 'MRR@10': 0.98125}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 9 {'MAP': 0.97727275, 'NDCG@1': 0.7229437, 'NDCG@5': 0.9117598, 'NDCG@10': 0.9117598, 'MRR@1': 0.95454544, 'MRR@5': 0.97727275, 'MRR@10': 0.97727275}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 10 {'MAP': 0.9791667, 'NDCG@1': 0.7311508, 'NDCG@5': 0.91436106, 'NDCG@10': 0.91436106, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 10 {'MAP': 0.97596157, 'NDCG@1': 0.7339744, 'NDCG@5': 0.9146096, 'NDCG@10': 0.9146096, 'MRR@1': 0.9519231, 'MRR@5': 0.97596157, 'MRR@10': 0.97596157}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 11 {'MAP': 0.9776786, 'NDCG@1': 0.73511904, 'NDCG@5': 0.9151535, 'NDCG@10': 0.9151535, 'MRR@1': 0.95535713, 'MRR@5': 0.9776786, 'MRR@10': 0.9776786}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 12 {'MAP': 0.975, 'NDCG@1': 0.7253969, 'NDCG@5': 0.91220075, 'NDCG@10': 0.91220075, 'MRR@1': 0.95, 'MRR@5': 0.975, 'MRR@10': 0.975}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 13 {'MAP': 0.9765625, 'NDCG@1': 0.73363096, 'NDCG@5': 0.9150943, 'NDCG@10': 0.9150943, 'MRR@1': 0.953125, 
'MRR@5': 0.9765625, 'MRR@10': 0.9765625}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 14 {'MAP': 0.97794116, 'NDCG@1': 0.7366947, 'NDCG@5': 0.91642684, 'NDCG@10': 0.91642684, 'MRR@1': 0.9558824, 'MRR@5': 0.97794116, 'MRR@10': 0.97794116}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 15 {'MAP': 0.9791667, 'NDCG@1': 0.7433862, 'NDCG@5': 0.9187641, 'NDCG@10': 0.9187641, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 16 {'MAP': 0.9769737, 'NDCG@1': 0.73903507, 'NDCG@5': 0.91679335, 'NDCG@10': 0.91679335, 'MRR@1': 0.95394737, 'MRR@5': 0.9769737, 'MRR@10': 0.9769737}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 17 {'MAP': 0.978125, 'NDCG@1': 0.7407738, 'NDCG@5': 0.91760796, 'NDCG@10': 0.91760796, 'MRR@1': 0.95625, 'MRR@5': 0.978125, 'MRR@10': 0.978125}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 18 {'MAP': 0.9791667, 'NDCG@1': 0.73781174, 'NDCG@5': 0.9168396, 'NDCG@10': 0.9168396, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 19 {'MAP': 0.9801136, 'NDCG@1': 0.7464827, 'NDCG@5': 0.91967636, 'NDCG@10': 0.91967636, 'MRR@1': 0.96022725, 'MRR@5': 0.9801136, 'MRR@10': 0.9801136}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 20 {'MAP': 0.9782609, 'NDCG@1': 0.7380952, 'NDCG@5': 0.91687906, 'NDCG@10': 0.91687906, 'MRR@1': 0.95652175, 'MRR@5': 0.9782609, 'MRR@10': 0.9782609}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 20 {'MAP': 0.9791667, 'NDCG@1': 0.74900794, 'NDCG@5': 0.92034245, 'NDCG@10': 0.92034245, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "Training metrics for epoch: 21 {'MAP': 0.98, 'NDCG@1': 0.7561905, 'NDCG@5': 0.9226987, 'NDCG@10': 0.9226987, 'MRR@1': 0.96, 'MRR@5': 0.98, 'MRR@10': 0.98}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 22 {'MAP': 0.9807692, 'NDCG@1': 0.7545787, 'NDCG@5': 0.92208344, 'NDCG@10': 0.92208344, 'MRR@1': 0.96153843, 'MRR@5': 0.9807692, 'MRR@10': 0.9807692}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 23 {'MAP': 0.9814815, 'NDCG@1': 0.75837743, 'NDCG@5': 0.9234321, 'NDCG@10': 0.9234321, 'MRR@1': 0.962963, 'MRR@5': 0.9814815, 'MRR@10': 0.9814815}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 24 {'MAP': 0.98214287, 'NDCG@1': 0.7568027, 'NDCG@5': 0.9228346, 'NDCG@10': 0.9228346, 'MRR@1': 0.96428573, 'MRR@5': 0.98214287, 'MRR@10': 0.98214287}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 25 {'MAP': 0.98275864, 'NDCG@1': 0.7627258, 'NDCG@5': 0.9247799, 'NDCG@10': 0.9247799, 'MRR@1': 0.9655172, 'MRR@5': 0.98275864, 'MRR@10': 0.98275864}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 26 {'MAP': 0.98333335, 'NDCG@1': 0.76468253, 'NDCG@5': 0.9253864, 'NDCG@10': 0.9253864, 'MRR@1': 0.96666664, 'MRR@5': 0.98333335, 'MRR@10': 0.98333335}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 27 {'MAP': 0.983871, 'NDCG@1': 0.765361, 'NDCG@5': 0.9257851, 'NDCG@10': 0.9257851, 'MRR@1': 0.9677419, 'MRR@5': 0.983871, 'MRR@10': 0.983871}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 28 {'MAP': 0.984375, 'NDCG@1': 0.7671131, 'NDCG@5': 0.92632234, 'NDCG@10': 0.92632234, 'MRR@1': 0.96875, 'MRR@5': 0.984375, 'MRR@10': 0.984375}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 29 {'MAP': 0.98295456, 'NDCG@1': 
0.7638889, 'NDCG@5': 0.92527056, 'NDCG@10': 0.92527056, 'MRR@1': 0.96590906, 'MRR@5': 0.98295456, 'MRR@10': 0.98295456}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 30 {'MAP': 0.9834559, 'NDCG@1': 0.7708334, 'NDCG@5': 0.9274685, 'NDCG@10': 0.9274685, 'MRR@1': 0.9669118, 'MRR@5': 0.9834559, 'MRR@10': 0.9834559}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 30 {'MAP': 0.98392856, 'NDCG@1': 0.76513606, 'NDCG@5': 0.9256894, 'NDCG@10': 0.9262823, 'MRR@1': 0.9678571, 'MRR@5': 0.98392856, 'MRR@10': 0.98392856}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 31 {'MAP': 0.984375, 'NDCG@1': 0.765377, 'NDCG@5': 0.92589486, 'NDCG@10': 0.9264713, 'MRR@1': 0.96875, 'MRR@5': 0.984375, 'MRR@10': 0.984375}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 32 {'MAP': 0.9831081, 'NDCG@1': 0.765444, 'NDCG@5': 0.92567044, 'NDCG@10': 0.92623127, 'MRR@1': 0.9662162, 'MRR@5': 0.9831081, 'MRR@10': 0.9831081}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 33 {'MAP': 0.98355263, 'NDCG@1': 0.768797, 'NDCG@5': 0.92667186, 'NDCG@10': 0.92721796, 'MRR@1': 0.96710527, 'MRR@5': 0.98355263, 'MRR@10': 0.98355263}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 34 {'MAP': 0.98397434, 'NDCG@1': 0.7728938, 'NDCG@5': 0.92802, 'NDCG@10': 0.9285521, 'MRR@1': 0.96794873, 'MRR@5': 0.98397434, 'MRR@10': 0.98397434}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 35 {'MAP': 0.984375, 'NDCG@1': 0.775, 'NDCG@5': 0.92878187, 'NDCG@10': 0.92930067, 'MRR@1': 0.96875, 'MRR@5': 0.984375, 'MRR@10': 0.984375}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 36 {'MAP': 0.9832317, 'NDCG@1': 0.77482575, 'NDCG@5': 0.92850894, 'NDCG@10': 0.9290151, 'MRR@1': 0.9664634, 
'MRR@5': 0.9832317, 'MRR@10': 0.9832317}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 37 {'MAP': 0.98214287, 'NDCG@1': 0.770975, 'NDCG@5': 0.9271499, 'NDCG@10': 0.927644, 'MRR@1': 0.96428573, 'MRR@5': 0.98214287, 'MRR@10': 0.98214287}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 38 {'MAP': 0.98255813, 'NDCG@1': 0.7743632, 'NDCG@5': 0.9282532, 'NDCG@10': 0.9287358, 'MRR@1': 0.96511626, 'MRR@5': 0.98255813, 'MRR@10': 0.98255813}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 39 {'MAP': 0.98295456, 'NDCG@1': 0.7746212, 'NDCG@5': 0.928235, 'NDCG@10': 0.9287066, 'MRR@1': 0.96590906, 'MRR@5': 0.98295456, 'MRR@10': 0.98295456}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 40 {'MAP': 0.98194444, 'NDCG@1': 0.77685183, 'NDCG@5': 0.9288045, 'NDCG@10': 0.9292657, 'MRR@1': 0.9638889, 'MRR@5': 0.98194444, 'MRR@10': 0.98194444}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 40 {'MAP': 0.98097825, 'NDCG@1': 0.7758799, 'NDCG@5': 0.9284471, 'NDCG@10': 0.9288983, 'MRR@1': 0.9619565, 'MRR@5': 0.98233694, 'MRR@10': 0.98233694}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 41 {'MAP': 0.98138297, 'NDCG@1': 0.7730496, 'NDCG@5': 0.927543, 'NDCG@10': 0.92798454, 'MRR@1': 0.96276593, 'MRR@5': 0.98271275, 'MRR@10': 0.98271275}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 42 {'MAP': 0.9791667, 'NDCG@1': 0.7703373, 'NDCG@5': 0.9263746, 'NDCG@10': 0.9268069, 'MRR@1': 0.9583333, 'MRR@5': 0.98046875, 'MRR@10': 0.98046875}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 43 {'MAP': 0.97959185, 'NDCG@1': 0.77502424, 'NDCG@5': 0.9278771, 'NDCG@10': 0.9283007, 'MRR@1': 0.9591837, 'MRR@5': 0.9808673, 'MRR@10': 0.9808673}\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 44 {'MAP': 0.97875, 'NDCG@1': 0.7705952, 'NDCG@5': 0.9264264, 'NDCG@10': 0.92684144, 'MRR@1': 0.9575, 'MRR@5': 0.98, 'MRR@10': 0.98}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 45 {'MAP': 0.9791667, 'NDCG@1': 0.7708916, 'NDCG@5': 0.9264465, 'NDCG@10': 0.9268534, 'MRR@1': 0.9583333, 'MRR@5': 0.98039216, 'MRR@10': 0.98039216}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 46 {'MAP': 0.9795673, 'NDCG@1': 0.7732371, 'NDCG@5': 0.9271634, 'NDCG@10': 0.9275625, 'MRR@1': 0.95913464, 'MRR@5': 0.9807692, 'MRR@10': 0.9807692}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 47 {'MAP': 0.9799528, 'NDCG@1': 0.77279866, 'NDCG@5': 0.9270702, 'NDCG@10': 0.92746174, 'MRR@1': 0.9599057, 'MRR@5': 0.9811321, 'MRR@10': 0.9811321}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 48 {'MAP': 0.9780093, 'NDCG@1': 0.7718253, 'NDCG@5': 0.92525107, 'NDCG@10': 0.9256354, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 49 {'MAP': 0.9784091, 'NDCG@1': 0.77077913, 'NDCG@5': 0.925101, 'NDCG@10': 0.9254783, 'MRR@1': 0.9590909, 'MRR@5': 0.9795455, 'MRR@10': 0.9795455}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 50 {'MAP': 0.9776786, 'NDCG@1': 0.7694515, 'NDCG@5': 0.92459637, 'NDCG@10': 0.92496693, 'MRR@1': 0.95758927, 'MRR@5': 0.97879463, 'MRR@10': 0.97879463}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 50 {'MAP': 0.9780702, 'NDCG@1': 0.7703634, 'NDCG@5': 0.9249188, 'NDCG@10': 0.92528284, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 51 {'MAP': 
0.9773707, 'NDCG@1': 0.7690887, 'NDCG@5': 0.9244347, 'NDCG@10': 0.9247925, 'MRR@1': 0.95689654, 'MRR@5': 0.9784483, 'MRR@10': 0.9784483}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 52 {'MAP': 0.97775424, 'NDCG@1': 0.76432604, 'NDCG@5': 0.9228256, 'NDCG@10': 0.9231773, 'MRR@1': 0.9576271, 'MRR@5': 0.9788136, 'MRR@10': 0.9788136}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 53 {'MAP': 0.978125, 'NDCG@1': 0.7664682, 'NDCG@5': 0.9235073, 'NDCG@10': 0.9238531, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 54 {'MAP': 0.9784836, 'NDCG@1': 0.7661983, 'NDCG@5': 0.9234862, 'NDCG@10': 0.9238264, 'MRR@1': 0.9590164, 'MRR@5': 0.9795082, 'MRR@10': 0.9795082}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 55 {'MAP': 0.97883064, 'NDCG@1': 0.76766515, 'NDCG@5': 0.92405087, 'NDCG@10': 0.92438555, 'MRR@1': 0.9596774, 'MRR@5': 0.9798387, 'MRR@10': 0.9798387}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 56 {'MAP': 0.9791667, 'NDCG@1': 0.7696523, 'NDCG@5': 0.9246806, 'NDCG@10': 0.92501, 'MRR@1': 0.96031743, 'MRR@5': 0.98015875, 'MRR@10': 0.98015875}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 57 {'MAP': 0.9785156, 'NDCG@1': 0.7679501, 'NDCG@5': 0.92400306, 'NDCG@10': 0.9243273, 'MRR@1': 0.9589844, 'MRR@5': 0.9794922, 'MRR@10': 0.9794922}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 58 {'MAP': 0.97884613, 'NDCG@1': 0.76767397, 'NDCG@5': 0.92397565, 'NDCG@10': 0.92429495, 'MRR@1': 0.9596154, 'MRR@5': 0.9798077, 'MRR@10': 0.9798077}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 59 {'MAP': 0.9782197, 'NDCG@1': 0.7655122, 'NDCG@5': 0.92325014, 'NDCG@10': 
0.92356455, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 60 {'MAP': 0.97761196, 'NDCG@1': 0.76341504, 'NDCG@5': 0.92254627, 'NDCG@10': 0.92285603, 'MRR@1': 0.95708954, 'MRR@5': 0.9785448, 'MRR@10': 0.9785448}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 60 {'MAP': 0.97794116, 'NDCG@1': 0.76216733, 'NDCG@5': 0.9222364, 'NDCG@10': 0.9227017, 'MRR@1': 0.9577206, 'MRR@5': 0.9788603, 'MRR@10': 0.9788603}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 61 {'MAP': 0.9782609, 'NDCG@1': 0.7623361, 'NDCG@5': 0.9223936, 'NDCG@10': 0.9228522, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 62 {'MAP': 0.9785714, 'NDCG@1': 0.76164967, 'NDCG@5': 0.92231655, 'NDCG@10': 0.92276853, 'MRR@1': 0.9589286, 'MRR@5': 0.9794643, 'MRR@10': 0.9794643}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 63 {'MAP': 0.97711265, 'NDCG@1': 0.75796443, 'NDCG@5': 0.9210157, 'NDCG@10': 0.9214613, 'MRR@1': 0.9559859, 'MRR@5': 0.97799295, 'MRR@10': 0.97799295}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 64 {'MAP': 0.9765625, 'NDCG@1': 0.7581018, 'NDCG@5': 0.9209682, 'NDCG@10': 0.9214076, 'MRR@1': 0.9548611, 'MRR@5': 0.9774306, 'MRR@10': 0.9774306}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 65 {'MAP': 0.97602737, 'NDCG@1': 0.7577462, 'NDCG@5': 0.9208503, 'NDCG@10': 0.92128366, 'MRR@1': 0.9537671, 'MRR@5': 0.9768836, 'MRR@10': 0.9768836}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 66 {'MAP': 0.9755068, 'NDCG@1': 0.75546974, 'NDCG@5': 0.9200356, 'NDCG@10': 0.92046314, 'MRR@1': 0.9527027, 'MRR@5': 0.9763514, 'MRR@10': 
0.9763514}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 67 {'MAP': 0.97583336, 'NDCG@1': 0.7563492, 'NDCG@5': 0.9203415, 'NDCG@10': 0.9207634, 'MRR@1': 0.9533333, 'MRR@5': 0.9766667, 'MRR@10': 0.9766667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 68 {'MAP': 0.9761513, 'NDCG@1': 0.7581454, 'NDCG@5': 0.9209124, 'NDCG@10': 0.9213287, 'MRR@1': 0.95394737, 'MRR@5': 0.9769737, 'MRR@10': 0.9769737}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 69 {'MAP': 0.97564936, 'NDCG@1': 0.7578077, 'NDCG@5': 0.92080134, 'NDCG@10': 0.92121226, 'MRR@1': 0.9529221, 'MRR@5': 0.97646105, 'MRR@10': 0.97646105}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 70 {'MAP': 0.97596157, 'NDCG@1': 0.75862336, 'NDCG@5': 0.92108566, 'NDCG@10': 0.92149127, 'MRR@1': 0.95352566, 'MRR@5': 0.97676283, 'MRR@10': 0.97676283}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 70 {'MAP': 0.97626585, 'NDCG@1': 0.75761, 'NDCG@5': 0.9207071, 'NDCG@10': 0.9211076, 'MRR@1': 0.9541139, 'MRR@5': 0.977057, 'MRR@10': 0.977057}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 71 {'MAP': 0.9765625, 'NDCG@1': 0.75796133, 'NDCG@5': 0.9209201, 'NDCG@10': 0.92131555, 'MRR@1': 0.9546875, 'MRR@5': 0.97734374, 'MRR@10': 0.97734374}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 72 {'MAP': 0.9768519, 'NDCG@1': 0.75727516, 'NDCG@5': 0.92068696, 'NDCG@10': 0.9210776, 'MRR@1': 0.9552469, 'MRR@5': 0.97762346, 'MRR@10': 0.97762346}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 73 {'MAP': 0.97713417, 'NDCG@1': 0.7589286, 'NDCG@5': 0.9212119, 'NDCG@10': 0.9215977, 'MRR@1': 0.95579267, 'MRR@5': 0.97789633, 'MRR@10': 0.97789633}\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "Training metrics for epoch: 74 {'MAP': 0.97740966, 'NDCG@1': 0.7618331, 'NDCG@5': 0.92216116, 'NDCG@10': 0.92254233, 'MRR@1': 0.9563253, 'MRR@5': 0.97816265, 'MRR@10': 0.97816265}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 75 {'MAP': 0.9776786, 'NDCG@1': 0.761267, 'NDCG@5': 0.9219771, 'NDCG@10': 0.92235374, 'MRR@1': 0.9568452, 'MRR@5': 0.97842264, 'MRR@10': 0.97842264}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 76 {'MAP': 0.97794116, 'NDCG@1': 0.759874, 'NDCG@5': 0.9215532, 'NDCG@10': 0.9219254, 'MRR@1': 0.95735294, 'MRR@5': 0.9786765, 'MRR@10': 0.9786765}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 77 {'MAP': 0.9781977, 'NDCG@1': 0.75754434, 'NDCG@5': 0.9208438, 'NDCG@10': 0.92121166, 'MRR@1': 0.95784885, 'MRR@5': 0.9789244, 'MRR@10': 0.9789244}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 78 {'MAP': 0.9784483, 'NDCG@1': 0.7573208, 'NDCG@5': 0.9208061, 'NDCG@10': 0.92116976, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 79 {'MAP': 0.9786932, 'NDCG@1': 0.758861, 'NDCG@5': 0.92129385, 'NDCG@10': 0.9216534, 'MRR@1': 0.9588068, 'MRR@5': 0.97940344, 'MRR@10': 0.97940344}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 80 {'MAP': 0.97893256, 'NDCG@1': 0.7603666, 'NDCG@5': 0.9217707, 'NDCG@10': 0.9221262, 'MRR@1': 0.95926964, 'MRR@5': 0.9796348, 'MRR@10': 0.9796348}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 80 {'MAP': 0.9791667, 'NDCG@1': 0.75925934, 'NDCG@5': 0.92172426, 'NDCG@10': 0.92167276, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.97986114}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 81 {'MAP': 
0.9793956, 'NDCG@1': 0.7595501, 'NDCG@5': 0.9217872, 'NDCG@10': 0.92173624, 'MRR@1': 0.9587912, 'MRR@5': 0.9793956, 'MRR@10': 0.9800824}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 82 {'MAP': 0.97961956, 'NDCG@1': 0.7606108, 'NDCG@5': 0.92218626, 'NDCG@10': 0.92213583, 'MRR@1': 0.9592391, 'MRR@5': 0.97961956, 'MRR@10': 0.98029894}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 83 {'MAP': 0.9798387, 'NDCG@1': 0.7601127, 'NDCG@5': 0.9220197, 'NDCG@10': 0.92196983, 'MRR@1': 0.9596774, 'MRR@5': 0.9798387, 'MRR@10': 0.9805108}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 84 {'MAP': 0.9800532, 'NDCG@1': 0.7596252, 'NDCG@5': 0.92185676, 'NDCG@10': 0.9218074, 'MRR@1': 0.9601064, 'MRR@5': 0.9800532, 'MRR@10': 0.9807181}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 85 {'MAP': 0.9802632, 'NDCG@1': 0.7612783, 'NDCG@5': 0.9224118, 'NDCG@10': 0.922363, 'MRR@1': 0.9605263, 'MRR@5': 0.9802632, 'MRR@10': 0.98092103}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 86 {'MAP': 0.98046875, 'NDCG@1': 0.7626489, 'NDCG@5': 0.9228422, 'NDCG@10': 0.92279387, 'MRR@1': 0.9609375, 'MRR@5': 0.98046875, 'MRR@10': 0.9811198}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 87 {'MAP': 0.9806701, 'NDCG@1': 0.7639913, 'NDCG@5': 0.9232637, 'NDCG@10': 0.92321587, 'MRR@1': 0.9613402, 'MRR@5': 0.9806701, 'MRR@10': 0.9813144}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 88 {'MAP': 0.9808673, 'NDCG@1': 0.76372707, 'NDCG@5': 0.92320555, 'NDCG@10': 0.9231582, 'MRR@1': 0.9617347, 'MRR@5': 0.9808673, 'MRR@10': 0.9815051}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 89 {'MAP': 0.9810606, 'NDCG@1': 0.76611364, 'NDCG@5': 0.92398125, 'NDCG@10': 
0.9239344, 'MRR@1': 0.9621212, 'MRR@5': 0.9810606, 'MRR@10': 0.9816919}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 90 {'MAP': 0.980625, 'NDCG@1': 0.7645835, 'NDCG@5': 0.9234557, 'NDCG@10': 0.92340934, 'MRR@1': 0.96125, 'MRR@5': 0.980625, 'MRR@10': 0.98125}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation metrics for epoch: 90 {'MAP': 0.98081684, 'NDCG@1': 0.7650285, 'NDCG@5': 0.923649, 'NDCG@10': 0.9235569, 'MRR@1': 0.9616337, 'MRR@5': 0.98081684, 'MRR@10': 0.98143566}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 91 {'MAP': 0.9810049, 'NDCG@1': 0.76581484, 'NDCG@5': 0.92394495, 'NDCG@10': 0.92385375, 'MRR@1': 0.9620098, 'MRR@5': 0.9810049, 'MRR@10': 0.9816176}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 92 {'MAP': 0.9811893, 'NDCG@1': 0.7662392, 'NDCG@5': 0.9240845, 'NDCG@10': 0.9239942, 'MRR@1': 0.9623786, 'MRR@5': 0.9811893, 'MRR@10': 0.98179615}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 93 {'MAP': 0.9813702, 'NDCG@1': 0.7678001, 'NDCG@5': 0.9246149, 'NDCG@10': 0.9245255, 'MRR@1': 0.96274036, 'MRR@5': 0.9813702, 'MRR@10': 0.98197114}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 94 {'MAP': 0.9815476, 'NDCG@1': 0.7676306, 'NDCG@5': 0.92459214, 'NDCG@10': 0.92450356, 'MRR@1': 0.96309525, 'MRR@5': 0.9815476, 'MRR@10': 0.98214287}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 95 {'MAP': 0.9817217, 'NDCG@1': 0.7698228, 'NDCG@5': 0.9253036, 'NDCG@10': 0.92521584, 'MRR@1': 0.9634434, 'MRR@5': 0.9817217, 'MRR@10': 0.9823113}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 96 {'MAP': 0.9807243, 'NDCG@1': 0.7683023, 'NDCG@5': 0.9247515, 'NDCG@10': 0.92466456, 'MRR@1': 0.9614486, 'MRR@5': 0.9807243, 'MRR@10': 0.9813084}\n" ] 
}, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 97 {'MAP': 0.9797454, 'NDCG@1': 0.76714087, 'NDCG@5': 0.92425805, 'NDCG@10': 0.9241719, 'MRR@1': 0.9594907, 'MRR@5': 0.9797454, 'MRR@10': 0.9803241}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 98 {'MAP': 0.9799312, 'NDCG@1': 0.7682942, 'NDCG@5': 0.9246202, 'NDCG@10': 0.92453486, 'MRR@1': 0.9598624, 'MRR@5': 0.9799312, 'MRR@10': 0.9805046}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training metrics for epoch: 99 {'MAP': 0.9801136, 'NDCG@1': 0.767695, 'NDCG@5': 0.9244149, 'NDCG@10': 0.92433035, 'MRR@1': 0.96022725, 'MRR@5': 0.9801136, 'MRR@10': 0.98068184}\n" ] } ], "source": [ "metrics = task.build_metrics()\n", "# Use a separate set of metric objects for validation so that validation\n", "# results are not mixed into the running training metrics.\n", "vali_metric_objs = task.build_metrics()\n", "model = task.build_model()\n", "task.initialize(model)\n", "train_dataset = task.build_inputs(task_config.train_data)\n", "vali_dataset = task.build_inputs(task_config.validation_data)\n", "train_iterator = iter(train_dataset)\n", "vali_iterator = iter(vali_dataset)\n", "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-6)\n", "\n", "NUM_TRAIN_STEPS = 100\n", "EVAL_STEPS = 10\n", "for train_step in range(NUM_TRAIN_STEPS):\n", "  # Run one training step on the next training batch and update `metrics`.\n", "  task.train_step(next(train_iterator), model, optimizer, metrics=metrics)\n", "  train_metrics = {m.name: m.result().numpy() for m in metrics}\n", "  print(\"Training metrics for epoch: \" + str(train_step) + \" \", train_metrics)\n", "\n", "  # Every EVAL_STEPS steps, evaluate on the next batch of the validation set.\n", "  if train_step % EVAL_STEPS == 0:\n", "    task.validation_step(next(vali_iterator), model, metrics=vali_metric_objs)\n", "    vali_metrics = {m.name: m.result().numpy() for m in vali_metric_objs}\n", "    print(\"Validation metrics for epoch: \" + str(train_step) + \" \",\n", "          vali_metrics)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "tfr_bert_orbit.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": 
".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }