Skip to content

Commit

Permalink
fix multi process download model (PaddlePaddle#662)
Browse files: browse the repository at this point in the history
  • Loading branch information
zhanghan1992 committed May 20, 2021
1 parent 2ed67ee commit af81999
Showing 1 changed file with 34 additions and 19 deletions.
53 changes: 34 additions & 19 deletions ernie/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from tqdm import tqdm
from pathlib import Path
import six
import paddle as P
import time
if six.PY2:
from pathlib2 import Path
else:
Expand All @@ -33,32 +35,45 @@ def _fetch_from_remote(url,
force_download=False,
cached_dir='~/.paddle-ernie-cache'):
import hashlib, tempfile, requests, tarfile
env = P.distributed.ParallelEnv()

sig = hashlib.md5(url.encode('utf8')).hexdigest()
cached_dir = Path(cached_dir).expanduser()
try:
cached_dir.mkdir()
except OSError:
pass
cached_dir_model = cached_dir / sig
if force_download or not cached_dir_model.exists():
cached_dir_model.mkdir()
tmpfile = cached_dir_model / 'tmp'
with tmpfile.open('wb') as f:
#url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz'
r = requests.get(url, stream=True)
total_len = int(r.headers.get('content-length'))
for chunk in tqdm(
r.iter_content(chunk_size=1024),
total=total_len // 1024,
desc='downloading %s' % url,
unit='KB'):
if chunk:
f.write(chunk)
f.flush()
log.debug('extacting... to %s' % tmpfile)
with tarfile.open(tmpfile.as_posix()) as tf:
tf.extractall(path=cached_dir_model.as_posix())
os.remove(tmpfile.as_posix())
done_file = cached_dir_model / 'fetch_done'
if force_download or not done_file.exists():
if env.dev_id == 0:
cached_dir_model.mkdir()
tmpfile = cached_dir_model / 'tmp'
with tmpfile.open('wb') as f:
#url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz'
r = requests.get(url, stream=True)
total_len = int(r.headers.get('content-length'))
for chunk in tqdm(
r.iter_content(chunk_size=1024),
total=total_len // 1024,
desc='downloading %s' % url,
unit='KB'):
if chunk:
f.write(chunk)
f.flush()
log.debug('extacting... to %s' % tmpfile)
with tarfile.open(tmpfile.as_posix()) as tf:
tf.extractall(path=cached_dir_model.as_posix())
os.remove(tmpfile.as_posix())
f = done_file.open('wb')
f.close()
else:
while True:
if done_file.exists():
break
else:
time.sleep(1)

log.debug('%s cached in %s' % (url, cached_dir))
return cached_dir_model

Expand Down

0 comments on commit af81999

Please sign in to comment.