Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,11 @@ private void handleFrame() {
final byte[] comp = WebSocketBufferOps.toBytes(payload);
final byte[] plain;
try {
plain = decChain.decode(comp);
plain = decChain.decode(comp, cfg.getMaxMessageSize());
} catch (final WebSocketProtocolException wspe) {
initiateClose(wspe.closeCode, wspe.getMessage());
inbuf.clear();
return;
} catch (final Exception e) {
initiateClose(1007, "Extension decode failed");
inbuf.clear();
Expand Down Expand Up @@ -506,7 +510,10 @@ private void deliverAssembledMessage() {
byte[] data = body;
if (compressed && decChain != null) {
try {
data = decChain.decode(body);
data = decChain.decode(body, cfg.getMaxMessageSize());
} catch (final WebSocketProtocolException wspe) {
initiateClose(wspe.closeCode, wspe.getMessage());
return;
} catch (final Exception e) {
try {
listener.onError(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,21 @@ public DecodeChain(final List<WebSocketExtensionChain.Decoder> decs) {
* Decode a full message (reverse order if stacking).
*/
public byte[] decode(final byte[] data) throws Exception {
return decode(data, 0L);
}

/**
* Decode a full message (reverse order if stacking), enforcing a hard cap on
* the decoded payload size at every step. A non-positive {@code maxDecodedSize}
* disables the cap. The cap is propagated into each extension so that expanding
* extensions (e.g. permessage-deflate) abort during expansion rather than after.
*
* @since 5.7
*/
public byte[] decode(final byte[] data, final long maxDecodedSize) throws Exception {
byte[] out = data;
for (int i = decs.size() - 1; i >= 0; i--) {
out = decs.get(i).decode(out);
out = decs.get(i).decode(out, maxDecodedSize);
}
return out;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.zip.Inflater;

import org.apache.hc.core5.annotation.Internal;
import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;
import org.apache.hc.core5.websocket.frame.FrameHeaderBits;

/**
Expand Down Expand Up @@ -144,6 +145,11 @@ public Decoder newDecoder() {

@Override
public byte[] decode(final byte[] compressedMessage) throws Exception {
return decode(compressedMessage, 0L);
}

@Override
public byte[] decode(final byte[] compressedMessage, final long maxDecodedSize) throws Exception {
final byte[] withTail;
if (compressedMessage == null || compressedMessage.length == 0) {
withTail = TAIL.clone();
Expand All @@ -156,10 +162,17 @@ public byte[] decode(final byte[] compressedMessage) throws Exception {
inf.setInput(withTail);
final ByteArrayOutputStream out = new ByteArrayOutputStream(Math.max(128, withTail.length * 2));
final byte[] buf = new byte[Math.min(16384, Math.max(1024, withTail.length * 2))];
long produced = 0L;
while (!inf.needsInput()) {
final int n = inf.inflate(buf);
if (n > 0) {
// Enforce the decoded size cap during inflation, not after, so a small
// compressed payload cannot expand into a huge buffer before we react.
if (maxDecodedSize > 0L && produced + n > maxDecodedSize) {
throw new WebSocketProtocolException(1009, "Message too big");
}
out.write(buf, 0, n);
produced += n;
} else {
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,17 @@ interface Decoder {
* Decode a full message produced with this extension.
*/
byte[] decode(byte[] payload) throws Exception;

/**
* Decode a full message, aborting as soon as the produced output exceeds
* {@code maxDecodedSize}. A non-positive limit means no limit. Implementations
* that may expand input (e.g. permessage-deflate) MUST honour the limit during
* the expansion step, not only after it, to prevent decompression-bomb attacks.
*
* @since 5.7
*/
default byte[] decode(final byte[] payload, final long maxDecodedSize) throws Exception {
return decode(payload);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;
import org.apache.hc.core5.websocket.frame.FrameHeaderBits;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -81,6 +84,56 @@ void roundTrip_message() throws Exception {
assertArrayEquals(plain, roundTrip);
}

@Test
void decode_withinLimit_succeeds() throws Exception {
final PerMessageDeflate pmce = new PerMessageDeflate(true, true, true, null, null);
final WebSocketExtensionChain.Encoder enc = pmce.newEncoder();
final WebSocketExtensionChain.Decoder dec = pmce.newDecoder();

final byte[] plain = "hello world hello world hello world".getBytes(StandardCharsets.UTF_8);
final byte[] wire = enc.encode(plain, true, true).payload;

// Limit comfortably above the inflated size.
final byte[] roundTrip = dec.decode(wire, plain.length + 16);
assertArrayEquals(plain, roundTrip);
}

@Test
void decode_inflationBomb_isRejectedDuringInflate() {
// A small, highly compressible payload that inflates to a much larger plaintext.
final byte[] plain = new byte[64 * 1024];
Arrays.fill(plain, (byte) 'A');

final PerMessageDeflate pmce = new PerMessageDeflate(true, true, true, null, null);
final WebSocketExtensionChain.Encoder enc = pmce.newEncoder();
final WebSocketExtensionChain.Decoder dec = pmce.newDecoder();

final byte[] wire = enc.encode(plain, true, true).payload;
// Sanity: the compressed wire form is far smaller than the inflated payload.
assertTrue(wire.length < plain.length / 4,
"test setup: payload should be highly compressible, was " + wire.length + " vs " + plain.length);

// maxDecodedSize is well below the inflated size; decode must abort with 1009.
final WebSocketProtocolException ex = assertThrows(WebSocketProtocolException.class,
() -> dec.decode(wire, 1024L));
assertEquals(1009, ex.closeCode);
assertEquals("Message too big", ex.getMessage());
}

@Test
void decode_zeroLimitMeansUnlimited() throws Exception {
final PerMessageDeflate pmce = new PerMessageDeflate(true, true, true, null, null);
final WebSocketExtensionChain.Encoder enc = pmce.newEncoder();
final WebSocketExtensionChain.Decoder dec = pmce.newDecoder();

final byte[] plain = new byte[8 * 1024];
Arrays.fill(plain, (byte) 'B');
final byte[] wire = enc.encode(plain, true, true).payload;

final byte[] roundTrip = dec.decode(wire, 0L);
assertArrayEquals(plain, roundTrip);
}

private static boolean endsWithTail(final byte[] b) {
if (b.length < 4) {
return false;
Expand Down
Loading