diff --git a/easy_rec/__init__.py b/easy_rec/__init__.py
index bd4bdae9b..ccb5b80bd 100644
--- a/easy_rec/__init__.py
+++ b/easy_rec/__init__.py
@@ -25,7 +25,10 @@
   elif tf.__version__.startswith('1.12'):
     ops_dir = os.path.join(ops_dir, '1.12')
   elif tf.__version__.startswith('1.15'):
-    ops_dir = os.path.join(ops_dir, '1.15')
+    if 'IS_ON_PAI' in os.environ:
+      ops_dir = os.path.join(ops_dir, 'DeepRec')
+    else:
+      ops_dir = os.path.join(ops_dir, '1.15')
   else:
     ops_dir = None
 else:
diff --git a/easy_rec/python/inference/client/client_demo.py b/easy_rec/python/inference/client/client_demo.py
index 50160bc21..9464b1073 100644
--- a/easy_rec/python/inference/client/client_demo.py
+++ b/easy_rec/python/inference/client/client_demo.py
@@ -5,8 +5,7 @@
 import sys
 import traceback
 
-from easyrec_request import EasyrecRequest
-
+from easy_rec.python.inference.client.easyrec_request import EasyrecRequest
 from easy_rec.python.protos.predict_pb2 import PBFeature
 from easy_rec.python.protos.predict_pb2 import PBRequest
diff --git a/easy_rec/python/input/csv_input_ex.py b/easy_rec/python/input/csv_input_ex.py
index 0072c04b2..d3b506fce 100644
--- a/easy_rec/python/input/csv_input_ex.py
+++ b/easy_rec/python/input/csv_input_ex.py
@@ -4,6 +4,7 @@
 import tensorflow as tf
 
 from easy_rec.python.input.csv_input import CSVInput
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
 
 if tf.__version__ >= '2.0':
   tf = tf.compat.v1
@@ -39,7 +40,7 @@ def _check_data(line):
             (sep, field_num, len(record_defaults))
         return True
 
-      fields = tf.string_split(
+      fields = str_split_by_chr(
           line, self._data_config.separator, skip_empty=False)
       tmp_fields = tf.reshape(fields.values, [-1, len(record_defaults)])
       fields = []
diff --git a/easy_rec/python/input/odps_rtp_input.py b/easy_rec/python/input/odps_rtp_input.py
index 0cbf6b093..6ae6096e0 100644
--- a/easy_rec/python/input/odps_rtp_input.py
+++ b/easy_rec/python/input/odps_rtp_input.py
@@ -5,6 +5,7 @@
 import tensorflow as tf
 
 from easy_rec.python.input.input import Input
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
 from easy_rec.python.utils.check_utils import check_split
 from easy_rec.python.utils.input_utils import string_to_number
@@ -79,10 +80,9 @@ def _parse_table(self, *fields):
             Tout=tf.bool)
     ] if self._check_mode else []
     with tf.control_dependencies(check_list):
-      fields = tf.string_split(
+      fields = str_split_by_chr(
           fields[-1], self._data_config.separator, skip_empty=False)
       tmp_fields = tf.reshape(fields.values, [-1, feature_num])
-      fields = labels[len(self._label_fields):]
       fields = []
       for i in range(feature_num):
         field = string_to_number(tmp_fields[:, i], record_types[i],
diff --git a/easy_rec/python/input/rtp_input.py b/easy_rec/python/input/rtp_input.py
index e8cb860b8..9f9679b9e 100644
--- a/easy_rec/python/input/rtp_input.py
+++ b/easy_rec/python/input/rtp_input.py
@@ -5,6 +5,7 @@
 import tensorflow as tf
 
 from easy_rec.python.input.input import Input
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
 from easy_rec.python.utils.check_utils import check_split
 from easy_rec.python.utils.check_utils import check_string_to_number
 from easy_rec.python.utils.input_utils import string_to_number
@@ -104,7 +105,7 @@
             Tout=tf.bool)
     ] if self._check_mode else []
     with tf.control_dependencies(check_list):
-      fields = tf.string_split(
+      fields = str_split_by_chr(
           feature_str, self._data_config.separator, skip_empty=False)
       tmp_fields = tf.reshape(fields.values, [-1, len(record_types)])
       rtp_record_defaults = [
diff --git a/easy_rec/python/model/easy_rec_estimator.py b/easy_rec/python/model/easy_rec_estimator.py
index 23b923a6b..51ecad09f 100644
--- a/easy_rec/python/model/easy_rec_estimator.py
+++ b/easy_rec/python/model/easy_rec_estimator.py
@@ -410,6 +410,9 @@ def _train_model_fn(self, features, labels, run_config):
     chief_hooks = []
     if estimator_utils.is_chief():
       hooks.append(saver_hook)
+      hooks.append(
+          basic_session_run_hooks.StepCounterHook(
+              every_n_steps=log_step_count_steps, output_dir=self.model_dir))
 
     # profiling hook
     if self.train_config.is_profiling and estimator_utils.is_chief():
diff --git a/easy_rec/python/ops/1.12/incr_record.so b/easy_rec/python/ops/1.12/incr_record.so
index 91ead3a4e..821391e7b 100755
Binary files a/easy_rec/python/ops/1.12/incr_record.so and b/easy_rec/python/ops/1.12/incr_record.so differ
diff --git a/easy_rec/python/ops/1.12/libstr_avx_op.so b/easy_rec/python/ops/1.12/libstr_avx_op.so
new file mode 100755
index 000000000..8544d120c
Binary files /dev/null and b/easy_rec/python/ops/1.12/libstr_avx_op.so differ
diff --git a/easy_rec/python/ops/1.12_pai/libstr_avx_op.so b/easy_rec/python/ops/1.12_pai/libstr_avx_op.so
new file mode 100755
index 000000000..8544d120c
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libstr_avx_op.so differ
diff --git a/easy_rec/python/ops/1.15/incr_record.so b/easy_rec/python/ops/1.15/incr_record.so
index 139a0aa1d..a548b9c9c 100755
Binary files a/easy_rec/python/ops/1.15/incr_record.so and b/easy_rec/python/ops/1.15/incr_record.so differ
diff --git a/easy_rec/python/ops/1.15/libstr_avx_op.so b/easy_rec/python/ops/1.15/libstr_avx_op.so
new file mode 100755
index 000000000..4237e9820
Binary files /dev/null and b/easy_rec/python/ops/1.15/libstr_avx_op.so differ
diff --git a/easy_rec/python/ops/DeepRec/incr_record.so b/easy_rec/python/ops/DeepRec/incr_record.so
new file mode 100755
index 000000000..fd8f73a48
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/incr_record.so differ
diff --git a/easy_rec/python/ops/DeepRec/kafka.so b/easy_rec/python/ops/DeepRec/kafka.so
new file mode 100755
index 000000000..ec0d5b9f0
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/kafka.so differ
diff --git a/easy_rec/python/ops/DeepRec/libembed_op.so b/easy_rec/python/ops/DeepRec/libembed_op.so
new file mode 100755
index 000000000..58975bd6f
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/libembed_op.so differ
diff --git a/easy_rec/python/ops/DeepRec/librdkafka++.so b/easy_rec/python/ops/DeepRec/librdkafka++.so
new file mode 100755
index 000000000..d9a8463e0
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/librdkafka++.so differ
diff --git a/easy_rec/python/ops/DeepRec/librdkafka++.so.1 b/easy_rec/python/ops/DeepRec/librdkafka++.so.1
new file mode 100755
index 000000000..d9a8463e0
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/librdkafka++.so.1 differ
diff --git a/easy_rec/python/ops/DeepRec/librdkafka.so b/easy_rec/python/ops/DeepRec/librdkafka.so
new file mode 100755
index 000000000..431eeb3cf
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/librdkafka.so differ
diff --git a/easy_rec/python/ops/DeepRec/librdkafka.so.1 b/easy_rec/python/ops/DeepRec/librdkafka.so.1
new file mode 100755
index 000000000..431eeb3cf
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/librdkafka.so.1 differ
diff --git a/easy_rec/python/ops/DeepRec/libstr_avx_op.so b/easy_rec/python/ops/DeepRec/libstr_avx_op.so
new file mode 100755
index 000000000..bb8d36306
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/libstr_avx_op.so differ
diff --git a/easy_rec/python/ops/gen_str_avx_op.py b/easy_rec/python/ops/gen_str_avx_op.py
new file mode 100644
index 000000000..d022d52cb
--- /dev/null
+++ b/easy_rec/python/ops/gen_str_avx_op.py
@@ -0,0 +1,28 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+
+import tensorflow as tf
+from tensorflow.python.ops import string_ops
+
+import easy_rec
+from easy_rec.python.utils import constant
+
+try:
+  str_avx_op_path = os.path.join(easy_rec.ops_dir, 'libstr_avx_op.so')
+  str_avx_op = tf.load_op_library(str_avx_op_path)
+  logging.info('load avx string_split op from %s succeed' % str_avx_op_path)
+except Exception as ex:
+  logging.warning('load avx string_split op failed: %s' % str(ex))
+  str_avx_op = None
+
+
+def str_split_by_chr(input_str, sep, skip_empty):
+  if constant.has_avx_str_split() and str_avx_op is not None:
+    assert len(sep) == 1, \
+        'invalid data_config.separator(%s) len(%d) != 1' % (
+            sep, len(sep))
+    return str_avx_op.avx512_string_split(input_str, sep, skip_empty=skip_empty)
+  else:
+    return string_ops.string_split(input_str, sep, skip_empty=skip_empty)
diff --git a/easy_rec/python/test/csv_input_test.py b/easy_rec/python/test/csv_input_test.py
index 576b42297..ae0793fa5 100644
--- a/easy_rec/python/test/csv_input_test.py
+++ b/easy_rec/python/test/csv_input_test.py
@@ -2,6 +2,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 """Define cv_input, the base class for cv tasks."""
 
+import os
+import unittest
+
 import tensorflow as tf
 from google.protobuf import text_format
 
@@ -10,6 +13,7 @@
 from easy_rec.python.protos.dataset_pb2 import DatasetConfig
 from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
 from easy_rec.python.utils import config_util
+from easy_rec.python.utils import constant
 from easy_rec.python.utils.test_utils import RunAsSubprocess
 
 if tf.__version__ >= '2.0':
@@ -264,6 +268,14 @@ def test_csv_input_ex(self):
       sess.run(init_op)
       feature_dict, label_dict = sess.run([features, labels])
 
+  @unittest.skipIf('AVX_TEST' not in os.environ,
+                   'Only execute when avx512 instructions are supported')
+  @RunAsSubprocess
+  def test_csv_input_ex_avx(self):
+    constant.enable_avx_str_split()
+    self.test_csv_input_ex()
+    constant.disable_avx_str_split()
+
   @RunAsSubprocess
   def test_csv_data_ignore_error(self):
     data_config_str = """
diff --git a/easy_rec/python/utils/constant.py b/easy_rec/python/utils/constant.py
index 8caecaba8..74ea046ed 100644
--- a/easy_rec/python/utils/constant.py
+++ b/easy_rec/python/utils/constant.py
@@ -1,8 +1,24 @@
 # -*- encoding:utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 
 SAMPLE_WEIGHT = 'SAMPLE_WEIGHT'
 
 DENSE_UPDATE_VARIABLES = 'DENSE_UPDATE_VARIABLES'
 
 SPARSE_UPDATE_VARIABLES = 'SPARSE_UPDATE_VARIABLES'
+
+ENABLE_AVX_STR_SPLIT = 'ENABLE_AVX_STR_SPLIT'
+
+
+def enable_avx_str_split():
+  os.environ[ENABLE_AVX_STR_SPLIT] = '1'
+
+
+def has_avx_str_split():
+  return ENABLE_AVX_STR_SPLIT in os.environ and os.environ[
+      ENABLE_AVX_STR_SPLIT] == '1'
+
+
+def disable_avx_str_split():
+  del os.environ[ENABLE_AVX_STR_SPLIT]
diff --git a/pai_jobs/run.py b/pai_jobs/run.py
index ed02c73c5..41c61ad31 100644
--- a/pai_jobs/run.py
+++ b/pai_jobs/run.py
@@ -16,6 +16,7 @@
 from easy_rec.python.inference.vector_retrieve import VectorRetrieve
 from easy_rec.python.tools.pre_check import run_check
 from easy_rec.python.utils import config_util
+from easy_rec.python.utils import constant
 from easy_rec.python.utils import estimator_utils
 from easy_rec.python.utils import fg_util
 from easy_rec.python.utils import hpo_util
@@ -23,6 +24,8 @@
 from easy_rec.python.utils.distribution_utils import DistributionStrategyMap
 from easy_rec.python.utils.distribution_utils import set_distribution_config
+os.environ['IS_ON_PAI'] = '1'
+
 from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num  # NOQA
 
 os.environ['OENV_MultiWriteThreadsNum'] = '4'
 os.environ['OENV_MultiCopyThreadsNum'] = '4'
@@ -175,6 +178,8 @@
 tf.app.flags.DEFINE_string('asset_files', None, 'extra files to add to export')
 tf.app.flags.DEFINE_bool('check_mode', False, 'is use check mode')
 tf.app.flags.DEFINE_string('fg_json_path', None, '')
+tf.app.flags.DEFINE_bool('enable_avx_str_split', False,
+                         'enable avx str split to speedup')
 
 FLAGS = tf.app.flags.FLAGS
@@ -220,7 +225,7 @@ def _wait_ckpt(ckpt_path, max_wait_ts):
       break
   else:
     while time.time() - start_ts < max_wait_ts:
-      if gfile.Exists(ckpt_path + '.index'):
+      if not gfile.Exists(ckpt_path + '.index'):
         logging.info('wait for checkpoint[%s]' % ckpt_path)
         time.sleep(30)
       else:
@@ -230,6 +235,11 @@
 def main(argv):
   pai_util.set_on_pai()
 
+  if FLAGS.enable_avx_str_split:
+    constant.enable_avx_str_split()
+    logging.info('will enable avx str split: %s' %
+                 constant.has_avx_str_split())
+
   if FLAGS.distribute_eval:
     os.environ['distribute_eval'] = 'True'
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index dbc3b5872..0bfc70146 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,3 +1,4 @@
+eas_prediction
 future
 matplotlib
 oss2
diff --git a/setup.cfg b/setup.cfg
index 7172f3302..2303ef802 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -10,7 +10,7 @@ multi_line_output = 7
 force_single_line = true
 known_standard_library = setuptools
 known_first_party = easy_rec
-known_third_party = absl,common_io,docutils,eas_prediction,easyrec_request,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
+known_third_party = absl,common_io,docutils,eas_prediction,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
 no_lines_before = LOCALFOLDER
 default_section = THIRDPARTY
 skip = easy_rec/python/protos
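
Usage sketch (reviewer note, not part of the patch): a minimal example of how the new str_split_by_chr helper and the ENABLE_AVX_STR_SPLIT toggle introduced above fit together. It assumes TF 1.x graph mode and a single-character separator, which the AVX kernel asserts on; when libstr_avx_op.so cannot be loaded or the toggle is off, the helper falls back to tf.string_split, so the snippet runs either way.

    # Hypothetical standalone example; all imported names come from this patch.
    import tensorflow as tf

    from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
    from easy_rec.python.utils import constant

    # Opt in to the AVX512 kernel; on PAI this is what --enable_avx_str_split
    # in pai_jobs/run.py does via constant.enable_avx_str_split().
    constant.enable_avx_str_split()

    lines = tf.constant(['1,2,3', '4,5,6'])
    # Same calling convention as tf.string_split in the input pipelines above;
    # .values holds the split tokens on both the AVX and the fallback path.
    fields = str_split_by_chr(lines, ',', skip_empty=False)

    with tf.Session() as sess:
      print(sess.run(fields.values))  # e.g. [b'1' b'2' b'3' b'4' b'5' b'6']

    # Mirrors test_csv_input_ex_avx: switch the toggle off again afterwards.
    constant.disable_avx_str_split()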