From 0c7ce3485dbcda196e125b4ea913a67e6b54f26c Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 12 Apr 2026 19:23:03 +0100 Subject: [PATCH] Allow decompression to continue after exceeding max_length (#11966) --- .github/workflows/ci-cd.yml | 5 +- CHANGES/11966.feature.rst | 8 + aiohttp/_cparser.pxd | 1 + aiohttp/_http_parser.pyx | 120 +++++--- aiohttp/_websocket/writer.py | 2 +- aiohttp/base_protocol.py | 33 ++- aiohttp/client.py | 4 +- aiohttp/client_proto.py | 19 +- aiohttp/compression_utils.py | 43 ++- aiohttp/http_exceptions.py | 4 - aiohttp/http_parser.py | 188 +++++++++---- aiohttp/multipart.py | 27 +- aiohttp/payload.py | 2 +- aiohttp/streams.py | 40 ++- aiohttp/web_exceptions.py | 8 +- aiohttp/web_protocol.py | 54 ++-- aiohttp/web_request.py | 14 +- aiohttp/web_ws.py | 2 +- setup.cfg | 2 + tests/test_base_protocol.py | 22 +- tests/test_benchmarks_http_websocket.py | 4 +- tests/test_client_functional.py | 42 +-- tests/test_client_proto.py | 4 +- tests/test_flowcontrol_streams.py | 7 +- tests/test_http_parser.py | 354 ++++++++++++++++++++---- tests/test_http_writer.py | 8 +- tests/test_multipart.py | 60 +++- tests/test_payload.py | 4 +- tests/test_streams.py | 40 +-- tests/test_web_exceptions.py | 11 +- tests/test_web_functional.py | 38 ++- tests/test_web_protocol.py | 45 +-- tests/test_web_request.py | 10 +- tests/test_websocket_parser.py | 5 +- tests/test_websocket_writer.py | 4 +- 35 files changed, 881 insertions(+), 353 deletions(-) create mode 100644 CHANGES/11966.feature.rst diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index ddc3b283094..2da43d1bd41 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -422,7 +422,10 @@ jobs: name: Cython coverage needs: gen_llhttp - runs-on: ubuntu-latest + strategy: + matrix: + os: [ubuntu, windows] + runs-on: ${{ matrix.os }}-latest steps: - name: Checkout uses: actions/checkout@v6 diff --git a/CHANGES/11966.feature.rst b/CHANGES/11966.feature.rst new file mode 100644 index 00000000000..9f298f98e12 --- /dev/null +++ b/CHANGES/11966.feature.rst @@ -0,0 +1,8 @@ +Large overhaul of parser/decompression code. + +The zip bomb security fix in aiohttp 3.13 stopped highly compressed payloads +from being decompressed, regardless of validity. aiohttp now decompresses +such payloads incrementally, in chunks of 256+ KiB, so that valid payloads +can still be decompressed safely. + +-- by :user:`Dreamsorcerer`.
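In practice this means a response body that inflates far beyond the per-call decompression limit can be consumed as a stream. A minimal client-side sketch of the behavior described above, mirroring the updated tests in this patch (the URL is hypothetical; iter_chunked is aiohttp's public streaming API):

import asyncio

import aiohttp


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        # Hypothetical endpoint serving a highly compressed (e.g. deflate) body.
        async with session.get("https://example.com/large-compressed") as resp:
            received = 0
            # Decompressed data arrives incrementally, so memory stays bounded
            # even if the body expands to many times the 256 KiB chunk size.
            async for chunk in resp.content.iter_chunked(64 * 1024):
                received += len(chunk)
            print(received)


asyncio.run(main())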
diff --git a/aiohttp/_cparser.pxd b/aiohttp/_cparser.pxd index 1b3be6d4efb..cc7ef58d664 100644 --- a/aiohttp/_cparser.pxd +++ b/aiohttp/_cparser.pxd @@ -145,6 +145,7 @@ cdef extern from "llhttp.h": int llhttp_should_keep_alive(const llhttp_t* parser) + void llhttp_resume(llhttp_t* parser) void llhttp_resume_after_upgrade(llhttp_t* parser) llhttp_errno_t llhttp_get_errno(const llhttp_t* parser) diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index 9a444be66fc..719387493f5 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -46,7 +46,8 @@ include "_headers.pxi" from aiohttp cimport _find_header -ALLOWED_UPGRADES = frozenset({"websocket"}) + +cdef frozenset ALLOWED_UPGRADES = frozenset({"websocket"}) DEF DEFAULT_FREELIST_SIZE = 250 cdef extern from "Python.h": @@ -69,7 +70,7 @@ cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD cdef object StreamReader = _StreamReader cdef object DeflateBuffer = _DeflateBuffer -cdef bytes EMPTY_BYTES = b"" +cdef tuple EMPTY_FEED_DATA_RESULT = ((), False, b"") # RFC 9110 singleton headers — duplicates are rejected in strict mode. # In lax mode (response parser default), the check is skipped entirely @@ -298,7 +299,7 @@ cdef class HttpParser: bint _has_value int _header_name_size - object _protocol + readonly object protocol object _loop object _timer @@ -309,6 +310,7 @@ cdef class HttpParser: bint _read_until_eof bint _lax + bytes _tail bint _started object _url bytearray _buf @@ -319,6 +321,9 @@ cdef class HttpParser: list _raw_headers bint _upgraded list _messages + bint _more_data_available + bint _paused + bint _eof_pending object _payload bint _payload_error object _payload_exception @@ -359,18 +364,22 @@ cdef class HttpParser: self._cparser.data = self self._cparser.content_length = 0 - self._protocol = protocol + self.protocol = protocol self._loop = loop self._timer = timer self._buf = bytearray() + self._more_data_available = False + self._paused = False + self._eof_pending = False self._payload = None self._payload_error = 0 self._payload_exception = payload_exception self._messages = [] - self._raw_name = EMPTY_BYTES - self._raw_value = EMPTY_BYTES + self._raw_name = b"" + self._raw_value = b"" + self._tail = b"" self._has_value = False self._header_name_size = 0 @@ -401,7 +410,7 @@ cdef class HttpParser: cdef _process_header(self): cdef str value - if self._raw_name is not EMPTY_BYTES: + if self._raw_name != b"": name = find_header(self._raw_name) value = self._raw_value.decode('utf-8', 'surrogateescape') @@ -426,20 +435,20 @@ cdef class HttpParser: self._has_value = False self._header_name_size = 0 self._raw_headers.append((self._raw_name, self._raw_value)) - self._raw_name = EMPTY_BYTES - self._raw_value = EMPTY_BYTES + self._raw_name = b"" + self._raw_value = b"" cdef _on_header_field(self, char* at, size_t length): if self._has_value: self._process_header() - if self._raw_name is EMPTY_BYTES: + if self._raw_name == b"": self._raw_name = at[:length] else: self._raw_name += at[:length] cdef _on_header_value(self, char* at, size_t length): - if self._raw_value is EMPTY_BYTES: + if self._raw_value == b"": self._raw_value = at[:length] else: self._raw_value += at[:length] @@ -495,14 +504,14 @@ cdef class HttpParser: self._read_until_eof) ): payload = StreamReader( - self._protocol, timer=self._timer, loop=self._loop, + self.protocol, timer=self._timer, loop=self._loop, limit=self._limit) else: payload = EMPTY_PAYLOAD self._payload = payload if encoding is not None and 
self._auto_decompress: - self._payload = DeflateBuffer(payload, encoding) + self._payload = DeflateBuffer(payload, encoding, max_decompress_size=self._limit) if not self._response_with_body: payload = EMPTY_PAYLOAD @@ -535,6 +544,10 @@ cdef class HttpParser: ### Public API ### + def pause_reading(self): + assert self._payload is not None + self._paused = True + def feed_eof(self): cdef bytes desc @@ -549,18 +562,52 @@ cdef class HttpParser: desc = cparser.llhttp_get_error_reason(self._cparser) raise PayloadEncodingError(desc.decode('latin-1')) else: + self._eof_pending = True + while self._more_data_available: + if self._paused: + self._paused = False + return # Will resume via feed_data(b"") later + self._more_data_available = self._payload.feed_data(b"") self._payload.feed_eof() + self._payload = None + self._more_data_available = False + self._eof_pending = False elif self._started: self._on_headers_complete() if self._messages: return self._messages[-1][0] - def feed_data(self, data): + def feed_data(self, incoming_data): cdef: size_t data_len size_t nb char* base cdef cparser.llhttp_errno_t errno + cdef bytes data + + # Proactor loop sends bytearray. + # Ensure cython sees `data` as bytes + if type(incoming_data) is not bytes: + data = bytes(incoming_data) + else: + data = incoming_data + + if self._tail: + data, self._tail = self._tail + data, b"" + + if self._more_data_available: + result = cb_on_body(self._cparser, b"", 0) + if result is cparser.HPE_PAUSED: + self._tail = data + return EMPTY_FEED_DATA_RESULT + + if self._eof_pending: + self._payload.feed_eof() + self._payload = None + self._eof_pending = False + # We can't have new messages here, otherwise we wouldn't have + # received EOF. + return EMPTY_FEED_DATA_RESULT PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE) # Cache buffer pointer before PyBuffer_Release to avoid use-after-release. 
@@ -574,12 +621,15 @@ cdef class HttpParser: if errno is cparser.HPE_PAUSED_UPGRADE: cparser.llhttp_resume_after_upgrade(self._cparser) - nb = cparser.llhttp_get_error_pos(self._cparser) - base + elif errno is cparser.HPE_PAUSED: + cparser.llhttp_resume(self._cparser) + pos = cparser.llhttp_get_error_pos(self._cparser) - base + self._tail = data[pos:] PyBuffer_Release(&self.py_buf) - if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE): + if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED, cparser.HPE_PAUSED_UPGRADE): if self._payload_error == 0: if self._last_error is not None: ex = self._last_error @@ -603,8 +653,9 @@ cdef class HttpParser: if self._upgraded: return messages, True, data[nb:] - else: - return messages, False, b"" + if not messages: # Shortcut to reduce Python overhead + return EMPTY_FEED_DATA_RESULT + return messages, False, b"" def set_upgraded(self, val): self._upgraded = val @@ -799,19 +850,26 @@ cdef int cb_on_body(cparser.llhttp_t* parser, const char *at, size_t length) except -1: cdef HttpParser pyparser = parser.data cdef bytes body = at[:length] - try: - pyparser._payload.feed_data(body) - except BaseException as underlying_exc: - reraised_exc = underlying_exc - if pyparser._payload_exception is not None: - reraised_exc = pyparser._payload_exception(str(underlying_exc)) - - set_exception(pyparser._payload, reraised_exc, underlying_exc) - - pyparser._payload_error = 1 - return -1 - else: - return 0 + while body or pyparser._more_data_available: + try: + pyparser._more_data_available = pyparser._payload.feed_data(body) + except BaseException as underlying_exc: + reraised_exc = underlying_exc + if pyparser._payload_exception is not None: + reraised_exc = pyparser._payload_exception(str(underlying_exc)) + + set_exception(pyparser._payload, reraised_exc, underlying_exc) + + pyparser._payload_error = 1 + pyparser._paused = False + return -1 + body = b"" + + if pyparser._paused: + pyparser._paused = False + return cparser.HPE_PAUSED + pyparser._paused = False + return 0 cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1: diff --git a/aiohttp/_websocket/writer.py b/aiohttp/_websocket/writer.py index 1b27dff9371..df89aabbd5b 100644 --- a/aiohttp/_websocket/writer.py +++ b/aiohttp/_websocket/writer.py @@ -21,7 +21,7 @@ ) from .models import WS_DEFLATE_TRAILING, WSMsgType -DEFAULT_LIMIT: Final[int] = 2**16 +DEFAULT_LIMIT: Final[int] = 2**18 # WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames # Control frames (ping, pong, close) are never compressed diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index d7d83425b88..f1f6edc3836 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -1,26 +1,35 @@ import asyncio -from typing import cast +from typing import TYPE_CHECKING, Any, cast from .client_exceptions import ClientConnectionResetError from .helpers import set_exception from .tcp_helpers import tcp_nodelay +if TYPE_CHECKING: + from .http_parser import HttpParser + class BaseProtocol(asyncio.Protocol): __slots__ = ( "_loop", "_paused", + "_parser", "_drain_waiter", "_connection_lost", "_reading_paused", + "_upgraded", "transport", ) - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, loop: asyncio.AbstractEventLoop, parser: "HttpParser[Any] | None" = None + ) -> None: self._loop: asyncio.AbstractEventLoop = loop self._paused = False self._drain_waiter: asyncio.Future[None] | None = None self._reading_paused = False + self._parser = parser + 
self._upgraded = False self.transport: asyncio.Transport | None = None @@ -48,15 +57,27 @@ def resume_writing(self) -> None: waiter.set_result(None) def pause_reading(self) -> None: - if not self._reading_paused and self.transport is not None: + self._reading_paused = True + # Parser shouldn't be paused on websockets. + if not self._upgraded: + assert self._parser is not None + self._parser.pause_reading() + if self.transport is not None: try: self.transport.pause_reading() except (AttributeError, NotImplementedError, RuntimeError): pass - self._reading_paused = True - def resume_reading(self) -> None: - if self._reading_paused and self.transport is not None: + def resume_reading(self, resume_parser: bool = True) -> None: + self._reading_paused = False + + # This will resume parsing any unprocessed data from the last pause. + if not self._upgraded and resume_parser: + self.data_received(b"") + + # Reading may have been paused again in the above call if there was a lot of + # compressed data still pending. + if not self._reading_paused and self.transport is not None: try: self.transport.resume_reading() except (AttributeError, NotImplementedError, RuntimeError): diff --git a/aiohttp/client.py b/aiohttp/client.py index c3e874e650d..9bd9af10bf2 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -331,7 +331,7 @@ def __init__( trust_env: bool = False, requote_redirect_url: bool = True, trace_configs: list[TraceConfig[object]] | None = None, - read_bufsize: int = 2**16, + read_bufsize: int = 2**18, max_line_size: int = 8190, max_field_size: int = 8190, max_headers: int = 128, @@ -1226,7 +1226,7 @@ async def _ws_connect( transport = conn.transport assert transport is not None - reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop) + reader = WebSocketDataQueue(conn_proto, 2**18, loop=self._loop) writer = WebSocketWriter( conn_proto, transport, diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index bb088b6a99c..19bd8564ca6 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -32,7 +32,7 @@ class ResponseHandler(BaseProtocol, DataQueue[tuple[RawResponseMessage, StreamRe """Helper class to adapt between Protocol and StreamReader.""" def __init__(self, loop: asyncio.AbstractEventLoop) -> None: - BaseProtocol.__init__(self, loop=loop) + BaseProtocol.__init__(self, loop=loop, parser=None) DataQueue.__init__(self, loop) self._should_close = False @@ -43,10 +43,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._data_received_cb: Callable[[], None] | None = None self._timer = None - self._tail = b"" - self._upgraded = False - self._parser: HttpResponseParser | None = None self._read_timeout: float | None = None self._read_timeout_handle: asyncio.TimerHandle | None = None @@ -197,8 +194,8 @@ def pause_reading(self) -> None: super().pause_reading() self._drop_timeout() - def resume_reading(self) -> None: - super().resume_reading() + def resume_reading(self, resume_parser: bool = True) -> None: + super().resume_reading(resume_parser) self._reschedule_timeout() def set_exception( @@ -234,7 +231,7 @@ def set_response_params( read_until_eof: bool = False, auto_decompress: bool = True, read_timeout: float | None = None, - read_bufsize: int = 2**16, + read_bufsize: int = 2**18, timeout_ceil_threshold: float = 5, max_line_size: int = 8190, max_field_size: int = 8190, @@ -299,10 +296,10 @@ def _on_read_timeout(self) -> None: set_exception(self._payload, exc) def data_received(self, data: bytes) -> None: - self._reschedule_timeout() - - if not 
data: - return + # If no data, then we are resuming decompression. We haven't received + # data from the socket, so we can avoid the reschedule overhead. + if data: + self._reschedule_timeout() # custom payload parser - currently always WebSocketReader if self._payload_parser is not None: diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 2a8818c4220..5e12337113c 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -34,7 +34,9 @@ MAX_SYNC_CHUNK_SIZE = 4096 -DEFAULT_MAX_DECOMPRESS_SIZE = 2**25 # 32MiB +# Matches the max size we receive from sockets: +# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766 +DEFAULT_MAX_DECOMPRESS_SIZE = 256 * 1024 # Unlimited decompression constants - different libraries use different conventions ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited @@ -53,6 +55,9 @@ def flush(self, length: int = ..., /) -> bytes: ... @property def eof(self) -> bool: ... + @property + def unconsumed_tail(self) -> bytes: ... + class ZLibBackendProtocol(Protocol): MAX_WBITS: int @@ -179,6 +184,11 @@ async def decompress( ) return self.decompress_sync(data, max_length) + @property + @abstractmethod + def data_available(self) -> bool: + """Return True if more output is available by passing b"".""" + class ZLibCompressor: def __init__( @@ -267,11 +277,17 @@ def __init__( self._mode = encoding_to_mode(encoding, suppress_deflate_header) self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend) self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode) + self._last_empty = False def decompress_sync( self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED ) -> bytes: - return self._decompressor.decompress(data, max_length) + result = self._decompressor.decompress( + self._decompressor.unconsumed_tail + data, max_length + ) + # The only way to know that isal has no further data is to check that we get no output + self._last_empty = result == b"" + return result def flush(self, length: int = 0) -> bytes: return ( @@ -280,6 +296,10 @@ def flush(self, length: int = 0) -> bytes: else self._decompressor.flush() ) + @property + def data_available(self) -> bool: + return bool(self._decompressor.unconsumed_tail) or not self._last_empty + @property def eof(self) -> bool: return self._decompressor.eof @@ -301,6 +321,7 @@ def __init__( "Please install `Brotli` module" ) self._obj = brotli.Decompressor() + self._last_empty = False super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size) def decompress_sync( @@ -308,8 +329,12 @@ def decompress_sync( ) -> bytes: """Decompress the given data.""" if hasattr(self._obj, "decompress"): - return cast(bytes, self._obj.decompress(data, max_length)) - return cast(bytes, self._obj.process(data, max_length)) + result = cast(bytes, self._obj.decompress(data, max_length)) + else: + result = cast(bytes, self._obj.process(data, max_length)) + # The only way to know that brotli has no further data is to check that we get no output + self._last_empty = result == b"" + return result def flush(self) -> bytes: """Flush the decompressor.""" @@ -317,6 +342,10 @@ def flush(self) -> bytes: return cast(bytes, self._obj.flush()) return b"" + @property + def data_available(self) -> bool: + return not self._obj.is_finished() and not self._last_empty + class ZSTDDecompressor(DecompressionBaseHandler): def __init__( @@ -373,3 +402,9 @@ def decompress_sync( def flush(self) -> bytes: return b"" + + @property + def 
data_available(self) -> bool: + return ( + not self._obj.needs_input and not self._obj.eof + ) or self._pending_unused_data is not None diff --git a/aiohttp/http_exceptions.py b/aiohttp/http_exceptions.py index cf3c05434c5..95d0d6373ae 100644 --- a/aiohttp/http_exceptions.py +++ b/aiohttp/http_exceptions.py @@ -73,10 +73,6 @@ class ContentLengthError(PayloadEncodingError): """Not enough data to satisfy content length header.""" -class DecompressSizeError(PayloadEncodingError): - """Decompressed size exceeds the configured limit.""" - - class LineTooLong(BadHttpMessage): def __init__(self, line: bytes, limit: int) -> None: super().__init__(f"Got more than {limit} bytes when reading: {line!r}.") diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 207cf8da39e..4601f201122 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -5,7 +5,16 @@ from contextlib import suppress from enum import IntEnum from re import Pattern -from typing import Any, ClassVar, Final, Generic, Literal, NamedTuple, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Final, + Generic, + Literal, + NamedTuple, + TypeVar, +) from multidict import CIMultiDict, CIMultiDictProxy, istr from yarl import URL @@ -35,7 +44,6 @@ BadStatusLine, ContentEncodingError, ContentLengthError, - DecompressSizeError, InvalidHeader, InvalidURLError, LineTooLong, @@ -45,6 +53,9 @@ from .streams import EMPTY_PAYLOAD, StreamReader from .typedefs import RawHeaders +if TYPE_CHECKING: + from .client_proto import ResponseHandler + __all__ = ( "HeadersParser", "HttpParser", @@ -124,6 +135,12 @@ class RawResponseMessage(NamedTuple): _MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage) +class PayloadState(IntEnum): + PAYLOAD_COMPLETE = 0 + PAYLOAD_NEEDS_INPUT = 1 + PAYLOAD_HAS_PENDING_INPUT = 2 + + class ParseState(IntEnum): PARSE_NONE = 0 PARSE_LENGTH = 1 @@ -265,6 +282,7 @@ def __init__( self._upgraded = False self._payload = None self._payload_parser: HttpPayloadParser | None = None + self._payload_has_more_data = False self._auto_decompress = auto_decompress self._limit = limit self._headers_parser = HeadersParser(max_field_size, self.lax) @@ -275,10 +293,15 @@ def parse_message(self, lines: list[bytes]) -> _MsgT: ... @abc.abstractmethod def _is_chunked_te(self, te: str) -> bool: ... 
+ def pause_reading(self) -> None: + assert self._payload_parser is not None + self._payload_parser.pause_reading() + def feed_eof(self) -> _MsgT | None: if self._payload_parser is not None: self._payload_parser.feed_eof() - self._payload_parser = None + if self._payload_parser.done: + self._payload_parser = None else: # try to extract partial message if self._tail: @@ -311,7 +334,7 @@ def feed_data( max_line_length = self.max_line_size should_close = False - while start_pos < data_len: + while start_pos < data_len or self._payload_has_more_data: # read HTTP message (request/response line + headers), \r\n\r\n # and split by lines if self._payload_parser is None and not self._upgraded: @@ -405,6 +428,7 @@ def get_content_length() -> int | None: max_line_size=self.max_line_size, max_field_size=self.max_field_size, max_trailers=max_trailers, + limit=self._limit, ) if not payload_parser.done: self._payload_parser = payload_parser @@ -427,6 +451,7 @@ def get_content_length() -> int | None: max_line_size=self.max_line_size, max_field_size=self.max_field_size, max_trailers=max_trailers, + limit=self._limit, ) elif not empty_body and length is None and self.read_until_eof: payload = StreamReader( @@ -449,6 +474,7 @@ def get_content_length() -> int | None: max_line_size=self.max_line_size, max_field_size=self.max_field_size, max_trailers=max_trailers, + limit=self._limit, ) if not payload_parser.done: self._payload_parser = payload_parser @@ -470,11 +496,13 @@ def get_content_length() -> int | None: break # feed payload - elif data and start_pos < data_len: + else: assert not self._lines assert self._payload_parser is not None try: - eof, data = self._payload_parser.feed_data(data[start_pos:], SEP) + payload_state, data = self._payload_parser.feed_data( + data[start_pos:], SEP + ) except Exception as underlying_exc: reraised_exc: BaseException = underlying_exc if self.payload_exception is not None: @@ -486,20 +514,25 @@ def get_content_length() -> int | None: underlying_exc, ) - eof = True + payload_state = PayloadState.PAYLOAD_COMPLETE data = b"" if isinstance( underlying_exc, (InvalidHeader, TransferEncodingError) ): raise - if eof: - start_pos = 0 - data_len = len(data) - self._payload_parser = None - continue - else: - break + self._payload_has_more_data = ( + payload_state == PayloadState.PAYLOAD_HAS_PENDING_INPUT + ) + + if payload_state is not PayloadState.PAYLOAD_COMPLETE: + # We've either consumed all available data, or we're pausing + # until the reader buffer is freed up. + break + + start_pos = 0 + data_len = len(data) + self._payload_parser = None if data and start_pos < data_len: data = data[start_pos:] @@ -681,6 +714,8 @@ class HttpResponseParser(HttpParser[RawResponseMessage]): Returns RawResponseMessage. """ + protocol: "ResponseHandler" + # Lax mode should only be enabled on response parser. 
lax = not DEBUG @@ -775,8 +810,10 @@ def __init__( max_line_size: int = 8190, max_field_size: int = 8190, max_trailers: int = 128, + limit: int = DEFAULT_MAX_DECOMPRESS_SIZE, ) -> None: self._length = 0 + self._paused = False self._type = ParseState.PARSE_UNTIL_EOF self._chunk = ChunkState.PARSE_CHUNKED_SIZE self._chunk_size = 0 @@ -787,13 +824,15 @@ def __init__( self._max_line_size = max_line_size self._max_field_size = max_field_size self._max_trailers = max_trailers + self._more_data_available = False self._trailer_lines: list[bytes] = [] self.done = False + self._eof_pending = False # payload decompression wrapper if response_with_body and compression and self._auto_decompress: real_payload: StreamReader | DeflateBuffer = DeflateBuffer( - payload, compression + payload, compression, max_decompress_size=limit ) else: real_payload = payload @@ -815,9 +854,20 @@ def __init__( self.payload = real_payload + def pause_reading(self) -> None: + self._paused = True + def feed_eof(self) -> None: if self._type == ParseState.PARSE_UNTIL_EOF: + self._eof_pending = True + while self._more_data_available: + if self._paused: + self._paused = False + return # Will resume via feed_data(b"") later + self._more_data_available = self.payload.feed_data(b"") self.payload.feed_eof() + self.done = True + self._eof_pending = False elif self._type == ParseState.PARSE_LENGTH: raise ContentLengthError( "Not enough data to satisfy content length header." @@ -829,32 +879,52 @@ def feed_eof(self) -> None: def feed_data( self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";" - ) -> tuple[bool, bytes]: + ) -> tuple[PayloadState, bytes]: + """Receive a chunk of data to process. + + Return: + PayloadState - The current state of payload processing. + This function may be called with empty bytes after returning + PAYLOAD_HAS_PENDING_INPUT to continue processing after a pause. + bytes - If payload is complete, this is the unconsumed bytes intended for the + next message/payload, b"" otherwise. + """ # Read specified amount of bytes if self._type == ParseState.PARSE_LENGTH: + if self._chunk_tail: + chunk = self._chunk_tail + chunk + self._chunk_tail = b"" + required = self._length self._length = max(required - len(chunk), 0) - self.payload.feed_data(chunk[:required]) + self._more_data_available = self.payload.feed_data(chunk[:required]) + while self._more_data_available: + if self._paused: + self._paused = False + self._chunk_tail = chunk[required:] + return PayloadState.PAYLOAD_HAS_PENDING_INPUT, b"" + self._more_data_available = self.payload.feed_data(b"") + if self._length == 0: self.payload.feed_eof() - return True, chunk[required:] - + return PayloadState.PAYLOAD_COMPLETE, chunk[required:] # Chunked transfer encoding parser elif self._type == ParseState.PARSE_CHUNKED: if self._chunk_tail: - # We should never have a tail if we're inside the payload body. - assert self._chunk != ChunkState.PARSE_CHUNKED_CHUNK - # We should check the length is sane. - max_line_length = self._max_line_size - if self._chunk == ChunkState.PARSE_TRAILERS: - max_line_length = self._max_field_size - if len(self._chunk_tail) > max_line_length: - raise LineTooLong(self._chunk_tail[:100] + b"...", max_line_length) + # We should check the length is sane when not processing payload body. 
+ if self._chunk != ChunkState.PARSE_CHUNKED_CHUNK: + max_line_length = self._max_line_size + if self._chunk == ChunkState.PARSE_TRAILERS: + max_line_length = self._max_field_size + if len(self._chunk_tail) > max_line_length: + raise LineTooLong( + self._chunk_tail[:100] + b"...", max_line_length + ) chunk = self._chunk_tail + chunk self._chunk_tail = b"" - while chunk: + while chunk or self._more_data_available: # read next chunk size if self._chunk == ChunkState.PARSE_CHUNKED_SIZE: pos = chunk.find(SEP) @@ -894,17 +964,26 @@ def feed_data( self.payload.begin_http_chunk_receiving() else: self._chunk_tail = chunk - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" # read chunk and feed buffer if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK: + if self._paused: + self._paused = False + self._chunk_tail = chunk + return PayloadState.PAYLOAD_HAS_PENDING_INPUT, b"" + required = self._chunk_size self._chunk_size = max(required - len(chunk), 0) - self.payload.feed_data(chunk[:required]) + self._more_data_available = self.payload.feed_data(chunk[:required]) + chunk = chunk[required:] + + if self._more_data_available: + continue if self._chunk_size: - return False, b"" - chunk = chunk[required:] + self._paused = False + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF self.payload.end_http_chunk_receiving() @@ -923,13 +1002,13 @@ def feed_data( raise exc else: self._chunk_tail = chunk - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" if self._chunk == ChunkState.PARSE_TRAILERS: pos = chunk.find(SEP) if pos < 0: # No line found self._chunk_tail = chunk - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" line = chunk[:pos] chunk = chunk[pos + len(SEP) :] @@ -955,13 +1034,24 @@ def feed_data( finally: self._trailer_lines.clear() self.payload.feed_eof() - return True, chunk + return PayloadState.PAYLOAD_COMPLETE, chunk # Read all bytes until eof elif self._type == ParseState.PARSE_UNTIL_EOF: - self.payload.feed_data(chunk) + self._more_data_available = self.payload.feed_data(chunk) + while self._more_data_available: + if self._paused: + self._paused = False + return PayloadState.PAYLOAD_HAS_PENDING_INPUT, b"" + self._more_data_available = self.payload.feed_data(b"") + + if self._eof_pending: + self.payload.feed_eof() + self.done = True + self._eof_pending = False + return PayloadState.PAYLOAD_COMPLETE, b"" - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" class DeflateBuffer: @@ -1006,10 +1096,8 @@ def set_exception( ) -> None: set_exception(self.out, exc, exc_cause) - def feed_data(self, chunk: bytes) -> None: - if not chunk: - return - + def feed_data(self, chunk: bytes) -> bool: + """Return True if more data is available and this method should be called again with b"".""" self.size += len(chunk) self.out.total_compressed_bytes = self.size @@ -1028,9 +1116,8 @@ def feed_data(self, chunk: bytes) -> None: ) try: - # Decompress with limit + 1 so we can detect if output exceeds limit chunk = self.decompressor.decompress_sync( - chunk, max_length=self._max_decompress_size + 1 + chunk, max_length=self._max_decompress_size ) except Exception: raise ContentEncodingError( @@ -1039,21 +1126,18 @@ def feed_data(self, chunk: bytes) -> None: self._started_decoding = True - # Check if decompression limit was exceeded - if len(chunk) > self._max_decompress_size: - raise DecompressSizeError( - "Decompressed data exceeds the configured limit of %d bytes" - % self._max_decompress_size - ) - if chunk: 
self.out.feed_data(chunk) + return self.decompressor.data_available def feed_eof(self) -> None: chunk = self.decompressor.flush() + # This should never contain data as we defer the call until exhausting + # the decompression. If .flush() is returning data, this may indicate a + # zip bomb vulnerability as it will decompress all remaining data at once. + assert not chunk - if chunk or self.size > 0: - self.out.feed_data(chunk) + if self.size > 0: # decompressor is not brotli unless encoding is "br" if self.encoding == "deflate" and not self.decompressor.eof: # type: ignore[union-attr] raise ContentEncodingError("deflate") diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index c44219e92b4..5bfce9e4074 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -268,6 +268,8 @@ def __init__( subtype: str = "mixed", default_charset: str | None = None, max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE, + client_max_size: int = sys.maxsize, + max_size_error_cls: type[Exception] = ValueError, ) -> None: self.headers = headers self._boundary = boundary @@ -285,6 +287,8 @@ def __init__( self._content_eof = 0 self._cache: dict[str, Any] = {} self._max_decompress_size = max_decompress_size + self._client_max_size = client_max_size + self._max_size_error_cls = max_size_error_cls def __aiter__(self) -> Self: return self @@ -313,11 +317,15 @@ async def read(self, *, decode: bool = False) -> bytes: data = bytearray() while not self._at_eof: data.extend(await self.read_chunk(self.chunk_size)) + if len(data) > self._client_max_size: + raise self._max_size_error_cls(self._client_max_size) # https://github.com/python/mypy/issues/17537 if decode: # type: ignore[unreachable] decoded_data = bytearray() async for d in self.decode_iter(data): decoded_data.extend(d) + if len(decoded_data) > self._client_max_size: + raise self._max_size_error_cls(self._client_max_size) return decoded_data return data @@ -559,6 +567,8 @@ async def _decode_content_async(self, data: bytes) -> AsyncIterator[bytes]: suppress_deflate_header=True, ) yield await d.decompress(data, max_length=self._max_decompress_size) + while d.data_available: + yield await d.decompress(b"", max_length=self._max_decompress_size) else: raise RuntimeError(f"unknown content encoding: {encoding}") @@ -652,8 +662,10 @@ def __init__( headers: Mapping[str, str], content: StreamReader, *, + client_max_size: int = sys.maxsize, max_field_size: int = 8190, max_headers: int = 128, + max_size_error_cls: type[Exception] = ValueError, ) -> None: self._mimetype = parse_mimetype(headers[CONTENT_TYPE]) assert self._mimetype.type == "multipart", "multipart/* content type expected" @@ -664,11 +676,13 @@ def __init__( self.headers = headers self._boundary = ("--" + self._get_boundary()).encode() + self._client_max_size = client_max_size self._content = content self._default_charset: str | None = None self._last_part: MultipartReader | BodyPartReader | None = None self._max_field_size = max_field_size self._max_headers = max_headers + self._max_size_error_cls = max_size_error_cls self._at_eof = False self._at_bof = True self._unread: list[bytes] = [] @@ -768,12 +782,21 @@ def _get_part_reader( if mimetype.type == "multipart": if self.multipart_reader_cls is None: - return type(self)(headers, self._content) + return type(self)( + headers, + self._content, + client_max_size=self._client_max_size, + max_field_size=self._max_field_size, + max_headers=self._max_headers, + max_size_error_cls=self._max_size_error_cls, + ) return self.multipart_reader_cls( headers, 
self._content, + client_max_size=self._client_max_size, max_field_size=self._max_field_size, max_headers=self._max_headers, + max_size_error_cls=self._max_size_error_cls, ) else: return self.part_reader_cls( @@ -782,6 +805,8 @@ def _get_part_reader( self._content, subtype=self._mimetype.subtype, default_charset=self._default_charset, + client_max_size=self._client_max_size, + max_size_error_cls=self._max_size_error_cls, ) def _get_boundary(self) -> str: diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 9a8dc2f3262..71c015499a6 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -43,7 +43,7 @@ ) TOO_LARGE_BYTES_BODY: Final[int] = 2**20 # 1 MB -READ_SIZE: Final[int] = 2**16 # 64 KB +READ_SIZE: Final[int] = 2**18 # 256 KiB _CLOSE_FUTURES: set[asyncio.Future[None]] = set() diff --git a/aiohttp/streams.py b/aiohttp/streams.py index bacb810958b..b9367066291 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -138,11 +138,11 @@ def __init__( self._protocol = protocol self._low_water = limit self._high_water = limit * 2 - # Ensure high_water_chunks >= 3 so it's always > low_water_chunks. - self._high_water_chunks = max(3, limit // 4) - # Use max(2, ...) because there's always at least 1 chunk split remaining + # Use max(4, ...) because there's always at least 1 chunk split remaining # (the current position), so we need low_water >= 2 to allow resume. - self._low_water_chunks = max(2, self._high_water_chunks // 2) + # limit // 16 gets us a reasonable value of 16k with default 256KiB limit. + self._high_water_chunks = max(4, limit // 16) + self._low_water_chunks = self._high_water_chunks // 2 self._loop = loop self._size = 0 self._cursor = 0 @@ -165,7 +165,7 @@ def __repr__(self) -> str: info.append("%d bytes" % self._size) if self._eof: info.append("eof") - if self._low_water != 2**16: # default limit + if self._low_water != 2**18: # default limit info.append("low=%d high=%d" % (self._low_water, self._high_water)) if self._waiter: info.append("w=%r" % self._waiter) @@ -219,8 +219,8 @@ def feed_eof(self) -> None: self._eof_waiter = None set_result(waiter, None) - if self._protocol._reading_paused: - self._protocol.resume_reading() + # At EOF the parser is done, there won't be unprocessed data. + self._protocol.resume_reading(resume_parser=False) for cb in self._eof_callbacks: try: @@ -274,11 +274,11 @@ def unread_data(self, data: bytes) -> None: self._buffer.appendleft(data) self._eof_counter = 0 - def feed_data(self, data: bytes) -> None: + def feed_data(self, data: bytes) -> bool: assert not self._eof, "feed_data after feed_eof" if not data: - return + return False data_len = len(data) self._size += data_len @@ -290,8 +290,9 @@ def feed_data(self, data: bytes) -> None: self._waiter = None set_result(waiter, None) - if self._size > self._high_water and not self._protocol._reading_paused: + if self._size > self._high_water: self._protocol.pause_reading() + return False def begin_http_chunk_receiving(self) -> None: if self._http_chunk_splits is None: @@ -328,10 +329,7 @@ def end_http_chunk_receiving(self) -> None: # If we get too many small chunks before self._high_water is reached, then any # .read() call becomes computationally expensive, and could block the event loop # for too long, hence an additional self._high_water_chunks here. 
- if ( - len(self._http_chunk_splits) > self._high_water_chunks - and not self._protocol._reading_paused - ): + if len(self._http_chunk_splits) > self._high_water_chunks: self._protocol.pause_reading() # wake up readchunk when end of http chunk received @@ -531,13 +529,9 @@ def _read_nowait_chunk(self, n: int) -> bytes: while chunk_splits and chunk_splits[0] < self._cursor: chunk_splits.popleft() - if ( - self._protocol._reading_paused - and self._size < self._low_water - and ( - self._http_chunk_splits is None - or len(self._http_chunk_splits) < self._low_water_chunks - ) + if self._size < self._low_water and ( + self._http_chunk_splits is None + or len(self._http_chunk_splits) < self._low_water_chunks ): self._protocol.resume_reading() return data @@ -597,8 +591,8 @@ def at_eof(self) -> bool: async def wait_eof(self) -> None: return - def feed_data(self, data: bytes) -> None: - pass + def feed_data(self, data: bytes) -> bool: + return False async def readline(self, *, max_line_length: int | None = None) -> bytes: return b"" diff --git a/aiohttp/web_exceptions.py b/aiohttp/web_exceptions.py index 782a4d39507..bd507a8813a 100644 --- a/aiohttp/web_exceptions.py +++ b/aiohttp/web_exceptions.py @@ -366,12 +366,8 @@ class HTTPPreconditionFailed(HTTPClientError): class HTTPRequestEntityTooLarge(HTTPClientError): status_code = 413 - def __init__(self, max_size: int, actual_size: int, **kwargs: Any) -> None: - kwargs.setdefault( - "text", - f"Maximum request body size {max_size} exceeded, " - f"actual body size {actual_size}", - ) + def __init__(self, max_size: int, **kwargs: Any) -> None: + kwargs.setdefault("text", f"Maximum request body size {max_size} exceeded.") super().__init__(**kwargs) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 20d76408d4f..9785c13fa4f 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -22,6 +22,7 @@ HttpVersion10, RawRequestMessage, StreamWriter, + WebSocketReader, ) from .http_exceptions import BadHttpMethod from .log import access_logger, server_logger @@ -171,10 +172,8 @@ class RequestHandler(BaseProtocol, Generic[_Request]): "_handler_waiter", "_waiter", "_task_handler", - "_upgrade", "_payload_parser", "_data_received_cb", - "_request_parser", "logger", "access_log", "access_logger", @@ -203,11 +202,21 @@ def __init__( max_headers: int = 128, max_field_size: int = 8190, lingering_time: float = 10.0, - read_bufsize: int = 2**16, + read_bufsize: int = 2**18, auto_decompress: bool = True, timeout_ceil_threshold: float = 5, ): - super().__init__(loop) + parser = HttpRequestParser( + self, + loop, + read_bufsize, + max_line_size=max_line_size, + max_field_size=max_field_size, + max_headers=max_headers, + payload_exception=RequestPayloadError, + auto_decompress=auto_decompress, + ) + super().__init__(loop, parser) # _request_count is the number of requests processed with the same connection. 
self._request_count = 0 @@ -239,19 +248,7 @@ def __init__( self._waiter: asyncio.Future[None] | None = None self._handler_waiter: asyncio.Future[None] | None = None self._task_handler: asyncio.Task[None] | None = None - - self._upgrade = False self._payload_parser: Any = None - self._request_parser: HttpRequestParser | None = HttpRequestParser( - self, - loop, - read_bufsize, - max_line_size=max_line_size, - max_field_size=max_field_size, - max_headers=max_headers, - payload_exception=RequestPayloadError, - auto_decompress=auto_decompress, - ) self._timeout_ceil_threshold: float = 5 try: @@ -392,7 +389,7 @@ def connection_lost(self, exc: BaseException | None) -> None: self._manager = None self._request_factory = None self._request_handler = None - self._request_parser = None + self._parser = None if self._keepalive_handle is not None: self._keepalive_handle.cancel() @@ -412,9 +409,10 @@ def connection_lost(self, exc: BaseException | None) -> None: self._payload_parser = None def set_parser( - self, parser: Any, data_received_cb: Callable[[], None] | None = None + self, + parser: WebSocketReader, + data_received_cb: Callable[[], None] | None = None, ) -> None: - # Actual type is WebReader assert self._payload_parser is None self._payload_parser = parser @@ -432,10 +430,10 @@ def data_received(self, data: bytes) -> None: return # parse http messages messages: Sequence[_MsgType] - if self._payload_parser is None and not self._upgrade: - assert self._request_parser is not None + if self._payload_parser is None and not self._upgraded: + assert self._parser is not None try: - messages, upgraded, tail = self._request_parser.feed_data(data) + messages, upgraded, tail = self._parser.feed_data(data) except HttpProcessingError as exc: messages = [ (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD) @@ -452,12 +450,12 @@ def data_received(self, data: bytes) -> None: # don't set result twice waiter.set_result(None) - self._upgrade = upgraded + self._upgraded = upgraded if upgraded and tail: self._message_tail = tail # no parser, just store - elif self._payload_parser is None and self._upgrade and data: + elif self._payload_parser is None and self._upgraded and data: self._message_tail += data # feed payload @@ -719,11 +717,11 @@ async def finish_response( prematurely. 
""" request._finish() - if self._request_parser is not None: - self._request_parser.set_upgraded(False) - self._upgrade = False + if self._parser is not None: + self._parser.set_upgraded(False) + self._upgraded = False if self._message_tail: - self._request_parser.feed_data(self._message_tail) + self._parser.feed_data(self._message_tail) self._message_tail = b"" try: prepare_meth = resp.prepare diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 42be85e2e74..b8feae19cec 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -634,9 +634,7 @@ async def read(self) -> bytes: if self._client_max_size: body_size = len(body) if body_size > self._client_max_size: - raise HTTPRequestEntityTooLarge( - max_size=self._client_max_size, actual_size=body_size - ) + raise HTTPRequestEntityTooLarge(self._client_max_size) if not chunk: break self._read_bytes = bytes(body) @@ -675,8 +673,10 @@ async def multipart(self) -> MultipartReader: return MultipartReader( self._headers, self._payload, + client_max_size=self._client_max_size, max_field_size=self._protocol.max_field_size, max_headers=self._protocol.max_headers, + max_size_error_cls=HTTPRequestEntityTooLarge, ) async def post(self) -> "MultiDictProxy[str | bytes | FileField]": @@ -727,9 +727,7 @@ async def post(self) -> "MultiDictProxy[str | bytes | FileField]": size += len(decoded_chunk) if 0 < max_size < size: await self._loop.run_in_executor(None, tmp.close) - raise HTTPRequestEntityTooLarge( - max_size=max_size, actual_size=size - ) + raise HTTPRequestEntityTooLarge(max_size) await self._loop.run_in_executor(None, tmp.seek, 0) if field_ct is None: @@ -749,9 +747,7 @@ async def post(self) -> "MultiDictProxy[str | bytes | FileField]": while chunk := await field.read_chunk(): size += len(chunk) if 0 < max_size < size: - raise HTTPRequestEntityTooLarge( - max_size=max_size, actual_size=size - ) + raise HTTPRequestEntityTooLarge(max_size) raw_data.extend(chunk) value = bytearray() diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 2aeeb6dec1f..1a7622b8421 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -383,7 +383,7 @@ def _post_start( loop = self._loop assert loop is not None - self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop) + self._reader = WebSocketDataQueue(request._protocol, 2**18, loop=loop) parser = WebSocketReader( self._reader, self._max_msg_size, diff --git a/setup.cfg b/setup.cfg index 2d7e24b0374..203c01c3754 100644 --- a/setup.cfg +++ b/setup.cfg @@ -88,6 +88,8 @@ filterwarnings = # https://github.com/spulec/freezegun/issues/508 # https://github.com/spulec/freezegun/pull/511 ignore:datetime.*utcnow\(\) is deprecated and scheduled for removal:DeprecationWarning:freezegun.api + # Weird issue in Python 3.13+ triggered in test_multipart.py + ignore:coroutine method 'aclose' of 'BodyPartReader._decode_content_async' was never awaited:RuntimeWarning junit_suite_name = aiohttp_test_suite norecursedirs = dist docs build .tox .eggs minversion = 3.8.2 diff --git a/tests/test_base_protocol.py b/tests/test_base_protocol.py index 713dba2d0c2..234e9927c02 100644 --- a/tests/test_base_protocol.py +++ b/tests/test_base_protocol.py @@ -5,6 +5,7 @@ import pytest from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_parser import HttpParser async def test_loop() -> None: @@ -26,33 +27,28 @@ async def test_pause_writing() -> None: async def test_pause_reading_no_transport() -> None: loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) - assert not pr._reading_paused + parser = 
mock.create_autospec(HttpParser, spec_set=True, instance=True) + pr = BaseProtocol(loop, parser=parser) pr.pause_reading() - assert not pr._reading_paused + parser.pause_reading.assert_called_once() async def test_pause_reading_stub_transport() -> None: loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + pr = BaseProtocol(loop, parser=parser) tr = asyncio.Transport() pr.transport = tr assert not pr._reading_paused pr.pause_reading() assert pr._reading_paused - - -async def test_resume_reading_no_transport() -> None: - loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) - pr._reading_paused = True - pr.resume_reading() - assert pr._reading_paused + parser.pause_reading.assert_called_once() # type: ignore[unreachable] async def test_resume_reading_stub_transport() -> None: loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + pr = BaseProtocol(loop, parser=parser) tr = asyncio.Transport() pr.transport = tr pr._reading_paused = True diff --git a/tests/test_benchmarks_http_websocket.py b/tests/test_benchmarks_http_websocket.py index 2aa9b8294bd..10115c1a2bd 100644 --- a/tests/test_benchmarks_http_websocket.py +++ b/tests/test_benchmarks_http_websocket.py @@ -36,8 +36,8 @@ def test_read_one_hundred_websocket_text_messages( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture ) -> None: """Benchmark reading 100 WebSocket text messages.""" - queue = WebSocketDataQueue(BaseProtocol(loop), 2**16, loop=loop) - reader = WebSocketReader(queue, max_msg_size=2**16) + queue = WebSocketDataQueue(BaseProtocol(loop), 2**18, loop=loop) + reader = WebSocketReader(queue, max_msg_size=2**18) raw_message = ( b'\x81~\x01!{"id":1,"src":"shellyplugus-c049ef8c30e4","dst":"aios-1453812500' b'8","result":{"name":null,"id":"shellyplugus-c049ef8c30e4","mac":"C049EF8C30E' diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 8ee45330bb5..80e95c29512 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -53,7 +53,6 @@ ) from aiohttp.client_reqrep import ClientRequest from aiohttp.compression_utils import DEFAULT_MAX_DECOMPRESS_SIZE -from aiohttp.http_exceptions import DecompressSizeError from aiohttp.payload import ( AsyncIterablePayload, BufferedReaderPayload, @@ -2407,10 +2406,9 @@ async def test_payload_decompress_size_limit(aiohttp_client: AiohttpClient) -> N - When a compressed payload expands beyond the configured limit, we - raise DecompressSizeError. + When a compressed payload expands beyond the configured limit, it is + now decompressed incrementally instead of raising an error. """ - # Create a highly compressible payload that exceeds the decompression limit. - # 64MiB of repeated bytes compresses to ~32KB but expands beyond the - # 32MiB per-call limit. - original = b"A" * (64 * 2**20) + # Create a highly compressible payload. 
+ payload_size = 64 * 2**20 + original = b"A" * payload_size compressed = zlib.compress(original) assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE @@ -2427,11 +2425,11 @@ async def handler(request: web.Request) -> web.Response: async with client.get("/") as resp: assert resp.status == 200 - with pytest.raises(aiohttp.ClientPayloadError) as exc_info: - await resp.read() + received = 0 + async for chunk in resp.content.iter_chunked(1024): + received += len(chunk) - assert isinstance(exc_info.value.__cause__, DecompressSizeError) - assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + assert received == payload_size @pytest.mark.skipif(brotli is None, reason="brotli is not installed") @@ -2440,8 +2438,9 @@ async def test_payload_decompress_size_limit_brotli( ) -> None: - """Test that brotli decompression size limit triggers DecompressSizeError.""" + """Test that a large brotli payload is decompressed incrementally.""" assert brotli is not None - # Create a highly compressible payload that exceeds the decompression limit. - original = b"A" * (64 * 2**20) + # Create a highly compressible payload. + payload_size = 64 * 2**20 + original = b"A" * payload_size compressed = brotli.compress(original) assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE @@ -2457,11 +2456,11 @@ async def handler(request: web.Request) -> web.Response: async with client.get("/") as resp: assert resp.status == 200 - with pytest.raises(aiohttp.ClientPayloadError) as exc_info: - await resp.read() + received = 0 + async for chunk in resp.content.iter_chunked(1024): + received += len(chunk) - assert isinstance(exc_info.value.__cause__, DecompressSizeError) - assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + assert received == payload_size @pytest.mark.skipif(ZstdCompressor is None, reason="backports.zstd is not installed") @@ -2470,8 +2469,9 @@ async def test_payload_decompress_size_limit_zstd( ) -> None: - """Test that zstd decompression size limit triggers DecompressSizeError.""" + """Test that a large zstd payload is decompressed incrementally.""" assert ZstdCompressor is not None - # Create a highly compressible payload that exceeds the decompression limit. - original = b"A" * (64 * 2**20) + # Create a highly compressible payload. 
+ payload_size = 64 * 2**20 + original = b"A" * payload_size compressor = ZstdCompressor() compressed = compressor.compress(original) + compressor.flush() assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE @@ -2488,11 +2488,11 @@ async def handler(request: web.Request) -> web.Response: async with client.get("/") as resp: assert resp.status == 200 - with pytest.raises(aiohttp.ClientPayloadError) as exc_info: - await resp.read() + received = 0 + async for chunk in resp.content.iter_chunked(1024): + received += len(chunk) - assert isinstance(exc_info.value.__cause__, DecompressSizeError) - assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + assert received == payload_size async def test_bad_payload_chunked_encoding(aiohttp_client: AiohttpClient) -> None: diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 49a81c8dbb3..0a26a211453 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -10,7 +10,7 @@ from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientResponse from aiohttp.helpers import TimerNoop -from aiohttp.http_parser import RawResponseMessage +from aiohttp.http_parser import HttpParser, RawResponseMessage async def test_force_close(loop: asyncio.AbstractEventLoop) -> None: @@ -35,7 +35,9 @@ async def test_oserror(loop: asyncio.AbstractEventLoop) -> None: async def test_pause_resume_on_error(loop: asyncio.AbstractEventLoop) -> None: + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) proto = ResponseHandler(loop=loop) + proto._parser = parser transport = mock.Mock() proto.connection_made(transport) diff --git a/tests/test_flowcontrol_streams.py b/tests/test_flowcontrol_streams.py index 9e21f786610..3654ba4aad2 100644 --- a/tests/test_flowcontrol_streams.py +++ b/tests/test_flowcontrol_streams.py @@ -5,6 +5,7 @@ from aiohttp import streams from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_parser import HttpParser @pytest.fixture @@ -38,7 +39,6 @@ async def test_readline(self, stream: streams.StreamReader) -> None: stream.feed_data(b"d\n") res = await stream.readline() assert res == b"d\n" - assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readline_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True @@ -51,7 +51,6 @@ async def test_readany(self, stream: streams.StreamReader) -> None: stream.feed_data(b"data") res = await stream.readany() assert res == b"data" - assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readany_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True @@ -65,7 +64,6 @@ async def test_readchunk(self, stream: streams.StreamReader) -> None: res, end_of_http_chunk = await stream.readchunk() assert res == b"data" assert not end_of_http_chunk - assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readchunk_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True @@ -120,7 +118,8 @@ async def test_resumed_on_eof(self, stream: streams.StreamReader) -> None: async def test_stream_reader_eof_when_full() -> None: loop = asyncio.get_event_loop() - protocol = BaseProtocol(loop=loop) + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + protocol = BaseProtocol(loop=loop, parser=parser) protocol.transport = asyncio.Transport() stream = streams.StreamReader(protocol, 1024, 
loop=loop) diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 2c593a7589c..35fce70ef5f 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -1,10 +1,11 @@ # Tests for aiohttp/protocol.py import asyncio +import platform import re import sys import zlib -from collections.abc import Iterable +from collections.abc import Iterable, Iterator from contextlib import suppress from typing import Any from unittest import mock @@ -17,6 +18,7 @@ import aiohttp from aiohttp import http_exceptions, streams from aiohttp.base_protocol import BaseProtocol +from aiohttp.client_proto import ResponseHandler from aiohttp.helpers import NO_EXTENSIONS from aiohttp.http_parser import ( DeflateBuffer, @@ -27,8 +29,12 @@ HttpRequestParserPy, HttpResponseParser, HttpResponseParserPy, + PayloadState, ) from aiohttp.http_writer import HttpVersion +from aiohttp.web_protocol import RequestHandler +from aiohttp.web_request import Request +from aiohttp.web_server import Server try: try: @@ -56,9 +62,23 @@ RESPONSE_PARSERS.append(HttpResponseParserC) +@pytest.fixture +def server() -> Any: + return mock.create_autospec( + Server, + request_factory=mock.Mock(), + request_handler=mock.AsyncMock(), + instance=True, + ) + + @pytest.fixture def protocol() -> Any: - return mock.create_autospec(BaseProtocol, spec_set=True, instance=True) + return mock.create_autospec( + BaseProtocol, + spec_set=True, + instance=True, + ) def _gen_ids(parsers: Iterable[type[HttpParser[Any]]]) -> list[str]: @@ -71,18 +91,24 @@ def _gen_ids(parsers: Iterable[type[HttpParser[Any]]]) -> list[str]: @pytest.fixture(params=REQUEST_PARSERS, ids=_gen_ids(REQUEST_PARSERS)) def parser( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request: pytest.FixtureRequest, -) -> HttpRequestParser: +) -> Iterator[HttpRequestParser]: + protocol = RequestHandler(server, loop=loop) + # Parser implementations - return request.param( # type: ignore[no-any-return] + parser = request.param( protocol, loop, - 2**16, + 2**18, max_line_size=8190, max_headers=128, max_field_size=8190, ) + protocol._force_close = False + protocol._parser = parser + with mock.patch.object(protocol, "transport", True): + yield parser @pytest.fixture(params=REQUEST_PARSERS, ids=_gen_ids(REQUEST_PARSERS)) @@ -94,19 +120,22 @@ def request_cls(request: pytest.FixtureRequest) -> type[HttpRequestParser]: @pytest.fixture(params=RESPONSE_PARSERS, ids=_gen_ids(RESPONSE_PARSERS)) def response( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, request: pytest.FixtureRequest, ) -> HttpResponseParser: + protocol = ResponseHandler(loop) + # Parser implementations - return request.param( # type: ignore[no-any-return] + parser = request.param( protocol, loop, - 2**16, + 2**18, max_line_size=8190, max_headers=128, max_field_size=8190, read_until_eof=True, ) + protocol._parser = parser + return parser # type: ignore[no-any-return] @pytest.fixture(params=RESPONSE_PARSERS, ids=_gen_ids(RESPONSE_PARSERS)) @@ -154,9 +183,11 @@ def test_reject_obsolete_line_folding(parser: HttpRequestParser) -> None: @pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.") def test_invalid_character( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request: pytest.FixtureRequest, ) -> None: + protocol = RequestHandler(server, loop=loop) + parser = HttpRequestParserC( protocol, loop, @@ -164,6 +195,7 @@ def test_invalid_character( max_line_size=8190, max_field_size=8190, ) + protocol._parser = parser text = 
b"POST / HTTP/1.1\r\nHost: localhost:8080\r\nSet-Cookie: abc\x01def\r\n\r\n" error_detail = re.escape(r""": @@ -176,9 +208,11 @@ def test_invalid_character( @pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.") def test_invalid_linebreak( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request: pytest.FixtureRequest, ) -> None: + protocol = RequestHandler(server, loop=loop) + parser = HttpRequestParserC( protocol, loop, @@ -186,6 +220,7 @@ def test_invalid_linebreak( max_line_size=8190, max_field_size=8190, ) + protocol._parser = parser text = b"GET /world HTTP/1.1\r\nHost: 127.0.0.1\n\r\n" error_detail = re.escape(r""": @@ -250,8 +285,10 @@ def test_ctl_host_header_bad_characters(parser: HttpRequestParser) -> None: def test_unpaired_surrogate_in_header_py( - loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, server: Server[Request] ) -> None: + protocol = RequestHandler(server, loop=loop) + parser = HttpRequestParserPy( protocol, loop, @@ -259,6 +296,7 @@ def test_unpaired_surrogate_in_header_py( max_line_size=8190, max_field_size=8190, ) + protocol._parser = parser text = b"POST / HTTP/1.1\r\n\xff\r\n\r\n" message = None try: @@ -1013,6 +1051,203 @@ def test_max_header_value_size_under_limit(parser: HttpRequestParser) -> None: assert msg.url == URL("/test") +async def test_chunk_splits_after_pause(parser: HttpRequestParser) -> None: + text = ( + b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + + b"1\r\nb\r\n" * 50000 + + b"0\r\n\r\n" + ) + + messages, upgrade, tail = parser.feed_data(text) + payload = messages[0][-1] + # Payload should have paused reading and stopped receiving new chunks after 16k. + assert payload._http_chunk_splits is not None + assert len(payload._http_chunk_splits) == 16385 + # We should still get the full result after read(), as it will continue processing. + result = await payload.read() + assert len(result) == 50000 # Compare len first, as it's easier to debug in diff. + assert result == b"b" * 50000 + + +async def test_compressed_with_tail(response: HttpResponseParser) -> None: + """Test compressed content-length body followed by a second response. + + With 2 responses arriving in one call and the first compressed, this should + trigger decompression pausing with the second response being saved as the tail. + Verify that the second response is resumed from the tail. + """ + # Must be large enough to exceed high water mark. + original = b"x" * 1024 * 1024 + compressed = zlib.compress(original) + resp1 = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: " + str(len(compressed)).encode() + b"\r\n" + b"Content-Encoding: deflate\r\n" + b"\r\n" + ) + compressed + resp2 = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok" + + msgs, upgrade, tail = response.feed_data(resp1 + resp2) + payload = msgs[0][-1] + result = await payload.read() + assert len(result) == len(original) + assert result == original + + payload = response.protocol._buffer[0][-1] + result = await payload.read() + assert result == b"ok" + + +async def test_two_content_length_responses_in_one_call( + response: HttpResponseParser, +) -> None: + """Two complete responses in a single feed_data call. + + The first payload completes with tail data for the second, hitting the + PAYLOAD_COMPLETE branch that resets the parser for the next message. 
+    """
+    resp1 = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"
+    resp2 = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nworld"
+
+    msgs, upgrade, tail = response.feed_data(resp1 + resp2)
+    assert len(msgs) == 2
+    assert await msgs[0][-1].read() == b"hello"
+    assert await msgs[1][-1].read() == b"world"
+
+
+@pytest.mark.usefixtures("parametrize_zlib_backend")
+async def test_compressed_zlib_64kb(response_cls: type[HttpResponseParser]) -> None:
+    loop = asyncio.get_running_loop()
+    protocol = ResponseHandler(loop)
+    response = response_cls(
+        protocol,
+        loop,
+        # A 64 KiB limit triggered a bug where the isal implementation did not return all data.
+        2**16,
+        max_line_size=8190,
+        max_headers=128,
+        max_field_size=8190,
+    )
+    protocol._parser = response
+
+    original = b"".join(
+        bytes((*range(0, i), *range(i, 0, -1))) for _ in range(255) for i in range(255)
+    )
+    compressed = zlib.compress(original)
+    headers = (
+        b"HTTP/1.1 200 OK\r\n"
+        b"Content-Length: " + str(len(compressed)).encode() + b"\r\n"
+        b"Content-Encoding: deflate\r\n"
+        b"\r\n"
+    )
+
+    msgs, upgrade, tail = response.feed_data(headers + compressed)
+    payload = msgs[0][-1]
+    result = await payload.read()
+    assert len(result) == len(original)
+    assert result == original
+
+
+async def test_compressed_chunked_with_pending(response: HttpResponseParser) -> None:
+    """Test chunked + compressed where the decompressor needs to resume from pause.
+
+    We need to verify that chunked messages continue parsing correctly after
+    a pause and resume in the decompression.
+    """
+    # Must be large enough to exceed high water mark.
+    original = b"A" * 1024 * 1024
+    compressed = zlib.compress(original)
+    chunk_data = hex(len(compressed))[2:].encode() + b"\r\n" + compressed + b"\r\n"
+    headers = (
+        b"HTTP/1.1 200 OK\r\n"
+        b"Transfer-Encoding: chunked\r\n"
+        b"Content-Encoding: deflate\r\n"
+        b"\r\n"
+    )
+    data = headers + chunk_data + b"0\r\n\r\n"
+
+    msgs, upgrade, tail = response.feed_data(data)
+    payload = msgs[0][-1]
+    result = await payload.read()
+    assert len(result) == len(original)
+    assert result == original
+
+
+async def test_compressed_until_eof_with_pending(response: HttpResponseParser) -> None:
+    """Test read-until-eof + compressed with pause."""
+    # Must be large enough to exceed high water mark.
+    original = b"B" * 5 * 1024 * 1024
+    compressed = zlib.compress(original)
+    # No Content-Length or Transfer-Encoding means the parser must parse until EOF.
+    headers = b"HTTP/1.1 200 OK\r\nContent-Encoding: deflate\r\n\r\n"
+
+    msgs, upgrade, tail = response.feed_data(headers + compressed)
+    response.feed_eof()
+    payload = msgs[0][-1]
+
+    # Check that .feed_eof() hasn't decompressed the entire payload into memory.
+    assert sum(len(b) for b in payload._buffer) <= (1024 * 1024)
+
+    result = await payload.read()
+    assert len(result) == len(original)
+    assert result == original
+
+
+async def test_compressed_until_eof_high_water(
+    response_cls: type[HttpResponseParser],
+) -> None:
+    """Test read-until-eof + compressed with higher limit."""
+    loop = asyncio.get_running_loop()
+    protocol = ResponseHandler(loop)
+    response = response_cls(
+        protocol,
+        loop,
+        2**19,  # 512 KiB limit
+        max_line_size=8190,
+        max_headers=128,
+        max_field_size=8190,
+        read_until_eof=True,
+    )
+    protocol._parser = response
+
+    # Must be large enough to exceed high water mark.
+    original = b"B" * 5 * 1024 * 1024
+    compressed = zlib.compress(original)
+    # No Content-Length or Transfer-Encoding means the parser must parse until EOF.
+    headers = b"HTTP/1.1 200 OK\r\nContent-Encoding: deflate\r\n\r\n"
+
+    msgs, upgrade, tail = response.feed_data(headers + compressed)
+    response.feed_eof()
+    payload = msgs[0][-1]
+
+    # Check that .feed_eof() hasn't decompressed the entire payload into memory.
+    assert sum(len(b) for b in payload._buffer) <= (2 * 1024 * 1024)
+    # Individual chunks should have been decompressed in blocks of the limit size (512 KiB).
+    assert all(len(b) == 512 * 1024 for b in payload._buffer)
+
+    result = await payload.read()
+    assert len(result) == len(original)
+    assert result == original
+
+
+async def test_compressed_256kb(response: HttpResponseParser) -> None:
+    original = b"x" * 256 * 1024
+    compressed = zlib.compress(original)
+    headers = (
+        b"HTTP/1.1 200 OK\r\n"
+        b"Content-Length: " + str(len(compressed)).encode() + b"\r\n"
+        b"Content-Encoding: deflate\r\n"
+        b"\r\n"
+    )
+
+    messages, upgrade, tail = response.feed_data(headers + compressed)
+    assert len(messages) == 1
+    payload = messages[0][-1]
+    result = await payload.read()
+    assert len(result) == len(original)
+    assert result == original
+
+
 @pytest.mark.parametrize("size", [40965, 8191])
 def test_max_header_value_size_continuation(
     response: HttpResponseParser, size: int
@@ -1447,15 +1682,18 @@ async def test_http_response_parser_bad_chunked_lax(
 
 @pytest.mark.dev_mode
 async def test_http_response_parser_bad_chunked_strict_py(
-    loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
+    loop: asyncio.AbstractEventLoop,
 ) -> None:
+    protocol = ResponseHandler(loop)
+
     response = HttpResponseParserPy(
         protocol,
         loop,
-        2**16,
+        2**18,
         max_line_size=8190,
         max_field_size=8190,
     )
+    protocol._parser = response
     text = (
         b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n"
     )
@@ -1469,8 +1707,10 @@ async def test_http_response_parser_bad_chunked_strict_py(
     reason="C based HTTP parser not available",
 )
 async def test_http_response_parser_bad_chunked_strict_c(
-    loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
+    loop: asyncio.AbstractEventLoop,
 ) -> None:
+    protocol = ResponseHandler(loop)
+
     response = HttpResponseParserC(
         protocol,
         loop,
@@ -1478,6 +1718,7 @@ async def test_http_response_parser_bad_chunked_strict_c(
         max_line_size=8190,
         max_field_size=8190,
     )
+    protocol._parser = response
     text = (
         b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n"
     )
@@ -1628,10 +1869,12 @@ async def test_request_chunked_reject_bad_trailer(parser: HttpRequestParser) ->
 
 def test_parse_no_length_or_te_on_post(
     loop: asyncio.AbstractEventLoop,
-    protocol: BaseProtocol,
+    server: Server[Request],
     request_cls: type[HttpRequestParser],
 ) -> None:
-    parser = request_cls(protocol, loop, limit=2**16)
+    protocol = RequestHandler(server, loop=loop)
+    parser = request_cls(protocol, loop, limit=2**18)
+    protocol._parser = parser
     text = b"POST /test HTTP/1.1\r\n\r\n"
     msg, payload = parser.feed_data(text)[0][0]
 
@@ -1640,10 +1883,11 @@ def test_parse_no_length_or_te_on_post(
 
 def test_parse_payload_response_without_body(
     loop: asyncio.AbstractEventLoop,
-    protocol: BaseProtocol,
     response_cls: type[HttpResponseParser],
 ) -> None:
+    protocol = ResponseHandler(loop)
     parser = response_cls(protocol, loop, 2**16, response_with_body=False)
+    protocol._parser = parser
     text = b"HTTP/1.1 200 Ok\r\ncontent-length: 10\r\n\r\n"
     msg, payload = parser.feed_data(text)[0][0]
 
@@ -1904,17 +2148,20 @@ def test_parse_uri_utf8_percent_encoded(parser: HttpRequestParser) -> None:
     reason="C based HTTP parser not available",
 )
 def test_parse_bad_method_for_c_parser_raises(
-
loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, server: Server[Request] ) -> None: + protocol = RequestHandler(server, loop=loop) + payload = b"GET1 /test HTTP/1.1\r\n\r\n" parser = HttpRequestParserC( protocol, loop, - 2**16, + 2**18, max_line_size=8190, max_headers=128, max_field_size=8190, ) + protocol._parser = parser with pytest.raises(aiohttp.http_exceptions.BadStatusLine): messages, upgrade, tail = parser.feed_data(payload) @@ -1931,7 +2178,7 @@ async def test_parse_eof_payload(self, protocol: BaseProtocol) -> None: assert [bytearray(b"data")] == list(out._buffer) async def test_parse_length_payload_eof(self, protocol: BaseProtocol) -> None: - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, length=4, headers_parser=HeadersParser()) p.feed_data(b"da") @@ -1955,7 +2202,7 @@ async def test_parse_chunked_payload_size_data_mismatch( Regression test for #10596. """ - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) # Declared chunk-size is 4 but actual data is "Hello" (5 bytes). # After consuming 4 bytes, remaining starts with "o" not "\r\n". @@ -1970,7 +2217,7 @@ async def test_parse_chunked_payload_size_data_mismatch_too_short( Regression test for #10596. """ - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) # Declared chunk-size is 6 but actual data before CRLF is "Hello" (5 bytes). # Parser reads 6 bytes: "Hello\r", then expects \r\n but sees "\n0\r\n..." 
@@ -1992,7 +2239,7 @@ async def test_parse_chunked_payload_split_end( async def test_parse_chunked_payload_split_end2( self, protocol: BaseProtocol ) -> None: - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n\r") p.feed_data(b"\n") @@ -2003,7 +2250,7 @@ async def test_parse_chunked_payload_split_end2( async def test_parse_chunked_payload_split_end_trailers( self, protocol: BaseProtocol ) -> None: - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n") p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n") @@ -2015,7 +2262,7 @@ async def test_parse_chunked_payload_split_end_trailers( async def test_parse_chunked_payload_split_end_trailers2( self, protocol: BaseProtocol ) -> None: - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n") p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r") @@ -2047,10 +2294,10 @@ async def test_parse_chunked_payload_split_end_trailers4( assert b"asdf" == b"".join(out._buffer) async def test_http_payload_parser_length(self, protocol: BaseProtocol) -> None: - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, length=2, headers_parser=HeadersParser()) - eof, tail = p.feed_data(b"1245") - assert eof + state, tail = p.feed_data(b"1245") + assert state is PayloadState.PAYLOAD_COMPLETE assert b"12" == out._buffer[0] assert b"45" == tail @@ -2060,7 +2307,7 @@ async def test_http_payload_parser_deflate(self, protocol: BaseProtocol) -> None COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b" length = len(COMPRESSED) - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=length, compression="deflate", headers_parser=HeadersParser() ) @@ -2131,7 +2378,7 @@ async def test_http_payload_parser_deflate_split_err( async def test_http_payload_parser_length_zero( self, protocol: BaseProtocol ) -> None: - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, length=0, headers_parser=HeadersParser()) assert p.done assert out.is_eof() @@ -2139,7 +2386,7 @@ async def test_http_payload_parser_length_zero( @pytest.mark.skipif(brotli is None, reason="brotli is not installed") async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None: compressed = brotli.compress(b"brotli data") - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=len(compressed), @@ -2153,7 +2400,7 @@ async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None: @pytest.mark.skipif(zstandard is None, reason="zstandard 
is not installed") async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None: compressed = zstandard.compress(b"zstd data") - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=len(compressed), @@ -2171,7 +2418,7 @@ async def test_http_payload_zstandard_multi_frame( frame1 = zstandard.compress(b"first") frame2 = zstandard.compress(b"second") payload = frame1 + frame2 - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=len(payload), @@ -2188,7 +2435,7 @@ async def test_http_payload_zstandard_multi_frame_chunked( ) -> None: frame1 = zstandard.compress(b"chunk1") frame2 = zstandard.compress(b"chunk2") - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=len(frame1) + len(frame2), @@ -2208,7 +2455,7 @@ async def test_http_payload_zstandard_frame_split_mid_chunk( frame2 = zstandard.compress(b"BBBB") combined = frame1 + frame2 split_point = len(frame1) + 3 # 3 bytes into frame2 - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=len(combined), @@ -2226,7 +2473,7 @@ async def test_http_payload_zstandard_many_small_frames( ) -> None: parts = [f"part{i}".encode() for i in range(10)] payload = b"".join(zstandard.compress(p) for p in parts) - out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=len(payload), @@ -2240,7 +2487,7 @@ async def test_http_payload_zstandard_many_small_frames( class TestDeflateBuffer: async def test_feed_data(self, protocol: BaseProtocol) -> None: - buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + buf = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() @@ -2268,10 +2515,10 @@ async def test_feed_eof(self, protocol: BaseProtocol) -> None: dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() - dbuf.decompressor.flush.return_value = b"line" + dbuf.decompressor.data_available = False + dbuf.decompressor.flush.return_value = b"" dbuf.feed_eof() - assert [b"line"] == list(buf._buffer) assert buf._eof async def test_feed_eof_err_deflate(self, protocol: BaseProtocol) -> None: @@ -2279,8 +2526,10 @@ async def test_feed_eof_err_deflate(self, protocol: BaseProtocol) -> None: dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() - dbuf.decompressor.flush.return_value = b"line" + dbuf.decompressor.data_available = False + dbuf.decompressor.flush.return_value = b"" dbuf.decompressor.eof = False + dbuf.size = 1 # Simulate that data was previously fed with pytest.raises(http_exceptions.ContentEncodingError): dbuf.feed_eof() @@ -2290,22 +2539,24 @@ async def test_feed_eof_no_err_gzip(self, protocol: BaseProtocol) -> None: dbuf = DeflateBuffer(buf, "gzip") dbuf.decompressor = mock.Mock() - dbuf.decompressor.flush.return_value = b"line" + dbuf.decompressor.data_available = False + dbuf.decompressor.flush.return_value = b"" 
         dbuf.decompressor.eof = False
 
         dbuf.feed_eof()
-        assert [b"line"] == list(buf._buffer)
+        assert buf._eof
 
     async def test_feed_eof_no_err_brotli(self, protocol: BaseProtocol) -> None:
         buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
         dbuf = DeflateBuffer(buf, "br")
 
         dbuf.decompressor = mock.Mock()
-        dbuf.decompressor.flush.return_value = b"line"
+        dbuf.decompressor.data_available = False
+        dbuf.decompressor.flush.return_value = b""
         dbuf.decompressor.eof = False
 
         dbuf.feed_eof()
-        assert [b"line"] == list(buf._buffer)
+        assert buf._eof
 
     @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
     async def test_feed_eof_no_err_zstandard(self, protocol: BaseProtocol) -> None:
@@ -2313,19 +2564,21 @@ async def test_feed_eof_no_err_zstandard(self, protocol: BaseProtocol) -> None:
         dbuf = DeflateBuffer(buf, "zstd")
 
         dbuf.decompressor = mock.Mock()
-        dbuf.decompressor.flush.return_value = b"line"
+        dbuf.decompressor.data_available = False
+        dbuf.decompressor.flush.return_value = b""
         dbuf.decompressor.eof = False
 
         dbuf.feed_eof()
-        assert [b"line"] == list(buf._buffer)
+        assert buf._eof
 
     async def test_empty_body(self, protocol: BaseProtocol) -> None:
-        buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
+        buf = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop())
         dbuf = DeflateBuffer(buf, "deflate")
         dbuf.feed_eof()
 
         assert buf.at_eof()
 
+    @pytest.mark.skipif(platform.python_implementation() == "PyPy", reason="Broken")
     @pytest.mark.parametrize(
         "chunk_size",
         [1024, 2**14, 2**16],  # 1KB, 16KB, 64KB
@@ -2344,13 +2597,14 @@ async def test_streaming_decompress_large_payload(
         original = b"A" * (3 * 2**20)
         compressed = zlib.compress(original)
 
-        buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
+        buf = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop())
         dbuf = DeflateBuffer(buf, "deflate")
 
         # Feed compressed data in chunks (simulating network streaming)
        for i in range(0, len(compressed), chunk_size):  # pragma: no branch
             chunk = compressed[i : i + chunk_size]
-            dbuf.feed_data(chunk)
+            # feed_data() returns True while decompressed output is still
+            # pending, so keep draining with empty input until it returns False.
+            while dbuf.feed_data(chunk):
+                chunk = b""
 
         dbuf.feed_eof()
diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py
index e2596ea2e96..546ea60cd8b 100644
--- a/tests/test_http_writer.py
+++ b/tests/test_http_writer.py
@@ -1459,7 +1459,7 @@ async def test_write_drain_condition_with_small_buffer(
     protocol._drain_helper.reset_mock()  # type: ignore[attr-defined]
 
     # Write small amount of data with drain=True but buffer under limit
-    small_data = b"x" * 100  # Much less than LIMIT (2**16)
+    small_data = b"x" * 100  # Much less than LIMIT (2**18)
     await msg.write(small_data, drain=True)
 
     # Drain should NOT be called because buffer_size <= LIMIT
@@ -1488,7 +1488,7 @@ async def test_write_drain_condition_with_large_buffer(
     protocol._drain_helper.reset_mock()  # type: ignore[attr-defined]
 
     # Write large amount of data with drain=True
-    large_data = b"x" * (2**16 + 1)  # Just over LIMIT
+    large_data = b"x" * (2**18 + 1)  # Just over LIMIT
     await msg.write(large_data, drain=True)
 
     # Drain should be called because drain=True AND buffer_size > LIMIT
@@ -1517,12 +1517,12 @@ async def test_write_no_drain_with_large_buffer(
     protocol._drain_helper.reset_mock()  # type: ignore[attr-defined]
 
     # Write large amount of data with drain=False
-    large_data = b"x" * (2**16 + 1)  # Just over LIMIT
+    large_data = b"x" * (2**18 + 1)  # Just over LIMIT
     await msg.write(large_data, drain=False)
 
     # Drain should NOT be called because drain=False
assert not protocol._drain_helper.called # type: ignore[attr-defined] - assert msg.buffer_size == (2**16 + 1) # Buffer not reset + assert msg.buffer_size == (2**18 + 1) # Buffer not reset assert large_data in buf diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 52e97a993a3..30659405599 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1,4 +1,5 @@ import asyncio +import gzip import io import json import pathlib @@ -27,6 +28,7 @@ MultipartResponseWrapper, ) from aiohttp.streams import StreamReader +from aiohttp.web_exceptions import HTTPRequestEntityTooLarge if sys.version_info >= (3, 11): from typing import Self @@ -354,12 +356,17 @@ async def test_read_with_content_encoding_gzip(self) -> None: result = await obj.read(decode=True) assert b"Time to Relax!" == result + @pytest.mark.skipif(sys.version_info < (3, 11), reason="wbits not available") async def test_read_with_content_encoding_deflate(self) -> None: + content = b"A" * 1_000_000 # Large enough to exceed max_length. + compressed = ZLibBackend.compress(content, wbits=-ZLibBackend.MAX_WBITS) + h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "deflate"})) - with Stream(b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--") as stream: + with Stream(compressed + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read(decode=True) - assert b"Time to Relax!" == result + assert len(result) == len(content) # Simplifies diff on failure + assert result == content async def test_read_with_content_encoding_identity(self) -> None: thing = ( @@ -380,6 +387,22 @@ async def test_read_with_content_encoding_unknown(self) -> None: with pytest.raises(RuntimeError): await obj.read(decode=True) + async def test_read_decode_compressed_exceeds_max_size(self) -> None: + # Compressed data is small, but decompresses beyond client_max_size. + original = b"A" * 1024 + compressed = gzip.compress(original) + h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "gzip"})) + with Stream(compressed + b"\r\n--:--") as stream: + obj = aiohttp.BodyPartReader( + BOUNDARY, + h, + stream, + client_max_size=256, + max_size_error_cls=HTTPRequestEntityTooLarge, + ) + with pytest.raises(HTTPRequestEntityTooLarge): + await obj.read(decode=True) + async def test_read_with_content_transfer_encoding_base64(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TRANSFER_ENCODING: "base64"})) with Stream(b"VGltZSB0byBSZWxheCE=\r\n--:--") as stream: @@ -651,9 +674,9 @@ async def test_filename(self) -> None: assert "foo.html" == part.filename async def test_reading_long_part(self) -> None: - size = 2 * 2**16 + size = 2 * 2**18 protocol = mock.Mock(_reading_paused=False) - stream = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) + stream = StreamReader(protocol, 2**18, loop=asyncio.get_event_loop()) stream.feed_data(b"0" * size + b"\r\n--:--") stream.feed_eof() d = CIMultiDictProxy[str](CIMultiDict()) @@ -1721,6 +1744,35 @@ async def test_body_part_reader_payload_as_bytes() -> None: payload.decode() +@pytest.mark.skipif(sys.version_info < (3, 11), reason="No wbits parameter") +async def test_body_part_reader_payload_write() -> None: + content = b"A" * 1_000_000 # Large enough to exceed max_length. 
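+    # Negative wbits yields a raw deflate stream (no zlib header), matching Content-Encoding: deflate.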
+ compressed = ZLibBackend.compress(content, wbits=-ZLibBackend.MAX_WBITS) + output = b"" + + async def write(inp: bytes) -> None: + nonlocal output + output += inp + + h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "deflate"})) + if sys.version_info >= (3, 12): + writer = mock.create_autospec( + AbstractStreamWriter, write=write, spec_set=True, instance=True + ) + else: + writer = mock.create_autospec( + AbstractStreamWriter, spec_set=True, instance=True + ) + writer.write.side_effect = write + with Stream(compressed + b"\r\n--:--") as stream: + body_part = aiohttp.BodyPartReader(BOUNDARY, h, stream) + payload = BodyPartReaderPayload(body_part) + await payload.write(writer) + + assert len(output) == len(content) # Simplifies diff on failure + assert output == content + + async def test_multipart_writer_close_with_exceptions() -> None: """Test that MultipartWriter.close() continues closing all parts even if one raises.""" writer = aiohttp.MultipartWriter() diff --git a/tests/test_payload.py b/tests/test_payload.py index 205a3efdf81..e38335f546f 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -328,7 +328,7 @@ def mock_read(size: int | None = None) -> bytes: async def test_bytesio_payload_large_data_multiple_chunks() -> None: """Test BytesIOPayload with large data requiring multiple read chunks.""" - chunk_size = 2**16 # 64KB (READ_SIZE) + chunk_size = 2**18 # 256KiB (READ_SIZE) data = b"x" * (chunk_size + 1000) # Slightly larger than READ_SIZE payload_bytesio = payload.BytesIOPayload(io.BytesIO(data)) writer = MockStreamWriter() @@ -352,7 +352,7 @@ async def test_bytesio_payload_remaining_bytes_exhausted() -> None: async def test_iobase_payload_exact_chunk_size_limit() -> None: """Test IOBasePayload with content length matching exactly one read chunk.""" - chunk_size = 2**16 # 65536 bytes (READ_SIZE) + chunk_size = 2**18 # 256KiB (READ_SIZE) data = b"x" * chunk_size + b"extra" # Slightly larger than one read chunk p = payload.IOBasePayload(io.BytesIO(data)) writer = MockStreamWriter() diff --git a/tests/test_streams.py b/tests/test_streams.py index 93e0caaac9b..6560b4698fb 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -29,7 +29,7 @@ def chunkify(seq: Sequence[_T], n: int) -> Iterator[Sequence[_T]]: async def create_stream() -> streams.StreamReader: loop = asyncio.get_event_loop() protocol = mock.Mock(_reading_paused=False) - stream = streams.StreamReader(protocol, 2**16, loop=loop) + stream = streams.StreamReader(protocol, 2**18, loop=loop) stream.feed_data(DATA) stream.feed_eof() return stream @@ -75,7 +75,7 @@ def get_memory_usage(obj: object) -> int: class TestStreamReader: DATA: bytes = b"line1\nline2\nline3\n" - def _make_one(self, limit: int = 2**16) -> streams.StreamReader: + def _make_one(self, limit: int = 2**18) -> streams.StreamReader: loop = asyncio.get_event_loop() return streams.StreamReader(mock.Mock(_reading_paused=False), limit, loop=loop) @@ -1110,7 +1110,7 @@ async def test_empty_stream_reader() -> None: assert s.set_exception(ValueError()) is None # type: ignore[func-returns-value] assert s.exception() is None assert s.feed_eof() is None # type: ignore[func-returns-value] - assert s.feed_data(b"data") is None # type: ignore[func-returns-value] + assert s.feed_data(b"data") is False assert s.at_eof() await s.wait_eof() assert await s.read() == b"" @@ -1276,7 +1276,7 @@ async def set_err() -> None: async def test_feed_data_waiters(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = 
streams.StreamReader(protocol, 2**16, loop=loop) + reader = streams.StreamReader(protocol, 2**18, loop=loop) waiter = reader._waiter = loop.create_future() eof_waiter = reader._eof_waiter = loop.create_future() @@ -1304,7 +1304,7 @@ async def test_feed_data_completed_waiters(protocol: BaseProtocol) -> None: async def test_feed_eof_waiters(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**16, loop=loop) + reader = streams.StreamReader(protocol, 2**18, loop=loop) waiter = reader._waiter = loop.create_future() eof_waiter = reader._eof_waiter = loop.create_future() @@ -1336,7 +1336,7 @@ async def test_feed_eof_cancelled(protocol: BaseProtocol) -> None: async def test_on_eof(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**16, loop=loop) + reader = streams.StreamReader(protocol, 2**18, loop=loop) on_eof = mock.Mock() reader.on_eof(on_eof) @@ -1357,7 +1357,7 @@ async def test_on_eof_empty_reader() -> None: async def test_on_eof_exc_in_callback(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**16, loop=loop) + reader = streams.StreamReader(protocol, 2**18, loop=loop) on_eof = mock.Mock() on_eof.side_effect = ValueError @@ -1392,7 +1392,7 @@ async def test_on_eof_eof_is_set(protocol: BaseProtocol) -> None: async def test_on_eof_eof_is_set_exception(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**16, loop=loop) + reader = streams.StreamReader(protocol, 2**18, loop=loop) reader.feed_eof() on_eof = mock.Mock() @@ -1438,7 +1438,7 @@ async def test_set_exception_cancelled(protocol: BaseProtocol) -> None: async def test_set_exception_eof_callbacks(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**16, loop=loop) + reader = streams.StreamReader(protocol, 2**18, loop=loop) on_eof = mock.Mock() reader.on_eof(on_eof) @@ -1545,8 +1545,8 @@ async def test_stream_reader_pause_on_high_water_chunks( ) -> None: """Test that reading is paused when chunk count exceeds high water mark.""" loop = asyncio.get_event_loop() - # Use small limit so high_water_chunks is small: limit // 4 = 10 - stream = streams.StreamReader(protocol, limit=40, loop=loop) + # Use small limit so high_water_chunks is small: limit // 16 = 10 + stream = streams.StreamReader(protocol, limit=160, loop=loop) assert stream._high_water_chunks == 10 assert stream._low_water_chunks == 5 @@ -1566,8 +1566,8 @@ async def test_stream_reader_resume_on_low_water_chunks( ) -> None: """Test that reading resumes when chunk count drops below low water mark.""" loop = asyncio.get_event_loop() - # Use small limit so high_water_chunks is small: limit // 4 = 10 - stream = streams.StreamReader(protocol, limit=40, loop=loop) + # Use small limit so high_water_chunks is small: limit // 16 = 10 + stream = streams.StreamReader(protocol, limit=160, loop=loop) assert stream._high_water_chunks == 10 assert stream._low_water_chunks == 5 @@ -1661,14 +1661,14 @@ async def test_stream_reader_resume_non_chunked_when_paused( protocol.resume_reading.assert_called() -@pytest.mark.parametrize("limit", [1, 2, 4]) +@pytest.mark.parametrize("limit", (1, 4, 7, 16)) async def test_stream_reader_small_limit_resumes_reading( protocol: mock.Mock, limit: int, ) -> None: """Test that small limits still allow resume_reading to be called. 
- Even with very small limits, high_water_chunks should be at least 3 + Even with very small limits, high_water_chunks should be at least 4 and low_water_chunks should be at least 2, with high > low to ensure proper flow control. """ @@ -1676,8 +1676,8 @@ async def test_stream_reader_small_limit_resumes_reading( stream = streams.StreamReader(protocol, limit=limit, loop=loop) # Verify minimum thresholds are enforced and high > low - assert stream._high_water_chunks >= 3 - assert stream._low_water_chunks >= 2 + assert stream._high_water_chunks == 4 + assert stream._low_water_chunks == 2 assert stream._high_water_chunks > stream._low_water_chunks # Set up pause/resume side effects @@ -1691,8 +1691,8 @@ def resume_reading() -> None: protocol.resume_reading.side_effect = resume_reading - # Feed 4 chunks (triggers pause at > high_water_chunks which is >= 3) - for char in b"abcd": + # Feed 5 chunks (triggers pause at > high_water_chunks which is 4) + for char in b"abcde": stream.begin_http_chunk_receiving() stream.feed_data(bytes([char])) stream.end_http_chunk_receiving() @@ -1703,7 +1703,7 @@ def resume_reading() -> None: # Read all data - should resume (chunk count drops below low_water_chunks) data = stream.read_nowait() - assert data == b"abcd" + assert data == b"abcde" assert stream._size == 0 protocol.resume_reading.assert_called() diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index b29de08b170..0f5d7e22ba1 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -329,14 +329,9 @@ def test_pickle(self) -> None: class TestHTTPRequestEntityTooLarge: def test_ctor(self) -> None: resp = web.HTTPRequestEntityTooLarge( - max_size=100, - actual_size=123, - headers={"X-Custom": "value"}, - reason="Too large", - ) - assert resp.text == ( - "Maximum request body size 100 exceeded, actual body size 123" + max_size=100, headers={"X-Custom": "value"}, reason="Too large" ) + assert resp.text == "Maximum request body size 100 exceeded." 
        compare: Mapping[str, str] = {"X-Custom": "value", "Content-Type": "text/plain"}
         assert resp.headers == compare
         assert resp.reason == "Too large"
@@ -344,7 +339,7 @@ def test_pickle(self) -> None:
 
     def test_pickle(self) -> None:
         resp = web.HTTPRequestEntityTooLarge(
-            100, actual_size=123, headers={"X-Custom": "value"}, reason="Too large"
+            100, headers={"X-Custom": "value"}, reason="Too large"
         )
         resp.foo = "bar"  # type: ignore[attr-defined]
         for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py
index eb0a0b54798..30976d7377a 100644
--- a/tests/test_web_functional.py
+++ b/tests/test_web_functional.py
@@ -328,6 +328,27 @@ async def handler(request: web.Request) -> web.Response:
     resp.release()
 
 
+async def test_multipart_client_max_size(aiohttp_client: AiohttpClient) -> None:
+    with multipart.MultipartWriter() as writer:
+        writer.append("A" * 1020)
+
+    async def handler(request: web.Request) -> web.Response:
+        reader = await request.multipart()
+        assert isinstance(reader, multipart.MultipartReader)
+
+        part = await reader.next()
+        assert isinstance(part, multipart.BodyPartReader)
+        await part.text()  # Should raise HTTPRequestEntityTooLarge
+        assert False
+
+    app = web.Application(client_max_size=1000)
+    app.router.add_post("/", handler)
+    client = await aiohttp_client(app)
+
+    async with client.post("/", data=writer) as resp:
+        assert resp.status == 413
+
+
 async def test_multipart_empty(aiohttp_client: AiohttpClient) -> None:
     with multipart.MultipartWriter() as writer:
         pass
@@ -1700,10 +1721,7 @@ async def handler(request: web.Request) -> NoReturn:
     resp = await client.post("/", data=data)
     assert 413 == resp.status
     resp_text = await resp.text()
-    assert "Maximum request body size 1048576 exceeded, actual body size" in resp_text
-    # Maximum request body size X exceeded, actual body size X
-    body_size = int(resp_text.split()[-1])
-    assert body_size >= max_size
+    assert "Maximum request body size 1048576 exceeded" in resp_text
 
     resp.release()
 
@@ -1725,7 +1743,7 @@ async def handler(request: web.Request) -> NoReturn:
     async with client.post("/", data=form) as resp:
         assert resp.status == 413
         resp_text = await resp.text()
-        assert "Maximum request body size 1048576 exceeded, actual body size" in resp_text
+        assert "Maximum request body size 1048576 exceeded" in resp_text
 
 
 async def test_app_max_client_size_adjusted(aiohttp_client: AiohttpClient) -> None:
@@ -1752,10 +1770,7 @@ async def handler(request: web.Request) -> web.Response:
     resp = await client.post("/", data=too_large_data)
     assert 413 == resp.status
     resp_text = await resp.text()
-    assert "Maximum request body size 2097152 exceeded, actual body size" in resp_text
-    # Maximum request body size X exceeded, actual body size X
-    body_size = int(resp_text.split()[-1])
-    assert body_size >= custom_max_size
+    assert "Maximum request body size 2097152 exceeded" in resp_text
 
     resp.release()
 
@@ -1802,10 +1817,7 @@ async def handler(request: web.Request) -> NoReturn:
 
     assert 413 == resp.status
     resp_text = await resp.text()
-    assert (
-        "Maximum request body size 10 exceeded, "
-        "actual body size 1024" in resp_text
-    )
+    assert "Maximum request body size 10 exceeded" in resp_text
     data_file = data["file"]
     assert isinstance(data_file, io.BytesIO)
     data_file.close()
diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py
index 5ae1e5dd756..9acad2f2101 100644
--- a/tests/test_web_protocol.py
+++ b/tests/test_web_protocol.py
@@ -1,48 +1,51 @@
 import asyncio
-from typing import Any, cast
 from unittest import mock
 
-from aiohttp.web_protocol import RequestHandler
+import pytest
 
+from aiohttp.http import WebSocketReader
+from aiohttp.web_protocol import RequestHandler
+from aiohttp.web_request import BaseRequest
+from aiohttp.web_server import Server
 
-class _DummyManager:
-    def __init__(self) -> None:
-        self.request_handler = mock.Mock()
-        self.request_factory = mock.Mock()
+@pytest.fixture
+def dummy_manager() -> Server[BaseRequest]:
+    return mock.create_autospec(  # type: ignore[no-any-return]
+        Server[BaseRequest],
+        request_handler=mock.Mock(),
+        request_factory=mock.Mock(),
+        instance=True,
+    )
 
 
-class _DummyParser:
-    def __init__(self) -> None:
-        self.received: list[bytes] = []
-    def feed_data(self, data: bytes) -> tuple[bool, bytes]:
-        self.received.append(data)
-        return False, b""
+@pytest.fixture
+def dummy_reader() -> tuple[WebSocketReader, mock.Mock]:
+    m = mock.create_autospec(WebSocketReader, spec_set=True, instance=True)
+    m.feed_data.return_value = False, b""
+    # The same object is returned twice: typed as WebSocketReader for
+    # set_parser() and as a plain Mock for call assertions.
+    return m, m
 
 
 def test_set_parser_does_not_call_data_received_cb_for_tail(
     loop: asyncio.AbstractEventLoop,
+    dummy_manager: Server[BaseRequest],
+    dummy_reader: tuple[WebSocketReader, mock.Mock],
 ) -> None:
-    handler: RequestHandler[Any] = RequestHandler(cast(Any, _DummyManager()), loop=loop)
+    handler = RequestHandler(dummy_manager, loop=loop)
     handler._message_tail = b"tail"
     cb = mock.Mock()
-    parser = _DummyParser()
-    handler.set_parser(parser, data_received_cb=cb)
+    handler.set_parser(dummy_reader[0], data_received_cb=cb)
 
     cb.assert_not_called()
-    assert parser.received == [b"tail"]
+    dummy_reader[1].feed_data.assert_called_once_with(b"tail")
 
 
 def test_data_received_calls_data_received_cb(
     loop: asyncio.AbstractEventLoop,
+    dummy_manager: Server[BaseRequest],
+    dummy_reader: tuple[WebSocketReader, mock.Mock],
 ) -> None:
-    handler: RequestHandler[Any] = RequestHandler(cast(Any, _DummyManager()), loop=loop)
+    handler = RequestHandler(dummy_manager, loop=loop)
     cb = mock.Mock()
-    parser = _DummyParser()
-    handler.set_parser(parser, data_received_cb=cb)
+    handler.set_parser(dummy_reader[0], data_received_cb=cb)
 
     handler.data_received(b"x")
-    assert cb.call_count == 1
-    assert parser.received == [b"x"]
+    cb.assert_called_once()
+    dummy_reader[1].feed_data.assert_called_once_with(b"x")
diff --git a/tests/test_web_request.py b/tests/test_web_request.py
index 038e6c141d9..efeb6b766b0 100644
--- a/tests/test_web_request.py
+++ b/tests/test_web_request.py
@@ -837,7 +837,7 @@ def test_clone_headers_dict() -> None:
 
 
 async def test_cannot_clone_after_read(protocol: BaseProtocol) -> None:
-    payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop())
+    payload = StreamReader(protocol, 2**18, loop=asyncio.get_event_loop())
     payload.feed_data(b"data")
     payload.feed_eof()
     req = make_mocked_request("GET", "/path", payload=payload)
@@ -860,7 +860,7 @@ async def test_make_too_big_request(protocol: BaseProtocol) -> None:
 
 
 async def test_request_with_wrong_content_type_encoding(protocol: BaseProtocol) -> None:
-    payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop())
+    payload = StreamReader(protocol, 2**18, loop=asyncio.get_event_loop())
     payload.feed_data(b"{}")
     payload.feed_eof()
     headers = {"Content-Type": "text/html; charset=test"}
@@ -920,7 +920,7 @@ async def test_multipart_formdata(protocol: BaseProtocol) -> None:
 
 
 async def test_multipart_formdata_field_missing_name(protocol: BaseProtocol) -> None:
     # Ensure ValueError is raised when Content-Disposition has no name
-    payload = StreamReader(protocol, 2**16,
loop=asyncio.get_event_loop()) + payload = StreamReader(protocol, 2**18, loop=asyncio.get_event_loop()) payload.feed_data( b"-----------------------------326931944431359\r\n" b"Content-Disposition: form-data\r\n" # Missing name! @@ -972,7 +972,7 @@ async def test_multipart_formdata_headers_too_many(protocol: BaseProtocol) -> No b"--b--\r\n" ) content_type = "multipart/form-data; boundary=b" - payload = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + payload = StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) payload.feed_data(body) payload.feed_eof() req = make_mocked_request( @@ -999,7 +999,7 @@ async def test_multipart_formdata_header_too_long(protocol: BaseProtocol) -> Non b"--b--\r\n" ) content_type = "multipart/form-data; boundary=b" - payload = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + payload = StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) payload.feed_data(body) payload.feed_eof() req = make_mocked_request( diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 26d1a275327..27dbae6630a 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -19,7 +19,7 @@ from aiohttp._websocket.reader import WebSocketDataQueue from aiohttp.base_protocol import BaseProtocol from aiohttp.compression_utils import ZLibBackend, ZLibBackendWrapper -from aiohttp.http import WebSocketError, WSCloseCode, WSMsgType +from aiohttp.http import HttpParser, WebSocketError, WSCloseCode, WSMsgType from aiohttp.http_websocket import ( WebSocketReader, WSMessageBinary, @@ -113,8 +113,9 @@ def build_close_frame( @pytest.fixture() def protocol(loop: asyncio.AbstractEventLoop) -> BaseProtocol: + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) transport = mock.Mock(spec_set=asyncio.Transport) - protocol = BaseProtocol(loop) + protocol = BaseProtocol(loop, parser=parser) protocol.connection_made(transport) return protocol diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 8dffc7c015e..3b6bc98b54f 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -158,7 +158,7 @@ async def test_send_compress_cancelled( monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024) writer = WebSocketWriter(protocol, transport, compress=15) loop = asyncio.get_running_loop() - queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop) + queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**18, loop=loop) reader = WebSocketReader(queue, 50000) # Replace executor with slow one to make race condition reproducible @@ -305,7 +305,7 @@ async def test_concurrent_messages( ): writer = WebSocketWriter(protocol, transport, compress=15) loop = asyncio.get_running_loop() - queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop) + queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**18, loop=loop) reader = WebSocketReader(queue, 50000) writers = [] payloads = []