Adding SetRankEncoder layer (#206)
* Adding rank as default required field (#153) (#155)

* Upgrading to tensorflow==2.7.0 (#156)

- Upgrading to tensorflow==2.7.0 - update interaction layer #156
- Update to tensorflow==2.7.0 - update model #157
- Update to tensorflow==2.7.0 - update requirements and build #158

* Adding setrank

* Deleting testfiles

* Renaming

* Renaming

* Redefining intrinsic failure metric

* Fixing metrics_helper bug

* Updating base config

* Pinning dependencies (#164)

* Ensuring dependency compatibility

* adding python file to test CI trigger

* Updating tensorflow to 2.4 in jvm

* Checkpoint

* Checkpoint - build passing

* Checkpoint - all integration tests passing

* Fixing issue

* Updating tensorflow version

* Fixing issues after merging with master

* Updating circleci config

* Updating circleci config

* Increasing coverage threshold

* Adding Auxiliary loss for ranking models

* Adding AutoDagNetwork

* Updating architecture factory

* Updating architecture keys

* Clean up

* Renaming config

* Checkpointing

* Upgrading to tensorflow==2.9.x

* Renaming config

* Changing keras.layers.merge.Concatenate

* Changing to keras.layers.merging.concatenate.Concatenate

* Adding tf-models

* Renaming to TransformerEncoderBlock

* Fixing key in set_Rank

* Implementing SetRank using tf-models

* Fixing issues

* Fixing issues

* Fixing issues

* Fixing tests

* Cleaning up SetRankEncoderKey

* Removing SetEncoderLayerKey

* Clean up

* Adding tests

* Fixing issues

* Fixing issues

* Adding more unit tests

* Adding set_rank_encoder DNNLayerKey

* Removing tf-models from optional requirements

* Updating changelog

---------

Co-authored-by: Arvind Srikantan <[email protected]>
lastmansleeping and arvindsrikantan authored Feb 2, 2023
1 parent 1870f2b commit 8e7d1eb
Showing 13 changed files with 409 additions and 81 deletions.
6 changes: 2 additions & 4 deletions docs/source/misc/changelog.md
@@ -16,10 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- AutoDAGNetwork which allows for building flexible connected architectures using config files

### Removed

- RankMatchFailure
- SetRankEncoder keras Layer to train SetRank like Ranking models
- Support for using tf-models-official deep learning garden library

## [0.1.14] - 2022-11-18

3 changes: 3 additions & 0 deletions python/build-requirements.txt
@@ -14,6 +14,9 @@ numpy==1.21.6
pandas==1.2.1
scipy==1.5.4

# transformers
tf-models-official==2.9.2

# pytest
pytest==6.2.1
pytest-cov==2.11.0
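The tf-models-official pin provides the tensorflow_models (tfm) namespace used by the new SetRankEncoder layer; a quick sanity check of the install could look like this (assumed check, not part of the commit):

import tensorflow_models as tfm

# The encoder block that SetRankEncoder wraps; this should resolve without error
print(tfm.nlp.models.TransformerEncoder)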
1 change: 1 addition & 0 deletions python/ml4ir/applications/ranking/model/layers/__init__.py
@@ -0,0 +1 @@
from ml4ir.applications.ranking.model.layers.set_rank_encoder import *
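With this wildcard export, the new layer can also be imported from the package directly (illustrative usage):

from ml4ir.applications.ranking.model.layers import SetRankEncoder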
104 changes: 104 additions & 0 deletions python/ml4ir/applications/ranking/model/layers/set_rank_encoder.py
@@ -0,0 +1,104 @@
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_models as tfm


class SetRankEncoder(layers.Layer):
"""
SetRank architecture layer that maps each document's features -> encoding
using a permutation-invariant multi-head self-attention mechanism that attends to all documents in the query.
Inspired by the transformer architecture described in the following paper:
* Liang Pang, Jun Xu, Qingyao Ai, Yanyan Lan, Xueqi Cheng, Jirong Wen. 2020.
SetRank: Learning a Permutation-Invariant Ranking Model for Information Retrieval. In Proceedings of SIGIR '20
Reference -> https://arxiv.org/pdf/1912.05891.pdf
"""

def __init__(self,
encoding_size: int,
requires_mask: bool = True,
projection_dropout_rate: float = 0.0,
**kwargs):
"""
Parameters
----------
encoding_size: int
Size of the projection which will serve as both the input and output size to the encoder
requires_mask: bool
Indicates if the layer requires a mask to be passed to it during forward pass
projection_dropout_rate: float
Dropout rate to be applied after the input projection layer
kwargs:
Additional key-value args that will be used for configuring the TransformerEncoder
Notes
-----
For a full list of args that can be passed to configure this layer, see the official layer documentation:
https://www.tensorflow.org/api_docs/python/tfm/nlp/models/TransformerEncoder
"""
super(SetRankEncoder, self).__init__()

self.requires_mask = requires_mask
self.encoding_size = encoding_size
self.projection_dropout_rate = projection_dropout_rate

self.input_projection_op = layers.Dense(units=self.encoding_size)
self.projection_dropout_op = layers.Dropout(rate=self.projection_dropout_rate)
self.transformer_encoder = tfm.nlp.models.TransformerEncoder(**kwargs)

def call(self, inputs, mask=None, training=None):
"""
Invoke the set transformer encoder (permutation invariant) for the input feature tensor
Parameters
----------
inputs: Tensor object
Input ranking feature tensor
Shape: [batch_size, sequence_len, num_features]
mask: Tensor object
Mask to be used as the attention mask for the TransformerEncoder,
indicating which documents in the query should not be attended to
Shape: [batch_size, sequence_len]
training: bool
Whether the layer should be run in training mode
Returns
-------
Tensor object
Set transformer encoder (permutation invariant) output tensor
Shape: [batch_size, sequence_len, encoding_size]
"""
# Project input from shape
# [batch_size, sequence_len, num_features] -> [batch_size, sequence_len, encoding_size]
encoder_inputs = self.input_projection_op(inputs, training=training)
encoder_inputs = self.projection_dropout_op(encoder_inputs, training=training)

# Compute attention mask if the layer requires masking
attention_mask = None
if self.requires_mask:
# Mask encoder inputs after projection
encoder_inputs = tf.transpose(
tf.multiply(
tf.transpose(encoder_inputs),
tf.transpose(tf.cast(mask, encoder_inputs.dtype))
)
)

# Convert 2D mask to 3D mask to be used for attention
attention_mask = tf.matmul(mask[:, :, tf.newaxis], mask[:, tf.newaxis, :])
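# For illustration (assumed example values): a per-query mask of [1, 1, 0]
# produces, via the outer product above, the 3x3 attention mask
# [[1, 1, 0],
#  [1, 1, 0],
#  [0, 0, 0]]
# so masked documents neither attend to others nor are attended to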

encoder_output = self.transformer_encoder(encoder_inputs=encoder_inputs,
attention_mask=attention_mask,
training=training)

return encoder_output

def get_config(self):
config = self.transformer_encoder.get_config()
config.update({
"encoding_size": self.encoding_size,
"projection_dropout_rate": self.projection_dropout_rate
})

return config
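For orientation, a minimal usage sketch of the new layer (not part of the commit; shapes and hyperparameters are assumed example values):

import numpy as np
import tensorflow as tf

from ml4ir.applications.ranking.model.layers.set_rank_encoder import SetRankEncoder

# Assumed example shapes: 4 queries, 5 documents per query, 16 features per document
features = tf.constant(np.random.randn(4, 5, 16), dtype=tf.float32)
mask = tf.constant(np.random.binomial(n=1, p=0.5, size=(4, 5)), dtype=tf.float32)

encoder = SetRankEncoder(encoding_size=8,
                         num_layers=2,
                         num_attention_heads=2,
                         intermediate_size=32)
encoder_output = encoder(features, mask)

# Document features are projected to the encoding size: (4, 5, 16) -> (4, 5, 8)
print(encoder_output.shape)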
34 changes: 34 additions & 0 deletions python/ml4ir/applications/ranking/tests/data/configs/model_config_set_rank.yaml
@@ -0,0 +1,34 @@
architecture_key: dnn
layers:
- type: set_rank_encoder
requires_mask: true
encoding_size: 128
projection_dropout_rate: 0.0
num_layers: 4
num_attention_heads: 8
intermediate_size: 128
dropout_rate: 0.1
- type: dense
name: first_dense
units: 256
activation: relu
- type: dropout
name: first_dropout
rate: 0.0
- type: dense
name: second_dense
units: 64
activation: relu
- type: dropout
name: second_dropout
rate: 0.0
- type: dense
name: final_dense
units: 1
activation: null
optimizer:
key: adam
gradient_clip_value: 5.0
lr_schedule:
key: constant
learning_rate: 0.001
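Read as a plain Keras stack, the layer list above corresponds roughly to the sketch below (illustrative only: in ml4ir the DNN architecture wires these layers together and threads the query mask through, and the dense/dropout types are assumed to map to their standard Keras counterparts):

from tensorflow.keras import layers

from ml4ir.applications.ranking.model.layers.set_rank_encoder import SetRankEncoder

# Rough equivalent of the config above (assumed mapping)
scoring_layers = [
    SetRankEncoder(encoding_size=128,
                   requires_mask=True,
                   projection_dropout_rate=0.0,
                   num_layers=4,
                   num_attention_heads=8,
                   intermediate_size=128,
                   dropout_rate=0.1),
    layers.Dense(256, activation="relu", name="first_dense"),
    layers.Dropout(0.0, name="first_dropout"),
    layers.Dense(64, activation="relu", name="second_dense"),
    layers.Dropout(0.0, name="second_dropout"),
    layers.Dense(1, activation=None, name="final_dense"),
]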
8 changes: 8 additions & 0 deletions python/ml4ir/applications/ranking/tests/test_serving.py
@@ -183,6 +183,14 @@ def test_model_serving_auto_dag_network(self):
model_config_path="ml4ir/applications/ranking/tests/data/configs/model_config_auto_dag_network.yaml"
)

def test_model_serving_set_rank(self):
"""
Train a simple SetRank model and test the serving flow by loading the SavedModel
"""
self.check_model_serving(
model_config_path="ml4ir/applications/ranking/tests/data/configs/model_config_set_rank.yaml"
)

def test_serving_n_records(self):
"""Test serving signature with different number of records"""
feature_config: FeatureConfig = self.get_feature_config()
142 changes: 142 additions & 0 deletions python/ml4ir/applications/ranking/tests/test_set_rank_encoder.py
@@ -0,0 +1,142 @@
import unittest

import tensorflow as tf
from tensorflow.experimental.numpy import isclose
import numpy as np

from ml4ir.applications.ranking.model.layers.set_rank_encoder import SetRankEncoder


class TestSetRankEncoder(unittest.TestCase):
"""Unit tests for ml4ir.applications.ranking.model.layers.set_rank_encoder"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.num_features = 16
self.encoding_size = 8

def validate_encoder_output(self, encoder_output, mask):
# Check that the encodings for masked inputs are all the same in a given query
for i in range(mask.shape[0]):
# Skip query if all documents are unmasked
if tf.reduce_sum(mask[i]) == len(mask[i]):
continue
encoder_output_for_masked_i = tf.gather_nd(encoder_output[i], tf.where(mask[i] == 0))
self.assertTrue(tf.reduce_all(tf.reduce_all(
isclose(encoder_output_for_masked_i, encoder_output_for_masked_i[0, :]), axis=1)).numpy())

# Check that the encodings for unmasked inputs are not all the same
for i in range(mask.shape[0]):
# Skip query if there is only 1 unmasked document
if tf.reduce_sum(mask[i]) <= 1:
continue
encoder_output_i = tf.gather_nd(encoder_output[i], tf.where(mask[i] == 1))
self.assertFalse(tf.reduce_all(tf.reduce_all(
isclose(encoder_output_i, encoder_output_i[0, :]), axis=1)).numpy())

def test_input_projection_op(self):
"""Test if the input_projection_op behaves as expected"""
encoder = SetRankEncoder(self.encoding_size)

x = np.random.randn(4, 5, self.num_features)
projected_input = encoder.input_projection_op(x)

self.assertEqual(x.shape[0], projected_input.shape[0])
self.assertEqual(x.shape[1], projected_input.shape[1])
self.assertNotEqual(x.shape[2], projected_input.shape[2])

# Check if input feature dimension gets mapped to encoding dimension
self.assertEqual(projected_input.shape[2], self.encoding_size)

def test_transformer_encoder(self):
"""Test if the transformer_encoder behaves as expected"""
encoder = SetRankEncoder(self.encoding_size)

x = np.random.randn(4, 5, self.encoding_size)
encoder_output = encoder.transformer_encoder(x)

self.assertEqual(x.shape[0], encoder_output.shape[0])
self.assertEqual(x.shape[1], encoder_output.shape[1])
self.assertEqual(x.shape[2], encoder_output.shape[2])

self.assertEqual(encoder_output.shape[2], self.encoding_size)

def test_transformer_encoder_with_attention_mask(self):
"""Test if the transformer_encoder behaves as expected with attention mask"""
encoder = SetRankEncoder(self.encoding_size)

x = np.random.randn(4, 5, self.encoding_size)
mask = np.random.binomial(n=1, p=0.5, size=[4 * 5]).reshape(4, 5)
attention_mask = np.matmul(mask[:, :, np.newaxis], mask[:, np.newaxis, :])
masked_x = (x.T * mask.T).T

encoder_output = encoder.transformer_encoder(masked_x, attention_mask)

self.assertEqual(x.shape[0], encoder_output.shape[0])
self.assertEqual(x.shape[1], encoder_output.shape[1])
self.assertEqual(x.shape[2], encoder_output.shape[2])

self.assertEqual(encoder_output.shape[2], self.encoding_size)

self.validate_encoder_output(encoder_output, mask)

def test_transformer_encoder_kwargs(self):
"""Test if passing additional key-value args to TransformerEncoder works"""
encoder = SetRankEncoder(self.encoding_size,
num_layers=4,
intermediate_size=32)
config = encoder.get_config()

# Default Args
self.assertEqual(config["projection_dropout_rate"], 0.)
self.assertEqual(config["num_attention_heads"], 8)

# Passed Args
self.assertEqual(config["encoding_size"], self.encoding_size)
self.assertEqual(config["num_layers"], 4)
self.assertEqual(config["intermediate_size"], 32)

def test_set_rank_encoder_layer(self):
"""Test the SetRankEncoder call() function end-to-end"""
encoder = SetRankEncoder(self.encoding_size,
requires_mask=False)

x = np.random.randn(4, 5, self.num_features)
encoder_output = encoder(x)

self.assertEqual(x.shape[0], encoder_output.shape[0])
self.assertEqual(x.shape[1], encoder_output.shape[1])
self.assertNotEqual(x.shape[2], encoder_output.shape[2])

# Check if input feature dimension gets mapped to encoding dimension
self.assertEqual(encoder_output.shape[2], self.encoding_size)

def test_set_rank_encoder_layer_with_dropout(self):
"""Test the SetRankEncoder call() function end-to-end with dropout"""
encoder = SetRankEncoder(self.encoding_size,
requires_mask=False,
dropout_rate=0.5)

x = np.random.randn(4, 5, self.num_features)

self.assertFalse(tf.reduce_all(encoder(x, training=True) == encoder(x, training=True)).numpy())
self.assertFalse(tf.reduce_all(encoder(x, training=False) == encoder(x, training=True)).numpy())
self.assertTrue(tf.reduce_all(encoder(x, training=False) == encoder(x, training=False)).numpy())

def test_set_rank_encoder_layer_with_mask(self):
"""Test the SetRankEncoder call() function end-to-end with mask"""
encoder = SetRankEncoder(self.encoding_size)

x = np.random.randn(4, 5, self.num_features)
mask = np.random.binomial(n=1, p=0.5, size=[4 * 5]).reshape(4, 5)

encoder_output = encoder(x, mask)

self.assertEqual(x.shape[0], encoder_output.shape[0])
self.assertEqual(x.shape[1], encoder_output.shape[1])
self.assertNotEqual(x.shape[2], encoder_output.shape[2])

self.assertEqual(encoder_output.shape[2], self.encoding_size)

self.validate_encoder_output(encoder_output, mask)
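The tests above exercise masking and output shapes; the permutation behavior described in the layer docstring (reordering the documents in a query reorders their encodings correspondingly) could be spot-checked along these lines (assumed sketch, not part of the commit):

import numpy as np

from ml4ir.applications.ranking.model.layers.set_rank_encoder import SetRankEncoder

encoder = SetRankEncoder(8, requires_mask=False)
x = np.random.randn(1, 5, 16).astype(np.float32)
perm = np.random.permutation(5)

output = encoder(x, training=False)
output_permuted = encoder(x[:, perm, :], training=False)

# Permuting the input documents should permute the encodings the same way
print(np.allclose(output.numpy()[:, perm, :], output_permuted.numpy(), atol=1e-4))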
6 changes: 2 additions & 4 deletions python/ml4ir/base/data/tfrecord_reader.py
@@ -806,10 +806,8 @@ def read(

if parse_tfrecord:
# Parallel calls set to AUTOTUNE: improved training performance by 40% with a classification model
dataset = (
dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.apply(data.experimental.ignore_errors())
)
dataset = (dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.apply(data.experimental.ignore_errors()))

# Create BatchedDataSet
if batch_size:
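The hunk above is truncated by the diff view; for context, a minimal standalone version of the same map + ignore_errors pattern (assumed toy data, not ml4ir code):

import tensorflow as tf

toy = tf.data.Dataset.from_tensor_slices(["1.0", "2.0", "not_a_number", "4.0"])
parsed = (toy.map(tf.strings.to_number, num_parallel_calls=tf.data.experimental.AUTOTUNE)
          .apply(tf.data.experimental.ignore_errors()))

# The malformed record is dropped silently: prints [1.0, 2.0, 4.0]
print(list(parsed.as_numpy_iterator()))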