* Adding rank as default required field (#153) (#155)
* Upgrading to tensorflow==2.7.0 (#156)
  - Upgrading to tensorflow==2.7.0 - update interaction layer #156
  - Update to tensorflow==2.7.0 - update model #157
  - Update to tensorflow==2.7.0 - update requirements and build #158
* Adding setrank
* Deleting testfiles
* Renaming
* Renaming
* Redefining intrinsic failure metric
* Fixing metrics_helper bug
* Updating base config
* Pinning dependencies (#164)
* Ensuring dependency compatibility
* Adding python file to test CI trigger
* Updating tensorflow to 2.4 in jvm
* Checkpoint
* Checkpoint - build passing
* Checkpoint - all integration tests passing
* Fixing issue
* Updating tensorflow version
* Fixing issues after merging with master
* Updating circleci config
* Updating circleci config
* Increasing coverage threshold
* Adding Auxiliary loss for ranking models
* Adding AutoDagNetwork
* Updating architecture factory
* Updating architecture keys
* Clean up
* Renaming config
* Checkpointing
* Upgrading to tensorflow==2.9.x
* Renaming config
* Changing keras.layers.merge.Concatenate
* Changing to keras.layers.merging.concatenate.Concatenate
* Adding tf-models
* Renaming to TransformerEncoderBlock
* Fixing key in set_Rank
* Implementing SetRank using tf-models
* Fixing issues
* Fixing issues
* Fixing issues
* Fixing tests
* Cleaning up SetRankEncoderKey
* Removing SetEncoderLayerKey
* Clean up
* Adding tests
* Fixing issues
* Fixing issues
* Adding more unit tests
* Adding set_rank_encoder DNNLayerKey
* Removing tf-models from optional requirements
* Updating changelog
---------
Co-authored-by: Arvind Srikantan <[email protected]>
1 parent 1870f2b, commit 8e7d1eb
Showing 13 changed files with 409 additions and 81 deletions.
@@ -0,0 +1 @@
from ml4ir.applications.ranking.model.layers.set_rank_encoder import *
python/ml4ir/applications/ranking/model/layers/set_rank_encoder.py (104 additions, 0 deletions)
@@ -0,0 +1,104 @@
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_models as tfm


class SetRankEncoder(layers.Layer):
    """
    SetRank architecture layer that maps features for a document -> encoding
    using a permutation invariant multi-head self attention technique attending to all documents in the query.
    Inspired by the transformer architecture as described in the following paper
    * Liang Pang, Jun Xu, Qingyao Ai, Yanyan Lan, Xueqi Cheng, Jirong Wen. 2020.
      SetRank: Learning a Permutation-Invariant Ranking Model for Information Retrieval. In Proceedings of SIGIR '20
      Reference -> https://arxiv.org/pdf/1912.05891.pdf
    """

    def __init__(self,
                 encoding_size: int,
                 requires_mask: bool = True,
                 projection_dropout_rate: float = 0.0,
                 **kwargs):
        """
        Parameters
        ----------
        encoding_size: int
            Size of the projection which will serve as both the input and output size to the encoder
        requires_mask: bool
            Indicates if the layer requires a mask to be passed to it during forward pass
        projection_dropout_rate: float
            Dropout rate to be applied after the input projection layer
        kwargs:
            Additional key-value args that will be used for configuring the TransformerEncoder

        Notes
        -----
        For a full list of args that can be passed to configure this layer, please check the official layer doc below
        https://www.tensorflow.org/api_docs/python/tfm/nlp/models/TransformerEncoder
        """
        super(SetRankEncoder, self).__init__()

        self.requires_mask = requires_mask
        self.encoding_size = encoding_size
        self.projection_dropout_rate = projection_dropout_rate

        self.input_projection_op = layers.Dense(units=self.encoding_size)
        self.projection_dropout_op = layers.Dropout(rate=self.projection_dropout_rate)
        self.transformer_encoder = tfm.nlp.models.TransformerEncoder(**kwargs)

    def call(self, inputs, mask=None, training=None):
        """
        Invoke the set transformer encoder (permutation invariant) for the input feature tensor

        Parameters
        ----------
        inputs: Tensor object
            Input ranking feature tensor
            Shape: [batch_size, sequence_len, num_features]
        mask: Tensor object
            Mask to be used as the attention mask for the TransformerEncoder
            to indicate which documents not to attend to in the query
            Shape: [batch_size, sequence_len]
        training: bool
            If the layer should be run as training or not

        Returns
        -------
        Tensor object
            Set transformer encoder (permutation invariant) output tensor
            Shape: [batch_size, sequence_len, encoding_size]
        """
        # Project input from shape
        # [batch_size, sequence_len, num_features] -> [batch_size, sequence_len, encoding_size]
        encoder_inputs = self.input_projection_op(inputs, training=training)
        encoder_inputs = self.projection_dropout_op(encoder_inputs, training=training)

        # Compute attention mask if mask is present
        attention_mask = None
        if self.requires_mask:
            # Mask encoder inputs after projection
            encoder_inputs = tf.transpose(
                tf.multiply(
                    tf.transpose(encoder_inputs),
                    tf.transpose(tf.cast(mask, encoder_inputs.dtype))
                )
            )

            # Convert 2D mask to 3D mask to be used for attention
            attention_mask = tf.matmul(mask[:, :, tf.newaxis], mask[:, tf.newaxis, :])
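            # The outer product of the 2D mask with itself yields a
            # [batch_size, sequence_len, sequence_len] attention mask per query,
            # where entry (j, k) is 1 only if both document j and document k are unmasked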

        encoder_output = self.transformer_encoder(encoder_inputs=encoder_inputs,
                                                  attention_mask=attention_mask,
                                                  training=training)

        return encoder_output

    def get_config(self):
        config = self.transformer_encoder.get_config()
        config.update({
            "encoding_size": self.encoding_size,
            "projection_dropout_rate": self.projection_dropout_rate
        })

        return config
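For context, a minimal usage sketch of the new layer (not part of this commit), assuming tensorflow and the tensorflow-models package are installed; the batch size, list size, and feature count below are illustrative, and any kwargs beyond the first three constructor args are forwarded to tfm.nlp.models.TransformerEncoder:

import numpy as np
import tensorflow as tf

from ml4ir.applications.ranking.model.layers.set_rank_encoder import SetRankEncoder

# Illustrative shapes: 2 queries, 5 documents per query, 16 features per document
features = tf.constant(np.random.randn(2, 5, 16), dtype=tf.float32)
# 1 = real document, 0 = padded document
mask = tf.constant([[1, 1, 1, 0, 0],
                    [1, 1, 1, 1, 1]], dtype=tf.float32)

# num_layers / num_attention_heads / intermediate_size are forwarded to the TransformerEncoder
encoder = SetRankEncoder(encoding_size=8,
                         num_layers=2,
                         num_attention_heads=2,
                         intermediate_size=16)

# Same calling convention as in the unit tests below: inputs first, then the 2D mask
encodings = encoder(features, mask)
print(encodings.shape)  # (2, 5, 8)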
python/ml4ir/applications/ranking/tests/data/configs/model_config_set_rank.yaml (34 additions, 0 deletions)
@@ -0,0 +1,34 @@
architecture_key: dnn
layers:
  - type: set_rank_encoder
    requires_mask: true
    encoding_size: 128
    projection_dropout_rate: 0.0
    num_layers: 4
    num_attention_heads: 8
    intermediate_size: 128
    dropout_rate: 0.1
  - type: dense
    name: first_dense
    units: 256
    activation: relu
  - type: dropout
    name: first_dropout
    rate: 0.0
  - type: dense
    name: second_dense
    units: 64
    activation: relu
  - type: dropout
    name: second_dropout
    rate: 0.0
  - type: dense
    name: final_dense
    units: 1
    activation: null
optimizer:
  key: adam
  gradient_clip_value: 5.0
lr_schedule:
  key: constant
  learning_rate: 0.001
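For reference, the set_rank_encoder entry above presumably translates into a layer construction along these lines: requires_mask, encoding_size, and projection_dropout_rate are consumed by SetRankEncoder itself, while the remaining keys are forwarded to the underlying TransformerEncoder. This is a sketch of the mapping, not the actual ml4ir layer-factory code:

from ml4ir.applications.ranking.model.layers.set_rank_encoder import SetRankEncoder

# Sketch only: the real DNN layer factory in ml4ir may wire this up differently
set_rank_layer = SetRankEncoder(
    encoding_size=128,
    requires_mask=True,
    projection_dropout_rate=0.0,
    # Remaining config keys become TransformerEncoder kwargs
    num_layers=4,
    num_attention_heads=8,
    intermediate_size=128,
    dropout_rate=0.1,
)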
python/ml4ir/applications/ranking/tests/test_set_rank_encoder.py (142 additions, 0 deletions)
@@ -0,0 +1,142 @@
import unittest

import tensorflow as tf
from tensorflow.experimental.numpy import isclose
import numpy as np

from ml4ir.applications.ranking.model.layers.set_rank_encoder import SetRankEncoder


class TestSetRankEncoder(unittest.TestCase):
    """Unit tests for ml4ir.applications.ranking.model.layers.set_rank_encoder"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.num_features = 16
        self.encoding_size = 8

    def validate_encoder_output(self, encoder_output, mask):
        # Check that the encodings for masked inputs are all the same in a given query
        for i in range(mask.shape[0]):
            # Skip query if all documents are unmasked
            if tf.reduce_sum(mask[i]) == len(mask[i]):
                continue
            encoder_output_for_masked_i = tf.gather_nd(encoder_output[i], tf.where(mask[i] == 0))
            self.assertTrue(tf.reduce_all(tf.reduce_all(
                isclose(encoder_output_for_masked_i, encoder_output_for_masked_i[0, :]), axis=1)).numpy())

        # Check that the encodings for unmasked inputs are not all the same
        for i in range(mask.shape[0]):
            # Skip query if there is only 1 unmasked document
            if tf.reduce_sum(mask[i]) <= 1:
                continue
            encoder_output_i = tf.gather_nd(encoder_output[i], tf.where(mask[i] == 1))
            self.assertFalse(tf.reduce_all(tf.reduce_all(
                isclose(encoder_output_i, encoder_output_i[0, :]), axis=1)).numpy())

    def test_input_projection_op(self):
        """Test if the input_projection_op behaves as expected"""
        encoder = SetRankEncoder(self.encoding_size)

        x = np.random.randn(4, 5, self.num_features)
        projected_input = encoder.input_projection_op(x)

        self.assertEqual(x.shape[0], projected_input.shape[0])
        self.assertEqual(x.shape[1], projected_input.shape[1])
        self.assertNotEqual(x.shape[2], projected_input.shape[2])

        # Check if input feature dimension gets mapped to encoding dimension
        self.assertEqual(projected_input.shape[2], self.encoding_size)

    def test_transformer_encoder(self):
        """Test if the transformer_encoder behaves as expected"""
        encoder = SetRankEncoder(self.encoding_size)

        x = np.random.randn(4, 5, self.encoding_size)
        encoder_output = encoder.transformer_encoder(x)

        self.assertEqual(x.shape[0], encoder_output.shape[0])
        self.assertEqual(x.shape[1], encoder_output.shape[1])
        self.assertEqual(x.shape[2], encoder_output.shape[2])

        self.assertEqual(encoder_output.shape[2], self.encoding_size)

    def test_transformer_encoder_with_attention_mask(self):
        """Test if the transformer_encoder behaves as expected with attention mask"""
        encoder = SetRankEncoder(self.encoding_size)

        x = np.random.randn(4, 5, self.encoding_size)
        mask = np.random.binomial(n=1, p=0.5, size=[4 * 5]).reshape(4, 5)
        attention_mask = np.matmul(mask[:, :, np.newaxis], mask[:, np.newaxis, :])
        masked_x = (x.T * mask.T).T

        encoder_output = encoder.transformer_encoder(masked_x, attention_mask)

        self.assertEqual(x.shape[0], encoder_output.shape[0])
        self.assertEqual(x.shape[1], encoder_output.shape[1])
        self.assertEqual(x.shape[2], encoder_output.shape[2])

        self.assertEqual(encoder_output.shape[2], self.encoding_size)

        self.validate_encoder_output(encoder_output, mask)

    def test_transformer_encoder_kwargs(self):
        """Test if passing additional key-value args to TransformerEncoder works"""
        encoder = SetRankEncoder(self.encoding_size,
                                 num_layers=4,
                                 intermediate_size=32)
        config = encoder.get_config()

        # Default Args
        self.assertEqual(config["projection_dropout_rate"], 0.)
        self.assertEqual(config["num_attention_heads"], 8)

        # Passed Args
        self.assertEqual(config["encoding_size"], self.encoding_size)
        self.assertEqual(config["num_layers"], 4)
        self.assertEqual(config["intermediate_size"], 32)

    def test_set_rank_encoder_layer(self):
        """Test the SetRankEncoder call() function end-to-end"""
        encoder = SetRankEncoder(self.encoding_size,
                                 requires_mask=False)

        x = np.random.randn(4, 5, self.num_features)
        encoder_output = encoder(x)

        self.assertEqual(x.shape[0], encoder_output.shape[0])
        self.assertEqual(x.shape[1], encoder_output.shape[1])
        self.assertNotEqual(x.shape[2], encoder_output.shape[2])

        # Check if input feature dimension gets mapped to encoding dimension
        self.assertEqual(encoder_output.shape[2], self.encoding_size)

    def test_set_rank_encoder_layer_with_dropout(self):
        """Test the SetRankEncoder call() function end-to-end with dropout"""
        encoder = SetRankEncoder(self.encoding_size,
                                 requires_mask=False,
                                 dropout_rate=0.5)

        x = np.random.randn(4, 5, self.num_features)

        self.assertFalse(tf.reduce_all(encoder(x, training=True) == encoder(x, training=True)).numpy())
        self.assertFalse(tf.reduce_all(encoder(x, training=False) == encoder(x, training=True)).numpy())
        self.assertTrue(tf.reduce_all(encoder(x, training=False) == encoder(x, training=False)).numpy())

    def test_set_rank_encoder_layer_with_mask(self):
        """Test the SetRankEncoder call() function end-to-end with mask"""
        encoder = SetRankEncoder(self.encoding_size)

        x = np.random.randn(4, 5, self.num_features)
        mask = np.random.binomial(n=1, p=0.5, size=[4 * 5]).reshape(4, 5)

        encoder_output = encoder(x, mask)

        self.assertEqual(x.shape[0], encoder_output.shape[0])
        self.assertEqual(x.shape[1], encoder_output.shape[1])
        self.assertNotEqual(x.shape[2], encoder_output.shape[2])

        self.assertEqual(encoder_output.shape[2], self.encoding_size)

        self.validate_encoder_output(encoder_output, mask)