Skip to content

Commit

Permalink
add custom user agent for download_url (#3499)
Browse files Browse the repository at this point in the history
* add custom user agent for download_url

* fix progress bar

* lint

* [test] use repo instead of nightly for download tests

* .circleci: Be specific about where pytorch is coming from

Signed-off-by: Eli Uriegas <[email protected]>

* torchvision: Add more info to user-agent

Signed-off-by: Eli Uriegas <[email protected]>

* .circleci: Increase timeout for conda packages

The conda resolver is extremely slow so let's just give it more time to
idly sit by and resolve dependencies

Signed-off-by: Eli Uriegas <[email protected]>

Co-authored-by: Eli Uriegas <[email protected]>
  • Loading branch information
pmeier and seemethere authored Mar 3, 2021
1 parent 506279c commit 01dfa8e
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 23 deletions.
8 changes: 6 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ jobs:
steps:
- checkout_merge
- designate_upload_channel
- run: packaging/build_conda.sh
- run:
name: Build conda packages
no_output_timeout: 20m
command: |
packaging/build_conda.sh
- store_artifacts:
path: /opt/conda/conda-bld/linux-64
- persist_to_workspace:
Expand Down Expand Up @@ -593,7 +597,7 @@ jobs:

keys:
- env-v1-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

- run:
name: Setup
command: .circleci/unittest/windows/scripts/setup_env.sh
Expand Down
8 changes: 6 additions & 2 deletions .circleci/config.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ jobs:
steps:
- checkout_merge
- designate_upload_channel
- run: packaging/build_conda.sh
- run:
name: Build conda packages
no_output_timeout: 20m
command: |
packaging/build_conda.sh
- store_artifacts:
path: /opt/conda/conda-bld/linux-64
- persist_to_workspace:
Expand Down Expand Up @@ -593,7 +597,7 @@ jobs:
{% raw %}
keys:
- env-v1-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
{% endraw %}
{% endraw %}
- run:
name: Setup
command: .circleci/unittest/windows/scripts/setup_env.sh
Expand Down
2 changes: 1 addition & 1 deletion .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ else
fi

printf "Installing PyTorch with %s\n" "${cudatoolkit}"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge pytorch "${cudatoolkit}"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}::pytorch" "${cudatoolkit}"

printf "* Installing torchvision\n"
python setup.py develop
2 changes: 1 addition & 1 deletion .circleci/unittest/windows/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ else
fi

printf "Installing PyTorch with %s\n" "${cudatoolkit}"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge pytorch "${cudatoolkit}"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}::pytorch" "${cudatoolkit}"

printf "* Installing torchvision\n"
"$this_dir/vc_env_helper.bat" python setup.py develop
9 changes: 5 additions & 4 deletions .github/workflows/tests-schedule.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v2

- name: Install PyTorch from the nightlies
run: |
pip install numpy
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- name: Install torch nightly build
run: pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html

- name: Install torchvision
run: pip install -e .

- name: Install all optional dataset requirements
run: pip install scipy pandas pycocotools lmdb requests
Expand Down
7 changes: 4 additions & 3 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pytest

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive, USER_AGENT

from common_utils import get_tmp_dir
from fakedata_generation import places365_root
Expand Down Expand Up @@ -150,7 +150,7 @@ def assert_server_response_ok():


def assert_url_is_accessible(url, timeout=5.0):
request = Request(url, headers=dict(method="HEAD"))
request = Request(url, headers={"method": "HEAD", "User-Agent": USER_AGENT})
with assert_server_response_ok():
urlopen(request, timeout=timeout)

Expand All @@ -160,7 +160,8 @@ def assert_file_downloads_correctly(url, md5, timeout=5.0):
file = path.join(root, path.basename(url))
with assert_server_response_ok():
with open(file, "wb") as fh:
response = urlopen(url, timeout=timeout)
request = Request(url, headers={"User-Agent": USER_AGENT})
response = urlopen(request, timeout=timeout)
fh.write(response.read())

assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
Expand Down
35 changes: 25 additions & 10 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,32 @@
from typing import Any, Callable, List, Iterable, Optional, TypeVar
from urllib.parse import urlparse
import zipfile
import urllib
import urllib.request
import urllib.error

import torch
from torch.utils.model_zoo import tqdm
try:
from ..version import __version__ as __vision_version__ # noqa: F401
except ImportError:
__vision_version__ = "undefined"

USER_AGENT = os.environ.get(
"TORCHVISION_USER_AGENT",
f"pytorch-{torch.__version__}/vision-{__vision_version__}"
)


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
with open(filename, "wb") as fh:
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
with tqdm(total=response.length) as pbar:
for chunk in iter(lambda: response.read(chunk_size), ""):
if not chunk:
break
pbar.update(chunk_size)
fh.write(chunk)


def gen_bar_updater() -> Callable[[int, int, int], None]:
Expand Down Expand Up @@ -83,8 +106,6 @@ def download_url(
md5 (str, optional): MD5 checksum of the download. If None, do not check
max_redirect_hops (int, optional): Maximum number of redirect hops allowed
"""
import urllib

root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
Expand All @@ -108,19 +129,13 @@ def download_url(
# download the file
try:
print('Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater()
)
_urlretrieve(url, fpath)
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
if url[:5] == 'https':
url = url.replace('https:', 'http:')
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater()
)
_urlretrieve(url, fpath)
else:
raise e
# check integrity of downloaded file
Expand Down

0 comments on commit 01dfa8e

Please sign in to comment.