Commit

[feat]: add avx string split feature (#364)
* add avx string split feature
* add pip requirement for eas_prediction
* update incr_record ops
* add step counter hook
chengmengli06 authored May 9, 2023
1 parent 135671a commit 3604da3
Showing 25 changed files with 83 additions and 9 deletions.
5 changes: 4 additions & 1 deletion easy_rec/__init__.py
@@ -25,7 +25,10 @@
   elif tf.__version__.startswith('1.12'):
     ops_dir = os.path.join(ops_dir, '1.12')
   elif tf.__version__.startswith('1.15'):
-    ops_dir = os.path.join(ops_dir, '1.15')
+    if 'IS_ON_PAI' in os.environ:
+      ops_dir = os.path.join(ops_dir, 'DeepRec')
+    else:
+      ops_dir = os.path.join(ops_dir, '1.15')
   else:
     ops_dir = None
 else:
3 changes: 1 addition & 2 deletions easy_rec/python/inference/client/client_demo.py
@@ -5,8 +5,7 @@
 import sys
 import traceback
 
-from easyrec_request import EasyrecRequest
-
+from easy_rec.python.inference.client.easyrec_request import EasyrecRequest
 from easy_rec.python.protos.predict_pb2 import PBFeature
 from easy_rec.python.protos.predict_pb2 import PBRequest
 
3 changes: 2 additions & 1 deletion easy_rec/python/input/csv_input_ex.py
@@ -4,6 +4,7 @@
 import tensorflow as tf
 
 from easy_rec.python.input.csv_input import CSVInput
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
 
 if tf.__version__ >= '2.0':
   tf = tf.compat.v1
@@ -39,7 +40,7 @@ def _check_data(line):
           (sep, field_num, len(record_defaults))
       return True
 
-    fields = tf.string_split(
+    fields = str_split_by_chr(
         line, self._data_config.separator, skip_empty=False)
     tmp_fields = tf.reshape(fields.values, [-1, len(record_defaults)])
     fields = []
4 changes: 2 additions & 2 deletions easy_rec/python/input/odps_rtp_input.py
@@ -5,6 +5,7 @@
 import tensorflow as tf
 
 from easy_rec.python.input.input import Input
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
 from easy_rec.python.utils.check_utils import check_split
 from easy_rec.python.utils.input_utils import string_to_number
 
@@ -79,10 +80,9 @@ def _parse_table(self, *fields):
             Tout=tf.bool)
     ] if self._check_mode else []
     with tf.control_dependencies(check_list):
-      fields = tf.string_split(
+      fields = str_split_by_chr(
           fields[-1], self._data_config.separator, skip_empty=False)
       tmp_fields = tf.reshape(fields.values, [-1, feature_num])
-
     fields = labels[len(self._label_fields):]
     for i in range(feature_num):
       field = string_to_number(tmp_fields[:, i], record_types[i],
3 changes: 2 additions & 1 deletion easy_rec/python/input/rtp_input.py
@@ -5,6 +5,7 @@
 import tensorflow as tf
 
 from easy_rec.python.input.input import Input
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
 from easy_rec.python.utils.check_utils import check_split
 from easy_rec.python.utils.check_utils import check_string_to_number
 from easy_rec.python.utils.input_utils import string_to_number
@@ -104,7 +105,7 @@ def _parse_csv(self, line):
             Tout=tf.bool)
     ] if self._check_mode else []
     with tf.control_dependencies(check_list):
-      fields = tf.string_split(
+      fields = str_split_by_chr(
           feature_str, self._data_config.separator, skip_empty=False)
       tmp_fields = tf.reshape(fields.values, [-1, len(record_types)])
       rtp_record_defaults = [
3 changes: 3 additions & 0 deletions easy_rec/python/model/easy_rec_estimator.py
@@ -410,6 +410,9 @@ def _train_model_fn(self, features, labels, run_config):
     chief_hooks = []
     if estimator_utils.is_chief():
       hooks.append(saver_hook)
+      hooks.append(
+          basic_session_run_hooks.StepCounterHook(
+              every_n_steps=log_step_count_steps, output_dir=self.model_dir))
 
     # profiling hook
     if self.train_config.is_profiling and estimator_utils.is_chief():
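For reference, a minimal sketch (not part of the commit) of what the StepCounterHook added above does when attached to a TF 1.x Estimator: it logs global_step/sec every every_n_steps steps and also writes the value as a summary to output_dir. The model_fn, input_fn and directory below are hypothetical placeholders.

# Sketch only (TF 1.x); my_model_fn / my_input_fn / the directory are hypothetical.
import tensorflow as tf
from tensorflow.python.training import basic_session_run_hooks

step_counter = basic_session_run_hooks.StepCounterHook(
    every_n_steps=100,            # log global_step/sec every 100 global steps
    output_dir='/tmp/model_dir')  # also written as a summary for TensorBoard

# estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir='/tmp/model_dir')
# estimator.train(input_fn=my_input_fn, max_steps=1000, hooks=[step_counter])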
Binary file modified easy_rec/python/ops/1.12/incr_record.so
Binary file added easy_rec/python/ops/1.12/libstr_avx_op.so
Binary file added easy_rec/python/ops/1.12_pai/libstr_avx_op.so
Binary file modified easy_rec/python/ops/1.15/incr_record.so
Binary file added easy_rec/python/ops/1.15/libstr_avx_op.so
Binary file added easy_rec/python/ops/DeepRec/incr_record.so
Binary file added easy_rec/python/ops/DeepRec/kafka.so
Binary file added easy_rec/python/ops/DeepRec/libembed_op.so
Binary file added easy_rec/python/ops/DeepRec/librdkafka++.so
Binary file added easy_rec/python/ops/DeepRec/librdkafka++.so.1
Binary file added easy_rec/python/ops/DeepRec/librdkafka.so
Binary file added easy_rec/python/ops/DeepRec/librdkafka.so.1
Binary file added easy_rec/python/ops/DeepRec/libstr_avx_op.so
28 changes: 28 additions & 0 deletions easy_rec/python/ops/gen_str_avx_op.py
@@ -0,0 +1,28 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+
+import tensorflow as tf
+from tensorflow.python.ops import string_ops
+
+import easy_rec
+from easy_rec.python.utils import constant
+
+try:
+  str_avx_op_path = os.path.join(easy_rec.ops_dir, 'libstr_avx_op.so')
+  str_avx_op = tf.load_op_library(str_avx_op_path)
+  logging.info('load avx string_split op from %s succeed' % str_avx_op_path)
+except Exception as ex:
+  logging.warning('load avx string_split op failed: %s' % str(ex))
+  str_avx_op = None
+
+
+def str_split_by_chr(input_str, sep, skip_empty):
+  if constant.has_avx_str_split() and str_avx_op is not None:
+    assert len(sep) == 1, \
+        'invalid data_config.separator(%s) len(%d) != 1' % (
+            sep, len(sep))
+    return str_avx_op.avx512_string_split(input_str, sep, skip_empty=skip_empty)
+  else:
+    return string_ops.string_split(input_str, sep, skip_empty=skip_empty)
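A short usage sketch of the new helper, not part of the diff: when the ENABLE_AVX_STR_SPLIT switch is on and libstr_avx_op.so loaded, the call goes through the AVX kernel (single-character separators only); otherwise it falls back to tf.string_split, so callers get the same split result either way. The graph/session setup below is illustrative and assumes TF 1.x.

# Illustrative only; assumes TF 1.x graph mode.
import tensorflow as tf

from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
from easy_rec.python.utils import constant

constant.enable_avx_str_split()  # otherwise the tf.string_split fallback is used
lines = tf.constant(['1,2,3', '4,5,6'])
fields = str_split_by_chr(lines, ',', skip_empty=False)  # SparseTensor-like: .values, .indices

with tf.Session() as sess:
  print(sess.run(fields.values))  # [b'1' b'2' b'3' b'4' b'5' b'6']

constant.disable_avx_str_split()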
12 changes: 12 additions & 0 deletions easy_rec/python/test/csv_input_test.py
@@ -2,6 +2,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 """Define cv_input, the base class for cv tasks."""
 
+import os
+import unittest
+
 import tensorflow as tf
 from google.protobuf import text_format
 
@@ -10,6 +13,7 @@
 from easy_rec.python.protos.dataset_pb2 import DatasetConfig
 from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
 from easy_rec.python.utils import config_util
+from easy_rec.python.utils import constant
 from easy_rec.python.utils.test_utils import RunAsSubprocess
 
 if tf.__version__ >= '2.0':
@@ -264,6 +268,14 @@ def test_csv_input_ex(self):
       sess.run(init_op)
       feature_dict, label_dict = sess.run([features, labels])
 
+  @unittest.skipIf('AVX_TEST' not in os.environ,
+                   'Only execute when avx512 instructions are supported')
+  @RunAsSubprocess
+  def test_csv_input_ex_avx(self):
+    constant.enable_avx_str_split()
+    self.test_csv_input_ex()
+    constant.disable_avx_str_split()
+
   @RunAsSubprocess
   def test_csv_data_ignore_error(self):
     data_config_str = """
16 changes: 16 additions & 0 deletions easy_rec/python/utils/constant.py
@@ -1,8 +1,24 @@
 # -*- encoding:utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
+import os
+
 SAMPLE_WEIGHT = 'SAMPLE_WEIGHT'
 
 DENSE_UPDATE_VARIABLES = 'DENSE_UPDATE_VARIABLES'
 
 SPARSE_UPDATE_VARIABLES = 'SPARSE_UPDATE_VARIABLES'
+ENABLE_AVX_STR_SPLIT = 'ENABLE_AVX_STR_SPLIT'
+
+
+def enable_avx_str_split():
+  os.environ[ENABLE_AVX_STR_SPLIT] = '1'
+
+
+def has_avx_str_split():
+  return ENABLE_AVX_STR_SPLIT in os.environ and os.environ[
+      ENABLE_AVX_STR_SPLIT] == '1'
+
+
+def disable_avx_str_split():
+  del os.environ[ENABLE_AVX_STR_SPLIT]
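The helpers above are thin wrappers around a process environment variable, so the switch can be flipped either in code or by the job launcher; a small sketch, not from the commit:

import os

from easy_rec.python.utils import constant

constant.enable_avx_str_split()                  # what run.py does for --enable_avx_str_split
assert constant.has_avx_str_split()

os.environ[constant.ENABLE_AVX_STR_SPLIT] = '1'  # equivalent, e.g. exported by the launcher
constant.disable_avx_str_split()                 # removes the variable; call only after enabling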
12 changes: 11 additions & 1 deletion pai_jobs/run.py
@@ -16,13 +16,16 @@
 from easy_rec.python.inference.vector_retrieve import VectorRetrieve
 from easy_rec.python.tools.pre_check import run_check
 from easy_rec.python.utils import config_util
+from easy_rec.python.utils import constant
 from easy_rec.python.utils import estimator_utils
 from easy_rec.python.utils import fg_util
 from easy_rec.python.utils import hpo_util
 from easy_rec.python.utils import pai_util
 from easy_rec.python.utils.distribution_utils import DistributionStrategyMap
 from easy_rec.python.utils.distribution_utils import set_distribution_config
 
+os.environ['IS_ON_PAI'] = '1'
+
 from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num  # NOQA
 os.environ['OENV_MultiWriteThreadsNum'] = '4'
 os.environ['OENV_MultiCopyThreadsNum'] = '4'
@@ -175,6 +178,8 @@
 tf.app.flags.DEFINE_string('asset_files', None, 'extra files to add to export')
 tf.app.flags.DEFINE_bool('check_mode', False, 'is use check mode')
 tf.app.flags.DEFINE_string('fg_json_path', None, '')
+tf.app.flags.DEFINE_bool('enable_avx_str_split', False,
+                         'enable avx str split to speedup')
 
 FLAGS = tf.app.flags.FLAGS
 
@@ -220,7 +225,7 @@ def _wait_ckpt(ckpt_path, max_wait_ts):
         break
   else:
     while time.time() - start_ts < max_wait_ts:
-      if gfile.Exists(ckpt_path + '.index'):
+      if not gfile.Exists(ckpt_path + '.index'):
         logging.info('wait for checkpoint[%s]' % ckpt_path)
         time.sleep(30)
       else:
@@ -230,6 +235,11 @@ def _wait_ckpt(ckpt_path, max_wait_ts):
 
 def main(argv):
   pai_util.set_on_pai()
+  if FLAGS.enable_avx_str_split:
+    constant.enable_avx_str_split()
+    logging.info('will enable avx str split: %s' %
+                 constant.is_avx_str_split_enabled())
+
   if FLAGS.distribute_eval:
     os.environ['distribute_eval'] = 'True'
 
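Besides wiring up the new flag, the run.py hunks above also flip the polling condition in _wait_ckpt (if gfile.Exists became if not gfile.Exists). A simplified sketch of the corrected wait loop, with the surrounding branches omitted:

# Simplified from the hunk above; not the full _wait_ckpt function.
import logging
import time

from tensorflow.python.platform import gfile

def wait_for_ckpt(ckpt_path, max_wait_ts):
  """Poll until ckpt_path + '.index' exists or max_wait_ts seconds elapse."""
  start_ts = time.time()
  while time.time() - start_ts < max_wait_ts:
    if not gfile.Exists(ckpt_path + '.index'):  # checkpoint not written yet
      logging.info('wait for checkpoint[%s]' % ckpt_path)
      time.sleep(30)                            # keep polling
    else:
      break                                     # checkpoint index found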
1 change: 1 addition & 0 deletions requirements/runtime.txt
@@ -1,3 +1,4 @@
+eas_prediction
 future
 matplotlib
 oss2
2 changes: 1 addition & 1 deletion setup.cfg
@@ -10,7 +10,7 @@ multi_line_output = 7
 force_single_line = true
 known_standard_library = setuptools
 known_first_party = easy_rec
-known_third_party = absl,common_io,docutils,eas_prediction,easyrec_request,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
+known_third_party = absl,common_io,docutils,eas_prediction,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
 no_lines_before = LOCALFOLDER
 default_section = THIRDPARTY
 skip = easy_rec/python/protos
