Skip to content

Commit

Permalink
[bugix]: Fix need_key_feature for negative sampler sequence feature (#…
Browse files Browse the repository at this point in the history
…288)

* fix bug for negative sampler sequence feature need key feature
  • Loading branch information
lgqfhwy authored Sep 23, 2022
1 parent 97af5b0 commit 009c01b
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 1 deletion.
2 changes: 1 addition & 1 deletion easy_rec/python/layers/sequence_feature_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def negative_sampler_target_attention(self,
all_hist_dim_emb.append(outputs)
hist_din_emb = tf.concat(all_hist_dim_emb, axis=1)
if not need_key_feature:
return hist_din_emb
return hist_din_emb, concat_features
din_output = tf.concat([hist_din_emb, cur_id], axis=2)
return din_output, concat_features

Expand Down
7 changes: 7 additions & 0 deletions easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,13 @@ def test_multi_tower_recall_neg_sampler_sequence_feature(self):
self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_multi_tower_recall_neg_sampler_only_sequence_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/multi_tower_recall_neg_sampler_only_sequence_feature.config',
self._test_dir)
self.assertTrue(self._success)


if __name__ == '__main__':
tf.test.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
train_input_path: "data/test/tb_data/taobao_train_data"
eval_input_path: "data/test/tb_data/taobao_test_data"
model_dir: "experiments/multi_tower_recall_neg_sampler_only_sequence_feature"

train_config {
optimizer_config: {
adam_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.001
decay_steps: 1000
decay_factor: 0.5
min_learning_rate: 1e-07
}
}
}
use_moving_average: false
}
num_steps: 6
sync_replicas: false
save_checkpoints_steps: 100
log_step_count_steps: 2
}

eval_config {
metrics_set: {
auc {
}
}
metrics_set: {
gauc {
uid_field: "user_id"
}
}
}

data_config {
batch_size: 16
input_fields {
input_name:'clk'
input_type: INT32
}
input_fields {
input_name:'buy'
input_type: INT32
}
input_fields {
input_name: 'pid'
input_type: STRING
}
input_fields {
input_name: 'adgroup_id'
input_type: STRING
}
input_fields {
input_name: 'cate_id'
input_type: STRING
}
input_fields {
input_name: 'campaign_id'
input_type: STRING
}
input_fields {
input_name: 'customer'
input_type: STRING
}
input_fields {
input_name: 'brand'
input_type: STRING
}
input_fields {
input_name: 'user_id'
input_type: STRING
}
input_fields {
input_name: 'cms_segid'
input_type: STRING
}
input_fields {
input_name: 'cms_group_id'
input_type: STRING
}
input_fields {
input_name: 'final_gender_code'
input_type: STRING
}
input_fields {
input_name: 'age_level'
input_type: STRING
}
input_fields {
input_name: 'pvalue_level'
input_type: STRING
}
input_fields {
input_name: 'shopping_level'
input_type: STRING
}
input_fields {
input_name: 'occupation'
input_type: STRING
}
input_fields {
input_name: 'new_user_class_level'
input_type: STRING
}
input_fields {
input_name: 'tag_category_list'
input_type: STRING
}
input_fields {
input_name: 'tag_brand_list'
input_type: STRING
}
input_fields {
input_name: 'price'
input_type: INT32
}

label_fields: 'clk'
num_epochs: 5
prefetch_size: 4
input_type: CSVInput

negative_sampler {
input_path: 'data/test/tb_data/taobao_ad_feature_gl'
num_sample: 4
num_eval_sample: 4
attr_fields: 'adgroup_id'
attr_fields: 'cate_id'
attr_fields: 'campaign_id'
attr_fields: 'customer'
attr_fields: 'brand'
item_id_field: 'adgroup_id'
}
}

feature_configs : {
input_names: 'pid'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
feature_configs : {
input_names: 'adgroup_id'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
feature_configs : {
input_names: 'cate_id'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10000
}
feature_configs : {
input_names: 'campaign_id'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
feature_configs : {
input_names: 'customer'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
feature_configs : {
input_names: 'brand'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
feature_configs : {
input_names: 'user_id'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
feature_configs : {
input_names: 'cms_segid'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100
}
feature_configs : {
input_names: 'cms_group_id'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100
}
feature_configs : {
input_names: 'final_gender_code'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
feature_configs : {
input_names: 'age_level'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
feature_configs : {
input_names: 'pvalue_level'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
feature_configs : {
input_names: 'shopping_level'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
feature_configs : {
input_names: 'occupation'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
feature_configs : {
input_names: 'new_user_class_level'
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
feature_configs {
input_names: "tag_category_list"
feature_type: SequenceFeature
embedding_dim: 16
hash_bucket_size: 100000
sub_feature_type: IdFeature
separator: "|"
}
feature_configs {
input_names: "tag_brand_list"
feature_type: SequenceFeature
embedding_dim: 16
hash_bucket_size: 100000
sub_feature_type: IdFeature
separator: "|"
}
feature_configs : {
input_names: 'price'
feature_type: IdFeature
embedding_dim: 16
num_buckets: 50
}
model_config:{
model_class: "MultiTowerRecall"
feature_groups: {
group_name: 'user'
feature_names: 'user_id'
feature_names: 'cms_segid'
feature_names: 'cms_group_id'
feature_names: 'age_level'
feature_names: 'pvalue_level'
feature_names: 'shopping_level'
feature_names: 'occupation'
feature_names: 'new_user_class_level'
wide_deep:DEEP
negative_sampler:true
sequence_features: {
group_name: "seq_fea"
allow_key_search: true
need_key_feature:false
seq_att_map: {
key: "brand"
key: "cate_id"
hist_seq: "tag_brand_list"
hist_seq: "tag_category_list"
}
}
}
feature_groups: {
group_name: "item"
feature_names: 'adgroup_id'
feature_names: 'cate_id'
feature_names: 'campaign_id'
feature_names: 'customer'
feature_names: 'brand'
wide_deep:DEEP
}
multi_tower_recall {
user_tower {
dnn {
hidden_units: [256, 128, 64, 32]
# dropout_ratio : [0.1, 0.1, 0.1, 0.1]
}
}
item_tower {
dnn {
hidden_units: [256, 128, 64, 32]
}
}
final_dnn {
hidden_units: [128, 96, 64, 32, 16]
}
l2_regularization: 1e-6
}
loss_type: CLASSIFICATION
embedding_regularization: 5e-6
}

0 comments on commit 009c01b

Please sign in to comment.