From bf94599efb9a6218785001b7b5deae6b242932fe Mon Sep 17 00:00:00 2001 From: Jay Chia Date: Thu, 22 Jun 2023 16:51:26 -0700 Subject: [PATCH] Add native downloader http tests --- requirements-dev.txt | 1 + tests/remote_io/test_url_download_http.py | 169 ++++++++++++++++++---- 2 files changed, 141 insertions(+), 29 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 4f0c9ed351..e593bff567 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -27,6 +27,7 @@ pandas==2.0.2; python_version >= '3.8' xxhash>=3.0.0 Pillow==9.5.0 opencv-python==4.7.0.72 +requests==2.31.0 # Ray ray[data, default]==2.4.0 diff --git a/tests/remote_io/test_url_download_http.py b/tests/remote_io/test_url_download_http.py index 1959fa46cc..8a5f40cb69 100644 --- a/tests/remote_io/test_url_download_http.py +++ b/tests/remote_io/test_url_download_http.py @@ -1,15 +1,115 @@ from __future__ import annotations import pathlib -import subprocess +import socketserver import time +from http.server import BaseHTTPRequestHandler, SimpleHTTPRequestHandler +from multiprocessing import Process +from socket import socket import pytest +import requests +from aiohttp.client_exceptions import ClientResponseError import daft from tests.remote_io.conftest import YieldFixture +def _get_free_port(): + """Helper to get a free port number - may be susceptible to race conditions, + but is likely good enough for our unit testing usecase. + """ + with socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _wait_for_server(ready_url: str, max_wait_time_s: int = 1): + """Waits for a server to be up and serving 200's from the provided URL""" + SLEEP_INTERVAL = 0.1 + for _ in range(int(max_wait_time_s / SLEEP_INTERVAL)): + try: + if requests.get(ready_url).status_code == 200: + break + except requests.exceptions.ConnectionError: + time.sleep(SLEEP_INTERVAL) + else: + raise RuntimeError("Timed out while waiting for mock HTTP server fixture to be ready") + + +def _serve_error_server(code, port): + """Target function for serving a HTTP service that always throws the specified error code""" + + class MyAlwaysThrowHandler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == "/ready": + self.send_response(200) + self.end_headers() + return + self.send_response(code) + self.end_headers() + self.wfile.write(b"Some message") + + with socketserver.TCPServer(("", port), MyAlwaysThrowHandler) as httpd: + httpd.serve_forever() + + +def _serve_file_server(port, directory): + """Target function for serving a HTTP service that serves files from a directory""" + + class ServeFileHandler(SimpleHTTPRequestHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, directory=directory) + + def do_GET(self): + if self.path == "/ready": + self.send_response(200) + self.end_headers() + return + super().do_GET() + + with socketserver.TCPServer(("", port), ServeFileHandler) as httpd: + httpd.serve_forever() + + +@pytest.fixture( + scope="function", + params=[ + # Unauthorized + 401, + # Forbidden + 403, + # Not found + 404, + # Too many requests + 429, + # Internal server error + 500, + # Service unavailable + 503, + ], +) +def mock_error_http_server(request) -> YieldFixture[tuple[str, int]]: + """Provides a mock HTTP server that throws various HTTP status code errors when receiving any GET requests + + This fixture yields a tuple of: + str: URL to the HTTP server + int: HTTP status code that it throws when accessed with a GET request at any path + """ + code = request.param + port = _get_free_port() + url = f"http://localhost:{port}" + + p = Process(target=_serve_error_server, args=(code, port)) + p.start() + try: + _wait_for_server(f"{url}/ready") + yield (url, code) + finally: + p.terminate() + p.join() + + @pytest.fixture(scope="session") def mock_http_server(tmp_path_factory) -> YieldFixture[tuple[str, pathlib.Path]]: """Provides a mock HTTP server that serves files in a given directory @@ -18,51 +118,62 @@ def mock_http_server(tmp_path_factory) -> YieldFixture[tuple[str, pathlib.Path]] str: URL to the HTTP server pathlib.Path: tmpdir to place files into, which will then be served by the HTTP server """ - tmpdir = tmp_path_factory.mktemp("data") + port = _get_free_port() + url = f"http://localhost:{port}" - PORT = 8000 - proc = subprocess.Popen(["python", "-m", "http.server", "-d", str(tmpdir), str(PORT)]) - - # Give the server some time to spin up - time.sleep(0.2) - - yield (f"http://localhost:{PORT}", tmpdir) - - proc.kill() + p = Process(target=_serve_file_server, args=(port, str(tmpdir))) + p.start() + try: + _wait_for_server(f"{url}/ready") + yield (url, tmpdir) + finally: + p.terminate() + p.join() @pytest.fixture(scope="function") def http_image_data_fixture(mock_http_server, image_data) -> YieldFixture[list[str]]: """Populates the mock HTTP server with some fake data and returns filepaths""" - # Dump some images into the tmpdir + # NOTE: We use 1 image because the HTTPServer that we use is pretty bad at handling concurrent requests server_url, tmpdir = mock_http_server - urls = [] - for i in range(10): - path = tmpdir / f"{i}.jpeg" - path.write_bytes(image_data) - urls.append(f"{server_url}/{path.relative_to(tmpdir)}") - - yield urls + path = tmpdir / f"img.jpeg" + path.write_bytes(image_data) + yield [f"{server_url}/{path.relative_to(tmpdir)}"] # Cleanup tmpdir for child in tmpdir.glob("*"): child.unlink() -def test_url_download_http(http_image_data_fixture, image_data): +@pytest.mark.parametrize("use_native_downloader", [True, False]) +def test_url_download_http(http_image_data_fixture, image_data, use_native_downloader): data = {"urls": http_image_data_fixture} df = daft.from_pydict(data) - df = df.with_column("data", df["urls"].url.download()) + df = df.with_column("data", df["urls"].url.download(use_native_downloader=use_native_downloader)) assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(http_image_data_fixture))]} -def test_url_download_http_missing(mock_http_server): - server_url, _ = mock_http_server - data = {"urls": [f"{server_url}/missing.jpeg"]} +@pytest.mark.parametrize("use_native_downloader", [True, False]) +def test_url_download_http_error_codes(mock_error_http_server, use_native_downloader): + url, code = mock_error_http_server + data = {"urls": [f"{url}/missing.jpeg"]} df = daft.from_pydict(data) - df = df.with_column("data", df["urls"].url.download(on_error="raise")) - - # NOTE: if using fsspec FileNotFoundError will be correctly thrown by fsspec.implementations.http.HTTPFileSystem - with pytest.raises(FileNotFoundError): - df.collect() + df = df.with_column("data", df["urls"].url.download(on_error="raise", use_native_downloader=use_native_downloader)) + + # 404 should always be corner-cased to return FileNotFoundError regardless of I/O implementation + if code == 404: + with pytest.raises(FileNotFoundError): + df.collect() + # When using fsspec, other error codes are bubbled up to the user as aiohttp.client_exceptions.ClientResponseError + elif not use_native_downloader: + with pytest.raises(ClientResponseError) as e: + df.collect() + assert e.value.code == code + # When using native downloader, we throw a ValueError + else: + with pytest.raises(ValueError) as e: + df.collect() + # NOTE: We may want to add better errors in the future to provide a better + # user-facing I/O error with the error code + assert f"Status({code})" in str(e.value)