diff --git a/docs/source/train.md b/docs/source/train.md
index 871d6eb30..e58bb6862 100644
--- a/docs/source/train.md
+++ b/docs/source/train.md
@@ -80,21 +80,21 @@ EasyRec supports two ways to configure loss functions: 1) use a single loss function; 2
 
 #### Using a single loss function
 
-| Loss function                              | Description                                                 |
+| Loss function                              | Description                                                 |
 | ------------------------------------------ | ----------------------------------------------------------- |
 | CLASSIFICATION                             | classification loss: sigmoid_cross_entropy for binary classification; softmax_cross_entropy for multi-class |
 | L2_LOSS                                    | squared loss                                                |
-| SIGMOID_L2_LOSS                            | squared loss computed on the sigmoid of the logits          |
+| SIGMOID_L2_LOSS                            | squared loss computed on the sigmoid of the logits          |
 | CROSS_ENTROPY_LOSS                         | log loss (negative log-likelihood)                          |
 | CIRCLE_LOSS                                | dedicated to the CoMetricLearningI2I model                  |
 | MULTI_SIMILARITY_LOSS                      | dedicated to the CoMetricLearningI2I model                  |
-| SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING | multi-class softmax_cross_entropy with automatic negative sampling, used in binary classification tasks |
-| BINARY_FOCAL_LOSS                          | focal loss with hard-example mining and class balancing     |
-| PAIR_WISE_LOSS                             | rank loss that optimizes global AUC                         |
-| PAIRWISE_FOCAL_LOSS                        | pair-level focal loss, supports custom pair grouping        |
-| PAIRWISE_LOGISTIC_LOSS                     | pair-level logistic loss, supports custom pair grouping     |
-| JRC_LOSS                                   | binary classification + listwise ranking loss               |
-| F1_REWEIGHTED_LOSS                         | loss that tunes the relative weight of recall vs. precision in binary classification; effective against positive/negative sample imbalance |
+| SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING | multi-class softmax_cross_entropy with automatic negative sampling, used in binary classification tasks |
+| BINARY_FOCAL_LOSS                          | focal loss with hard-example mining and class balancing     |
+| PAIR_WISE_LOSS                             | rank loss that optimizes global AUC                         |
+| PAIRWISE_FOCAL_LOSS                        | pair-level focal loss, supports custom pair grouping        |
+| PAIRWISE_LOGISTIC_LOSS                     | pair-level logistic loss, supports custom pair grouping     |
+| JRC_LOSS                                   | binary classification + listwise ranking loss               |
+| F1_REWEIGHTED_LOSS                         | loss that tunes the relative weight of recall vs. precision in binary classification; effective against positive/negative sample imbalance |
 
 - Note: SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
   - supports parameter configuration and has been upgraded to [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317),
@@ -153,40 +153,49 @@ EasyRec supports two ways to configure loss functions: 1) use a single loss function; 2
 - f1_beta_square is the square of the beta coefficient in the formula above.
 - Parameters of PAIRWISE_FOCAL_LOSS
+
   - gamma: exponent of the focal loss, default 2.0
-  - alpha: class-balancing parameter for sample weights; it is recommended to set alpha from the negative/positive sample ratio, $\frac{\alpha}{1-\alpha}=\frac{#Neg}{#Pos}$
+  - alpha: class-balancing parameter for sample weights; it is recommended to set alpha from the negative/positive sample ratio, $\\frac{\\alpha}{1-\\alpha}=\\frac{\\#Neg}{\\#Pos}$
   - session_name: field used to group pairs, e.g. user_id
   - hinge_margin: when the logit difference of a pair exceeds this value, the pair's loss is 0; default 1.0
   - ohem_ratio: fraction of hard examples; only that fraction of the hardest examples contributes to the loss, default 1.0
   - temperature: temperature coefficient; logits are divided by this value before the loss is computed, default 1.0
 - Parameters of PAIRWISE_LOGISTIC_LOSS
+
   - session_name: field used to group pairs, e.g. user_id
   - hinge_margin: when the logit difference of a pair exceeds this value, the pair's loss is 0; default 1.0
   - ohem_ratio: fraction of hard examples; only that fraction of the hardest examples contributes to the loss, default 1.0
   - temperature: temperature coefficient; logits are divided by this value before the loss is computed, default 1.0
 - Parameters of PAIRWISE_LOSS
+
   - session_name: field used to group pairs, e.g. user_id
   - margin: subtracted from the pair's logit difference before the loss is computed, i.e. the logit of the positive sample must exceed that of the negative sample by at least margin; default 0
   - temperature: temperature coefficient; logits are divided by this value before the loss is computed, default 1.0
 
-Note: all of the PAIRWISE_*_LOSS functions above build positive/negative pairs inside a mini-batch, aiming to push the logits of each positive/negative pair as far apart as possible
+Note: all of the PAIRWISE\_\*\_LOSS functions above build positive/negative pairs inside a mini-batch, aiming to push the logits of each positive/negative pair as far apart as possible
 
 - Parameters of BINARY_FOCAL_LOSS
+
   - gamma: exponent of the focal loss, default 2.0
-  - alpha: class-balancing parameter for sample weights; it is recommended to set alpha from the negative/positive sample ratio, $\frac{\alpha}{1-\alpha}=\frac{#Neg}{#Pos}$
+  - alpha: class-balancing parameter for sample weights; it is recommended to set alpha from the negative/positive sample ratio, $\\frac{\\alpha}{1-\\alpha}=\\frac{\\#Neg}{\\#Pos}$
   - ohem_ratio: fraction of hard examples; only that fraction of the hardest examples contributes to the loss, default 1.0
   - label_smoothing: label smoothing coefficient
 - Parameters of JRC_LOSS
+
   - alpha: relative weight of the ranking loss vs. the calibration loss; when unset, adaptive weight learning is triggered
   - session_name: field used to group lists, e.g. user_id
   - Reference: [Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model](https://arxiv.org/pdf/2208.06164.pdf)
+  - Example: [dbmtl_with_jrc_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/dbmtl_on_taobao_with_multi_loss.config)
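To make the pairwise parameters above concrete, here is a minimal sketch of a pair-level logistic loss with session grouping, hinge_margin, ohem_ratio and temperature. It is an illustration only, not EasyRec's implementation (the real losses live under easy_rec/python/loss); `toy_pairwise_logistic_loss` and its exact formulas are assumptions.

```python
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def toy_pairwise_logistic_loss(labels, logits, session_ids,
                               temperature=1.0, hinge_margin=1.0,
                               ohem_ratio=1.0):
  """labels (float), logits, session_ids: 1-D tensors of shape [batch_size]."""
  logits = logits / temperature  # temperature scaling happens first
  # logit difference of every (i, j) pair in the mini-batch: [B, B]
  diff = tf.expand_dims(logits, 1) - tf.expand_dims(logits, 0)
  # a pair is valid if label_i > label_j and both samples share a session
  label_gap = tf.expand_dims(labels, 1) - tf.expand_dims(labels, 0)
  same_session = tf.equal(
      tf.expand_dims(session_ids, 1), tf.expand_dims(session_ids, 0))
  mask = tf.logical_and(tf.greater(label_gap, 0.0), same_session)
  pair_diff = tf.boolean_mask(diff, mask)
  # logistic loss log(1 + exp(-diff)), forced to 0 once the pair is
  # already separated by more than hinge_margin
  loss = tf.where(pair_diff > hinge_margin,
                  tf.zeros_like(pair_diff),
                  tf.nn.softplus(-pair_diff))
  # OHEM: keep only the hardest ohem_ratio fraction of pairs
  # (assumes the batch contains at least one valid pair)
  num = tf.cast(tf.size(loss), tf.float32)
  k = tf.maximum(tf.cast(num * ohem_ratio, tf.int32), 1)
  hardest, _ = tf.nn.top_k(loss, k=k)
  return tf.reduce_mean(hardest)
```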
 A complete example of a ranking model using multiple loss functions:
 [cmbf_with_multi_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/cmbf_with_multi_loss.config)
+A complete example of a multi-objective ranking model using multiple loss functions:
+[dbmtl_with_multi_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/dbmtl_on_taobao_with_multi_loss.config)
+
 ##### Adaptive loss-weight learning
 
 In multi-objective learning tasks, manually specified static weights for multiple loss functions usually do not give the best results. EasyRec supports adaptive learning of the loss-function weights; an example follows:
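The adaptive-weight config example referenced above lies outside this hunk. As background, one widely used formulation of adaptive loss weighting is homoscedastic-uncertainty weighting (Kendall et al., 2018), sketched below in a simplified variant; `uncertainty_weighted_sum` is a hypothetical helper, not necessarily EasyRec's exact scheme.

```python
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def uncertainty_weighted_sum(losses):
  """total = sum_i exp(-s_i) * L_i + s_i, one trainable scalar s_i per loss."""
  total = 0.0
  for i, loss in enumerate(losses):
    # s_i = log(sigma_i^2); initialized to 0 so the initial weight is 1
    log_var = tf.get_variable(
        'loss_log_var_%d' % i, shape=[], dtype=tf.float32,
        initializer=tf.zeros_initializer())
    total += tf.exp(-log_var) * loss + log_var
  return total


# usage sketch: total_loss = uncertainty_weighted_sum([ctr_loss, cvr_loss])
```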
diff --git a/easy_rec/python/compat/early_stopping.py b/easy_rec/python/compat/early_stopping.py
index fc850fb62..fe4c12132 100644
--- a/easy_rec/python/compat/early_stopping.py
+++ b/easy_rec/python/compat/early_stopping.py
@@ -21,9 +21,9 @@ import os
 import threading
 import time
+from distutils.version import LooseVersion
 
 import tensorflow as tf
-from distutils.version import LooseVersion
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import init_ops
diff --git a/easy_rec/python/compat/weight_decay_optimizers.py b/easy_rec/python/compat/weight_decay_optimizers.py
index 7c9baf905..47a755e0f 100755
--- a/easy_rec/python/compat/weight_decay_optimizers.py
+++ b/easy_rec/python/compat/weight_decay_optimizers.py
@@ -472,5 +472,4 @@ def __init__(self,
         use_locking=use_locking,
         name=name)
 except ImportError:
-  print('import AdamAsyncOptimizer failed when loading AdamAsyncWOptimizer')
   pass
diff --git a/easy_rec/python/feature_column/feature_column.py b/easy_rec/python/feature_column/feature_column.py
index 94a9cd132..04fc07baf 100644
--- a/easy_rec/python/feature_column/feature_column.py
+++ b/easy_rec/python/feature_column/feature_column.py
@@ -86,6 +86,8 @@ def _cmp_embed_config(a, b):
           'shared embed info of [%s] is not matched [%s] vs [%s]' % (
               embed_name, config, self._share_embed_infos[embed_name])
       self._share_embed_names[embed_name] += 1
+      if config.feature_type == FeatureConfig.FeatureType.SequenceFeature:
+        self._share_embed_infos[embed_name] = copy_obj(config)
     else:
       self._share_embed_names[embed_name] = 1
       self._share_embed_infos[embed_name] = copy_obj(config)
@@ -156,6 +158,11 @@ def _cmp_embed_config(a, b):
           combiner=self._share_embed_infos[embed_name].combiner,
           partitioner=partitioner,
           ev_params=ev_params)
+      config = self._share_embed_infos[embed_name]
+      max_seq_len = config.max_seq_len if config.HasField(
+          'max_seq_len') else -1
+      for fc in share_embed_fcs:
+        fc.max_seq_length = max_seq_len
       self._deep_share_embed_columns[embed_name] = share_embed_fcs
     # for handling wide share embedding columns
@@ -168,6 +175,11 @@ def _cmp_embed_config(a, b):
           combiner='sum',
           partitioner=partitioner,
           ev_params=ev_params)
+      config = self._share_embed_infos[embed_name]
+      max_seq_len = config.max_seq_len if config.HasField(
+          'max_seq_len') else -1
+      for fc in share_embed_fcs:
+        fc.max_seq_length = max_seq_len
       self._wide_share_embed_columns[embed_name] = share_embed_fcs
 
     for fc_name in self._deep_columns:
diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py
index 4aec1ed17..52581b4e2 100644
--- a/easy_rec/python/input/input.py
+++ b/easy_rec/python/input/input.py
@@ -141,7 +141,8 @@ def __init__(self,
     model_name = model_config.WhichOneof('model')
     if model_name in {'mmoe', 'esmm', 'dbmtl', 'simple_multi_task', 'ple'}:
       model = getattr(model_config, model_name)
-      towers = [model.ctr_tower, model.cvr_tower] if model_name == 'esmm' else model.task_towers
+      towers = [model.ctr_tower, model.cvr_tower
+               ] if model_name == 'esmm' else model.task_towers
       for tower in towers:
         metrics = tower.metrics_set
         for metric in metrics:
diff --git a/easy_rec/python/layers/cmbf.py b/easy_rec/python/layers/cmbf.py
index 2c6ed8444..b633bac2b 100644
--- a/easy_rec/python/layers/cmbf.py
+++ b/easy_rec/python/layers/cmbf.py
@@ -325,6 +325,10 @@ def merge_text_embedding(self, txt_embeddings, input_masks):
     return txt_embeddings
 
   def __call__(self, is_training, *args, **kwargs):
+    if not is_training:
+      self._model_config.hidden_dropout_prob = 0.0
+      self._model_config.attention_probs_dropout_prob = 0.0
+
     # shape: [batch_size, image_num/image_dim, hidden_size]
     img_attention_fea = self.image_self_attention_tower()
diff --git a/easy_rec/python/layers/common_layers.py b/easy_rec/python/layers/common_layers.py
index 80ad1496f..165fce5e1 100644
--- a/easy_rec/python/layers/common_layers.py
+++ b/easy_rec/python/layers/common_layers.py
@@ -1,29 +1,12 @@
 # -*- encoding: utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import numpy as np
 import tensorflow as tf
 
 if tf.__version__ >= '2.0':
   tf = tf.compat.v1
 
 
-def gelu(x):
-  """Gaussian Error Linear Unit.
-
-  This is a smoother version of the RELU.
-  Original paper: https://arxiv.org/abs/1606.08415
-  Args:
-    x: float Tensor to perform activation.
-
-  Returns:
-    `x` with the GELU activation applied.
-  """
-  cdf = 0.5 * (1.0 + tf.tanh(
-      (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
-  return x * cdf
-
-
 def highway(x,
             size=None,
             activation=None,
diff --git a/easy_rec/python/layers/dnn.py b/easy_rec/python/layers/dnn.py
index 4fdce37ba..7a57f5661 100644
--- a/easy_rec/python/layers/dnn.py
+++ b/easy_rec/python/layers/dnn.py
@@ -4,7 +4,7 @@
 
 import tensorflow as tf
 
-from easy_rec.python.utils.load_class import load_by_path
+from easy_rec.python.utils.activation import get_activation
 
 if tf.__version__ >= '2.0':
   tf = tf.compat.v1
@@ -25,7 +25,7 @@ def __init__(self,
       dnn_config: instance of easy_rec.python.protos.dnn_pb2.DNN
       l2_reg: l2 regularizer
       name: scope of the DNN, so that the parameters could be separated from other dnns
-      is_training: train phase or not, impact batchnorm and dropout
+      is_training: train phase or not, impacts batch_norm and dropout
       last_layer_no_activation: in last layer, use or not use activation
       last_layer_no_batch_norm: in last layer, use or not use batch norm
     """
@@ -34,7 +34,8 @@ def __init__(self,
     self._name = name
     self._is_training = is_training
     logging.info('dnn activation function = %s' % self._config.activation)
-    self.activation = load_by_path(self._config.activation)
+    self.activation = get_activation(
+        self._config.activation, training=is_training)
     self._last_layer_no_activation = last_layer_no_activation
     self._last_layer_no_batch_norm = last_layer_no_batch_norm
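dnn.py now resolves activations through get_activation (added in easy_rec/python/utils/activation.py further down this patch) so that stateful activations such as dice receive the train/eval flag; dice's internal batch normalization behaves differently between the two phases. A usage sketch, with made-up tensor shapes and variable names:

```python
import tensorflow as tf

from easy_rec.python.utils.activation import get_activation

if tf.__version__ >= '2.0':
  tf = tf.compat.v1

x = tf.ones([4, 32])  # any float feature tensor

relu_fn = get_activation('relu')                 # plain tf.nn.relu
dice_fn = get_activation('dice', training=True)  # dice needs the phase flag

# each dice call creates its own 'alpha_*' variable, so names must differ
y = dice_fn(x, name='dice_0')
```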
diff --git a/easy_rec/python/layers/multihead_cross_attention.py b/easy_rec/python/layers/multihead_cross_attention.py
index 911ff7bae..511b2711d 100644
--- a/easy_rec/python/layers/multihead_cross_attention.py
+++ b/easy_rec/python/layers/multihead_cross_attention.py
@@ -6,11 +6,10 @@
 
 import math
 
-import six
 import tensorflow as tf
 
 from easy_rec.python.compat.layers import layer_norm as tf_layer_norm
-from easy_rec.python.layers.common_layers import gelu
+from easy_rec.python.utils.activation import gelu
 from easy_rec.python.utils.shape_utils import get_shape_list
 
 if tf.__version__ >= '2.0':
@@ -53,7 +52,8 @@ def attention_layer(from_tensor,
                     do_return_2d_tensor=False,
                     batch_size=None,
                     from_seq_length=None,
-                    to_seq_length=None):
+                    to_seq_length=None,
+                    reuse=None):
   """Performs multi-headed attention from `from_tensor` to `to_tensor`.
 
   This is an implementation of multi-headed attention based on "Attention
   is all you Need".
@@ -96,6 +96,7 @@ def attention_layer(from_tensor,
       of the 3D version of the `from_tensor`.
     to_seq_length: (Optional) If the input is 2D, this might be the seq length
       of the 3D version of the `to_tensor`.
+    reuse: whether to reuse this layer
 
   Returns:
     float Tensor of shape [batch_size, from_seq_length,
@@ -149,7 +150,8 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
       num_attention_heads * size_per_head,
       activation=query_act,
       name='query',
-      kernel_initializer=create_initializer(initializer_range))
+      kernel_initializer=create_initializer(initializer_range),
+      reuse=reuse)
 
   # `key_layer` = [B*T, N*H]
   key_layer = tf.layers.dense(
@@ -157,7 +159,8 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
       num_attention_heads * size_per_head,
       activation=key_act,
       name='key',
-      kernel_initializer=create_initializer(initializer_range))
+      kernel_initializer=create_initializer(initializer_range),
+      reuse=reuse)
 
   # `value_layer` = [B*T, N*H]
   value_layer = tf.layers.dense(
@@ -165,7 +168,8 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
       num_attention_heads * size_per_head,
       activation=value_act,
       name='value',
-      kernel_initializer=create_initializer(initializer_range))
+      kernel_initializer=create_initializer(initializer_range),
+      reuse=reuse)
 
   # `query_layer` = [B, N, F, H]
   query_layer = transpose_for_scores(query_layer, batch_size,
@@ -242,6 +246,7 @@ def transformer_encoder(input_tensor,
                         hidden_dropout_prob=0.1,
                         attention_probs_dropout_prob=0.1,
                         initializer_range=0.02,
+                        reuse=None,
                         name='transformer'):
   """Multi-headed, multi-layer Transformer from "Attention is All You Need".
 
@@ -265,6 +270,7 @@ def transformer_encoder(input_tensor,
       probabilities.
     initializer_range: float. Range of the initializer (stddev of truncated
       normal).
+    reuse: whether to reuse this encoder
     name: scope name prefix
 
   Returns:
@@ -315,11 +321,12 @@ def transformer_encoder(input_tensor,
             do_return_2d_tensor=True,
             batch_size=batch_size,
             from_seq_length=seq_length,
-            to_seq_length=seq_length)
+            to_seq_length=seq_length,
+            reuse=reuse)
 
         # Run a linear projection of `hidden_size` then add a residual
         # with `layer_input`.
-        with tf.variable_scope('output'):
+        with tf.variable_scope('output', reuse=reuse):
           attention_output = tf.layers.dense(
               attention_output,
               hidden_size,
@@ -328,7 +335,7 @@ def transformer_encoder(input_tensor,
       attention_output = layer_norm(attention_output + layer_input)
 
       # The activation is only applied to the "intermediate" hidden layer.
-      with tf.variable_scope('intermediate'):
+      with tf.variable_scope('intermediate', reuse=reuse):
         intermediate_output = tf.layers.dense(
             attention_output,
             intermediate_size,
@@ -336,7 +343,7 @@ def transformer_encoder(input_tensor,
             kernel_initializer=create_initializer(initializer_range))
 
       # Down-project back to `hidden_size` then add the residual.
-      with tf.variable_scope('output'):
+      with tf.variable_scope('output', reuse=reuse):
         layer_output = tf.layers.dense(
             intermediate_output,
             hidden_size,
@@ -640,6 +647,7 @@ def embedding_postprocessor(input_tensor,
                             reuse_token_type=None,
                             use_position_embeddings=True,
                             position_embedding_name='position_embeddings',
+                            reuse_position_embedding=None,
                             initializer_range=0.02,
                             max_position_embeddings=512,
                             dropout_prob=0.1):
@@ -659,6 +667,7 @@ def embedding_postprocessor(input_tensor,
       position of each token in the sequence.
     position_embedding_name: string. The name of the embedding table variable
       for positional embeddings.
+    reuse_position_embedding: bool. Whether to reuse the position embedding variable.
     initializer_range: float. Range of the weight initialization.
     max_position_embeddings: int. Maximum sequence length that might ever be
       used with this model. This can be longer than the sequence length of
@@ -699,7 +708,8 @@ def embedding_postprocessor(input_tensor,
   if use_position_embeddings:
     assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
     with tf.control_dependencies([assert_op]):
-      full_position_embeddings = tf.get_variable(
+      with tf.variable_scope("position_embedding", reuse=reuse_position_embedding):
+        full_position_embeddings = tf.get_variable(
           name=position_embedding_name,
           shape=[max_position_embeddings, width],
           initializer=create_initializer(initializer_range))
@@ -736,41 +746,3 @@ def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
   output_tensor = layer_norm(input_tensor, name)
   output_tensor = dropout(output_tensor, dropout_prob)
   return output_tensor
-
-
-def get_activation(activation_string):
-  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
-
-  Args:
-    activation_string: String name of the activation function.
-
-  Returns:
-    A Python function corresponding to the activation function. If
-    `activation_string` is None, empty, or "linear", this will return None.
-    If `activation_string` is not a string, it will return `activation_string`.
-
-  Raises:
-    ValueError: The `activation_string` does not correspond to a known
-      activation.
-  """
-  # We assume that anything that's not a string is already an activation
-  # function, so we just return it.
-  if not isinstance(activation_string, six.string_types):
-    return activation_string
-
-  if not activation_string:
-    return None
-
-  act = activation_string.lower()
-  if act == 'linear':
-    return None
-  elif act == 'relu':
-    return tf.nn.relu
-  elif act == 'gelu':
-    return gelu
-  elif act == 'tanh':
-    return tf.tanh
-  elif act == 'swish':
-    return tf.nn.swish
-  else:
-    raise ValueError('Unsupported activation: %s' % act)
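The `reuse` argument threaded through attention_layer and transformer_encoder above relies on standard TF1 variable sharing: calling a layer again under the same name with `reuse=True` binds to the existing variables instead of raising an error, so two towers can share one encoder. A minimal illustration of that mechanism (not code from this patch):

```python
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def project(x, reuse=None):
  # same `name` + reuse=True binds to the variables of the first call
  return tf.layers.dense(x, 8, name='query', reuse=reuse)


a = project(tf.ones([2, 4]))              # creates query/kernel, query/bias
b = project(tf.ones([2, 4]), reuse=True)  # shares them instead of raising
```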
diff --git a/easy_rec/python/layers/uniter.py b/easy_rec/python/layers/uniter.py
index 96b9cdc46..fa5c6a3ca 100644
--- a/easy_rec/python/layers/uniter.py
+++ b/easy_rec/python/layers/uniter.py
@@ -4,6 +4,7 @@
 
 from easy_rec.python.layers import dnn
 from easy_rec.python.layers import multihead_cross_attention
+from easy_rec.python.utils.activation import get_activation
 from easy_rec.python.utils.shape_utils import get_shape_list
 
 if tf.__version__ >= '2.0':
@@ -223,6 +224,10 @@ def image_embeddings(self):
     return img_fea
 
   def __call__(self, is_training, *args, **kwargs):
+    if not is_training:
+      self._model_config.hidden_dropout_prob = 0.0
+      self._model_config.attention_probs_dropout_prob = 0.0
+
     sub_modules = []
 
     img_fea = self.image_embeddings()
@@ -261,8 +266,7 @@ def __call__(self, is_training, *args, **kwargs):
     input_mask = tf.concat(masks, axis=1)
     attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
         from_tensor=all_fea, to_mask=input_mask)
-    hidden_act = multihead_cross_attention.get_activation(
-        self._model_config.hidden_act)
+    hidden_act = get_activation(self._model_config.hidden_act)
     attention_fea = multihead_cross_attention.transformer_encoder(
         all_fea,
         hidden_size=hidden_size,
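cmbf.py and uniter.py above disable dropout at evaluation time by zeroing the two dropout probabilities in the model config. The guard below sketches the equivalent behavior at the op level; `dropout` here is a generic illustrative helper, not this patch's code:

```python
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def dropout(x, dropout_prob, is_training):
  # outside of training (or with probability 0) dropout is a no-op
  if not is_training or dropout_prob == 0.0:
    return x
  return tf.nn.dropout(x, 1.0 - dropout_prob)
```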
diff --git a/easy_rec/python/model/collaborative_metric_learning.py b/easy_rec/python/model/collaborative_metric_learning.py
index 84c87ccaa..d785e7141 100644
--- a/easy_rec/python/model/collaborative_metric_learning.py
+++ b/easy_rec/python/model/collaborative_metric_learning.py
@@ -3,12 +3,12 @@
 from easy_rec.python.core.metrics import metric_learning_average_precision_at_k
 from easy_rec.python.core.metrics import metric_learning_recall_at_k
 from easy_rec.python.layers import dnn
-from easy_rec.python.layers.common_layers import gelu
 from easy_rec.python.layers.common_layers import highway
 from easy_rec.python.loss.circle_loss import circle_loss
 from easy_rec.python.loss.multi_similarity import ms_loss
 from easy_rec.python.model.easy_rec_model import EasyRecModel
 from easy_rec.python.protos.loss_pb2 import LossType
+from easy_rec.python.utils.activation import gelu
 from easy_rec.python.utils.proto_util import copy_obj
 
 from easy_rec.python.protos.collaborative_metric_learning_pb2 import CoMetricLearningI2I as MetricLearningI2IConfig  # NOQA
diff --git a/easy_rec/python/model/match_model.py b/easy_rec/python/model/match_model.py
index 851c7eb38..1be4fcb7f 100644
--- a/easy_rec/python/model/match_model.py
+++ b/easy_rec/python/model/match_model.py
@@ -179,7 +179,8 @@ def _build_point_wise_loss_graph(self):
           self._loss_type,
           label=label,
           pred=pred,
-          loss_weight=self._sample_weight, **kwargs)
+          loss_weight=self._sample_weight,
+          **kwargs)
     # build kd loss
     kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py
index 5fca892b2..57c1d79bd 100644
--- a/easy_rec/python/test/train_eval_test.py
+++ b/easy_rec/python/test/train_eval_test.py
@@ -7,11 +7,11 @@
 import threading
 import time
 import unittest
+from distutils.version import LooseVersion
 
 import numpy as np
 import six
 import tensorflow as tf
-from distutils.version import LooseVersion
 from tensorflow.python.platform import gfile
 
 from easy_rec.python.main import predict
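The import moves in early_stopping.py and train_eval_test.py (together with the setup.cfg change at the end of this patch) reclassify distutils as standard library for isort. LooseVersion is also the safe way to compare TF versions, since raw string comparison mishandles multi-digit components:

```python
from distutils.version import LooseVersion

# plain string comparison gets multi-digit components wrong:
assert not ('1.9.0' < '1.13.0')                        # lexicographic: '9' > '1'
assert LooseVersion('1.9.0') < LooseVersion('1.13.0')  # numeric: correct
```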
diff --git a/easy_rec/python/utils/activation.py b/easy_rec/python/utils/activation.py
new file mode 100644
index 000000000..f52a012ae
--- /dev/null
+++ b/easy_rec/python/utils/activation.py
@@ -0,0 +1,120 @@
+# -*- encoding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from distutils.version import LooseVersion
+
+import numpy as np
+import six
+import tensorflow as tf
+
+from easy_rec.python.utils.load_class import load_by_path
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+def dice(_x, axis=-1, epsilon=1e-9, name='dice', training=True):
+  """The Data Adaptive Activation Function in DIN.
+
+  It can be viewed as a generalization of PReLU that adaptively adjusts the
+  rectified point according to the distribution of the input data.
+
+  Arguments
+    - **axis** : Integer, the axis that should be used to compute data distribution (typically the features axis).
+    - **epsilon** : Small float added to variance to avoid dividing by zero.
+    - **training** : Python boolean; whether batch statistics are updated (train) or frozen (eval).
+
+  References
+    - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]
+      Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining.
+      ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
+  """
+  alphas = tf.get_variable(
+      'alpha_' + name,
+      _x.get_shape()[-1],
+      initializer=tf.constant_initializer(0.0),
+      dtype=tf.float32)
+  inputs_normed = tf.layers.batch_normalization(
+      inputs=_x,
+      axis=axis,
+      epsilon=epsilon,
+      center=False,
+      scale=False,
+      training=training)
+  x_p = tf.sigmoid(inputs_normed)
+  return alphas * (1.0 - x_p) * _x + x_p * _x
+
+
+def gelu(x, name='gelu'):
+  """Gaussian Error Linear Unit.
+
+  This is a smoother version of the RELU.
+  Original paper: https://arxiv.org/abs/1606.08415
+
+  Args:
+    x: float Tensor to perform activation.
+    name: name for this activation
+
+  Returns:
+    `x` with the GELU activation applied.
+  """
+  with tf.name_scope(name):
+    cdf = 0.5 * (1.0 + tf.tanh(
+        (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
+    return x * cdf
+
+
+def swish(x, name='swish'):
+  with tf.name_scope(name):
+    return x * tf.sigmoid(x)
+
+
+def get_activation(activation_string, **kwargs):
+  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
+
+  Args:
+    activation_string: String name of the activation function.
+    **kwargs: keyword arguments passed through to parameterized activations,
+      e.g. `training` for dice and the constructor arguments of PReLU.
+
+  Returns:
+    A Python function corresponding to the activation function. If
+    `activation_string` is None, empty, or "linear", this will return None.
+    If `activation_string` is not a string, it will return `activation_string`.
+
+  Raises:
+    ValueError: The `activation_string` does not correspond to a known
+      activation and cannot be loaded by path.
+  """
+  # We assume that anything that's not a string is already an activation
+  # function, so we just return it.
+  if not isinstance(activation_string, six.string_types):
+    return activation_string
+
+  if not activation_string:
+    return None
+
+  act = activation_string.lower()
+  if act == 'linear':
+    return None
+  elif act == 'relu':
+    return tf.nn.relu
+  elif act == 'gelu':
+    return gelu
+  elif act == 'leaky_relu':
+    return tf.nn.leaky_relu
+  elif act == 'prelu':
+    if len(kwargs) == 0:
+      return tf.nn.leaky_relu
+    return tf.keras.layers.PReLU(**kwargs)
+  elif act == 'dice':
+    return lambda x, name='dice': dice(x, name=name, **kwargs)
+  elif act == 'elu':
+    return tf.nn.elu
+  elif act == 'selu':
+    return tf.nn.selu
+  elif act == 'tanh':
+    return tf.tanh
+  elif act == 'swish':
+    # compare versions numerically; raw string comparison would mis-order
+    # multi-digit components such as '1.9' vs '1.13'
+    if LooseVersion(tf.__version__) < LooseVersion('1.13.0'):
+      return swish
+    return tf.nn.swish
+  elif act == 'sigmoid':
+    return tf.nn.sigmoid
+  else:
+    return load_by_path(activation_string)
diff --git a/samples/model_config/dbmtl_on_taobao_with_multi_loss.config b/samples/model_config/dbmtl_on_taobao_with_multi_loss.config
index 6c5aeacd1..d04564b02 100644
--- a/samples/model_config/dbmtl_on_taobao_with_multi_loss.config
+++ b/samples/model_config/dbmtl_on_taobao_with_multi_loss.config
@@ -252,6 +252,7 @@ model_config {
   dbmtl {
     bottom_dnn {
       hidden_units: [1024, 512, 256]
+      activation: "dice"
     }
     task_towers {
       tower_name: "ctr"
@@ -272,25 +273,36 @@ model_config {
       }
       dnn {
         hidden_units: [256, 128, 64, 32]
+        activation: "dice"
       }
       relation_dnn {
         hidden_units: [32]
+        activation: "dice"
       }
       weight: 1.0
     }
     task_towers {
       tower_name: "cvr"
       label_name: "buy"
-      loss_type: CLASSIFICATION
+      num_class: 2
+      losses {
+        loss_type: JRC_LOSS
+        jrc_loss {
+          session_name: "user_id"
+          alpha: 0.5
+        }
+      }
       metrics_set: {
         auc {}
       }
       dnn {
         hidden_units: [256, 128, 64, 32]
+        activation: "dice"
       }
       relation_tower_names: ["ctr"]
       relation_dnn {
         hidden_units: [32]
+        activation: "dice"
       }
       weight: 1.0
     }
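For intuition about the JRC_LOSS now configured on the cvr tower, here is a rough sketch of the referenced paper's idea: a listwise ranking cross-entropy within each session combined with a pointwise calibration cross-entropy, mixed by alpha. Everything here (`toy_jrc_loss`, the exact weighting convention) is illustrative, not EasyRec's implementation:

```python
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def toy_jrc_loss(labels, logits, session_ids, alpha=0.5):
  """labels: int [B] in {0, 1}; logits: float [B, 2]; session_ids: [B]."""
  # calibration: ordinary 2-class cross entropy per sample
  calibration = tf.losses.sparse_softmax_cross_entropy(labels, logits)

  # ranking: listwise softmax over the samples of the same session
  same_session = tf.cast(
      tf.equal(tf.expand_dims(session_ids, 1),
               tf.expand_dims(session_ids, 0)), tf.float32)  # [B, B]
  neg_inf_mask = tf.log(same_session + 1e-12)  # ~ -inf outside the session

  def listwise_ce(class_logits, targets):
    # row i: softmax of sample i's logit against all logits in i's session
    scores = tf.expand_dims(class_logits, 0) + neg_inf_mask
    log_prob = tf.linalg.diag_part(tf.nn.log_softmax(scores, axis=1))
    return -tf.reduce_sum(targets * log_prob) / (
        tf.reduce_sum(targets) + 1e-12)

  y = tf.cast(labels, tf.float32)
  ranking = listwise_ce(logits[:, 1], y) + listwise_ce(logits[:, 0], 1.0 - y)
  # weighting convention is illustrative; the paper learns alpha when unset
  return alpha * ranking + (1.0 - alpha) * calibration
```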
diff --git a/setup.cfg b/setup.cfg
index b5b966faa..b180b9fb1 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -10,7 +10,7 @@ multi_line_output = 7
 force_single_line = true
 known_standard_library = setuptools
 known_first_party = easy_rec
-known_third_party = absl,common_io,distutils,docutils,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
+known_third_party = absl,common_io,docutils,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
 no_lines_before = LOCALFOLDER
 default_section = THIRDPARTY
 skip = easy_rec/python/protos