Skip to content

Commit

Permalink
Support macOS & Windows and Support Tensorflow >= 2.12.0 (#361)
Browse files Browse the repository at this point in the history
* support macOS

* modify path with processing '/'

* modify path with processing '/'
  • Loading branch information
wwxxzz authored Apr 26, 2023
1 parent daacc71 commit d11321c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
8 changes: 7 additions & 1 deletion easy_rec/python/compat/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import threading
import time

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
Expand All @@ -33,11 +35,15 @@
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export

from easy_rec.python.utils.config_util import parse_time
from easy_rec.python.utils.load_class import load_by_path

if LooseVersion(tf.__version__) >= LooseVersion('2.12.0'):
from tensorflow_estimator.python.estimator.estimator_export import estimator_export
else:
from tensorflow.python.util.tf_export import estimator_export

_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'

EARLY_STOP_SIG_SCOPE = 'signal_early_stopping'
Expand Down
12 changes: 9 additions & 3 deletions git-lfs/git_lfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def get_yes_no(msg):
if line_str.startswith('#'):
continue
line_str = line_str.replace('~/', os.environ['HOME'] + '/')
line_str = line_str.replace('${TMPDIR}', os.environ.get('TMPDIR', '/tmp'))
line_str = line_str.replace('${TMPDIR}/',
os.environ.get('TMPDIR', '/tmp/'))
line_str = line_str.replace('${PROJECT_NAME}', get_proj_name())
line_tok = [x.strip() for x in line_str.split('=') if x != '']
if line_tok[0] == 'host':
Expand Down Expand Up @@ -363,7 +364,6 @@ def get_yes_no(msg):
remote_path = git_bin_url[leaf_path][1]
_, file_name_with_sig = os.path.split(remote_path)
tar_tmp_path = '%s/%s.tar.gz' % (git_oss_cache_dir, file_name_with_sig)

max_retry = 5
while max_retry > 0:
try:
Expand All @@ -373,7 +373,13 @@ def get_yes_no(msg):
oss_bucket.get_object_to_file(remote_path, tar_tmp_path)
else:
url = 'http://%s.%s/%s' % (bucket_name, host, remote_path)
subprocess.check_output(['wget', url, '-O', tar_tmp_path])
# subprocess.check_output(['wget', url, '-O', tar_tmp_path])
if sys.platform.startswith('linux'):
subprocess.check_output(['wget', url, '-O', tar_tmp_path])
elif sys.platform.startswith('darwin'):
subprocess.check_output(['curl', url, '--output', tar_tmp_path])
elif sys.platform.startswith('win'):
subprocess.check_output(['curl', url, '--output', tar_tmp_path])
else:
in_cache = True
logging.info('%s is in cache' % file_name_with_sig)
Expand Down

0 comments on commit d11321c

Please sign in to comment.