1
1
package org .whispersystems .textsecuregcm .storage .foundationdb ;
2
2
3
+ import static org .junit .jupiter .api .Assertions .assertArrayEquals ;
3
4
import static org .junit .jupiter .api .Assertions .assertEquals ;
4
5
import static org .junit .jupiter .api .Assertions .assertNotNull ;
6
+ import static org .junit .jupiter .api .Assertions .assertTrue ;
5
7
6
8
import com .apple .foundationdb .Database ;
7
9
import com .apple .foundationdb .tuple .Tuple ;
8
10
import com .apple .foundationdb .tuple .Versionstamp ;
9
11
import com .google .protobuf .ByteString ;
10
12
import com .google .protobuf .InvalidProtocolBufferException ;
11
13
import java .io .UncheckedIOException ;
14
+ import java .time .Clock ;
15
+ import java .time .Instant ;
16
+ import java .time .ZoneId ;
12
17
import java .util .List ;
13
18
import java .util .Map ;
19
+ import java .util .Objects ;
20
+ import java .util .Optional ;
14
21
import java .util .UUID ;
15
22
import java .util .concurrent .Executors ;
16
23
import java .util .function .Function ;
17
24
import java .util .stream .Collectors ;
18
25
import java .util .stream .IntStream ;
26
+ import java .util .stream .Stream ;
19
27
import org .junit .jupiter .api .BeforeEach ;
20
28
import org .junit .jupiter .api .Test ;
21
29
import org .junit .jupiter .api .Timeout ;
22
30
import org .junit .jupiter .api .extension .RegisterExtension ;
31
+ import org .junit .jupiter .params .ParameterizedTest ;
32
+ import org .junit .jupiter .params .provider .Arguments ;
33
+ import org .junit .jupiter .params .provider .MethodSource ;
34
+ import org .junit .jupiter .params .provider .ValueSource ;
23
35
import org .whispersystems .textsecuregcm .entities .MessageProtos ;
24
36
import org .whispersystems .textsecuregcm .identity .AciServiceIdentifier ;
25
37
import org .whispersystems .textsecuregcm .storage .Device ;
26
38
import org .whispersystems .textsecuregcm .storage .FoundationDbExtension ;
39
+ import org .whispersystems .textsecuregcm .util .Conversions ;
27
40
import org .whispersystems .textsecuregcm .util .TestRandomUtil ;
28
41
29
42
@ Timeout (value = 5 , threadMode = Timeout .ThreadMode .SEPARATE_THREAD )
@@ -34,50 +47,139 @@ class FoundationDbMessageStoreTest {
34
47
35
48
private FoundationDbMessageStore foundationDbMessageStore ;
36
49
50
+ private static final Clock CLOCK = Clock .fixed (Instant .ofEpochSecond (500 ), ZoneId .of ("UTC" ));
51
+
37
52
@ BeforeEach
38
53
void setup () {
39
54
foundationDbMessageStore = new FoundationDbMessageStore (
40
55
new Database []{FOUNDATION_DB_EXTENSION .getDatabase ()},
41
- Executors .newVirtualThreadPerTaskExecutor ());
56
+ Executors .newVirtualThreadPerTaskExecutor (),
57
+ CLOCK );
42
58
}
43
59
44
- @ Test
45
- void insert () {
60
+ @ ParameterizedTest
61
+ @ MethodSource
62
+ void insert (final long presenceUpdatedBeforeSeconds , final boolean ephemeral , final boolean expectMessagesInserted ,
63
+ final boolean expectVersionstampUpdated ) {
46
64
final AciServiceIdentifier aci = new AciServiceIdentifier (UUID .randomUUID ());
47
65
final List <Byte > deviceIds = IntStream .range (Device .PRIMARY_ID , Device .PRIMARY_ID + 6 )
48
66
.mapToObj (i -> (byte ) i )
49
67
.toList ();
68
+ deviceIds .forEach (deviceId -> writePresenceKey (aci , deviceId , 1 , presenceUpdatedBeforeSeconds ));
50
69
final Map <Byte , MessageProtos .Envelope > messagesByDeviceId = deviceIds .stream ()
51
- .collect (Collectors .toMap (Function .identity (), _ -> generateRandomMessage ()));
52
- final Versionstamp versionstamp = foundationDbMessageStore .insert (aci , messagesByDeviceId ).join ();
70
+ .collect (Collectors .toMap (Function .identity (), _ -> generateRandomMessage (ephemeral )));
71
+ final Optional < Versionstamp > versionstamp = foundationDbMessageStore .insert (aci , messagesByDeviceId ).join ();
53
72
assertNotNull (versionstamp );
54
73
55
- final Map <Byte , MessageProtos .Envelope > storedMessagesByDeviceId = deviceIds .stream ()
56
- .collect (Collectors .toMap (Function .identity (), deviceId -> {
57
- try {
58
- return MessageProtos .Envelope .parseFrom (getMessageByVersionstamp (aci , deviceId , versionstamp ));
59
- } catch (final InvalidProtocolBufferException e ) {
60
- throw new UncheckedIOException (e );
61
- }
62
- }));
63
-
64
- assertEquals (messagesByDeviceId , storedMessagesByDeviceId );
65
- assertEquals (versionstamp , getLastMessageVersionstamp (aci ),
66
- "last message versionstamp should be the versionstamp of the last insert transaction" );
74
+ if (expectMessagesInserted ) {
75
+ assertTrue (versionstamp .isPresent ());
76
+ final Map <Byte , MessageProtos .Envelope > storedMessagesByDeviceId = deviceIds .stream ()
77
+ .collect (Collectors .toMap (Function .identity (), deviceId -> {
78
+ try {
79
+ return MessageProtos .Envelope .parseFrom (getMessageByVersionstamp (aci , deviceId , versionstamp .get ()));
80
+ } catch (final InvalidProtocolBufferException e ) {
81
+ throw new UncheckedIOException (e );
82
+ }
83
+ }));
84
+
85
+ assertEquals (messagesByDeviceId , storedMessagesByDeviceId );
86
+ } else {
87
+ assertTrue (versionstamp .isEmpty ());
88
+ }
89
+
90
+ if (expectVersionstampUpdated ) {
91
+ assertEquals (versionstamp , getMessagesAvailableWatch (aci ),
92
+ "messages available versionstamp should be the versionstamp of the last insert transaction" );
93
+ } else {
94
+ assertTrue (getMessagesAvailableWatch (aci ).isEmpty ());
95
+ }
96
+ }
97
+
98
+ private static Stream <Arguments > insert () {
99
+ return Stream .of (
100
+ Arguments .argumentSet ("Non-ephemeral messages with all devices online" ,
101
+ 10L , false , true , true ),
102
+ Arguments .argumentSet (
103
+ "Ephemeral messages with presence updated exactly at the second before which the device would be considered offline" ,
104
+ 300L , true , true , true ),
105
+ Arguments .argumentSet ("Non-ephemeral messages for with all devices offline" ,
106
+ 310L , false , true , false ),
107
+ Arguments .argumentSet ("Ephemeral messages with all devices offline" ,
108
+ 310L , true , false , false )
109
+ );
67
110
}
68
111
69
112
@ Test
70
113
void versionstampCorrectlyUpdatedOnMultipleInserts () {
71
114
final AciServiceIdentifier aci = new AciServiceIdentifier (UUID .randomUUID ());
72
- foundationDbMessageStore .insert (aci , Map .of (Device .PRIMARY_ID , generateRandomMessage ())).join ();
73
- final Versionstamp secondMessageVersionstamp = foundationDbMessageStore .insert (aci ,
74
- Map .of (Device .PRIMARY_ID , generateRandomMessage ())).join ();
75
- assertEquals (secondMessageVersionstamp , getLastMessageVersionstamp (aci ));
115
+ writePresenceKey (aci , Device .PRIMARY_ID , 1 , 10L );
116
+ foundationDbMessageStore .insert (aci , Map .of (Device .PRIMARY_ID , generateRandomMessage (false ))).join ();
117
+ final Optional <Versionstamp > secondMessageVersionstamp = foundationDbMessageStore .insert (aci ,
118
+ Map .of (Device .PRIMARY_ID , generateRandomMessage (false ))).join ();
119
+ assertEquals (secondMessageVersionstamp , getMessagesAvailableWatch (aci ));
76
120
}
77
121
78
- private static MessageProtos .Envelope generateRandomMessage () {
122
+ @ ParameterizedTest
123
+ @ ValueSource (booleans = {true , false })
124
+ void insertOnlyOneDevicePresent (final boolean ephemeral ) {
125
+ final AciServiceIdentifier aci = new AciServiceIdentifier (UUID .randomUUID ());
126
+ final List <Byte > deviceIds = IntStream .range (Device .PRIMARY_ID , Device .PRIMARY_ID + 6 )
127
+ .mapToObj (i -> (byte ) i )
128
+ .toList ();
129
+ // Only 1 device has a recent presence, the others do not have presence keys present.
130
+ writePresenceKey (aci , Device .PRIMARY_ID , 1 , 10L );
131
+ final Map <Byte , MessageProtos .Envelope > messagesByDeviceId = deviceIds .stream ()
132
+ .collect (Collectors .toMap (Function .identity (), _ -> generateRandomMessage (ephemeral )));
133
+ final Optional <Versionstamp > versionstamp = foundationDbMessageStore .insert (aci , messagesByDeviceId ).join ();
134
+ assertNotNull (versionstamp );
135
+ assertTrue (versionstamp .isPresent (),
136
+ "versionstamp should be present since at least one message should be inserted" );
137
+
138
+ assertArrayEquals (
139
+ messagesByDeviceId .get (Device .PRIMARY_ID ).toByteArray (),
140
+ getMessageByVersionstamp (aci , Device .PRIMARY_ID , versionstamp .get ()),
141
+ "Message for primary should always be stored since it has a recently updated presence" );
142
+
143
+ if (ephemeral ) {
144
+ assertTrue (IntStream .range (Device .PRIMARY_ID + 1 , Device .PRIMARY_ID + 6 )
145
+ .mapToObj (deviceId -> getMessageByVersionstamp (aci , (byte ) deviceId , versionstamp .get ()))
146
+ .allMatch (Objects ::isNull ), "Ephemeral messages for non-present devices must not be stored" );
147
+ } else {
148
+ IntStream .range (Device .PRIMARY_ID + 1 , Device .PRIMARY_ID )
149
+ .forEach (deviceId -> {
150
+ final byte [] messageBytes = getMessageByVersionstamp (aci , (byte ) deviceId , versionstamp .get ());
151
+ assertEquals (messagesByDeviceId .get ((byte ) deviceId ).toByteArray (), messageBytes ,
152
+ "Non-ephemeral messages must always be stored" );
153
+ });
154
+ }
155
+
156
+ }
157
+
158
+ @ ParameterizedTest
159
+ @ MethodSource
160
+ void isClientPresent (final byte [] presenceValueBytes , final boolean expectPresent ) {
161
+ assertEquals (expectPresent , foundationDbMessageStore .isClientPresent (presenceValueBytes ));
162
+ }
163
+
164
+ static Stream <Arguments > isClientPresent () {
165
+ return Stream .of (
166
+ Arguments .argumentSet ("Presence value doesn't exist" ,
167
+ null , false ),
168
+ Arguments .argumentSet ("Presence updated recently" ,
169
+ Conversions .longToByteArray (constructPresenceValue (42 , getEpochSecondsBeforeClock (5 ))), true ),
170
+ Arguments .argumentSet ("Presence updated same second as current time" ,
171
+ Conversions .longToByteArray (constructPresenceValue (42 , getEpochSecondsBeforeClock (0 ))), true ),
172
+ Arguments .argumentSet ("Presence updated exactly at the second before which it would have been considered offline" ,
173
+ Conversions .longToByteArray (constructPresenceValue (42 , getEpochSecondsBeforeClock (300 ))), true ),
174
+ Arguments .argumentSet ("Presence expired" ,
175
+ Conversions .longToByteArray (constructPresenceValue (42 , getEpochSecondsBeforeClock (400 ))), false )
176
+ );
177
+ }
178
+
179
+ private static MessageProtos .Envelope generateRandomMessage (final boolean ephemeral ) {
79
180
return MessageProtos .Envelope .newBuilder ()
80
181
.setContent (ByteString .copyFrom (TestRandomUtil .nextBytes (16 )))
182
+ .setEphemeral (ephemeral )
81
183
.build ();
82
184
}
83
185
@@ -90,12 +192,31 @@ private byte[] getMessageByVersionstamp(final AciServiceIdentifier aci, final by
90
192
}).join ();
91
193
}
92
194
93
- private Versionstamp getLastMessageVersionstamp (final AciServiceIdentifier aci ) {
195
+ private Optional < Versionstamp > getMessagesAvailableWatch (final AciServiceIdentifier aci ) {
94
196
return FOUNDATION_DB_EXTENSION .getDatabase ()
95
- .read (transaction -> transaction .get (foundationDbMessageStore .getLastMessageKey (aci ))
96
- .thenApply (Tuple :: fromBytes )
97
- .thenApply (t -> t . getVersionstamp ( 0 ) ))
197
+ .read (transaction -> transaction .get (foundationDbMessageStore .getMessagesAvailableWatchKey (aci ))
198
+ .thenApply (value -> value == null ? null : Tuple . fromBytes ( value ). getVersionstamp ( 0 ) )
199
+ .thenApply (Optional :: ofNullable ))
98
200
.join ();
99
201
}
100
202
203
+ private void writePresenceKey (final AciServiceIdentifier aci , final byte deviceId , final int serverId ,
204
+ final long secondsBeforeCurrentTime ) {
205
+ FOUNDATION_DB_EXTENSION .getDatabase ().run (transaction -> {
206
+ final byte [] presenceKey = foundationDbMessageStore .getPresenceKey (aci , deviceId );
207
+ final long presenceUpdateEpochSeconds = getEpochSecondsBeforeClock (secondsBeforeCurrentTime );
208
+ final long presenceValue = constructPresenceValue (serverId , presenceUpdateEpochSeconds );
209
+ transaction .set (presenceKey , Conversions .longToByteArray (presenceValue ));
210
+ return null ;
211
+ });
212
+ }
213
+
214
+ private static long getEpochSecondsBeforeClock (final long secondsBefore ) {
215
+ return CLOCK .instant ().minusSeconds (secondsBefore ).getEpochSecond ();
216
+ }
217
+
218
+ private static long constructPresenceValue (final int serverId , final long presenceUpdateEpochSeconds ) {
219
+ return (long ) (serverId & 0x0ffff ) << 48 | (presenceUpdateEpochSeconds & 0x0000ffffffffffffL );
220
+ }
221
+
101
222
}
0 commit comments