diff --git a/httpclient5-websocket/src/main/java/org/apache/hc/client5/http/websocket/transport/WebSocketSessionEngine.java b/httpclient5-websocket/src/main/java/org/apache/hc/client5/http/websocket/transport/WebSocketSessionEngine.java index 149a0074c8..f2e9b3a27c 100644 --- a/httpclient5-websocket/src/main/java/org/apache/hc/client5/http/websocket/transport/WebSocketSessionEngine.java +++ b/httpclient5-websocket/src/main/java/org/apache/hc/client5/http/websocket/transport/WebSocketSessionEngine.java @@ -404,7 +404,11 @@ private void handleFrame() { final byte[] comp = WebSocketBufferOps.toBytes(payload); final byte[] plain; try { - plain = decChain.decode(comp); + plain = decChain.decode(comp, cfg.getMaxMessageSize()); + } catch (final WebSocketProtocolException wspe) { + initiateClose(wspe.closeCode, wspe.getMessage()); + inbuf.clear(); + return; } catch (final Exception e) { initiateClose(1007, "Extension decode failed"); inbuf.clear(); @@ -506,7 +510,10 @@ private void deliverAssembledMessage() { byte[] data = body; if (compressed && decChain != null) { try { - data = decChain.decode(body); + data = decChain.decode(body, cfg.getMaxMessageSize()); + } catch (final WebSocketProtocolException wspe) { + initiateClose(wspe.closeCode, wspe.getMessage()); + return; } catch (final Exception e) { try { listener.onError(e); diff --git a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/ExtensionChain.java b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/ExtensionChain.java index 03190a9eb4..3b73baf01b 100644 --- a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/ExtensionChain.java +++ b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/ExtensionChain.java @@ -125,9 +125,21 @@ public DecodeChain(final List decs) { * Decode a full message (reverse order if stacking). */ public byte[] decode(final byte[] data) throws Exception { + return decode(data, 0L); + } + + /** + * Decode a full message (reverse order if stacking), enforcing a hard cap on + * the decoded payload size at every step. A non-positive {@code maxDecodedSize} + * disables the cap. The cap is propagated into each extension so that expanding + * extensions (e.g. permessage-deflate) abort during expansion rather than after. + * + * @since 5.7 + */ + public byte[] decode(final byte[] data, final long maxDecodedSize) throws Exception { byte[] out = data; for (int i = decs.size() - 1; i >= 0; i--) { - out = decs.get(i).decode(out); + out = decs.get(i).decode(out, maxDecodedSize); } return out; } diff --git a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/PerMessageDeflate.java b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/PerMessageDeflate.java index de2612b0ed..002b9d486e 100644 --- a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/PerMessageDeflate.java +++ b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/PerMessageDeflate.java @@ -31,6 +31,7 @@ import java.util.zip.Inflater; import org.apache.hc.core5.annotation.Internal; +import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException; import org.apache.hc.core5.websocket.frame.FrameHeaderBits; /** @@ -144,6 +145,11 @@ public Decoder newDecoder() { @Override public byte[] decode(final byte[] compressedMessage) throws Exception { + return decode(compressedMessage, 0L); + } + + @Override + public byte[] decode(final byte[] compressedMessage, final long maxDecodedSize) throws Exception { final byte[] withTail; if (compressedMessage == null || compressedMessage.length == 0) { withTail = TAIL.clone(); @@ -156,10 +162,17 @@ public byte[] decode(final byte[] compressedMessage) throws Exception { inf.setInput(withTail); final ByteArrayOutputStream out = new ByteArrayOutputStream(Math.max(128, withTail.length * 2)); final byte[] buf = new byte[Math.min(16384, Math.max(1024, withTail.length * 2))]; + long produced = 0L; while (!inf.needsInput()) { final int n = inf.inflate(buf); if (n > 0) { + // Enforce the decoded size cap during inflation, not after, so a small + // compressed payload cannot expand into a huge buffer before we react. + if (maxDecodedSize > 0L && produced + n > maxDecodedSize) { + throw new WebSocketProtocolException(1009, "Message too big"); + } out.write(buf, 0, n); + produced += n; } else { break; } diff --git a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/WebSocketExtensionChain.java b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/WebSocketExtensionChain.java index 006ee8eeab..92ad719015 100644 --- a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/WebSocketExtensionChain.java +++ b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/WebSocketExtensionChain.java @@ -76,5 +76,17 @@ interface Decoder { * Decode a full message produced with this extension. */ byte[] decode(byte[] payload) throws Exception; + + /** + * Decode a full message, aborting as soon as the produced output exceeds + * {@code maxDecodedSize}. A non-positive limit means no limit. Implementations + * that may expand input (e.g. permessage-deflate) MUST honour the limit during + * the expansion step, not only after it, to prevent decompression-bomb attacks. + * + * @since 5.7 + */ + default byte[] decode(final byte[] payload, final long maxDecodedSize) throws Exception { + return decode(payload); + } } } diff --git a/httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/extension/MessageDeflateTest.java b/httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/extension/MessageDeflateTest.java index 84a3b92f6b..284566dac4 100644 --- a/httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/extension/MessageDeflateTest.java +++ b/httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/extension/MessageDeflateTest.java @@ -30,10 +30,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException; import org.apache.hc.core5.websocket.frame.FrameHeaderBits; import org.junit.jupiter.api.Test; @@ -81,6 +84,56 @@ void roundTrip_message() throws Exception { assertArrayEquals(plain, roundTrip); } + @Test + void decode_withinLimit_succeeds() throws Exception { + final PerMessageDeflate pmce = new PerMessageDeflate(true, true, true, null, null); + final WebSocketExtensionChain.Encoder enc = pmce.newEncoder(); + final WebSocketExtensionChain.Decoder dec = pmce.newDecoder(); + + final byte[] plain = "hello world hello world hello world".getBytes(StandardCharsets.UTF_8); + final byte[] wire = enc.encode(plain, true, true).payload; + + // Limit comfortably above the inflated size. + final byte[] roundTrip = dec.decode(wire, plain.length + 16); + assertArrayEquals(plain, roundTrip); + } + + @Test + void decode_inflationBomb_isRejectedDuringInflate() { + // A small, highly compressible payload that inflates to a much larger plaintext. + final byte[] plain = new byte[64 * 1024]; + Arrays.fill(plain, (byte) 'A'); + + final PerMessageDeflate pmce = new PerMessageDeflate(true, true, true, null, null); + final WebSocketExtensionChain.Encoder enc = pmce.newEncoder(); + final WebSocketExtensionChain.Decoder dec = pmce.newDecoder(); + + final byte[] wire = enc.encode(plain, true, true).payload; + // Sanity: the compressed wire form is far smaller than the inflated payload. + assertTrue(wire.length < plain.length / 4, + "test setup: payload should be highly compressible, was " + wire.length + " vs " + plain.length); + + // maxDecodedSize is well below the inflated size; decode must abort with 1009. + final WebSocketProtocolException ex = assertThrows(WebSocketProtocolException.class, + () -> dec.decode(wire, 1024L)); + assertEquals(1009, ex.closeCode); + assertEquals("Message too big", ex.getMessage()); + } + + @Test + void decode_zeroLimitMeansUnlimited() throws Exception { + final PerMessageDeflate pmce = new PerMessageDeflate(true, true, true, null, null); + final WebSocketExtensionChain.Encoder enc = pmce.newEncoder(); + final WebSocketExtensionChain.Decoder dec = pmce.newDecoder(); + + final byte[] plain = new byte[8 * 1024]; + Arrays.fill(plain, (byte) 'B'); + final byte[] wire = enc.encode(plain, true, true).payload; + + final byte[] roundTrip = dec.decode(wire, 0L); + assertArrayEquals(plain, roundTrip); + } + private static boolean endsWithTail(final byte[] b) { if (b.length < 4) { return false;