Skip to content

Commit 30c194c

Browse files
authored
Exclude RateLimitExceededException from fail-open checks
1 parent cc7b030 commit 30c194c

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ public void validate(final String key, final int amount) throws RateLimitExceede
6262
throw new RateLimitExceededException(retryAfter);
6363
}
6464
} catch (final Exception e) {
65+
if (e instanceof RateLimitExceededException rateLimitExceededException) {
66+
throw rateLimitExceededException;
67+
}
68+
6569
if (!config.failOpen()) {
6670
throw e;
6771
}
@@ -81,10 +85,15 @@ public CompletionStage<Void> validateAsync(final String key, final int amount) {
8185
return failedFuture(new RateLimitExceededException(retryAfter));
8286
})
8387
.exceptionally(throwable -> {
88+
if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) {
89+
throw ExceptionUtils.wrap(rateLimitExceededException);
90+
}
91+
8492
if (config.failOpen()) {
8593
return null;
8694
}
87-
throw ExceptionUtils.wrap(new RateLimitExceededException(null));
95+
96+
throw ExceptionUtils.wrap(throwable);
8897
});
8998
}
9099

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright 2025 Signal Messenger, LLC
3+
* SPDX-License-Identifier: AGPL-3.0-only
4+
*/
5+
6+
package org.whispersystems.textsecuregcm.limits;
7+
8+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
9+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
10+
import static org.junit.jupiter.api.Assertions.assertThrows;
11+
12+
import io.lettuce.core.ScriptOutputType;
13+
import java.io.IOException;
14+
import java.time.Duration;
15+
import java.time.Instant;
16+
import java.util.concurrent.CompletionException;
17+
import org.apache.commons.lang3.RandomStringUtils;
18+
import org.junit.jupiter.api.BeforeEach;
19+
import org.junit.jupiter.api.extension.RegisterExtension;
20+
import org.junit.jupiter.params.ParameterizedTest;
21+
import org.junit.jupiter.params.provider.ValueSource;
22+
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
23+
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
24+
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
25+
import org.whispersystems.textsecuregcm.util.TestClock;
26+
27+
class StaticRateLimiterTest {
28+
29+
private ClusterLuaScript validateRateLimitScript;
30+
31+
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
32+
33+
@RegisterExtension
34+
private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
35+
36+
@BeforeEach
37+
void setUp() throws IOException {
38+
validateRateLimitScript = ClusterLuaScript.fromResource(
39+
REDIS_CLUSTER_EXTENSION.getRedisCluster(), "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
40+
}
41+
42+
@ParameterizedTest
43+
@ValueSource(booleans = {true, false})
44+
void validate(final boolean failOpen) {
45+
final StaticRateLimiter rateLimiter = new StaticRateLimiter("test",
46+
new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
47+
validateRateLimitScript,
48+
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
49+
CLOCK);
50+
51+
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
52+
53+
assertDoesNotThrow(() -> rateLimiter.validate(key));
54+
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
55+
}
56+
57+
@ParameterizedTest
58+
@ValueSource(booleans = {true, false})
59+
void validateAsync(final boolean failOpen) {
60+
final StaticRateLimiter rateLimiter = new StaticRateLimiter("test",
61+
new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
62+
validateRateLimitScript,
63+
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
64+
CLOCK);
65+
66+
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
67+
68+
assertDoesNotThrow(() -> rateLimiter.validateAsync(key).toCompletableFuture().join());
69+
final CompletionException completionException =
70+
assertThrows(CompletionException.class, () -> rateLimiter.validateAsync(key).toCompletableFuture().join());
71+
72+
assertInstanceOf(RateLimitExceededException.class, completionException.getCause());
73+
}
74+
}

0 commit comments

Comments
 (0)