Skip to content

Commit

Permalink
Preserve query string on redirect (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w authored Oct 2, 2024
1 parent 62ef942 commit ede5f09
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Using the following categories, list your changes in this order:
### Added

- Support Python 3.13.
- Query strings are now preserved during HTTP redirection.

## [2.0.1] - 2024-09-13

Expand Down
1 change: 1 addition & 0 deletions src/servestatic/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async def __call__(self, scope, receive, send):
wsgi_headers = {
"HTTP_" + key.decode().upper().replace("-", "_"): value.decode() for key, value in scope["headers"]
}
wsgi_headers["QUERY_STRING"] = scope["query_string"].decode()

# Get the ServeStatic file response
response = await self.static_file.aget_response(scope["method"], wsgi_headers)
Expand Down
13 changes: 11 additions & 2 deletions src/servestatic/responders.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,16 +315,25 @@ def get_path_and_headers(self, request_headers):


class Redirect:
location = "Location"

def __init__(self, location, headers=None):
headers = list(headers.items()) if headers else []
headers.append(("Location", quote(location.encode("utf8"))))
headers.append((self.location, quote(location.encode("utf8"))))
self.response = Response(HTTPStatus.FOUND, headers, None)

def get_response(self, method, request_headers):
query_string = request_headers.get("QUERY_STRING")
if query_string:
headers = list(self.response.headers)
i, value = next((i, value) for (i, (name, value)) in enumerate(headers) if name == self.location)
value = f"{value}?{query_string}"
headers[i] = (self.location, value)
return Response(self.response.status, headers, None)
return self.response

async def aget_response(self, method, request_headers):
return self.response
return self.get_response(method, request_headers)


class NotARegularFileError(Exception):
Expand Down
16 changes: 15 additions & 1 deletion tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
def test_files():
return Files(
js=str(Path("static") / "app.js"),
index=str(Path("static") / "with-index" / "index.html"),
)


Expand All @@ -34,7 +35,12 @@ async def asgi_app(scope, receive, send):
})
await send({"type": "http.response.body", "body": b"Not Found"})

return ServeStaticASGI(asgi_app, root=test_files.directory, autorefresh=request.param)
return ServeStaticASGI(
asgi_app,
root=test_files.directory,
autorefresh=request.param,
index_file=True,
)


def test_get_js_static_file(application, test_files):
Expand All @@ -47,6 +53,14 @@ def test_get_js_static_file(application, test_files):
assert send.headers[b"content-length"] == str(len(test_files.js_content)).encode()


def test_redirect_preserves_query_string(application, test_files):
scope = AsgiScopeEmulator({"path": "/static/with-index", "query_string": b"v=1&x=2"})
receive = AsgiReceiveEmulator()
send = AsgiSendEmulator()
asyncio.run(application(scope, receive, send))
assert send.headers[b"location"] == b"with-index/?v=1&x=2"


def test_user_app(application):
scope = AsgiScopeEmulator({"path": "/"})
receive = AsgiReceiveEmulator()
Expand Down
17 changes: 16 additions & 1 deletion tests/test_servestatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest

from servestatic import ServeStatic
from servestatic.responders import StaticFile
from servestatic.responders import Redirect, StaticFile

from .utils import AppServer, Files

Expand Down Expand Up @@ -245,6 +245,15 @@ def test_index_file_path_redirected(server, files):
assert location == directory_url


def test_index_file_path_redirected_with_query_string(server, files):
directory_url = files.index_url.rpartition("/")[0] + "/"
query_string = "v=1"
response = server.get(f"{files.index_url}?{query_string}", allow_redirects=False)
location = urljoin(files.index_url, response.headers["Location"])
assert response.status_code == 302
assert location == f"{directory_url}?{query_string}"


def test_directory_path_without_trailing_slash_redirected(server, files):
directory_url = files.index_url.rpartition("/")[0] + "/"
no_slash_url = directory_url.rstrip("/")
Expand Down Expand Up @@ -376,3 +385,9 @@ def test_chunked_file_size_matches_range_with_range_header():
while response.file.read(1):
file_size += 1
assert file_size == 14


def test_redirect_preserves_query_string():
responder = Redirect("/redirect/to/here/")
response = responder.get_response("GET", {"QUERY_STRING": "foo=1&bar=2"})
assert response.headers[0] == ("Location", "/redirect/to/here/?foo=1&bar=2")

0 comments on commit ede5f09

Please sign in to comment.