Skip to content

Commit 3d96d73

Browse files
committed
Break up large outbound noise messages
1 parent 542422b commit 3d96d73

File tree

2 files changed

+54
-12
lines changed

2 files changed

+54
-12
lines changed

service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import com.southernstorm.noise.protocol.CipherState;
88
import com.southernstorm.noise.protocol.CipherStatePair;
9+
import com.southernstorm.noise.protocol.Noise;
910
import io.netty.buffer.ByteBuf;
1011
import io.netty.buffer.ByteBufUtil;
1112
import io.netty.buffer.Unpooled;
@@ -16,6 +17,7 @@
1617
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
1718
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
1819
import io.netty.util.ReferenceCountUtil;
20+
import io.netty.util.concurrent.PromiseCombiner;
1921
import io.netty.util.internal.EmptyArrays;
2022
import java.util.Optional;
2123
import java.util.concurrent.CompletableFuture;
@@ -163,25 +165,35 @@ private void fail(final ChannelHandlerContext context, final Throwable cause) {
163165
@Override
164166
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
165167
throws Exception {
166-
if (message instanceof ByteBuf plaintext) {
168+
if (message instanceof ByteBuf byteBuf) {
167169
try {
168170
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
169171
final CipherState cipherState = cipherStatePair.getSender();
170-
final int plaintextLength = plaintext.readableBytes();
171172

172-
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
173-
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
174-
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
175-
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
176-
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
173+
// Server message might not fit in a single noise packet, break it up into as many chunks as we need
174+
final PromiseCombiner pc = new PromiseCombiner(context.executor());
175+
while (byteBuf.isReadable()) {
176+
final ByteBuf plaintext = byteBuf.readSlice(Math.min(
177+
// need room for a 16-byte AEAD tag
178+
Noise.MAX_PACKET_LEN - 16,
179+
byteBuf.readableBytes()));
177180

178-
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
179-
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
181+
final int plaintextLength = plaintext.readableBytes();
180182

181-
context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
183+
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
184+
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
185+
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
186+
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
187+
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
182188

189+
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
190+
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
191+
192+
pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer))));
193+
}
194+
pc.finish(promise);
183195
} finally {
184-
ReferenceCountUtil.release(plaintext);
196+
ReferenceCountUtil.release(byteBuf);
185197
}
186198
} else {
187199
if (!(message instanceof WebSocketFrame)) {

service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.junit.jupiter.api.Assertions.assertTrue;
1010

1111
import com.southernstorm.noise.protocol.CipherStatePair;
12+
import com.southernstorm.noise.protocol.Noise;
1213
import io.netty.buffer.ByteBuf;
1314
import io.netty.buffer.ByteBufUtil;
1415
import io.netty.buffer.Unpooled;
@@ -19,18 +20,22 @@
1920
import io.netty.channel.ChannelInboundHandlerAdapter;
2021
import io.netty.channel.embedded.EmbeddedChannel;
2122
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
23+
import io.netty.util.ReferenceCountUtil;
2224
import java.nio.charset.StandardCharsets;
25+
import java.util.Arrays;
2326
import java.util.concurrent.ThreadLocalRandom;
2427
import javax.annotation.Nullable;
2528
import javax.crypto.AEADBadTagException;
2629
import javax.crypto.BadPaddingException;
2730
import javax.crypto.ShortBufferException;
28-
import io.netty.util.ReferenceCountUtil;
2931
import org.junit.jupiter.api.AfterEach;
3032
import org.junit.jupiter.api.BeforeEach;
3133
import org.junit.jupiter.api.Test;
34+
import org.junit.jupiter.params.ParameterizedTest;
35+
import org.junit.jupiter.params.provider.ValueSource;
3236
import org.signal.libsignal.protocol.ecc.Curve;
3337
import org.signal.libsignal.protocol.ecc.ECKeyPair;
38+
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
3439

3540
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
3641

@@ -254,4 +259,29 @@ void writeUnexpectedMessageType() throws Throwable {
254259
assertTrue(embeddedChannel.outboundMessages().isEmpty());
255260
}
256261

262+
@ParameterizedTest
263+
@ValueSource(ints = {Noise.MAX_PACKET_LEN - 16, Noise.MAX_PACKET_LEN - 15, Noise.MAX_PACKET_LEN * 5})
264+
void writeHugeOutboundMessage(final int plaintextLength) throws Throwable {
265+
final CipherStatePair clientCipherStatePair = doHandshake();
266+
final byte[] plaintext = TestRandomUtil.nextBytes(plaintextLength);
267+
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(Arrays.copyOf(plaintext, plaintext.length));
268+
269+
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
270+
assertTrue(writePlaintextFuture.isSuccess());
271+
272+
final byte[] decryptedPlaintext = new byte[plaintextLength];
273+
int plaintextOffset = 0;
274+
BinaryWebSocketFrame ciphertextFrame;
275+
while ((ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll()) != null) {
276+
assertTrue(ciphertextFrame.content().readableBytes() <= Noise.MAX_PACKET_LEN);
277+
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
278+
ciphertextFrame.release();
279+
plaintextOffset += clientCipherStatePair.getReceiver()
280+
.decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length);
281+
}
282+
assertArrayEquals(plaintext, decryptedPlaintext);
283+
assertEquals(0, plaintextBuffer.refCnt());
284+
285+
}
286+
257287
}

0 commit comments

Comments
 (0)