diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7ed5bd15..3001b2e0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,6 @@ -name: CI +name: CI_TF2 -on: +on: push: path: - 'deepctr/*' @@ -9,7 +9,7 @@ on: path: - 'deepctr/*' - 'tests/*' - + jobs: build: @@ -17,9 +17,9 @@ jobs: timeout-minutes: 180 strategy: matrix: - python-version: [3.6,3.7,3.8,3.9,3.10.7] - tf-version: [1.4.0,1.15.0,2.6.0,2.7.0,2.8.0,2.9.0,2.10.0] - + python-version: [ 3.6,3.7,3.8, 3.9,3.10.7 ] + tf-version: [ 2.6.0,2.7.0,2.8.0,2.9.0,2.10.0 ] + exclude: - python-version: 3.7 tf-version: 1.4.0 @@ -64,31 +64,31 @@ jobs: - python-version: 3.10.7 tf-version: 2.7.0 steps: - - - uses: actions/checkout@v3 - - - name: Setup python environment - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip3 install -q tensorflow==${{ matrix.tf-version }} - pip install -q protobuf==3.19.0 - pip install -q requests - pip install -e . - - name: Test with pytest - timeout-minutes: 180 - run: | - pip install -q pytest - pip install -q pytest-cov - pip install -q python-coveralls - pytest --cov=deepctr --cov-report=xml - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3.1.0 - with: - token: ${{secrets.CODECOV_TOKEN}} - file: ./coverage.xml - flags: pytest - name: py${{ matrix.python-version }}-tf${{ matrix.tf-version }} + - uses: actions/checkout@v3 + + - name: Setup python environment + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + pip3 install -q tensorflow==${{ matrix.tf-version }} + pip install -q protobuf==3.19.0 + pip install -q requests + pip install -e . 
+ - name: Test with pytest + timeout-minutes: 180 + run: | + pip install -q pytest + pip install -q pytest-cov + pip install -q python-coveralls + pytest --cov=deepctr --cov-report=xml + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3.1.0 + with: + token: ${{secrets.CODECOV_TOKEN}} + file: ./coverage.xml + flags: pytest + name: py${{ matrix.python-version }}-tf${{ matrix.tf-version }} diff --git a/.github/workflows/ci2.yml b/.github/workflows/ci2.yml new file mode 100644 index 00000000..e9901cb1 --- /dev/null +++ b/.github/workflows/ci2.yml @@ -0,0 +1,96 @@ +name: CI_TF1 + +on: + push: + path: + - 'deepctr/*' + - 'tests/*' + pull_request: + path: + - 'deepctr/*' + - 'tests/*' + +jobs: + build: + + runs-on: ubuntu-latest + timeout-minutes: 180 + strategy: + matrix: + python-version: [ 3.6,3.7 ] + tf-version: [ 1.15.0 ] + + exclude: + - python-version: 3.7 + tf-version: 1.4.0 + - python-version: 3.7 + tf-version: 1.12.0 + - python-version: 3.7 + tf-version: 1.15.0 + - python-version: 3.8 + tf-version: 1.4.0 + - python-version: 3.8 + tf-version: 1.14.0 + - python-version: 3.8 + tf-version: 1.15.0 + - python-version: 3.6 + tf-version: 2.7.0 + - python-version: 3.6 + tf-version: 2.8.0 + - python-version: 3.6 + tf-version: 2.9.0 + - python-version: 3.6 + tf-version: 2.10.0 + - python-version: 3.9 + tf-version: 1.4.0 + - python-version: 3.9 + tf-version: 1.15.0 + - python-version: 3.9 + tf-version: 2.2.0 + - python-version: 3.9 + tf-version: 2.5.0 + - python-version: 3.9 + tf-version: 2.6.0 + - python-version: 3.9 + tf-version: 2.7.0 + - python-version: 3.10.7 + tf-version: 1.4.0 + - python-version: 3.10.7 + tf-version: 1.15.0 + - python-version: 3.10.7 + tf-version: 2.2.0 + - python-version: 3.10.7 + tf-version: 2.5.0 + - python-version: 3.10.7 + tf-version: 2.6.0 + - python-version: 3.10.7 + tf-version: 2.7.0 + steps: + + - uses: actions/checkout@v3 + + - name: Setup python environment + uses: actions/setup-python@v4 + with: + python-version: ${{ 
matrix.python-version }} + + - name: Install dependencies + run: | + pip3 install -q tensorflow==${{ matrix.tf-version }} + pip install -q protobuf==3.19.0 + pip install -q requests + pip install -e . + - name: Test with pytest + timeout-minutes: 180 + run: | + pip install -q pytest + pip install -q pytest-cov + pip install -q python-coveralls + pytest --cov=deepctr --cov-report=xml + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3.1.0 + with: + token: ${{secrets.CODECOV_TOKEN}} + file: ./coverage.xml + flags: pytest + name: py${{ matrix.python-version }}-tf${{ matrix.tf-version }} diff --git a/README.md b/README.md index f0d90c13..ec42fd31 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ Introduction](https://zhuanlan.zhihu.com/p/53231955)) and [welcome to join us!]( | ESMM | [SIGIR 2018][Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate](https://arxiv.org/abs/1804.07931) | | MMOE | [KDD 2018][Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) | | PLE | [RecSys 2020][Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) | +| EDCN | [KDD 2021][Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf) | ## Citation diff --git a/deepctr/__init__.py b/deepctr/__init__.py index 3c6d40b5..7eaabe48 100644 --- a/deepctr/__init__.py +++ b/deepctr/__init__.py @@ -1,4 +1,4 @@ from .utils import check_version -__version__ = '0.9.2' +__version__ = '0.9.3' check_version(__version__) diff --git a/deepctr/layers/__init__.py b/deepctr/layers/__init__.py index 1bfd40ef..18e45011 100644 --- a/deepctr/layers/__init__.py +++ b/deepctr/layers/__init__.py @@ -1,17 +1,16 @@ import tensorflow as tf from .activation 
import Dice -from .core import DNN, LocalActivationUnit, PredictionLayer +from .core import DNN, LocalActivationUnit, PredictionLayer, RegulationModule from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet, CrossNetMix, InnerProductLayer, InteractingLayer, OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction, - FieldWiseBiInteraction, FwFMLayer, FEFMLayer) + FieldWiseBiInteraction, FwFMLayer, FEFMLayer, BridgeModule) from .normalization import LayerNormalization from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM, KMaxPooling, SequencePoolingLayer, WeightedSequenceLayer, - Transformer, DynamicGRU,PositionEncoding) - -from .utils import NoMask, Hash, Linear, _Add, combined_dnn_input, softmax, reduce_sum + Transformer, DynamicGRU, PositionEncoding) +from .utils import NoMask, Hash, Linear, _Add, combined_dnn_input, softmax, reduce_sum, Concat custom_objects = {'tf': tf, 'InnerProductLayer': InnerProductLayer, @@ -38,6 +37,7 @@ 'FGCNNLayer': FGCNNLayer, 'Hash': Hash, 'Linear': Linear, + 'Concat': Concat, 'DynamicGRU': DynamicGRU, 'SENETLayer': SENETLayer, 'BilinearInteraction': BilinearInteraction, @@ -48,5 +48,7 @@ 'softmax': softmax, 'FEFMLayer': FEFMLayer, 'reduce_sum': reduce_sum, - 'PositionEncoding':PositionEncoding + 'PositionEncoding': PositionEncoding, + 'RegulationModule': RegulationModule, + 'BridgeModule': BridgeModule } diff --git a/deepctr/layers/core.py b/deepctr/layers/core.py index 668348d2..ad249473 100644 --- a/deepctr/layers/core.py +++ b/deepctr/layers/core.py @@ -10,9 +10,9 @@ from tensorflow.python.keras import backend as K try: - from tensorflow.python.ops.init_ops_v2 import Zeros, glorot_normal + from tensorflow.python.ops.init_ops_v2 import Zeros, Ones, glorot_normal except ImportError: - from tensorflow.python.ops.init_ops import Zeros, glorot_normal_initializer as glorot_normal + from tensorflow.python.ops.init_ops import Zeros, Ones, glorot_normal_initializer as glorot_normal from 
tensorflow.python.keras.layers import Layer, Dropout @@ -265,3 +265,57 @@ def get_config(self, ): config = {'task': self.task, 'use_bias': self.use_bias} base_config = super(PredictionLayer, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + +class RegulationModule(Layer): + """Regulation module used in EDCN. + + Input shape + - 3D tensor with shape: ``(batch_size,field_size,embedding_size)``. + + Output shape + - 2D tensor with shape: ``(batch_size,field_size * embedding_size)``. + + Arguments + - **tau** : Positive float, the temperature coefficient to control + distribution of field-wise gating unit. + + References + - [Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf) + """ + + def __init__(self, tau=1.0, **kwargs): + if tau == 0: + raise ValueError("RegulationModule tau can not be zero.") + self.tau = 1.0 / tau + super(RegulationModule, self).__init__(**kwargs) + + def build(self, input_shape): + self.field_size = int(input_shape[1]) + self.embedding_size = int(input_shape[2]) + self.g = self.add_weight( + shape=(1, self.field_size, 1), + initializer=Ones(), + name=self.name + '_field_weight') + + # Be sure to call this somewhere! 
+ super(RegulationModule, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if K.ndim(inputs) != 3: + raise ValueError( + "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs))) + + feild_gating_score = tf.nn.softmax(self.g * self.tau, 1) + E = inputs * feild_gating_score + return tf.reshape(E, [-1, self.field_size * self.embedding_size]) + + def compute_output_shape(self, input_shape): + return (None, self.field_size * self.embedding_size) + + def get_config(self): + config = {'tau': self.tau} + base_config = super(RegulationModule, self).get_config() + base_config.update(config) + return base_config diff --git a/deepctr/layers/interaction.py b/deepctr/layers/interaction.py index d26eb2c1..f76eda32 100644 --- a/deepctr/layers/interaction.py +++ b/deepctr/layers/interaction.py @@ -3,7 +3,8 @@ Authors: Weichen Shen,weichenswc@163.com, - Harshit Pande + Harshit Pande, + Yi He, heyi_jack@163.com """ @@ -26,6 +27,7 @@ from .activation import activation_layer from .utils import concat_func, reduce_sum, softmax, reduce_mean +from .core import DNN class AFMLayer(Layer): @@ -1489,3 +1491,69 @@ def get_config(self): 'regularizer': self.regularizer, }) return config + + +class BridgeModule(Layer): + """Bridge Module used in EDCN + + Input shape + - A list of two 2D tensor with shape: ``(batch_size, units)``. + + Output shape + - 2D tensor with shape: ``(batch_size, units)``. + + Arguments + - **bridge_type**: The type of bridge interaction, one of 'pointwise_addition', 'hadamard_product', 'concatenation', 'attention_pooling' + + - **activation**: Activation function to use. 
+ + References + - [Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf) + + """ + + def __init__(self, bridge_type='hadamard_product', activation='relu', **kwargs): + self.bridge_type = bridge_type + self.activation = activation + + super(BridgeModule, self).__init__(**kwargs) + + def build(self, input_shape): + if not isinstance(input_shape, list) or len(input_shape) < 2: + raise ValueError( + 'A `BridgeModule` layer should be called ' + 'on a list of 2 inputs') + + self.dnn_dim = int(input_shape[0][-1]) + if self.bridge_type == "concatenation": + self.dense = Dense(self.dnn_dim, self.activation) + elif self.bridge_type == "attention_pooling": + self.dense_x = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax') + self.dense_h = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax') + + super(BridgeModule, self).build(input_shape) # Be sure to call this somewhere! 
+ + def call(self, inputs, **kwargs): + x, h = inputs + if self.bridge_type == "pointwise_addition": + return x + h + elif self.bridge_type == "hadamard_product": + return x * h + elif self.bridge_type == "concatenation": + return self.dense(tf.concat([x, h], axis=-1)) + elif self.bridge_type == "attention_pooling": + a_x = self.dense_x(x) + a_h = self.dense_h(h) + return a_x * x + a_h * h + + def compute_output_shape(self, input_shape): + return (None, self.dnn_dim) + + def get_config(self): + base_config = super(BridgeModule, self).get_config().copy() + config = { + 'bridge_type': self.bridge_type, + 'activation': self.activation + } + config.update(base_config) + return config diff --git a/deepctr/layers/sequence.py b/deepctr/layers/sequence.py index 93866640..6b8b93b6 100644 --- a/deepctr/layers/sequence.py +++ b/deepctr/layers/sequence.py @@ -11,10 +11,9 @@ from tensorflow.python.keras import backend as K try: - from tensorflow.python.ops.init_ops import TruncatedNormal, glorot_uniform_initializer as glorot_uniform, \ - identity_initializer as identity + from tensorflow.python.ops.init_ops import TruncatedNormal, Constant, glorot_uniform_initializer as glorot_uniform except ImportError: - from tensorflow.python.ops.init_ops_v2 import TruncatedNormal, glorot_uniform, identity + from tensorflow.python.ops.init_ops_v2 import TruncatedNormal, Constant, glorot_uniform from tensorflow.python.keras.layers import LSTM, Lambda, Layer, Dropout @@ -387,7 +386,7 @@ def call(self, inputs, mask=None, **kwargs): elif self.merge_mode == "bw": output = output_bw elif self.merge_mode == 'concat': - output = K.concatenate([output_fw, output_bw]) + output = tf.concat([output_fw, output_bw], axis=-1) elif self.merge_mode == 'sum': output = output_fw + output_bw elif self.merge_mode == 'ave': @@ -530,7 +529,7 @@ def call(self, inputs, mask=None, training=None, **kwargs): if self.use_positional_encoding: queries = self.query_pe(queries) - keys = self.key_pe(queries) + keys = 
self.key_pe(keys) Q = tf.tensordot(queries, self.W_Query, axes=(-1, 0)) # N T_q D*h @@ -665,7 +664,7 @@ def build(self, input_shape): if self.zero_pad: position_enc[0, :] = np.zeros(num_units) self.lookup_table = self.add_weight("lookup_table", (T, num_units), - initializer=identity(position_enc), + initializer=Constant(position_enc), trainable=self.pos_embedding_trainable) # Be sure to call this somewhere! @@ -867,52 +866,3 @@ def get_config(self, ): config = {'k': self.k, 'axis': self.axis} base_config = super(KMaxPooling, self).get_config() return dict(list(base_config.items()) + list(config.items())) - -# def positional_encoding(inputs, -# pos_embedding_trainable=True, -# zero_pad=False, -# scale=True, -# ): -# '''Sinusoidal Positional_Encoding. -# -# Args: -# -# - inputs: A 2d Tensor with shape of (N, T). -# - num_units: Output dimensionality -# - zero_pad: Boolean. If True, all the values of the first row (id = 0) should be constant zero -# - scale: Boolean. If True, the output will be multiplied by sqrt num_units(check details from paper) -# - scope: Optional scope for `variable_scope`. -# - reuse: Boolean, whether to reuse the weights of a previous layer by the same name. -# -# Returns: -# -# - A 'Tensor' with one more rank than inputs's, with the dimensionality should be 'num_units' -# ''' -# -# _, T, num_units = inputs.get_shape().as_list() -# # with tf.variable_scope(scope, reuse=reuse): -# position_ind = tf.expand_dims(tf.range(T), 0) -# # First part of the PE function: sin and cos argument -# position_enc = np.array([ -# [pos / np.power(10000, 2. * i / num_units) -# for i in range(num_units)] -# for pos in range(T)]) -# -# # Second part, apply the cosine to even columns and sin to odds. 
-# position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i -# position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1 -# -# # Convert to a tensor -# -# if pos_embedding_trainable: -# lookup_table = K.variable(position_enc, dtype=tf.float32) -# -# if zero_pad: -# lookup_table = tf.concat((tf.zeros(shape=[1, num_units]), -# lookup_table[1:, :]), 0) -# -# outputs = tf.nn.embedding_lookup(lookup_table, position_ind) -# -# if scale: -# outputs = outputs * num_units ** 0.5 -# return outputs + inputs diff --git a/deepctr/layers/utils.py b/deepctr/layers/utils.py index 2be8f3fe..07eec6e0 100644 --- a/deepctr/layers/utils.py +++ b/deepctr/layers/utils.py @@ -6,7 +6,8 @@ """ import tensorflow as tf -from tensorflow.python.keras.layers import Flatten, Concatenate, Layer, Add +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.layers import Flatten, Layer, Add from tensorflow.python.ops.lookup_ops import TextFileInitializer try: @@ -185,13 +186,60 @@ def get_config(self, ): return dict(list(base_config.items()) + list(config.items())) +class Concat(Layer): + def __init__(self, axis, supports_masking=True, **kwargs): + super(Concat, self).__init__(**kwargs) + self.axis = axis + self.supports_masking = supports_masking + + def call(self, inputs): + return tf.concat(inputs, axis=self.axis) + + def compute_mask(self, inputs, mask=None): + if not self.supports_masking: + return None + if mask is None: + mask = [inputs_i._keras_mask if hasattr(inputs_i, "_keras_mask") else None for inputs_i in inputs] + if mask is None: + return None + if not isinstance(mask, list): + raise ValueError('`mask` should be a list.') + if not isinstance(inputs, list): + raise ValueError('`inputs` should be a list.') + if len(mask) != len(inputs): + raise ValueError('The lists `inputs` and `mask` ' + 'should have the same length.') + if all([m is None for m in mask]): + return None + # Make a list of masks while making sure + # the dimensionality of each mask + 
# is the same as the corresponding input. + masks = [] + for input_i, mask_i in zip(inputs, mask): + if mask_i is None: + # Input is unmasked. Append all 1s to masks, + masks.append(tf.ones_like(input_i, dtype='bool')) + elif K.ndim(mask_i) < K.ndim(input_i): + # Mask is smaller than the input, expand it + masks.append(tf.expand_dims(mask_i, axis=-1)) + else: + masks.append(mask_i) + concatenated = K.concatenate(masks, axis=self.axis) + return K.all(concatenated, axis=-1, keepdims=False) + + def get_config(self, ): + config = {'axis': self.axis, 'supports_masking': self.supports_masking} + base_config = super(Concat, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def concat_func(inputs, axis=-1, mask=False): - if not mask: - inputs = list(map(NoMask(), inputs)) if len(inputs) == 1: - return inputs[0] - else: - return Concatenate(axis=axis)(inputs) + input = inputs[0] + if not mask: + input = NoMask()(input) + return input + return Concat(axis, supports_masking=mask)(inputs) def reduce_mean(input_tensor, @@ -271,10 +319,6 @@ def build(self, input_shape): super(_Add, self).build(input_shape) def call(self, inputs, **kwargs): - # if not isinstance(inputs, list): - # return inputs - # if len(inputs) == 1: - # return inputs[0] if len(inputs) == 0: return tf.constant([[0.0]]) diff --git a/deepctr/models/__init__.py b/deepctr/models/__init__.py index 2d19714b..1d797e78 100644 --- a/deepctr/models/__init__.py +++ b/deepctr/models/__init__.py @@ -20,7 +20,8 @@ from .sequence import DIN, DIEN, DSIN, BST from .wdl import WDL from .xdeepfm import xDeepFM +from .edcn import EDCN __all__ = ["AFM", "CCPM", "DCN", "IFM", "DIFM", "DCNMix", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM", "BST", "DeepFEFM", - "SharedBottom", "ESMM", "MMOE", "PLE"] + "SharedBottom", "ESMM", "MMOE", "PLE", 'EDCN'] diff --git a/deepctr/models/edcn.py 
b/deepctr/models/edcn.py new file mode 100644 index 00000000..973d6391 --- /dev/null +++ b/deepctr/models/edcn.py @@ -0,0 +1,94 @@ +# -*- coding:utf-8 -*- +""" +Author: + Yi He, heyi_jack@163.com + +Reference: + [1] Chen, B., Wang, Y., Liu, et al. Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models. CIKM, 2021, October (https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf) +""" +from tensorflow.python.keras.layers import Dense, Reshape, Concatenate +from tensorflow.python.keras.models import Model + +from ..feature_column import build_input_features, get_linear_logit, input_from_feature_columns +from ..layers.core import PredictionLayer, DNN, RegulationModule +from ..layers.interaction import CrossNet, BridgeModule +from ..layers.utils import add_func, concat_func + + +def EDCN(linear_feature_columns, + dnn_feature_columns, + cross_num=2, + cross_parameterization='vector', + bridge_type='concatenation', + tau=1.0, + l2_reg_linear=1e-5, + l2_reg_embedding=1e-5, + l2_reg_cross=1e-5, + l2_reg_dnn=0, + seed=1024, + dnn_dropout=0, + dnn_use_bn=False, + dnn_activation='relu', + task='binary'): + """Instantiates the Enhanced Deep&Cross Network architecture. + + :param linear_feature_columns: An iterable containing all the features used by linear part of the model. + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param cross_num: positive integet,cross layer number + :param cross_parameterization: str, ``"vector"`` or ``"matrix"``, how to parameterize the cross network. + :param bridge_type: The type of bridge interaction, one of ``"pointwise_addition"``, ``"hadamard_product"``, ``"concatenation"`` , ``"attention_pooling"`` + :param tau: Positive float, the temperature coefficient to control distribution of field-wise gating unit + :param l2_reg_linear: float. L2 regularizer strength applied to linear part + :param l2_reg_embedding: float. 
L2 regularizer strength applied to embedding vector + :param l2_reg_cross: float. L2 regularizer strength applied to cross net + :param l2_reg_dnn: float. L2 regularizer strength applied to DNN + :param seed: integer ,to use as random seed. + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not DNN + :param dnn_activation: Activation function to use in DNN + :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss + :return: A Keras model instance. + + """ + if cross_num == 0: + raise ValueError("Cross layer num must > 0") + + print('EDCN brige type: ', bridge_type) + + features = build_input_features(dnn_feature_columns) + inputs_list = list(features.values()) + + linear_logit = get_linear_logit(features, linear_feature_columns, seed=seed, prefix='linear', l2_reg=l2_reg_linear) + + sparse_embedding_list, _ = input_from_feature_columns( + features, dnn_feature_columns, l2_reg_embedding, seed, support_dense=False) + + emb_input = concat_func(sparse_embedding_list, axis=1) + deep_in = RegulationModule(tau)(emb_input) + cross_in = RegulationModule(tau)(emb_input) + + field_size = len(sparse_embedding_list) + embedding_size = int(sparse_embedding_list[0].shape[-1]) + cross_dim = field_size * embedding_size + + for i in range(cross_num): + cross_out = CrossNet(1, parameterization=cross_parameterization, + l2_reg=l2_reg_cross)(cross_in) + deep_out = DNN([cross_dim], dnn_activation, l2_reg_dnn, + dnn_dropout, dnn_use_bn, seed=seed)(deep_in) + print(cross_out, deep_out) + bridge_out = BridgeModule(bridge_type)([cross_out, deep_out]) + if i + 1 < cross_num: + bridge_out_list = Reshape([field_size, embedding_size])(bridge_out) + deep_in = RegulationModule(tau)(bridge_out_list) + cross_in = RegulationModule(tau)(bridge_out_list) + + stack_out = Concatenate()([cross_out, deep_out, bridge_out]) + final_logit = Dense(1, 
use_bias=False)(stack_out) + + final_logit = add_func([final_logit, linear_logit]) + output = PredictionLayer(task)(final_logit) + + model = Model(inputs=inputs_list, outputs=output) + + return model diff --git a/deepctr/models/sequence/din.py b/deepctr/models/sequence/din.py index 14877a7a..84b7b432 100644 --- a/deepctr/models/sequence/din.py +++ b/deepctr/models/sequence/din.py @@ -6,15 +6,15 @@ Reference: [1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068. (https://arxiv.org/pdf/1706.06978.pdf) """ +from tensorflow.python.keras.layers import Dense, Flatten from tensorflow.python.keras.models import Model -from tensorflow.python.keras.layers import Dense, Concatenate, Flatten from ...feature_column import SparseFeat, VarLenSparseFeat, DenseFeat, build_input_features from ...inputs import create_embedding_matrix, embedding_lookup, get_dense_input, varlen_embedding_lookup, \ get_varlen_pooling_list from ...layers.core import DNN, PredictionLayer from ...layers.sequence import AttentionSequencePoolingLayer -from ...layers.utils import concat_func, NoMask, combined_dnn_input +from ...layers.utils import concat_func, combined_dnn_input def DIN(dnn_feature_columns, history_feature_list, dnn_use_bn=False, @@ -84,7 +84,7 @@ def DIN(dnn_feature_columns, history_feature_list, dnn_use_bn=False, weight_normalization=att_weight_normalization, supports_masking=True)([ query_emb, keys_emb]) - deep_input_emb = Concatenate()([NoMask()(deep_input_emb), hist]) + deep_input_emb = concat_func([deep_input_emb, hist]) deep_input_emb = Flatten()(deep_input_emb) dnn_input = combined_dnn_input([deep_input_emb], dense_value_list) output = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, seed=seed)(dnn_input) diff --git a/deepctr/models/sequence/dsin.py b/deepctr/models/sequence/dsin.py index 
c7c2ea1a..f02f89cb 100644 --- a/deepctr/models/sequence/dsin.py +++ b/deepctr/models/sequence/dsin.py @@ -24,7 +24,7 @@ def DSIN(dnn_feature_columns, sess_feature_list, sess_max_count=5, bias_encoding=False, - att_embedding_size=1, att_head_num=8, dnn_hidden_units=(256, 128, 64), dnn_activation='sigmoid', dnn_dropout=0, + att_embedding_size=1, att_head_num=8, dnn_hidden_units=(256, 128, 64), dnn_activation='relu', dnn_dropout=0, dnn_use_bn=False, l2_reg_dnn=0, l2_reg_embedding=1e-6, seed=1024, task='binary', ): """Instantiates the Deep Session Interest Network architecture. diff --git a/docs/pics/EDCN.png b/docs/pics/EDCN.png new file mode 100644 index 00000000..fea8cd5e Binary files /dev/null and b/docs/pics/EDCN.png differ diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index 41cbc3b6..5e7dfdf6 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -128,7 +128,7 @@ from deepctr.models import DeepFM from deepctr.feature_column import SparseFeat,get_feature_names pretrained_item_weights = np.random.randn(60,4) -pretrained_weights_initializer = tf.initializers.identity(pretrained_item_weights) +pretrained_weights_initializer = tf.initializers.constant(pretrained_item_weights) feature_columns = [SparseFeat('user_id',120,),SparseFeat('item_id',60,embedding_dim=4,embeddings_initializer=pretrained_weights_initializer,trainable=False)] fixlen_feature_names = get_feature_names(feature_columns) diff --git a/docs/source/Features.md b/docs/source/Features.md index c9a9a550..6acee071 100644 --- a/docs/source/Features.md +++ b/docs/source/Features.md @@ -304,6 +304,17 @@ feature. FEFM has significantly lower model complexity than FFM and roughly the [Pande H. Field-Embedded Factorization Machines for Click-through rate prediction[J]. 
arXiv preprint arXiv:2009.09931, 2020.](https://arxiv.org/pdf/2009.09931) +### EDCN(Enhancing Explicit and Implicit Feature Interactions DCN) + +EDCN introduces two advanced modules, namely the bridge module and the regulation module, which work collaboratively to capture the layer-wise interactive signals and learn discriminative feature distributions for each hidden layer of the parallel networks. + +[**EDCN Model API**](./deepctr.models.edcn.html) + +![EDCN](../pics/EDCN.png) + +[Chen B, Wang Y, Liu Z, et al. Enhancing explicit and implicit feature interactions via information sharing for parallel deep ctr models[C]//Proceedings of the 30th ACM International Conference on Information & Knowledge Management. 2021: 3757-3766.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf) + + ## Sequence Models ### DIN (Deep Interest Network) @@ -413,6 +424,8 @@ information routing across tasks in a general setup. [Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//Fourteenth ACM Conference on Recommender Systems. 2020.](https://dl.acm.org/doi/10.1145/3383313.3412236) + + ## Layers The models of deepctr are modular, so you can use different modules to build your own models. diff --git a/docs/source/History.md b/docs/source/History.md index 2e19942a..8735d457 100644 --- a/docs/source/History.md +++ b/docs/source/History.md @@ -1,4 +1,5 @@ # History +- 11/10/2022 : [v0.9.3](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.3) released.Add [EDCN](./Features.html#edcn-enhancing-explicit-and-implicit-feature-interactions-dcn). - 10/15/2022 : [v0.9.2](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.2) released.Support python `3.9`,`3.10`. - 06/11/2022 : [v0.9.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.1) released.Improve compatibility with tensorflow `2.x`.
- 09/03/2021 : [v0.9.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0) released.Add multitask learning models:[SharedBottom](./Features.html#sharedbottom),[ESMM](./Features.html#esmm-entire-space-multi-task-model),[MMOE](./Features.html#mmoe-multi-gate-mixture-of-experts) and [PLE](./Features.html#ple-progressive-layered-extraction). [running example](./Examples.html#multitask-learning-mmoe) @@ -10,8 +11,8 @@ - 10/11/2020 : [v0.8.2](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.2) released.Refactor `DNN` Layer. - 09/12/2020 : [v0.8.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.1) released.Improve the reproducibility & fix some bugs. - 06/27/2020 : [v0.8.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.0) released. - - Support `Tensorflow Estimator` for large scale data and distributed training. [example: Estimator with TFRecord](https://deepctr-doc.readthedocs.io/en/latest/Examples.html#estimator-with-tfrecord-classification-criteo) - - Support different initializers for different embedding weights and loading pretrained embeddings. [example](https://deepctr-doc.readthedocs.io/en/latest/FAQ.html#how-to-use-pretrained-weights-to-initialize-embedding-weights-and-frozen-embedding-weights) + - Support `Tensorflow Estimator` for large scale data and distributed training. [example: Estimator with TFRecord](./Examples.html#estimator-with-tfrecord-classification-criteo) + - Support different initializers for different embedding weights and loading pretrained embeddings. [example](./FAQ.html#how-to-use-pretrained-weights-to-initialize-embedding-weights-and-frozen-embedding-weights) - Add new model `FwFM`. - 05/17/2020 : [v0.7.5](https://github.com/shenweichen/DeepCTR/releases/tag/v0.7.5) released.Fix numerical instability in `LayerNormalization`. 
- 03/15/2020 : [v0.7.4](https://github.com/shenweichen/DeepCTR/releases/tag/v0.7.4) released.Add [FLEN](./Features.html#flen-field-leveraged-embedding-network) and `FieldWiseBiInteraction`. diff --git a/docs/source/Models.rst b/docs/source/Models.rst index a3f5691e..4f864184 100644 --- a/docs/source/Models.rst +++ b/docs/source/Models.rst @@ -30,5 +30,6 @@ DeepCTR Models API ESMM MMOE PLE + EDCN \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index d0f0df24..e0ae9c06 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ # The short X.Y version version = '' # The full version, including alpha/beta/rc tags -release = '0.9.2' +release = '0.9.3' # -- General configuration --------------------------------------------------- diff --git a/docs/source/deepctr.models.edcn.rst b/docs/source/deepctr.models.edcn.rst new file mode 100644 index 00000000..3772f3b5 --- /dev/null +++ b/docs/source/deepctr.models.edcn.rst @@ -0,0 +1,7 @@ +deepctr.models.edcn module +========================= + +.. automodule:: deepctr.models.edcn + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/docs/source/deepctr.models.rst b/docs/source/deepctr.models.rst index 2b4e9e18..4acf2a12 100644 --- a/docs/source/deepctr.models.rst +++ b/docs/source/deepctr.models.rst @@ -11,6 +11,7 @@ Submodules deepctr.models.ccpm deepctr.models.dcn deepctr.models.dcnmix + deepctr.models.edcn deepctr.models.deepfm deepctr.models.dien deepctr.models.din diff --git a/docs/source/index.rst b/docs/source/index.rst index 0330a10d..93316678 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -42,11 +42,11 @@ You can read the latest code and related projects News ----- -10/15/2022 : Support python `3.9`,`3.10`. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.2>`_ +11/10/2022 : Add `EDCN` . `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.3>`_ -06/11/2022 : Improve compatibility with tensorflow `2.x`. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.1>`_ +10/15/2022 : Support python `3.9` , `3.10` .
`Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.2>`_ -09/03/2021 : Add multitask learning models: `SharedBottom <./Features.html#sharedbottom>`_ , `ESMM <./Features.html#esmm-entire-space-multi-task-model>`_ , `MMOE <./Features.html#mmoe-multi-gate-mixture-of-experts>`_ , `PLE <./Features.html#ple-progressive-layered-extraction>`_ . `running example <./Examples.html#multitask-learning-mmoe>`_ `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0>`_ +06/11/2022 : Improve compatibility with tensorflow `2.x`. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.1>`_ DisscussionGroup ----------------------- diff --git a/setup.py b/setup.py index 43eee556..c389ea41 100644 --- a/setup.py +++ b/setup.py @@ -1,21 +1,19 @@ +import sys + import setuptools -with open("README.md", "r",encoding='utf-8') as fh: +with open("README.md", "r") as fh: long_description = fh.read() -import sys -if sys.version_info < (3, 9): - REQUIRED_PACKAGES = [ - 'h5py==2.10.0', 'requests' - ] -else: - REQUIRED_PACKAGES = [ - 'h5py==3.7.0', 'requests' - ] +REQUIRED_PACKAGES = [ + 'requests', + 'h5py==3.7.0; python_version>="3.9"', + 'h5py==2.10.0; python_version<"3.9"' +] setuptools.setup( name="deepctr", - version="0.9.2", + version="0.9.3", author="Weichen Shen", author_email="weichenswc@163.com", description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with tensorflow 1.x and 2.x .", diff --git a/tests/models/EDCN_test.py b/tests/models/EDCN_test.py new file mode 100644 index 00000000..f01f7fe0 --- /dev/null +++ b/tests/models/EDCN_test.py @@ -0,0 +1,28 @@ +import pytest + +from deepctr.models import EDCN +from ..utils import check_model, get_test_data, SAMPLE_SIZE + + +@pytest.mark.parametrize( + 'bridge_type, cross_num, cross_parameterization, sparse_feature_num', + [ + ('pointwise_addition', 2, 'vector', 3), + ('hadamard_product', 2, 'vector', 4), + ('concatenation', 1, 'vector', 5), + ('attention_pooling', 2, 'matrix', 6), + ] +) +def test_EDCN(bridge_type, cross_num, cross_parameterization, sparse_feature_num): + model_name = "EDCN" + + sample_size =
SAMPLE_SIZE + x, y, feature_columns = get_test_data(sample_size, sparse_feature_num=sparse_feature_num, + dense_feature_num=0) + + model = EDCN(feature_columns, feature_columns, cross_num, cross_parameterization, bridge_type) + check_model(model, model_name, x, y) + + +if __name__ == "__main__": + pass diff --git a/tests/models/xDeepFM_test.py b/tests/models/xDeepFM_test.py index 3981e229..b350ad28 100644 --- a/tests/models/xDeepFM_test.py +++ b/tests/models/xDeepFM_test.py @@ -61,4 +61,4 @@ def test_xDeepFMEstimator(dnn_hidden_units, cin_layer_size, cin_split_half, cin_ if __name__ == "__main__": - pass + pass \ No newline at end of file