[feat]: add dice activation for dnn layer (#362)

yangxudong authored May 5, 2023
1 parent d11321c commit adc5f25
Showing 16 changed files with 209 additions and 91 deletions.
33 changes: 21 additions & 12 deletions docs/source/train.md
@@ -80,21 +80,21 @@ EasyRec supports two ways to configure loss functions: 1) use a single loss function; 2

#### Using a single loss function

| Loss function | Description |
| ------------------------------------------ | ---------------------------------------------------------- |
| CLASSIFICATION | Classification loss: sigmoid_cross_entropy for binary classification, softmax_cross_entropy for multi-class |
| L2_LOSS | Squared loss |
| SIGMOID_L2_LOSS | Squared loss computed on the output of the sigmoid function |
| CROSS_ENTROPY_LOSS | Log loss (negative log-likelihood) |
| CIRCLE_LOSS | Used only by the CoMetricLearningI2I model |
| MULTI_SIMILARITY_LOSS | Used only by the CoMetricLearningI2I model |
| SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING | Multi-class softmax_cross_entropy with automatic negative sampling, used in binary classification tasks |
| BINARY_FOCAL_LOSS | Focal loss with hard-example mining and class balancing |
| PAIR_WISE_LOSS | Rank loss that optimizes global AUC |
| PAIRWISE_FOCAL_LOSS | Pair-level focal loss; supports custom pair grouping |
| PAIRWISE_LOGISTIC_LOSS | Pair-level logistic loss; supports custom pair grouping |
| JRC_LOSS | Binary classification + listwise ranking loss |
| F1_REWEIGHTED_LOSS | Loss that adjusts the relative weight of recall and precision in binary classification; effective against positive/negative sample imbalance |

- Note: SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
  - Supports parameter configuration; upgraded to [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317),
@@ -153,40 +153,49 @@ EasyRec supports two ways to configure loss functions: 1) use a single loss function; 2
- f1_beta_square is the square of the beta coefficient in the formula above.

- Parameters of PAIRWISE_FOCAL_LOSS

  - gamma: exponent of the focal loss, default 2.0
  - alpha: class-balancing parameter for re-weighting samples; it is recommended to set alpha from the positive/negative sample ratio so that $\frac{\alpha}{1-\alpha}=\frac{\#Neg}{\#Pos}$
  - session_name: field used to group pairs, e.g. user_id
  - hinge_margin: when the logit difference of a pair exceeds this value, the loss of the corresponding sample is 0; default 1.0
  - ohem_ratio: fraction of hard examples that participate in the loss computation; default 1.0
  - temperature: temperature coefficient; logits are divided by this value before the loss is computed; default 1.0

- Parameters of PAIRWISE_LOGISTIC_LOSS

  - session_name: field used to group pairs, e.g. user_id
  - hinge_margin: when the logit difference of a pair exceeds this value, the loss of the corresponding sample is 0; default 1.0
  - ohem_ratio: fraction of hard examples that participate in the loss computation; default 1.0
  - temperature: temperature coefficient; logits are divided by this value before the loss is computed; default 1.0

- Parameters of PAIRWISE_LOSS

  - session_name: field used to group pairs, e.g. user_id
  - margin: subtracted from the logit difference of a pair before the loss is computed, i.e. the logit of the positive sample must exceed that of the negative sample by at least margin; default 0
  - temperature: temperature coefficient; logits are divided by this value before the loss is computed; default 1.0

Note: all of the PAIRWISE\_\*\_LOSS functions above build positive/negative pairs within a mini-batch; the objective is to make the logit gap between the positive and the negative sample of each pair as large as possible (see the sketch below).
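To make this concrete, the following is a minimal TensorFlow sketch of a pair-level logistic loss written for this document (it is not EasyRec's implementation); it shows where the session_name, temperature and hinge_margin options enter:

```python
import tensorflow as tf


def pairwise_logistic_loss_sketch(labels, logits, session_ids=None,
                                  temperature=1.0, hinge_margin=None):
  """Pair-level logistic loss built within a mini-batch (illustration only).

  labels/logits are 1-D float tensors of shape [batch_size]; session_ids is an
  optional 1-D tensor (e.g. user_id) used to restrict pairs to one session.
  """
  logits = logits / temperature
  # logit difference of every (i, j) pair: pair_logits[i, j] = logits[i] - logits[j]
  pair_logits = tf.expand_dims(logits, -1) - tf.expand_dims(logits, 0)
  # a pair (i, j) is valid when sample i is positive and sample j is negative
  valid_pair = tf.greater(
      tf.expand_dims(labels, -1) - tf.expand_dims(labels, 0), 0)
  if session_ids is not None:
    same_session = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    valid_pair = tf.logical_and(valid_pair, same_session)
  pair_logits = tf.boolean_mask(pair_logits, valid_pair)
  if hinge_margin is not None:
    # pairs already separated by more than hinge_margin contribute no loss
    pair_logits = tf.boolean_mask(pair_logits,
                                  tf.less(pair_logits, hinge_margin))
  # logistic loss log(1 + exp(-x)) in a numerically stable form
  losses = tf.nn.relu(-pair_logits) + tf.log1p(tf.exp(-tf.abs(pair_logits)))
  return tf.reduce_mean(losses)
```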

- Parameters of BINARY_FOCAL_LOSS (an illustrative sketch of the core formula follows this list)

  - gamma: exponent of the focal loss, default 2.0
  - alpha: class-balancing parameter for re-weighting samples; it is recommended to set alpha from the positive/negative sample ratio so that $\frac{\alpha}{1-\alpha}=\frac{\#Neg}{\#Pos}$
  - ohem_ratio: fraction of hard examples that participate in the loss computation; default 1.0
  - label_smoothing: label smoothing coefficient
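A minimal illustration of the core focal-loss formula with the gamma and alpha parameters above (ohem_ratio and label_smoothing omitted; this is not the EasyRec implementation):

```python
import tensorflow as tf


def binary_focal_loss_sketch(labels, logits, gamma=2.0, alpha=0.5):
  """FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), illustration only."""
  probs = tf.sigmoid(logits)
  # p_t: predicted probability of the true class
  p_t = tf.where(tf.equal(labels, 1.0), probs, 1.0 - probs)
  # alpha_t: class-balancing weight (alpha for positives, 1 - alpha for negatives)
  alpha_t = tf.where(tf.equal(labels, 1.0),
                     alpha * tf.ones_like(labels),
                     (1.0 - alpha) * tf.ones_like(labels))
  ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
  return tf.reduce_mean(alpha_t * tf.pow(1.0 - p_t, gamma) * ce)
```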

- Parameters of JRC_LOSS

  - alpha: relative weight of the ranking loss versus the calibration loss; when unset, adaptive weight learning is triggered
  - session_name: field used to group lists, e.g. user_id
  - Reference paper: [Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model](https://arxiv.org/pdf/2208.06164.pdf)
  - Usage example: [dbmtl_with_jrc_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/dbmtl_on_taobao_with_multi_loss.config)

A complete example of a ranking model that uses multiple loss functions at the same time:
[cmbf_with_multi_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/cmbf_with_multi_loss.config)

A complete example of a multi-objective ranking model that uses multiple loss functions at the same time:
[dbmtl_with_multi_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/dbmtl_on_taobao_with_multi_loss.config)

##### Adaptive learning of loss-function weights

In multi-objective learning tasks, manually assigning static weights to multiple loss functions usually does not give the best results. EasyRec supports adaptive learning of the loss-function weights; an example follows:
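The sample configuration itself is folded out of this diff view and is not reproduced here. Purely as a conceptual illustration (not necessarily how EasyRec implements it), adaptive loss weights are commonly learned with the uncertainty-weighting formulation, where each loss gets a learnable log-variance term:

```python
import tensorflow as tf


def uncertainty_weighted_loss_sketch(losses):
  """Combine per-task losses with learnable log-variance weights (illustration only).

  total = sum_i(exp(-s_i) * loss_i + 0.5 * s_i), where each s_i is trained
  jointly with the model, so noisier tasks receive smaller effective weights.
  """
  total_loss = 0.0
  for i, loss in enumerate(losses):
    log_var = tf.get_variable(
        'loss_log_var_%d' % i, shape=[], initializer=tf.zeros_initializer())
    total_loss += tf.exp(-log_var) * loss + 0.5 * log_var
  return total_loss
```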
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
@@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
1 change: 0 additions & 1 deletion easy_rec/python/compat/weight_decay_optimizers.py
@@ -472,5 +472,4 @@ def __init__(self,
use_locking=use_locking,
name=name)
except ImportError:
print('import AdamAsyncOptimizer failed when loading AdamAsyncWOptimizer')
pass
12 changes: 12 additions & 0 deletions easy_rec/python/feature_column/feature_column.py
@@ -86,6 +86,8 @@ def _cmp_embed_config(a, b):
'shared embed info of [%s] is not matched [%s] vs [%s]' % (
embed_name, config, self._share_embed_infos[embed_name])
self._share_embed_names[embed_name] += 1
if config.feature_type == FeatureConfig.FeatureType.SequenceFeature:
self._share_embed_infos[embed_name] = copy_obj(config)
else:
self._share_embed_names[embed_name] = 1
self._share_embed_infos[embed_name] = copy_obj(config)
@@ -156,6 +158,11 @@ def _cmp_embed_config(a, b):
combiner=self._share_embed_infos[embed_name].combiner,
partitioner=partitioner,
ev_params=ev_params)
config = self._share_embed_infos[embed_name]
max_seq_len = config.max_seq_len if config.HasField(
'max_seq_len') else -1
for fc in share_embed_fcs:
fc.max_seq_length = max_seq_len
self._deep_share_embed_columns[embed_name] = share_embed_fcs

# for handling wide share embedding columns
@@ -168,6 +175,11 @@ def _cmp_embed_config(a, b):
combiner='sum',
partitioner=partitioner,
ev_params=ev_params)
config = self._share_embed_infos[embed_name]
max_seq_len = config.max_seq_len if config.HasField(
'max_seq_len') else -1
for fc in share_embed_fcs:
fc.max_seq_length = max_seq_len
self._wide_share_embed_columns[embed_name] = share_embed_fcs

for fc_name in self._deep_columns:
3 changes: 2 additions & 1 deletion easy_rec/python/input/input.py
@@ -141,7 +141,8 @@ def __init__(self,
model_name = model_config.WhichOneof('model')
if model_name in {'mmoe', 'esmm', 'dbmtl', 'simple_multi_task', 'ple'}:
model = getattr(model_config, model_name)
towers = [model.ctr_tower, model.cvr_tower] if model_name == 'esmm' else model.task_towers
towers = [model.ctr_tower, model.cvr_tower
] if model_name == 'esmm' else model.task_towers
for tower in towers:
metrics = tower.metrics_set
for metric in metrics:
4 changes: 4 additions & 0 deletions easy_rec/python/layers/cmbf.py
@@ -325,6 +325,10 @@ def merge_text_embedding(self, txt_embeddings, input_masks):
return txt_embeddings

def __call__(self, is_training, *args, **kwargs):
if not is_training:
self._model_config.hidden_dropout_prob = 0.0
self._model_config.attention_probs_dropout_prob = 0.0

# shape: [batch_size, image_num/image_dim, hidden_size]
img_attention_fea = self.image_self_attention_tower()

17 changes: 0 additions & 17 deletions easy_rec/python/layers/common_layers.py
@@ -1,29 +1,12 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
tf = tf.compat.v1


def gelu(x):
"""Gaussian Error Linear Unit.
This is a smoother version of the RELU.
Original paper: https://arxiv.org/abs/1606.08415
Args:
x: float Tensor to perform activation.
Returns:
`x` with the GELU activation applied.
"""
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf


def highway(x,
size=None,
activation=None,
7 changes: 4 additions & 3 deletions easy_rec/python/layers/dnn.py
@@ -4,7 +4,7 @@

import tensorflow as tf

from easy_rec.python.utils.load_class import load_by_path
from easy_rec.python.utils.activation import get_activation

if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -25,7 +25,7 @@ def __init__(self,
dnn_config: instance of easy_rec.python.protos.dnn_pb2.DNN
l2_reg: l2 regularizer
name: scope of the DNN, so that the parameters could be separated from other dnns
is_training: train phase or not, impact batchnorm and dropout
is_training: train phase or not, impact batch_norm and dropout
last_layer_no_activation: in last layer, use or not use activation
last_layer_no_batch_norm: in last layer, use or not use batch norm
"""
@@ -34,7 +34,8 @@ def __init__(self,
self._name = name
self._is_training = is_training
logging.info('dnn activation function = %s' % self._config.activation)
self.activation = load_by_path(self._config.activation)
self.activation = get_activation(
self._config.activation, training=is_training)
self._last_layer_no_activation = last_layer_no_activation
self._last_layer_no_batch_norm = last_layer_no_batch_norm

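The commit title refers to a DICE activation: the get_activation helper imported above lives in easy_rec/python/utils/activation.py, which is not among the files shown in this diff. Assuming it follows the usual DICE formulation from the DIN paper, f(x) = p(x) * x + (1 - p(x)) * alpha * x with p(x) = sigmoid(BN(x)), a hypothetical sketch looks like this; the training flag matters because the batch-norm statistics inside DICE behave differently at train and inference time, which is why dnn.py now passes training=is_training:

```python
import tensorflow as tf


def dice_sketch(x, name='dice', training=True):
  """DICE activation sketch: f(x) = p(x) * x + (1 - p(x)) * alpha * x.

  p(x) = sigmoid(BN(x)) where the batch norm has no learnable scale/shift, and
  alpha is a learned per-channel parameter. Assumes the last dim of x is known.
  """
  with tf.variable_scope(name):
    alpha = tf.get_variable(
        'alpha', shape=[x.get_shape().as_list()[-1]],
        initializer=tf.constant_initializer(0.0))
    # batch statistics differ between training and inference, hence the flag;
    # real training code must also run the BN update ops (tf.GraphKeys.UPDATE_OPS)
    x_norm = tf.layers.batch_normalization(
        x, center=False, scale=False, training=training, name='bn')
    p = tf.sigmoid(x_norm)
  return p * x + (1.0 - p) * alpha * x
```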
70 changes: 21 additions & 49 deletions easy_rec/python/layers/multihead_cross_attention.py
@@ -6,11 +6,10 @@

import math

import six
import tensorflow as tf

from easy_rec.python.compat.layers import layer_norm as tf_layer_norm
from easy_rec.python.layers.common_layers import gelu
from easy_rec.python.utils.activation import gelu
from easy_rec.python.utils.shape_utils import get_shape_list

if tf.__version__ >= '2.0':
@@ -53,7 +52,8 @@ def attention_layer(from_tensor,
do_return_2d_tensor=False,
batch_size=None,
from_seq_length=None,
to_seq_length=None):
to_seq_length=None,
reuse=None):
"""Performs multi-headed attention from `from_tensor` to `to_tensor`.
This is an implementation of multi-headed attention based on "Attention is all you Need".
@@ -96,6 +96,7 @@
of the 3D version of the `from_tensor`.
to_seq_length: (Optional) If the input is 2D, this might be the seq length
of the 3D version of the `to_tensor`.
reuse: whether to reuse this layer
Returns:
float Tensor of shape [batch_size, from_seq_length,
@@ -149,23 +150,26 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
num_attention_heads * size_per_head,
activation=query_act,
name='query',
kernel_initializer=create_initializer(initializer_range))
kernel_initializer=create_initializer(initializer_range),
reuse=reuse)

# `key_layer` = [B*T, N*H]
key_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=key_act,
name='key',
kernel_initializer=create_initializer(initializer_range))
kernel_initializer=create_initializer(initializer_range),
reuse=reuse)

# `value_layer` = [B*T, N*H]
value_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=value_act,
name='value',
kernel_initializer=create_initializer(initializer_range))
kernel_initializer=create_initializer(initializer_range),
reuse=reuse)

# `query_layer` = [B, N, F, H]
query_layer = transpose_for_scores(query_layer, batch_size,
@@ -242,6 +246,7 @@ def transformer_encoder(input_tensor,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
reuse=None,
name='transformer'):
"""Multi-headed, multi-layer Transformer from "Attention is All You Need".
@@ -265,6 +270,7 @@
probabilities.
initializer_range: float. Range of the initializer (stddev of truncated
normal).
reuse: whether to reuse this encoder
name: scope name prefix
Returns:
@@ -315,11 +321,12 @@ def transformer_encoder(input_tensor,
do_return_2d_tensor=True,
batch_size=batch_size,
from_seq_length=seq_length,
to_seq_length=seq_length)
to_seq_length=seq_length,
reuse=reuse)

# Run a linear projection of `hidden_size` then add a residual
# with `layer_input`.
with tf.variable_scope('output'):
with tf.variable_scope('output', reuse=reuse):
attention_output = tf.layers.dense(
attention_output,
hidden_size,
@@ -328,15 +335,15 @@ def transformer_encoder(input_tensor,
attention_output = layer_norm(attention_output + layer_input)

# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope('intermediate'):
with tf.variable_scope('intermediate', reuse=reuse):
intermediate_output = tf.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
kernel_initializer=create_initializer(initializer_range))

# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope('output'):
with tf.variable_scope('output', reuse=reuse):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
@@ -640,6 +647,7 @@ def embedding_postprocessor(input_tensor,
reuse_token_type=None,
use_position_embeddings=True,
position_embedding_name='position_embeddings',
reuse_position_embedding=None,
initializer_range=0.02,
max_position_embeddings=512,
dropout_prob=0.1):
@@ -659,6 +667,7 @@
position of each token in the sequence.
position_embedding_name: string. The name of the embedding table variable
for positional embeddings.
reuse_position_embedding: bool. Whether to reuse position embedding variable.
initializer_range: float. Range of the weight initialization.
max_position_embeddings: int. Maximum sequence length that might ever be
used with this model. This can be longer than the sequence length of
@@ -699,7 +708,8 @@ def embedding_postprocessor(input_tensor,
if use_position_embeddings:
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
with tf.control_dependencies([assert_op]):
full_position_embeddings = tf.get_variable(
with tf.variable_scope("position_embedding", reuse=reuse_position_embedding):
full_position_embeddings = tf.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
initializer=create_initializer(initializer_range))
@@ -736,41 +746,3 @@ def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
output_tensor = layer_norm(input_tensor, name)
output_tensor = dropout(output_tensor, dropout_prob)
return output_tensor


def get_activation(activation_string):
"""Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
Args:
activation_string: String name of the activation function.
Returns:
A Python function corresponding to the activation function. If
`activation_string` is None, empty, or "linear", this will return None.
If `activation_string` is not a string, it will return `activation_string`.
Raises:
ValueError: The `activation_string` does not correspond to a known
activation.
"""
# We assume that anything that's not a string is already an activation
# function, so we just return it.
if not isinstance(activation_string, six.string_types):
return activation_string

if not activation_string:
return None

act = activation_string.lower()
if act == 'linear':
return None
elif act == 'relu':
return tf.nn.relu
elif act == 'gelu':
return gelu
elif act == 'tanh':
return tf.tanh
elif act == 'swish':
return tf.nn.swish
else:
raise ValueError('Unsupported activation: %s' % act)