Skip to content

Commit a1d9c4c

Browse files
authored
Check presence before updating last message versionstamp
1 parent 4acb3b5 commit a1d9c4c

File tree

2 files changed

+232
-51
lines changed

2 files changed

+232
-51
lines changed

service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java

Lines changed: 85 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,63 +8,106 @@
88
import com.apple.foundationdb.tuple.Versionstamp;
99
import com.google.common.annotations.VisibleForTesting;
1010
import com.google.common.hash.Hashing;
11+
import java.time.Clock;
12+
import java.util.List;
1113
import java.util.Map;
14+
import java.util.Optional;
1215
import java.util.concurrent.CompletableFuture;
1316
import java.util.concurrent.Executor;
1417
import java.util.function.Function;
1518
import org.whispersystems.textsecuregcm.entities.MessageProtos;
1619
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
20+
import org.whispersystems.textsecuregcm.util.Conversions;
21+
import org.whispersystems.textsecuregcm.util.Pair;
1722

1823
/// An implementation of a message store backed by FoundationDB.
1924
///
2025
/// @implNote The layout of elements in FoundationDB is as follows:
2126
/// * messages
2227
/// * {aci}
23-
/// * last => versionstamp
28+
/// * messageAvailableWatch => versionstamp
2429
/// * {deviceId}
30+
/// * presence => server_id | last_seen_seconds_since_epoch
2531
/// * queue
2632
/// * {versionstamp_1} => envelope_1
2733
/// * {versionstamp_2} => envelope_2
2834
public class FoundationDbMessageStore {
2935

3036
private final Database[] databases;
31-
private static final Subspace MESSAGES_SUBSPACE = new Subspace(Tuple.from("M"));
3237
private final Executor executor;
38+
private final Clock clock;
39+
40+
private static final Subspace MESSAGES_SUBSPACE = new Subspace(Tuple.from("M"));
41+
private static final int MAX_SECONDS_SINCE_UPDATE_FOR_PRESENCE = 300;
3342

34-
public FoundationDbMessageStore(final Database[] databases, final Executor executor) {
43+
public FoundationDbMessageStore(final Database[] databases, final Executor executor, final Clock clock) {
3544
this.databases = databases;
3645
this.executor = executor;
46+
this.clock = clock;
3747
}
3848

39-
/**
40-
* Insert a message bundle for a set of devices belonging to a single recipient
41-
*
42-
* @param aci destination account identifier
43-
* @param messagesByDeviceId a map of deviceId => message envelope
44-
* @return a future that completes with a {@link Versionstamp} of the committed transaction
45-
*/
46-
public CompletableFuture<Versionstamp> insert(final AciServiceIdentifier aci,
49+
/// Insert a message bundle for a set of devices belonging to a single recipient. A message may not be inserted if the
50+
/// device is not present (as determined from its presence key) and the message is ephemeral. If all messages in the
51+
/// bundle don't end up being inserted, we won't return a versionstamp since the transaction was read-only.
52+
///
53+
/// @param aci destination account identifier
54+
/// @param messagesByDeviceId a map of deviceId => message envelope
55+
/// @return a future that completes with a [Versionstamp] of the committed transaction if at least one message was
56+
/// inserted
57+
public CompletableFuture<Optional<Versionstamp>> insert(final AciServiceIdentifier aci,
4758
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId) {
4859
// We use Database#runAsync and not Database#run here because the latter would commit the transaction synchronously
4960
// and we would like to avoid any potential blocking in native code that could unexpectedly pin virtual threads. See https://forums.foundationdb.org/t/fdbdatabase-usage-from-java-api/593/2
5061
// for details.
51-
return getShardForAci(aci).runAsync(transaction -> {
52-
insert(aci, messagesByDeviceId, transaction);
53-
return CompletableFuture.completedFuture(transaction.getVersionstamp());
54-
})
62+
return getShardForAci(aci).runAsync(transaction -> insert(aci, messagesByDeviceId, transaction)
63+
.thenApply(hasMutations -> {
64+
if (hasMutations) {
65+
return transaction.getVersionstamp();
66+
}
67+
return CompletableFuture.completedFuture((byte[]) null);
68+
}))
5569
.thenComposeAsync(Function.identity(), executor)
56-
.thenApply(Versionstamp::complete);
70+
.thenApply(versionstampBytes -> Optional.ofNullable(versionstampBytes).map(Versionstamp::complete));
5771
}
5872

59-
private void insert(final AciServiceIdentifier aci, final Map<Byte, MessageProtos.Envelope> messagesByDeviceId,
73+
private CompletableFuture<Boolean> insert(final AciServiceIdentifier aci,
74+
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId,
6075
final Transaction transaction) {
61-
messagesByDeviceId.forEach((deviceId, message) -> {
62-
final Subspace deviceQueueSubspace = getDeviceQueueSubspace(aci, deviceId);
63-
transaction.mutate(MutationType.SET_VERSIONSTAMPED_KEY, deviceQueueSubspace.packWithVersionstamp(Tuple.from(
64-
Versionstamp.incomplete())), message.toByteArray());
65-
});
66-
transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getLastMessageKey(aci),
67-
Tuple.from(Versionstamp.incomplete()).packWithVersionstamp());
76+
final List<CompletableFuture<Pair<Boolean, Boolean>>> messageInsertFutures = messagesByDeviceId.entrySet()
77+
.stream()
78+
.map(e -> {
79+
final byte deviceId = e.getKey();
80+
final MessageProtos.Envelope message = e.getValue();
81+
final byte[] presenceKey = getPresenceKey(aci, deviceId);
82+
return transaction.get(presenceKey)
83+
.thenApply(this::isClientPresent)
84+
.thenApply(isPresent -> {
85+
boolean hasMutations = false;
86+
if (isPresent || !message.getEphemeral()) {
87+
final Subspace deviceQueueSubspace = getDeviceQueueSubspace(aci, deviceId);
88+
transaction.mutate(MutationType.SET_VERSIONSTAMPED_KEY,
89+
deviceQueueSubspace.packWithVersionstamp(Tuple.from(
90+
Versionstamp.incomplete())), message.toByteArray());
91+
hasMutations = true;
92+
}
93+
return new Pair<>(isPresent, hasMutations);
94+
});
95+
})
96+
.toList();
97+
return CompletableFuture.allOf(messageInsertFutures.toArray(CompletableFuture[]::new))
98+
.thenApply(_ -> {
99+
final boolean anyClientPresent = messageInsertFutures
100+
.stream()
101+
.anyMatch(future -> future.join().first());
102+
final boolean hasMutations = messageInsertFutures
103+
.stream()
104+
.anyMatch(future -> future.join().second());
105+
if (anyClientPresent && hasMutations) {
106+
transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getMessagesAvailableWatchKey(aci),
107+
Tuple.from(Versionstamp.incomplete()).packWithVersionstamp());
108+
}
109+
return hasMutations;
110+
});
68111
}
69112

70113
private Database getShardForAci(final AciServiceIdentifier aci) {
@@ -90,8 +133,25 @@ private Subspace getAccountSubspace(final AciServiceIdentifier aci) {
90133
}
91134

92135
@VisibleForTesting
93-
byte[] getLastMessageKey(final AciServiceIdentifier aci) {
136+
byte[] getMessagesAvailableWatchKey(final AciServiceIdentifier aci) {
94137
return getAccountSubspace(aci).pack("l");
95138
}
96139

140+
@VisibleForTesting
141+
byte[] getPresenceKey(final AciServiceIdentifier aci, final byte deviceId) {
142+
return getDeviceQueueSubspace(aci, deviceId).pack("p");
143+
}
144+
145+
@VisibleForTesting
146+
boolean isClientPresent(final byte[] presenceValueBytes) {
147+
if (presenceValueBytes == null) {
148+
return false;
149+
}
150+
final long presenceValue = Conversions.byteArrayToLong(presenceValueBytes);
151+
// The presence value is a long with the higher order 16 bits containing a server id, and the lower 48 bits
152+
// containing the timestamp (seconds since epoch) that the client updates periodically.
153+
final long lastSeenSecondsSinceEpoch = presenceValue & 0x0000ffffffffffffL;
154+
return (clock.instant().getEpochSecond() - lastSeenSecondsSinceEpoch) <= MAX_SECONDS_SINCE_UPDATE_FOR_PRESENCE;
155+
}
156+
97157
}
Lines changed: 147 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,42 @@
11
package org.whispersystems.textsecuregcm.storage.foundationdb;
22

3+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
34
import static org.junit.jupiter.api.Assertions.assertEquals;
45
import static org.junit.jupiter.api.Assertions.assertNotNull;
6+
import static org.junit.jupiter.api.Assertions.assertTrue;
57

68
import com.apple.foundationdb.Database;
79
import com.apple.foundationdb.tuple.Tuple;
810
import com.apple.foundationdb.tuple.Versionstamp;
911
import com.google.protobuf.ByteString;
1012
import com.google.protobuf.InvalidProtocolBufferException;
1113
import java.io.UncheckedIOException;
14+
import java.time.Clock;
15+
import java.time.Instant;
16+
import java.time.ZoneId;
1217
import java.util.List;
1318
import java.util.Map;
19+
import java.util.Objects;
20+
import java.util.Optional;
1421
import java.util.UUID;
1522
import java.util.concurrent.Executors;
1623
import java.util.function.Function;
1724
import java.util.stream.Collectors;
1825
import java.util.stream.IntStream;
26+
import java.util.stream.Stream;
1927
import org.junit.jupiter.api.BeforeEach;
2028
import org.junit.jupiter.api.Test;
2129
import org.junit.jupiter.api.Timeout;
2230
import org.junit.jupiter.api.extension.RegisterExtension;
31+
import org.junit.jupiter.params.ParameterizedTest;
32+
import org.junit.jupiter.params.provider.Arguments;
33+
import org.junit.jupiter.params.provider.MethodSource;
34+
import org.junit.jupiter.params.provider.ValueSource;
2335
import org.whispersystems.textsecuregcm.entities.MessageProtos;
2436
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
2537
import org.whispersystems.textsecuregcm.storage.Device;
2638
import org.whispersystems.textsecuregcm.storage.FoundationDbExtension;
39+
import org.whispersystems.textsecuregcm.util.Conversions;
2740
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
2841

2942
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
@@ -34,50 +47,139 @@ class FoundationDbMessageStoreTest {
3447

3548
private FoundationDbMessageStore foundationDbMessageStore;
3649

50+
private static final Clock CLOCK = Clock.fixed(Instant.ofEpochSecond(500), ZoneId.of("UTC"));
51+
3752
@BeforeEach
3853
void setup() {
3954
foundationDbMessageStore = new FoundationDbMessageStore(
4055
new Database[]{FOUNDATION_DB_EXTENSION.getDatabase()},
41-
Executors.newVirtualThreadPerTaskExecutor());
56+
Executors.newVirtualThreadPerTaskExecutor(),
57+
CLOCK);
4258
}
4359

44-
@Test
45-
void insert() {
60+
@ParameterizedTest
61+
@MethodSource
62+
void insert(final long presenceUpdatedBeforeSeconds, final boolean ephemeral, final boolean expectMessagesInserted,
63+
final boolean expectVersionstampUpdated) {
4664
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
4765
final List<Byte> deviceIds = IntStream.range(Device.PRIMARY_ID, Device.PRIMARY_ID + 6)
4866
.mapToObj(i -> (byte) i)
4967
.toList();
68+
deviceIds.forEach(deviceId -> writePresenceKey(aci, deviceId, 1, presenceUpdatedBeforeSeconds));
5069
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = deviceIds.stream()
51-
.collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage()));
52-
final Versionstamp versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join();
70+
.collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage(ephemeral)));
71+
final Optional<Versionstamp> versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join();
5372
assertNotNull(versionstamp);
5473

55-
final Map<Byte, MessageProtos.Envelope> storedMessagesByDeviceId = deviceIds.stream()
56-
.collect(Collectors.toMap(Function.identity(), deviceId -> {
57-
try {
58-
return MessageProtos.Envelope.parseFrom(getMessageByVersionstamp(aci, deviceId, versionstamp));
59-
} catch (final InvalidProtocolBufferException e) {
60-
throw new UncheckedIOException(e);
61-
}
62-
}));
63-
64-
assertEquals(messagesByDeviceId, storedMessagesByDeviceId);
65-
assertEquals(versionstamp, getLastMessageVersionstamp(aci),
66-
"last message versionstamp should be the versionstamp of the last insert transaction");
74+
if (expectMessagesInserted) {
75+
assertTrue(versionstamp.isPresent());
76+
final Map<Byte, MessageProtos.Envelope> storedMessagesByDeviceId = deviceIds.stream()
77+
.collect(Collectors.toMap(Function.identity(), deviceId -> {
78+
try {
79+
return MessageProtos.Envelope.parseFrom(getMessageByVersionstamp(aci, deviceId, versionstamp.get()));
80+
} catch (final InvalidProtocolBufferException e) {
81+
throw new UncheckedIOException(e);
82+
}
83+
}));
84+
85+
assertEquals(messagesByDeviceId, storedMessagesByDeviceId);
86+
} else {
87+
assertTrue(versionstamp.isEmpty());
88+
}
89+
90+
if (expectVersionstampUpdated) {
91+
assertEquals(versionstamp, getMessagesAvailableWatch(aci),
92+
"messages available versionstamp should be the versionstamp of the last insert transaction");
93+
} else {
94+
assertTrue(getMessagesAvailableWatch(aci).isEmpty());
95+
}
96+
}
97+
98+
private static Stream<Arguments> insert() {
99+
return Stream.of(
100+
Arguments.argumentSet("Non-ephemeral messages with all devices online",
101+
10L, false, true, true),
102+
Arguments.argumentSet(
103+
"Ephemeral messages with presence updated exactly at the second before which the device would be considered offline",
104+
300L, true, true, true),
105+
Arguments.argumentSet("Non-ephemeral messages for with all devices offline",
106+
310L, false, true, false),
107+
Arguments.argumentSet("Ephemeral messages with all devices offline",
108+
310L, true, false, false)
109+
);
67110
}
68111

69112
@Test
70113
void versionstampCorrectlyUpdatedOnMultipleInserts() {
71114
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
72-
foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage())).join();
73-
final Versionstamp secondMessageVersionstamp = foundationDbMessageStore.insert(aci,
74-
Map.of(Device.PRIMARY_ID, generateRandomMessage())).join();
75-
assertEquals(secondMessageVersionstamp, getLastMessageVersionstamp(aci));
115+
writePresenceKey(aci, Device.PRIMARY_ID, 1, 10L);
116+
foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join();
117+
final Optional<Versionstamp> secondMessageVersionstamp = foundationDbMessageStore.insert(aci,
118+
Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join();
119+
assertEquals(secondMessageVersionstamp, getMessagesAvailableWatch(aci));
76120
}
77121

78-
private static MessageProtos.Envelope generateRandomMessage() {
122+
@ParameterizedTest
123+
@ValueSource(booleans = {true, false})
124+
void insertOnlyOneDevicePresent(final boolean ephemeral) {
125+
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
126+
final List<Byte> deviceIds = IntStream.range(Device.PRIMARY_ID, Device.PRIMARY_ID + 6)
127+
.mapToObj(i -> (byte) i)
128+
.toList();
129+
// Only 1 device has a recent presence, the others do not have presence keys present.
130+
writePresenceKey(aci, Device.PRIMARY_ID, 1, 10L);
131+
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = deviceIds.stream()
132+
.collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage(ephemeral)));
133+
final Optional<Versionstamp> versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join();
134+
assertNotNull(versionstamp);
135+
assertTrue(versionstamp.isPresent(),
136+
"versionstamp should be present since at least one message should be inserted");
137+
138+
assertArrayEquals(
139+
messagesByDeviceId.get(Device.PRIMARY_ID).toByteArray(),
140+
getMessageByVersionstamp(aci, Device.PRIMARY_ID, versionstamp.get()),
141+
"Message for primary should always be stored since it has a recently updated presence");
142+
143+
if (ephemeral) {
144+
assertTrue(IntStream.range(Device.PRIMARY_ID + 1, Device.PRIMARY_ID + 6)
145+
.mapToObj(deviceId -> getMessageByVersionstamp(aci, (byte) deviceId, versionstamp.get()))
146+
.allMatch(Objects::isNull), "Ephemeral messages for non-present devices must not be stored");
147+
} else {
148+
IntStream.range(Device.PRIMARY_ID + 1, Device.PRIMARY_ID)
149+
.forEach(deviceId -> {
150+
final byte[] messageBytes = getMessageByVersionstamp(aci, (byte) deviceId, versionstamp.get());
151+
assertEquals(messagesByDeviceId.get((byte) deviceId).toByteArray(), messageBytes,
152+
"Non-ephemeral messages must always be stored");
153+
});
154+
}
155+
156+
}
157+
158+
@ParameterizedTest
159+
@MethodSource
160+
void isClientPresent(final byte[] presenceValueBytes, final boolean expectPresent) {
161+
assertEquals(expectPresent, foundationDbMessageStore.isClientPresent(presenceValueBytes));
162+
}
163+
164+
static Stream<Arguments> isClientPresent() {
165+
return Stream.of(
166+
Arguments.argumentSet("Presence value doesn't exist",
167+
null, false),
168+
Arguments.argumentSet("Presence updated recently",
169+
Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(5))), true),
170+
Arguments.argumentSet("Presence updated same second as current time",
171+
Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(0))), true),
172+
Arguments.argumentSet("Presence updated exactly at the second before which it would have been considered offline",
173+
Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(300))), true),
174+
Arguments.argumentSet("Presence expired",
175+
Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(400))), false)
176+
);
177+
}
178+
179+
private static MessageProtos.Envelope generateRandomMessage(final boolean ephemeral) {
79180
return MessageProtos.Envelope.newBuilder()
80181
.setContent(ByteString.copyFrom(TestRandomUtil.nextBytes(16)))
182+
.setEphemeral(ephemeral)
81183
.build();
82184
}
83185

@@ -90,12 +192,31 @@ private byte[] getMessageByVersionstamp(final AciServiceIdentifier aci, final by
90192
}).join();
91193
}
92194

93-
private Versionstamp getLastMessageVersionstamp(final AciServiceIdentifier aci) {
195+
private Optional<Versionstamp> getMessagesAvailableWatch(final AciServiceIdentifier aci) {
94196
return FOUNDATION_DB_EXTENSION.getDatabase()
95-
.read(transaction -> transaction.get(foundationDbMessageStore.getLastMessageKey(aci))
96-
.thenApply(Tuple::fromBytes)
97-
.thenApply(t -> t.getVersionstamp(0)))
197+
.read(transaction -> transaction.get(foundationDbMessageStore.getMessagesAvailableWatchKey(aci))
198+
.thenApply(value -> value == null ? null : Tuple.fromBytes(value).getVersionstamp(0))
199+
.thenApply(Optional::ofNullable))
98200
.join();
99201
}
100202

203+
private void writePresenceKey(final AciServiceIdentifier aci, final byte deviceId, final int serverId,
204+
final long secondsBeforeCurrentTime) {
205+
FOUNDATION_DB_EXTENSION.getDatabase().run(transaction -> {
206+
final byte[] presenceKey = foundationDbMessageStore.getPresenceKey(aci, deviceId);
207+
final long presenceUpdateEpochSeconds = getEpochSecondsBeforeClock(secondsBeforeCurrentTime);
208+
final long presenceValue = constructPresenceValue(serverId, presenceUpdateEpochSeconds);
209+
transaction.set(presenceKey, Conversions.longToByteArray(presenceValue));
210+
return null;
211+
});
212+
}
213+
214+
private static long getEpochSecondsBeforeClock(final long secondsBefore) {
215+
return CLOCK.instant().minusSeconds(secondsBefore).getEpochSecond();
216+
}
217+
218+
private static long constructPresenceValue(final int serverId, final long presenceUpdateEpochSeconds) {
219+
return (long) (serverId & 0x0ffff) << 48 | (presenceUpdateEpochSeconds & 0x0000ffffffffffffL);
220+
}
221+
101222
}

0 commit comments

Comments
 (0)