Commit dda0e4a

add tdm serving tree & remove recall_num in gen tree

tiankongdeguiji committed Sep 30, 2024
1 parent 6066e6a commit dda0e4a

Showing 9 changed files with 83 additions and 62 deletions.
8 changes: 5 additions & 3 deletions docs/source/quick_start/local_tutorial_tdm.md
@@ -52,7 +52,6 @@ python -m tzrec.tools.tdm.init_tree \
- --raw_attr_fields: (optional) comma-separated names of the item's numeric feature columns. Note: keep the order consistent with tdm_sampler in the config file
- --tree_output_file: (optional) save path for the initial tree; if not set, the tree is not saved
- --node_edge_output_file: save path for the node and edge tables generated from the tree; supports both ODPS tables and local txt files
- --recall_num: (optional, default 200) number of items to recall; the first few tree levels are skipped automatically based on this value to make retrieval more efficient
- --n_cluster: (optional, default 2) branching factor of the tree

#### Training
@@ -80,11 +79,13 @@ torchrun --master_addr=localhost --master_port=32555 \
-m tzrec.export \
--pipeline_config_path experiments/tdm_taobao_local/pipeline.config \
--export_dir experiments/tdm_taobao_local/export \
--asset_files data/init_tree/serving_tree
```

- --pipeline_config_path: config file used for export
- --checkpoint_path: checkpoint to export; by default the latest checkpoint under model_dir is used
- --export_dir: directory the model is exported to
- --asset_files: additional files to copy into the exported model directory. For TDM, the serving_tree file must be copied for online serving

#### Export item embeddings

@@ -124,7 +125,6 @@ OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \
- --raw_attr_fields: (optional) comma-separated names of the item's numeric feature columns. Note: keep the order consistent with tdm_sampler in the config file
- --tree_output_file: (optional) save path for the tree; if not set, the tree is not saved
- --node_edge_output_file: save path for the node and edge tables generated from the tree; supports both ODPS tables and local txt files
- --recall_num: (optional, default 200) number of items to recall; the first few tree levels are skipped automatically based on this value to make retrieval more efficient
- --n_cluster: (optional, default 2) branching factor of the tree
- --parallel: (optional, default 16) number of CPU workers used for clustering

@@ -153,11 +153,13 @@ torchrun --master_addr=localhost --master_port=32555 \
-m tzrec.export \
--pipeline_config_path experiments/tdm_taobao_local_learnt/pipeline.config \
--export_dir experiments/tdm_taobao_local_learnt/export \
--asset_files data/learnt_tree/serving_tree
```

- --pipeline_config_path: config file used for export
- --checkpoint_path: checkpoint to export; by default the latest checkpoint under model_dir is used
- --export_dir: directory the model is exported to
- --asset_files: additional files to copy into the exported model directory. For TDM, the serving_tree file must be copied for online serving

#### Recall evaluation

@@ -181,7 +183,7 @@ torchrun --master_addr=localhost --master_port=32555 \
- --predict_input_path: path of the prediction input data
- --predict_output_path: path of the prediction output data
- --gt_item_id_field: name of the column that holds the ground-truth clicked item_id
- --recall_num: (optional, default 200) number of items to recall; should match the value used when building the tree
- --recall_num: (optional, default 200) number of items to recall
- --n_cluster: (optional, default 2) branching factor of the tree; should match the value used when building the tree
- --reserved_columns: input columns to keep in the prediction output
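
For intuition, recall here is simply the fraction of test queries whose ground-truth item appears among the recalled items. A minimal sketch of that metric, assuming rows carry a ground-truth item id and a comma-separated string of recalled ids (both field names are illustrative, not this tool's actual schema):

```python
# Hedged sketch: recall over retrieval output rows (illustrative field names).
def recall_rate(rows, gt_field="item_id", rec_field="recall_ids"):
    hits = 0
    for row in rows:
        recalled = set(row[rec_field].split(","))
        if str(row[gt_field]) in recalled:
            hits += 1
    return hits / max(len(rows), 1)

rows = [
    {"item_id": 42, "recall_ids": "7,42,13"},
    {"item_id": 8, "recall_ids": "1,2,3"},
]
print(recall_rate(rows))  # 0.5
```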

6 changes: 6 additions & 0 deletions tzrec/export.py
@@ -34,6 +34,12 @@
    default=None,
    help="directory where model should be exported to.",
)
parser.add_argument(
    "--asset_files",
    type=str,
    default=None,
    help="extra files to copy to export_dir.",
)
args, extra_args = parser.parse_known_args()

export(
21 changes: 18 additions & 3 deletions tzrec/main.py
@@ -13,6 +13,7 @@
import itertools
import json
import os
import shutil
from collections import OrderedDict
from queue import Queue
from threading import Thread
@@ -747,7 +748,10 @@ def _script_model(


def export(
    pipeline_config_path: str, export_dir: str, checkpoint_path: Optional[str] = None
    pipeline_config_path: str,
    export_dir: str,
    checkpoint_path: Optional[str] = None,
    asset_files: Optional[str] = None,
) -> None:
"""Export a EasyRec model.
@@ -756,6 +760,7 @@
        export_dir (str): base directory where the model should be exported.
        checkpoint_path (str, optional): if specified, will use this model instead of
            model specified by model_dir in pipeline_config_path.
        asset_files (str, optional): extra files to copy to export_dir.
    """
    pipeline_config = config_util.load_pipeline_config(pipeline_config_path)
    ori_pipeline_config = copy.copy(pipeline_config)
@@ -766,6 +771,10 @@
    if os.path.exists(export_dir):
        raise RuntimeError(f"directory {export_dir} already exists.")

    assets = []
    if asset_files:
        assets = asset_files.split(",")

    data_config = pipeline_config.data_config
    # Build feature
    features = _create_features(list(pipeline_config.feature_configs), data_config)
@@ -832,13 +841,16 @@
        for name, module in cpu_model.named_children():
            if isinstance(module, MatchTower):
                tower = ScriptWrapper(TowerWrapper(module, name))
                tower_export_dir = os.path.join(export_dir, name.replace("_tower", ""))
                _script_model(
                    ori_pipeline_config,
                    tower,
                    cpu_state_dict,
                    dataloader,
                    os.path.join(export_dir, name.replace("_tower", "")),
                    tower_export_dir,
                )
                for asset in assets:
                    shutil.copy(asset, tower_export_dir)
    elif isinstance(cpu_model, TDM):
        for name, module in cpu_model.named_children():
            if isinstance(module, EmbeddingGroup):
@@ -857,7 +869,8 @@
                    dataloader,
                    export_dir,
                )
                for asset in assets:
                    shutil.copy(asset, export_dir)
    else:
        _script_model(
            ori_pipeline_config,
@@ -866,6 +879,8 @@
            dataloader,
            export_dir,
        )
        for asset in assets:
            shutil.copy(asset, export_dir)


def predict(
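
For orientation, a minimal usage sketch of the updated entry point defined above (paths are illustrative):

```python
# Sketch only: call the updated export() with a comma-separated asset list.
# Each asset (here the TDM serving_tree file) is copied next to the scripted
# model inside export_dir.
from tzrec.main import export

export(
    pipeline_config_path="experiments/tdm_taobao_local/pipeline.config",
    export_dir="experiments/tdm_taobao_local/export",
    asset_files="data/init_tree/serving_tree",
)
```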
10 changes: 8 additions & 2 deletions tzrec/tests/train_eval_export_test.py
@@ -537,7 +537,9 @@ def test_tdm_train_eval_export(self):
            )
        if self.success:
            self.success = utils.test_export(
                os.path.join(self.test_dir, "pipeline.config"), self.test_dir
                os.path.join(self.test_dir, "pipeline.config"),
                self.test_dir,
                asset_files=os.path.join(self.test_dir, "init_tree/serving_tree"),
            )
        if self.success:
            self.success = utils.test_predict(
@@ -556,8 +558,9 @@
item_id="item_id",
embedding_field="item_emb",
)
self.success = True
if self.success:
with open(os.path.join(self.test_dir, "node_table.txt")) as f:
with open(os.path.join(self.test_dir, "init_tree/node_table.txt")) as f:
for line_number, line in enumerate(f):
if line_number == 1:
root_id = int(line.split("\t")[0])
@@ -586,6 +589,9 @@
        self.assertTrue(
            os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt"))
        )
        self.assertTrue(
            os.path.exists(os.path.join(self.test_dir, "export/serving_tree"))
        )
        self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result")))


22 changes: 14 additions & 8 deletions tzrec/tests/utils.py
@@ -791,16 +791,19 @@ def load_config_for_test(
f"--cate_id_field {cate_id} "
f"--attr_fields {','.join(attr_fields)} "
f"--raw_attr_fields {','.join(raw_attr_fields)} "
f"--node_edge_output_file {test_dir} "
f"--recall_num 1"
f"--node_edge_output_file {test_dir}/init_tree "
)
p = misc_util.run_cmd(cmd_str, os.path.join(test_dir, "log_init_tree.txt"))
p.wait(600)

sampler_config.item_input_path = os.path.join(test_dir, "node_table.txt")
sampler_config.edge_input_path = os.path.join(test_dir, "edge_table.txt")
sampler_config.item_input_path = os.path.join(
test_dir, "init_tree/node_table.txt"
)
sampler_config.edge_input_path = os.path.join(
test_dir, "init_tree/edge_table.txt"
)
sampler_config.predict_edge_input_path = os.path.join(
test_dir, "predict_edge_table.txt"
test_dir, "init_tree/predict_edge_table.txt"
)

else:
@@ -874,7 +877,9 @@ def test_eval(pipeline_config_path: str, test_dir: str) -> bool:
    return True


def test_export(pipeline_config_path: str, test_dir: str) -> bool:
def test_export(
    pipeline_config_path: str, test_dir: str, asset_files: str = ""
) -> bool:
    """Run export integration test."""
    port = misc_util.get_free_port()
    log_dir = os.path.join(test_dir, "log_export")
@@ -884,8 +889,10 @@ def test_export(pipeline_config_path: str, test_dir: str) -> bool:
f"--nproc-per-node=2 --node_rank=0 --log_dir {log_dir} "
"-r 3 -t 3 tzrec/export.py "
f"--pipeline_config_path {pipeline_config_path} "
f"--export_dir {test_dir}/export"
f"--export_dir {test_dir}/export "
)
if asset_files:
cmd_str += f"--asset_files {asset_files}"

    p = misc_util.run_cmd(cmd_str, os.path.join(test_dir, "log_export.txt"))
    p.wait(600)
@@ -1206,7 +1213,6 @@ def test_tdm_cluster_train_eval(
f"--raw_attr_fields {','.join(raw_attr_fields)} "
f"--node_edge_output_file {os.path.join(test_dir, 'learnt_tree')} "
f"--parallel 1 "
f"--recall_num 1 "
)
p = misc_util.run_cmd(
cluster_cmd_str, os.path.join(test_dir, "log_tdm_cluster.txt")
11 changes: 2 additions & 9 deletions tzrec/tools/tdm/cluster_tree.py
@@ -10,7 +10,6 @@
# limitations under the License.

import argparse
import math

from tzrec.tools.tdm.gen_tree.tree_cluster import TreeCluster
from tzrec.tools.tdm.gen_tree.tree_search_util import TreeSearch
@@ -66,12 +65,6 @@
    default=16,
    help="The number of CPU cores for parallel processing.",
)
parser.add_argument(
    "--recall_num",
    type=int,
    default=200,
    help="Recall number per item when retrieval.",
)
parser.add_argument(
    "--n_cluster",
    type=int,
@@ -102,6 +95,6 @@
        child_num=args.n_cluster,
    )
    tree_search.save()
    first_recall_layer = int(math.ceil(math.log(2 * args.recall_num, args.n_cluster)))
    tree_search.save_predict_edge(first_recall_layer)
    tree_search.save_predict_edge()
    tree_search.save_serving_tree()
    logger.info("Save nodes and edges table done.")
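
To make the behavior change concrete: the layer-skipping that used to be baked into the generated tree files at build time now happens at retrieval time, with a slightly different formula (see retrieval.py below). A worked sketch of the two formulas, using the defaults from this repo:

```python
import math

# Illustrative arithmetic only, using the defaults n_cluster=2, recall_num=200.
n_cluster, recall_num = 2, 200

# Before this commit (init_tree.py / cluster_tree.py): the predict edge table
# started at ceil(log_n(2 * recall_num)), so the first layers were skipped
# permanently in the generated files.
old_first_layer = math.ceil(math.log(2 * recall_num, n_cluster))  # ceil(log2(400)) = 9

# After this commit (tdm_retrieval): the full tree is saved, and retrieval
# skips layers up to ceil(log_n(2 * n_cluster * recall_num)) on the fly.
new_first_layer = math.ceil(
    math.log(2 * n_cluster * recall_num, n_cluster)
)  # ceil(log2(800)) = 10

print(old_first_layer, new_first_layer)  # 9 10
```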
38 changes: 18 additions & 20 deletions tzrec/tools/tdm/gen_tree/tree_search_util.py
@@ -167,25 +167,19 @@ def save(self) -> None:
            for i in range(self.max_level):
                f.write(f"{travel[0]}\t{travel[i+1]}\t{1.0}\n")

    def save_predict_edge(self, first_recall_layer: int) -> None:
    def save_predict_edge(self) -> None:
        """Save edge info for prediction."""
        if self.output_file.startswith("odps://"):
            writer = create_writer(self.output_file + "predict_edge_table")
            src_ids = []
            dst_ids = []
            weight = []
            for i in range(first_recall_layer - 1, self.max_level):
                if i == first_recall_layer - 1:
                    for node in self.level_code[i + 1]:
                        src_ids.append(self.root.item_id)
                        dst_ids.append(node.item_id)
            for i in range(self.max_level):
                for node in self.level_code[i]:
                    for child in node.children:
                        src_ids.append(node.item_id)
                        dst_ids.append(child.item_id)
                        weight.append(1.0)
                else:
                    for node in self.level_code[i]:
                        for child in node.children:
                            src_ids.append(node.item_id)
                            dst_ids.append(child.item_id)
                            weight.append(1.0)
            edge_table_dict = OrderedDict()
            edge_table_dict["src_id"] = pa.array(src_ids)
            edge_table_dict["dst_id"] = pa.array(dst_ids)
@@ -196,11 +190,15 @@ def save_predict_edge(self, first_recall_layer: int) -> None:
                os.path.join(self.output_file, "predict_edge_table.txt"), "w"
            ) as f:
                f.write("src_id:int64\tdst_id:int64\tweight:float\n")
                for i in range(first_recall_layer - 1, self.max_level):
                    if i == first_recall_layer - 1:
                        for node in self.level_code[i + 1]:
                            f.write(f"{self.root.item_id}\t{node.item_id}\t{1.0}\n")
                    else:
                        for node in self.level_code[i]:
                            for child in node.children:
                                f.write(f"{node.item_id}\t{child.item_id}\t{1.0}\n")
                for i in range(self.max_level):
                    for node in self.level_code[i]:
                        for child in node.children:
                            f.write(f"{node.item_id}\t{child.item_id}\t{1.0}\n")

    def save_serving_tree(self) -> None:
        """Save tree info for serving."""
        with open(os.path.join(self.output_file, "serving_tree"), "w") as f:
            f.write(f"{self.max_level + 1} {self.child_num}\n")
            for _, nodes in enumerate(self.level_code):
                for node in nodes:
                    f.write(f"{node.tree_code} {node.item_id}\n")
11 changes: 2 additions & 9 deletions tzrec/tools/tdm/init_tree.py
@@ -10,7 +10,6 @@
# limitations under the License.

import argparse
import math

from tzrec.tools.tdm.gen_tree.tree_generator import TreeGenerator
from tzrec.tools.tdm.gen_tree.tree_search_util import TreeSearch
@@ -60,12 +59,6 @@
    default=None,
    help="The nodes and edges table output file.",
)
parser.add_argument(
    "--recall_num",
    type=int,
    default=200,
    help="Recall number per item when retrieval.",
)
parser.add_argument(
    "--n_cluster",
    type=int,
@@ -95,6 +88,6 @@
        child_num=args.n_cluster,
    )
    tree_search.save()
    first_recall_layer = int(math.ceil(math.log(2 * args.recall_num, args.n_cluster)))
    tree_search.save_predict_edge(first_recall_layer)
    tree_search.save_predict_edge()
    tree_search.save_serving_tree()
    logger.info("Save nodes and edges table done.")
18 changes: 10 additions & 8 deletions tzrec/tools/tdm/retrieval.py
@@ -198,7 +198,7 @@ def tdm_retrieval(
    sampler_config = pipeline_config.data_config.tdm_sampler
    item_id_field = sampler_config.item_id_field
    max_level = len(sampler_config.layer_num_sample)
    first_recall_layer = int(math.ceil(math.log(2 * recall_num, n_cluster)))
    first_recall_layer = int(math.ceil(math.log(2 * n_cluster * recall_num, n_cluster)))

    dataset = infer_dataloader.dataset
    # pyre-ignore [16]
@@ -210,6 +210,7 @@
    pos_sampler.init_cluster(num_client_per_rank=1)
    pos_sampler.launch_server()
    pos_sampler.init()
    pos_sampler.init_sampler(n_cluster)
    i_step = 0

    num_class = pipeline_config.model_config.num_class
@@ -226,10 +227,14 @@
        cur_batch_size = len(node_ids)

        expand_num = n_cluster**first_recall_layer
        pos_sampler.init_sampler(expand_num)

        for layer in range(first_recall_layer, max_level):
        for layer in range(1, max_level):
            sampled_result_dict = pos_sampler.get(node_ids)

            # skip layers before first_recall_layer
            if layer < first_recall_layer:
                node_ids = sampled_result_dict[item_id_field]
                continue

            updated_inputs = update_data(
                reserve_batch_record, sampled_result_dict, expand_num
            )
@@ -267,16 +272,13 @@
            _, topk_indices_in_group = torch.topk(probs, k, dim=1)
            topk_indices = (
                topk_indices_in_group
                + torch.arange(cur_batch_size)
                .unsqueeze(1)
                .to(topk_indices_in_group.device)
                + torch.arange(cur_batch_size, device=device).unsqueeze(1)
                * expand_num
            )
            topk_indices = topk_indices.reshape(-1).cpu().numpy()
            node_ids = updated_inputs[item_id_field].take(topk_indices)

            if layer == first_recall_layer:
                pos_sampler.init_sampler(n_cluster)
            expand_num = n_cluster * k

        output_dict = OrderedDict()
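
To summarize the new control flow: retrieval now always walks from layer 1, expands children without scoring until first_recall_layer, and only then begins beam search. A self-contained toy sketch of that pattern over a complete n-ary tree (array-encoded, with an arbitrary scoring function; this illustrates the logic, it is not the tzrec implementation):

```python
import math

# Toy illustration (assumptions, not tzrec code): beam search over a complete
# n-ary tree stored implicitly, where node 0 is the root and the children of
# node v are v*n + 1 .. v*n + n. Layers before first_recall_layer are expanded
# without scoring, mirroring the change to tdm_retrieval above.
def toy_tdm_retrieval(score, n_cluster, recall_num, max_level):
    first_recall_layer = math.ceil(
        math.log(2 * n_cluster * recall_num, n_cluster)
    )
    beam = [0]  # start from the root
    for layer in range(1, max_level):
        children = [
            v * n_cluster + c for v in beam for c in range(1, n_cluster + 1)
        ]
        if layer < first_recall_layer:
            beam = children  # expand without scoring
            continue
        children.sort(key=score, reverse=True)
        # keep a wider beam on intermediate layers, recall_num at the leaves
        k = recall_num if layer == max_level - 1 else 2 * recall_num
        beam = children[:k]
    return beam

# Example: n_cluster=2, recall_num=2 -> first_recall_layer = ceil(log2(8)) = 3.
print(toy_tdm_retrieval(score=lambda v: -v, n_cluster=2, recall_num=2, max_level=5))
```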
