[feature]: add continue train on evaluation data (#399)
* add support for continued training on evaluation data
* fix sync_replicas_optimizer queue runners
chengmengli06 authored Jul 20, 2023
1 parent 1be51db commit 0049d0d
Showing 8 changed files with 419 additions and 21 deletions.
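
In short: when the normal train-and-evaluate cycle finishes, the chief marks training as done in model_dir, and each training task can then rebuild the graph and keep fitting on the evaluation data for a bounded number of extra steps. A condensed sketch of the new control flow, assuming the module paths shown in the diffs below (not a drop-in implementation):

import time

import tensorflow as tf

from easy_rec.python.compat import estimator_train
from easy_rec.python.utils import estimator_utils


def fit_on_eval_pass(estimator, eval_input_fn, fit_on_eval_steps=None):
  # Continue training on eval data after train_and_evaluate returns.
  tf.reset_default_graph()
  if fit_on_eval_steps is not None:
    # Wait for the chief's ESTIMATOR_TRAIN_DONE marker so the latest
    # checkpoint step is final, then convert the requested extra steps
    # into an absolute max_steps (estimator.train counts global steps).
    while not estimator_train.estimator_train_done(estimator):
      time.sleep(1)
    fit_on_eval_steps += estimator_utils.get_trained_steps(estimator.model_dir)
  estimator.train(input_fn=eval_input_fn, max_steps=fit_on_eval_steps)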
5 changes: 5 additions & 0 deletions easy_rec/python/compat/estimator_train.py
@@ -109,3 +109,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
fout.write('Train Done.')

return result


def estimator_train_done(estimator):
  train_done_file = os.path.join(estimator.model_dir, 'ESTIMATOR_TRAIN_DONE')
  return gfile.Exists(train_done_file)
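
The new helper pairs with the 'Train Done.' marker that train_and_evaluate already writes on the chief (visible in the context above): once <model_dir>/ESTIMATOR_TRAIN_DONE exists, other tasks know the final checkpoint has been produced. A minimal sketch of the writer side, assuming os and gfile are this module's existing imports:

# Chief side, on train completion (sketch consistent with the context above):
train_done_file = os.path.join(estimator.model_dir, 'ESTIMATOR_TRAIN_DONE')
with gfile.GFile(train_done_file, 'w') as fout:
  fout.write('Train Done.')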
24 changes: 17 additions & 7 deletions easy_rec/python/compat/sync_replicas_optimizer.py
@@ -135,6 +135,8 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
```
"""

sync_que_id = -1

def __init__(self,
opt,
replicas_to_aggregate,
@@ -299,15 +301,24 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
global_step)

    def _get_token_qname():
      SyncReplicasOptimizer.sync_que_id += 1
      if SyncReplicasOptimizer.sync_que_id == 0:
        return 'sync_token_q'
      else:
        return 'sync_token_q_' + str(SyncReplicasOptimizer.sync_que_id)

# Create token queue.
token_qname = _get_token_qname()
logging.info('create sync_token_queue[%s]' % token_qname)
with ops.device(global_step.device), ops.name_scope(''):
sync_token_queue = (
data_flow_ops.FIFOQueue(
-1,
global_step.dtype.base_dtype,
shapes=(),
name='sync_token_q',
shared_name='sync_token_q'))
name=token_qname,
shared_name=token_qname))
self._sync_token_queue = sync_token_queue
self._is_sync_que_closed = sync_token_queue.is_closed()
self._close_sync_que = sync_token_queue.close(
@@ -342,6 +353,8 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):

self._chief_queue_runner = queue_runner.QueueRunner(
dummy_queue, [sync_op])
ops.add_to_collection(ops.GraphKeys.QUEUE_RUNNERS,
self._chief_queue_runner)
for accum, dev in self._accumulator_list:
with ops.device(dev):
chief_init_ops.append(
@@ -479,14 +492,12 @@ def begin(self):
self._local_init_op = self._sync_optimizer.chief_init_op
self._ready_for_local_init_op = (
self._sync_optimizer.ready_for_local_init_op)
self._q_runner = self._sync_optimizer.get_chief_queue_runner()
self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
self._num_tokens)
else:
self._local_init_op = self._sync_optimizer.local_step_init_op
self._ready_for_local_init_op = (
self._sync_optimizer.ready_for_local_init_op)
self._q_runner = None
self._init_tokens_op = None

def after_create_session(self, session, coord):
@@ -500,11 +511,10 @@ def after_create_session(self, session, coord):
'local_init. Init op: %s, error: %s' %
(self._local_init_op.name, msg))
session.run(self._local_init_op)
is_closed = session.run(self._sync_optimizer._is_sync_que_closed)
assert not is_closed, 'sync_que is closed'
if self._init_tokens_op is not None:
session.run(self._init_tokens_op)
if self._q_runner is not None:
self._q_runner.create_threads(
session, coord=coord, daemon=True, start=True)

def end(self, session):
try:
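
Why the unique queue names: the fit-on-eval pass resets the default graph and constructs a second SyncReplicasOptimizer in the same processes, and reusing the fixed shared_name 'sync_token_q' would collide on the parameter servers with the (already closed) token queue from the first training phase. The class-level counter hands out a fresh shared name per instance while keeping the legacy name for the first one; registering the chief queue runner in GraphKeys.QUEUE_RUNNERS lets the session machinery start it, which is why the hook no longer creates those threads itself. A standalone sketch of the naming scheme (_QueueNamer is a hypothetical stand-in for the class attribute added above):

class _QueueNamer(object):
  sync_que_id = -1

  @classmethod
  def next_token_qname(cls):
    cls.sync_que_id += 1
    if cls.sync_que_id == 0:
      return 'sync_token_q'  # first instance keeps the legacy name
    return 'sync_token_q_%d' % cls.sync_que_id

assert _QueueNamer.next_token_qname() == 'sync_token_q'
assert _QueueNamer.next_token_qname() == 'sync_token_q_1'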
34 changes: 29 additions & 5 deletions easy_rec/python/main.py
@@ -9,6 +9,7 @@
import logging
import math
import os
import time

import six
import tensorflow as tf
@@ -279,7 +280,9 @@ def train_and_evaluate(pipeline_config_path, continue_train=False):

def _train_and_evaluate_impl(pipeline_config,
                             continue_train=False,
                             check_mode=False):
                             check_mode=False,
                             fit_on_eval=False,
                             fit_on_eval_steps=None):
train_config = pipeline_config.train_config
data_config = pipeline_config.data_config
feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
@@ -301,18 +304,15 @@ def _train_and_evaluate_impl(pipeline_config,
estimator, run_config = _create_estimator(
pipeline_config, distribution=distribution, params=params)

master_stat_file = os.path.join(pipeline_config.model_dir, 'master.stat')
version_file = os.path.join(pipeline_config.model_dir, 'version')
if estimator_utils.is_chief():
_check_model_dir(pipeline_config.model_dir, continue_train)
config_util.save_pipeline_config(pipeline_config, pipeline_config.model_dir)
with gfile.GFile(version_file, 'w') as f:
f.write(easy_rec.__version__ + '\n')
if gfile.Exists(master_stat_file):
gfile.Remove(master_stat_file)

train_steps = None
if train_config.HasField('num_steps'):
if train_config.HasField('num_steps') and train_config.num_steps > 0:
train_steps = train_config.num_steps
assert train_steps is not None or data_config.num_epochs > 0, (
  'either num_steps or num_epochs must be set to an integer > 0.')
@@ -347,6 +347,30 @@ def _train_and_evaluate_impl(pipeline_config,
from easy_rec.python.compat import estimator_train
estimator_train.train_and_evaluate(estimator, train_spec, eval_spec)
logging.info('Train and evaluate finish')
if fit_on_eval and (not estimator_utils.is_evaluator()):
tf.reset_default_graph()
logging.info('Start continue training on eval data')
eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
**input_fn_kwargs)
if fit_on_eval_steps is not None:
# wait until estimator training is done, so train_steps reflects the final checkpoint
while not estimator_train.estimator_train_done(estimator):
time.sleep(1)
train_steps = estimator_utils.get_trained_steps(estimator.model_dir)
logging.info('\ttrain_steps=%d fit_on_eval_steps=%d' %
(train_steps, fit_on_eval_steps))
fit_on_eval_steps += train_steps
# Do not use estimator_train.train_and_evaluate here: it starts a TF server,
# which is redundant and fails with a port-not-available error.
estimator.train(
input_fn=eval_input_fn,
max_steps=fit_on_eval_steps,
hooks=list(train_spec.hooks),
saving_listeners=train_spec.saving_listeners if hasattr(
train_spec, 'saving_listeners') else None)
logging.info('Finished training on eval data')
# return estimator for custom training using estimator.train
return estimator


def evaluate(pipeline_config,
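
The max_steps bookkeeping above matters because estimator.train interprets max_steps as an absolute global step, not an increment. A worked example with hypothetical numbers:

train_steps = 1000       # e.g. the latest checkpoint is model.ckpt-1000
fit_on_eval_steps = 100  # extra steps requested on the eval data
max_steps = train_steps + fit_on_eval_steps  # trains up to global step 1100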
17 changes: 17 additions & 0 deletions easy_rec/python/test/train_eval_test.py
@@ -756,6 +756,23 @@ def test_train_with_ps_worker(self):
'samples/model_config/multi_tower_on_taobao.config', self._test_dir)
self.assertTrue(self._success)

  def test_fit_on_eval(self):
    self._success = test_utils.test_distributed_train_eval(
        'samples/model_config/multi_tower_on_taobao.config',
        self._test_dir,
        num_evaluator=1,
        fit_on_eval=True)
    self.assertTrue(self._success)

  def test_unbalance_data(self):
    self._success = test_utils.test_distributed_train_eval(
        'samples/model_config/multi_tower_on_taobao_unblanace.config',
        self._test_dir,
        total_steps=0,
        num_epoch=1,
        num_evaluator=1)
    self.assertTrue(self._success)

def test_train_with_ps_worker_with_evaluator(self):
self._success = test_utils.test_distributed_train_eval(
'samples/model_config/multi_tower_on_taobao.config',
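
Note that test_unbalance_data passes total_steps=0 together with num_epoch=1, which exercises the relaxed num_steps guard in main.py above: with num_steps unset or zero, train_steps stays None and training runs for the configured number of epochs instead of a fixed step budget (assuming the test config loader writes total_steps into train_config.num_steps).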
18 changes: 16 additions & 2 deletions easy_rec/python/train_eval.py
@@ -67,6 +67,16 @@
nargs='*',
default=None,
help='eval data input path')
parser.add_argument(
    '--fit_on_eval',
    action='store_true',
    default=False,
    help='continue training on the evaluation data after train/evaluate finishes')
parser.add_argument(
    '--fit_on_eval_steps',
    type=int,
    default=None,
    help='number of extra steps to train on the evaluation data')
parser.add_argument(
'--fine_tune_checkpoint',
type=str,
@@ -169,7 +179,11 @@
has_evaluator=False)
else:
config_util.auto_expand_share_feature_configs(pipeline_config)
_train_and_evaluate_impl(pipeline_config, args.continue_train,
args.check_mode)
    _train_and_evaluate_impl(
        pipeline_config,
        args.continue_train,
        args.check_mode,
        fit_on_eval=args.fit_on_eval,
        fit_on_eval_steps=args.fit_on_eval_steps)
else:
raise ValueError('pipeline_config_path should not be empty when training!')
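
A hypothetical invocation of the new flags (the config path is illustrative; the tests in this commit build the same command line as a string):

import subprocess

subprocess.check_call([
    'python', '-m', 'easy_rec.python.train_eval',
    '--pipeline_config_path',
    'samples/model_config/multi_tower_on_taobao.config',
    '--fit_on_eval',               # keep training on eval data afterwards
    '--fit_on_eval_steps', '100',  # extra steps, added to the trained steps
])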
8 changes: 8 additions & 0 deletions easy_rec/python/utils/estimator_utils.py
@@ -922,6 +922,14 @@ def latest_checkpoint(model_dir):
return None


def get_trained_steps(model_dir):
  ckpt_path = latest_checkpoint(model_dir)
  if ckpt_path is not None:
    return int(ckpt_path.split('-')[-1])
  else:
    return 0


def master_to_chief():
if 'TF_CONFIG' in os.environ:
tf_config = json.loads(os.environ['TF_CONFIG'])
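
get_trained_steps leans on the standard TensorFlow checkpoint naming convention, in which the path of the latest checkpoint ends with the global step. A worked example with an illustrative path:

ckpt_path = '/tmp/model_dir/model.ckpt-1000'  # illustrative path
assert int(ckpt_path.split('-')[-1]) == 1000  # steps trained so far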
30 changes: 23 additions & 7 deletions easy_rec/python/utils/test_utils.py
@@ -159,7 +159,10 @@ def _replace_data_for_test(data_path):
return data_path


def _load_config_for_test(pipeline_config_path, test_dir, total_steps=50):
def _load_config_for_test(pipeline_config_path,
                          test_dir,
                          total_steps=50,
                          num_epochs=0):
pipeline_config = config_util.get_configs_from_pipeline_file(
pipeline_config_path)
train_config = pipeline_config.train_config
@@ -171,7 +174,7 @@ def _load_config_for_test(pipeline_config_path, test_dir, total_steps=50):
pipeline_config.model_dir = os.path.join(test_dir, 'train')
logging.info('test_model_dir %s' % pipeline_config.model_dir)
eval_config.num_examples = max(10, data_config.batch_size)
data_config.num_epochs = 0
data_config.num_epochs = num_epochs
return pipeline_config


@@ -529,7 +532,9 @@ def _get_ports(num_worker):
def _ps_worker_train(pipeline_config_path,
                     test_dir,
                     num_worker,
                     num_evaluator=0):
                     num_evaluator=0,
                     fit_on_eval=False,
                     fit_on_eval_steps=None):
gpus = get_available_gpus()
# not enough gpus, run on cpu only
if len(gpus) < num_worker:
@@ -547,6 +552,10 @@ def _ps_worker_train(pipeline_config_path,
os.environ['TF_CONFIG'] = json.dumps(tf_config)
set_gpu_id(gpus[0])
train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path %s' % pipeline_config_path
if fit_on_eval:
train_cmd += ' --fit_on_eval'
if fit_on_eval_steps is not None:
train_cmd += ' --fit_on_eval_steps ' + str(int(fit_on_eval_steps))
procs[chief_or_master] = run_cmd(
train_cmd, '%s/log_%s.txt' % (test_dir, chief_or_master))
tf_config['task'] = {'type': 'ps', 'index': 0}
@@ -665,10 +674,12 @@ def test_distributed_train_eval(pipeline_config_path,
total_steps=50,
num_evaluator=0,
edit_config_json=None,
use_hvd=False):
use_hvd=False,
fit_on_eval=False,
num_epoch=0):
logging.info('testing pipeline config %s' % pipeline_config_path)
pipeline_config = _load_config_for_test(pipeline_config_path, test_dir,
total_steps)
total_steps, num_epoch)
if edit_config_json is not None:
config_util.edit_config(pipeline_config, edit_config_json)

@@ -687,8 +698,13 @@ def test_distributed_train_eval(pipeline_config_path,
return _multi_worker_hvd_train(test_pipeline_config_path, test_dir, 2)
if train_config.train_distribute == DistributionStrategy.NoStrategy:
num_worker = 2
procs = _ps_worker_train(test_pipeline_config_path, test_dir, num_worker,
num_evaluator)
procs = _ps_worker_train(
test_pipeline_config_path,
test_dir,
num_worker,
num_evaluator,
fit_on_eval,
fit_on_eval_steps=int(total_steps // 2))
elif train_config.train_distribute == DistributionStrategy.MultiWorkerMirroredStrategy:
num_worker = 2
procs = _multi_worker_mirror_train(test_pipeline_config_path, test_dir,
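
With the defaults above, a distributed fit-on-eval test runs total_steps=50 of normal training and then half as many extra steps on the evaluation data:

total_steps = 50
fit_on_eval_steps = int(total_steps // 2)  # 25 extra steps on eval data
# the run should stop at global step 50 + 25 = 75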
(diff for the eighth changed file did not load)
