Skip to content

Commit

Permalink
Add native downloader http tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Jun 22, 2023
1 parent 15092fa commit bf94599
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 29 deletions.
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pandas==2.0.2; python_version >= '3.8'
xxhash>=3.0.0
Pillow==9.5.0
opencv-python==4.7.0.72
requests==2.31.0

# Ray
ray[data, default]==2.4.0
Expand Down
169 changes: 140 additions & 29 deletions tests/remote_io/test_url_download_http.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,115 @@
from __future__ import annotations

import pathlib
import subprocess
import socketserver
import time
from http.server import BaseHTTPRequestHandler, SimpleHTTPRequestHandler
from multiprocessing import Process
from socket import socket

import pytest
import requests
from aiohttp.client_exceptions import ClientResponseError

import daft
from tests.remote_io.conftest import YieldFixture


def _get_free_port():
"""Helper to get a free port number - may be susceptible to race conditions,
but is likely good enough for our unit testing usecase.
"""
with socket() as s:
s.bind(("", 0))
return s.getsockname()[1]


def _wait_for_server(ready_url: str, max_wait_time_s: int = 1):
"""Waits for a server to be up and serving 200's from the provided URL"""
SLEEP_INTERVAL = 0.1
for _ in range(int(max_wait_time_s / SLEEP_INTERVAL)):
try:
if requests.get(ready_url).status_code == 200:
break
except requests.exceptions.ConnectionError:
time.sleep(SLEEP_INTERVAL)
else:
raise RuntimeError("Timed out while waiting for mock HTTP server fixture to be ready")


def _serve_error_server(code, port):
"""Target function for serving a HTTP service that always throws the specified error code"""

class MyAlwaysThrowHandler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == "/ready":
self.send_response(200)
self.end_headers()
return
self.send_response(code)
self.end_headers()
self.wfile.write(b"Some message")

with socketserver.TCPServer(("", port), MyAlwaysThrowHandler) as httpd:
httpd.serve_forever()


def _serve_file_server(port, directory):
"""Target function for serving a HTTP service that serves files from a directory"""

class ServeFileHandler(SimpleHTTPRequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, directory=directory)

def do_GET(self):
if self.path == "/ready":
self.send_response(200)
self.end_headers()
return
super().do_GET()

with socketserver.TCPServer(("", port), ServeFileHandler) as httpd:
httpd.serve_forever()


@pytest.fixture(
scope="function",
params=[
# Unauthorized
401,
# Forbidden
403,
# Not found
404,
# Too many requests
429,
# Internal server error
500,
# Service unavailable
503,
],
)
def mock_error_http_server(request) -> YieldFixture[tuple[str, int]]:
"""Provides a mock HTTP server that throws various HTTP status code errors when receiving any GET requests
This fixture yields a tuple of:
str: URL to the HTTP server
int: HTTP status code that it throws when accessed with a GET request at any path
"""
code = request.param
port = _get_free_port()
url = f"http://localhost:{port}"

p = Process(target=_serve_error_server, args=(code, port))
p.start()
try:
_wait_for_server(f"{url}/ready")
yield (url, code)
finally:
p.terminate()
p.join()


@pytest.fixture(scope="session")
def mock_http_server(tmp_path_factory) -> YieldFixture[tuple[str, pathlib.Path]]:
"""Provides a mock HTTP server that serves files in a given directory
Expand All @@ -18,51 +118,62 @@ def mock_http_server(tmp_path_factory) -> YieldFixture[tuple[str, pathlib.Path]]
str: URL to the HTTP server
pathlib.Path: tmpdir to place files into, which will then be served by the HTTP server
"""

tmpdir = tmp_path_factory.mktemp("data")
port = _get_free_port()
url = f"http://localhost:{port}"

PORT = 8000
proc = subprocess.Popen(["python", "-m", "http.server", "-d", str(tmpdir), str(PORT)])

# Give the server some time to spin up
time.sleep(0.2)

yield (f"http://localhost:{PORT}", tmpdir)

proc.kill()
p = Process(target=_serve_file_server, args=(port, str(tmpdir)))
p.start()
try:
_wait_for_server(f"{url}/ready")
yield (url, tmpdir)
finally:
p.terminate()
p.join()


@pytest.fixture(scope="function")
def http_image_data_fixture(mock_http_server, image_data) -> YieldFixture[list[str]]:
"""Populates the mock HTTP server with some fake data and returns filepaths"""
# Dump some images into the tmpdir
# NOTE: We use 1 image because the HTTPServer that we use is pretty bad at handling concurrent requests
server_url, tmpdir = mock_http_server
urls = []
for i in range(10):
path = tmpdir / f"{i}.jpeg"
path.write_bytes(image_data)
urls.append(f"{server_url}/{path.relative_to(tmpdir)}")

yield urls
path = tmpdir / f"img.jpeg"
path.write_bytes(image_data)
yield [f"{server_url}/{path.relative_to(tmpdir)}"]

# Cleanup tmpdir
for child in tmpdir.glob("*"):
child.unlink()


def test_url_download_http(http_image_data_fixture, image_data):
@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_url_download_http(http_image_data_fixture, image_data, use_native_downloader):
data = {"urls": http_image_data_fixture}
df = daft.from_pydict(data)
df = df.with_column("data", df["urls"].url.download())
df = df.with_column("data", df["urls"].url.download(use_native_downloader=use_native_downloader))
assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(http_image_data_fixture))]}


def test_url_download_http_missing(mock_http_server):
server_url, _ = mock_http_server
data = {"urls": [f"{server_url}/missing.jpeg"]}
@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_url_download_http_error_codes(mock_error_http_server, use_native_downloader):
url, code = mock_error_http_server
data = {"urls": [f"{url}/missing.jpeg"]}
df = daft.from_pydict(data)
df = df.with_column("data", df["urls"].url.download(on_error="raise"))

# NOTE: if using fsspec FileNotFoundError will be correctly thrown by fsspec.implementations.http.HTTPFileSystem
with pytest.raises(FileNotFoundError):
df.collect()
df = df.with_column("data", df["urls"].url.download(on_error="raise", use_native_downloader=use_native_downloader))

# 404 should always be corner-cased to return FileNotFoundError regardless of I/O implementation
if code == 404:
with pytest.raises(FileNotFoundError):
df.collect()
# When using fsspec, other error codes are bubbled up to the user as aiohttp.client_exceptions.ClientResponseError
elif not use_native_downloader:
with pytest.raises(ClientResponseError) as e:
df.collect()
assert e.value.code == code
# When using native downloader, we throw a ValueError
else:
with pytest.raises(ValueError) as e:
df.collect()
# NOTE: We may want to add better errors in the future to provide a better
# user-facing I/O error with the error code
assert f"Status({code})" in str(e.value)

0 comments on commit bf94599

Please sign in to comment.