Skip to content

Commit 6032764

Browse files
authored
Validate parsed message size, not base64-encoded message size
1 parent 908a418 commit 6032764

File tree

4 files changed

+55
-33
lines changed

4 files changed

+55
-33
lines changed

service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,13 @@ public class MessageController {
167167
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
168168
private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage");
169169

170-
private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType");
171170
private static final String SEND_MESSAGE_LATENCY_TIMER_NAME = MetricsUtil.name(MessageController.class, "sendMessageLatency");
172171

173172
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
174173
private static final String SENDER_TYPE_TAG_NAME = "senderType";
175174
private static final String AUTH_TYPE_TAG_NAME = "authType";
176175
private static final String SENDER_COUNTRY_TAG_NAME = "senderCountry";
177176
private static final String RATE_LIMIT_REASON_TAG_NAME = "rateLimitReason";
178-
private static final String ENVELOPE_TYPE_TAG_NAME = "envelopeType";
179177
private static final String IDENTITY_TYPE_TAG_NAME = "identityType";
180178
private static final String ENDPOINT_TYPE_TAG_NAME = "endpoint";
181179

@@ -192,7 +190,7 @@ public class MessageController {
192190
private static final String ENDPOINT_TYPE_MULTI = "multi";
193191

194192
@VisibleForTesting
195-
static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
193+
static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
196194
private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes();
197195

198196
// The Signal desktop client (really, JavaScript in general) can handle message timestamps at most 100,000,000 days
@@ -332,14 +330,9 @@ public Response sendMessage(@ReadOnly @Auth final Optional<AuthenticatedDevice>
332330
int totalContentLength = 0;
333331

334332
for (final IncomingMessage message : messages.messages()) {
335-
int contentLength = 0;
336-
337-
if (StringUtils.isNotEmpty(message.content())) {
338-
contentLength += message.content().length();
339-
}
333+
final int contentLength = decodedSize(message.content());
340334

341335
validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent);
342-
validateEnvelopeType(message.type(), userAgent);
343336

344337
totalContentLength += contentLength;
345338
}
@@ -971,12 +964,18 @@ private void validateContentLength(final int contentLength,
971964
}
972965
}
973966

974-
private void validateEnvelopeType(final int type, final String userAgent) {
975-
if (type == Type.SERVER_DELIVERY_RECEIPT_VALUE) {
976-
Metrics.counter(REJECT_INVALID_ENVELOPE_TYPE,
977-
Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(ENVELOPE_TYPE_TAG_NAME, String.valueOf(type))))
978-
.increment();
979-
throw new BadRequestException("reserved envelope type");
967+
@VisibleForTesting
968+
static int decodedSize(final String base64) {
969+
final int padding;
970+
971+
if (StringUtils.endsWith(base64, "==")) {
972+
padding = 2;
973+
} else if (StringUtils.endsWith(base64, "=")) {
974+
padding = 1;
975+
} else {
976+
padding = 0;
980977
}
978+
979+
return ((StringUtils.length(base64) - padding) * 3) / 4;
981980
}
982981
}

service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,21 @@
55
package org.whispersystems.textsecuregcm.entities;
66

77
import com.google.protobuf.ByteString;
8+
import io.micrometer.core.instrument.Metrics;
9+
import jakarta.validation.constraints.AssertTrue;
810
import java.util.Base64;
911
import javax.annotation.Nullable;
1012
import org.apache.commons.lang3.StringUtils;
1113
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
1214
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
15+
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
1316
import org.whispersystems.textsecuregcm.storage.Account;
1417

1518
public record IncomingMessage(int type, byte destinationDeviceId, int destinationRegistrationId, String content) {
1619

20+
private static final String REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME =
21+
MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType");
22+
1723
public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier,
1824
@Nullable Account sourceAccount,
1925
@Nullable Byte sourceDeviceId,
@@ -23,15 +29,10 @@ public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIden
2329
final boolean urgent,
2430
@Nullable byte[] reportSpamToken) {
2531

26-
final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type());
27-
28-
if (envelopeType == null) {
29-
throw new IllegalArgumentException("Bad envelope type: " + type());
30-
}
31-
3232
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder();
3333

34-
envelopeBuilder.setType(envelopeType)
34+
envelopeBuilder
35+
.setType(MessageProtos.Envelope.Type.forNumber(type))
3536
.setClientTimestamp(timestamp)
3637
.setServerTimestamp(System.currentTimeMillis())
3738
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
@@ -55,4 +56,17 @@ public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIden
5556

5657
return envelopeBuilder.build();
5758
}
59+
60+
@AssertTrue
61+
public boolean isValidEnvelopeType() {
62+
if (type() == MessageProtos.Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE ||
63+
MessageProtos.Envelope.Type.forNumber(type()) == null) {
64+
65+
Metrics.counter(REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME).increment();
66+
67+
return false;
68+
}
69+
70+
return true;
71+
}
5872
}

service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,16 @@
88

99
import com.fasterxml.jackson.annotation.JsonCreator;
1010
import com.fasterxml.jackson.annotation.JsonProperty;
11-
12-
import jakarta.validation.constraints.Max;
13-
import jakarta.validation.constraints.PositiveOrZero;
14-
import org.whispersystems.textsecuregcm.controllers.MessageController;
15-
1611
import io.micrometer.core.instrument.Counter;
1712
import io.micrometer.core.instrument.Metrics;
18-
19-
import java.util.List;
2013
import jakarta.validation.Valid;
2114
import jakarta.validation.constraints.AssertTrue;
15+
import jakarta.validation.constraints.Max;
2216
import jakarta.validation.constraints.NotNull;
17+
import jakarta.validation.constraints.PositiveOrZero;
18+
import java.util.List;
19+
import java.util.Objects;
20+
import org.whispersystems.textsecuregcm.controllers.MessageController;
2321

2422
public record IncomingMessageList(@NotNull
2523
@Valid
@@ -49,10 +47,14 @@ public IncomingMessageList(@JsonProperty("messages") @NotNull @Valid List<@NotNu
4947

5048
@AssertTrue
5149
public boolean hasNoDuplicateRecipients() {
52-
boolean valid = messages.stream().filter(m -> m != null).map(IncomingMessage::destinationDeviceId).distinct().count() == messages.size();
50+
final boolean valid = messages.stream()
51+
.filter(Objects::nonNull)
52+
.map(IncomingMessage::destinationDeviceId).distinct().count() == messages.size();
53+
5354
if (!valid) {
5455
REJECT_DUPLICATE_RECIPIENT_COUNTER.increment();
5556
}
57+
5658
return valid;
5759
}
5860
}

service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import static org.hamcrest.CoreMatchers.is;
1010
import static org.hamcrest.CoreMatchers.not;
1111
import static org.hamcrest.MatcherAssert.assertThat;
12-
import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
1312
import static org.junit.jupiter.api.Assertions.assertEquals;
1413
import static org.junit.jupiter.api.Assertions.assertFalse;
1514
import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -48,7 +47,6 @@
4847
import java.time.temporal.ChronoUnit;
4948
import java.util.Arrays;
5049
import java.util.Base64;
51-
import java.util.Collection;
5250
import java.util.Collections;
5351
import java.util.HashSet;
5452
import java.util.List;
@@ -1141,7 +1139,7 @@ void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws E
11411139
assertEquals(200, response.getStatus());
11421140
verify(messageSender).sendMessages(any(), any());
11431141
} else {
1144-
assertEquals(400, response.getStatus());
1142+
assertEquals(422, response.getStatus());
11451143
verify(messageSender, never()).sendMessages(any(), any());
11461144
}
11471145
}
@@ -1662,4 +1660,13 @@ private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UU
16621660
return builder.build();
16631661
}
16641662

1663+
@Test
1664+
void decodedSize() {
1665+
for (int size = MessageController.MAX_MESSAGE_SIZE - 3; size <= MessageController.MAX_MESSAGE_SIZE + 3; size++) {
1666+
final byte[] bytes = TestRandomUtil.nextBytes(size);
1667+
final String base64Encoded = Base64.getEncoder().encodeToString(bytes);
1668+
1669+
assertEquals(bytes.length, MessageController.decodedSize(base64Encoded));
1670+
}
1671+
}
16651672
}

0 commit comments

Comments
 (0)