Skip to content

Commit

Permalink
Merge pull request #739 from kevin1024/issue-734-fix-body-matcher-for…
Browse files Browse the repository at this point in the history
…-chunked-requests

Fix body matcher for chunked requests (fixes #734)
  • Loading branch information
hartwork authored Jul 23, 2023
2 parents 92dd4d0 + e69b10c commit e7c00a4
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 13 deletions.
33 changes: 33 additions & 0 deletions tests/unit/test_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def test_uri_matcher():
"Expect": b"100-continue",
"Content-Length": "21",
}
# Headers that mark a request as using chunked transfer encoding; shared by the
# chunked-body cases of the body matcher tests.
chunked_headers = {"Transfer-Encoding": "chunked"}


@pytest.mark.parametrize(
Expand Down Expand Up @@ -151,6 +154,36 @@ def test_uri_matcher():
request.Request("POST", "http://aws.custom.com/", b"123", boto3_bytes_headers),
request.Request("POST", "http://aws.custom.com/", b"123", boto3_bytes_headers),
),
(
# chunked transfer encoding: decoded bytes versus encoded bytes
request.Request("POST", "scheme1://host1.test/", b"123456789_123456", chunked_headers),
request.Request(
"GET",
"scheme2://host2.test/",
b"10\r\n123456789_123456\r\n0\r\n\r\n",
chunked_headers,
),
),
(
# chunked transfer encoding: bytes iterator versus string iterator
request.Request(
"POST",
"scheme1://host1.test/",
iter([b"123456789_", b"123456"]),
chunked_headers,
),
request.Request("GET", "scheme2://host2.test/", iter(["123456789_", "123456"]), chunked_headers),
),
(
# chunked transfer encoding: bytes iterator versus single byte iterator
request.Request(
"POST",
"scheme1://host1.test/",
iter([b"123456789_", b"123456"]),
chunked_headers,
),
request.Request("GET", "scheme2://host2.test/", iter(b"123456789_123456"), chunked_headers),
),
],
)
def test_body_matcher_does_match(r1, r2):
Expand Down
87 changes: 74 additions & 13 deletions vcr/matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import logging
import urllib
import xmlrpc.client
from string import hexdigits
from typing import List, Set

from .util import read_body

# Integer code points of the ASCII hexadecimal digits (0-9, a-f, A-F).
# Iterating over a bytes object yields ints, so the set can be built straight
# from the encoded digit string.
_HEXDIG_CODE_POINTS: Set[int] = set(hexdigits.encode("ascii"))

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


Expand Down Expand Up @@ -49,11 +53,17 @@ def raw_body(r1, r2):


def body(r1, r2):
    """Assert that the bodies of two requests match after normalization.

    Every transformer selected by ``_get_transformers`` for the request is
    applied, in order, to both bodies before they are compared.  If the two
    requests select different transformer lists, no transformer is applied
    and the bodies are compared as read.

    Raises:
        AssertionError: if the (transformed) bodies differ.
    """
    transformers = list(_get_transformers(r1))
    if transformers != list(_get_transformers(r2)):
        # The requests disagree on how their bodies should be interpreted,
        # so fall back to comparing the untransformed bodies.
        transformers = []

    b1 = read_body(r1)
    b2 = read_body(r2)
    for transform in transformers:
        b1 = transform(b1)
        b2 = transform(b2)

    if b1 != b2:
        raise AssertionError


Expand All @@ -72,6 +82,62 @@ def checker(headers):
return checker


def _coerce_body_to_bytes(body):
    """Convert a str, bytearray or iterable body to ``bytes`` where possible.

    Objects that are neither strings, bytearrays nor iterables are returned
    unchanged; an empty iterable becomes ``None``.

    Raises:
        ValueError: if the first item of an iterable body is not ``str``,
            ``bytes`` or ``int``.
    """
    if isinstance(body, str):
        return body.encode("utf-8")
    if isinstance(body, bytearray):
        return bytes(body)
    if not hasattr(body, "__iter__"):
        return body

    # Materialize first: the body may be a one-shot iterator.  A ``bytes``
    # object also takes this path and arrives here as a list of ints.
    parts = list(body)
    if not parts:
        return None

    first = parts[0]
    if isinstance(first, str):
        return "".join(parts).encode("utf-8")
    if isinstance(first, bytes):
        return b"".join(parts)
    if isinstance(first, int):
        return bytes(parts)
    raise ValueError(f"Body chunk type {type(first)} not supported")


def _dechunk(body):
    """Return *body* as bytes with HTTP chunked transfer framing removed.

    The body is first normalized to ``bytes`` (str is UTF-8 encoded; iterables
    of str, bytes or ints are joined).  If the result does not start with a
    run of hex digits directly followed by CRLF, it is assumed not to be
    chunk-encoded and is returned as is.  Otherwise the chunks are decoded up
    to the terminating zero-size chunk and their payloads are concatenated.
    Chunk extensions and trailers are not supported.

    Returns:
        The decoded ``bytes``; ``None`` for an empty body; or the original
        object if it could not be converted to bytes.

    Raises:
        ValueError: if the data starts like chunked data but is malformed or
            truncated, or if an iterable body has unsupported item types.
    """
    body = _coerce_body_to_bytes(body)
    if not isinstance(body, bytes):
        return body
    if not body:
        # Empty str/bytearray bodies are normalized like empty iterables
        # (including b""), so that all empty bodies compare equal.
        return None

    # Now decode chunked data format (https://en.wikipedia.org/wiki/Chunked_transfer_encoding)
    # Example input: b"45\r\n<69 bytes>\r\n0\r\n\r\n" where int(b"45", 16) == 69.
    CHUNK_GAP = b"\r\n"
    gap_len = len(CHUNK_GAP)
    body_len = len(body)
    hex_digits = hexdigits.encode("ascii")  # ``int in bytes`` tests byte values

    chunks: List[bytes] = []
    pos = 0

    while True:
        # Scan the hexadecimal chunk-size field starting at ``pos``.
        size_end = pos
        while size_end < body_len and body[size_end] in hex_digits:
            size_end += 1

        # The size field must be non-empty and directly followed by CRLF.
        if size_end == pos or body[size_end : size_end + gap_len] != CHUNK_GAP:
            if pos == 0:
                return body  # i.e. assume non-chunk data
            raise ValueError("Malformed chunked data")

        size_bytes = int(body[pos:size_end], 16)
        if size_bytes == 0:  # i.e. well-formed ending
            return b"".join(chunks)

        data_start = size_end + gap_len
        data_end = data_start + size_bytes

        # Each chunk payload must be terminated by CRLF; this also rejects
        # size fields that point past the end of the body.
        if body[data_end : data_end + gap_len] != CHUNK_GAP:
            raise ValueError("Malformed chunked data")

        chunks.append(body[data_start:data_end])
        pos = data_end + gap_len


def _transform_json(body):
if body:
return json.loads(body)
Expand All @@ -80,6 +146,7 @@ def _transform_json(body):
_xml_header_checker = _header_checker("text/xml")
_xmlrpc_header_checker = _header_checker("xmlrpc", header="User-Agent")
_checker_transformer_pairs = (
(_header_checker("chunked", header="Transfer-Encoding"), _dechunk),
(
_header_checker("application/x-www-form-urlencoded"),
lambda body: urllib.parse.parse_qs(body.decode("ascii")),
Expand All @@ -89,16 +156,10 @@ def _transform_json(body):
)


def _identity(x):
    """Return the argument unchanged (the no-op body transformer)."""
    return x


def _get_transformer(request):
    """Return the first transformer matching the request, or ``_identity``.

    Single-transformer variant of ``_get_transformers``: only the first entry
    of ``_checker_transformer_pairs`` whose checker accepts the request
    headers is used.
    """
    return next(_get_transformers(request), _identity)


def _get_transformers(request):
    """Yield every body transformer whose checker accepts the request headers.

    Transformers are yielded in the order in which they appear in
    ``_checker_transformer_pairs``; nothing is yielded if no checker matches.
    """
    for checker, transformer in _checker_transformer_pairs:
        if checker(request.headers):
            yield transformer


def requests_match(r1, r2, matchers):
Expand Down

0 comments on commit e7c00a4

Please sign in to comment.