Skip to content

Commit bbf4237

Browse files
authored
KTOR-8312 Make the clearToken behavior consistent (#4735)
1 parent fc0fc4f commit bbf4237

File tree

2 files changed

+241
-81
lines changed

2 files changed

+241
-81
lines changed

ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/providers/AuthTokenHolder.kt

Lines changed: 82 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,82 +4,103 @@
44

55
package io.ktor.client.plugins.auth.providers
66

7-
import kotlinx.atomicfu.*
8-
import kotlinx.coroutines.*
7+
import kotlinx.coroutines.DelicateCoroutinesApi
8+
import kotlinx.coroutines.GlobalScope
9+
import kotlinx.coroutines.launch
10+
import kotlinx.coroutines.sync.Mutex
11+
import kotlinx.coroutines.sync.withLock
12+
import kotlinx.coroutines.withContext
13+
import kotlin.concurrent.Volatile
14+
import kotlin.coroutines.CoroutineContext
15+
import kotlin.coroutines.coroutineContext
916

10-
internal class AuthTokenHolder<T>(
11-
private val loadTokens: suspend () -> T?
12-
) {
13-
private val refreshTokensDeferred = atomic<CompletableDeferred<T?>?>(null)
14-
private val loadTokensDeferred = atomic<CompletableDeferred<T?>?>(null)
17+
internal class AuthTokenHolder<T>(private val loadTokens: suspend () -> T?) {
1518

16-
internal fun clearToken() {
17-
loadTokensDeferred.value = null
18-
refreshTokensDeferred.value = null
19-
}
19+
@Volatile private var value: T? = null
2020

21-
internal suspend fun loadToken(): T? {
22-
var deferred: CompletableDeferred<T?>?
23-
lateinit var newDeferred: CompletableDeferred<T?>
24-
while (true) {
25-
deferred = loadTokensDeferred.value
26-
val newValue = deferred ?: CompletableDeferred()
27-
if (loadTokensDeferred.compareAndSet(deferred, newValue)) {
28-
newDeferred = newValue
29-
break
30-
}
31-
}
21+
@Volatile private var isLoadRequest = false
3222

33-
// if there's already a pending loadTokens(), just wait for it to complete
34-
if (deferred != null) {
35-
return deferred.await()
36-
}
23+
private val mutex = Mutex()
24+
25+
/**
26+
* Exist only for testing
27+
*/
28+
internal fun get(): T? = value
3729

38-
try {
39-
val newTokens = loadTokens()
30+
/**
31+
* Returns a cached value if any. Otherwise, computes a value using [loadTokens] and caches it.
32+
* Only one [loadToken] call can be executed at a time. The other calls are suspended and have no effect on the cached value.
33+
*/
34+
internal suspend fun loadToken(): T? {
35+
if (value != null) return value // Hot path
36+
val prevValue = value
4037

41-
// [loadTokensDeferred.value] could be null by now (if clearToken() was called while
42-
// suspended), which is why we are using [newDeferred] to complete the suspending callback.
43-
newDeferred.complete(newTokens)
38+
return if (coroutineContext[SetTokenContext] != null) { // Already locked by setToken
39+
value = loadTokens()
40+
value
41+
} else {
42+
mutex.withLock {
43+
isLoadRequest = true
44+
try {
45+
if (prevValue == value) { // Raced first
46+
value = loadTokens()
47+
}
48+
} finally {
49+
isLoadRequest = false
50+
}
4451

45-
return newTokens
46-
} catch (cause: Throwable) {
47-
newDeferred.completeExceptionally(cause)
48-
loadTokensDeferred.compareAndSet(newDeferred, null)
49-
throw cause
52+
value
53+
}
5054
}
5155
}
5256

57+
private class SetTokenContext : CoroutineContext.Element {
58+
override val key: CoroutineContext.Key<*>
59+
get() = SetTokenContext
60+
61+
companion object : CoroutineContext.Key<SetTokenContext>
62+
}
63+
64+
private val setTokenMarker = SetTokenContext()
65+
66+
/**
67+
* Replaces the current cached value with one computed with [block].
68+
* Only one [loadToken] or [setToken] call can be executed at a time,
69+
* although the resumed [setToken] call recomputes the value cached by [loadToken].
70+
*/
5371
internal suspend fun setToken(block: suspend () -> T?): T? {
54-
var deferred: CompletableDeferred<T?>?
55-
lateinit var newDeferred: CompletableDeferred<T?>
56-
while (true) {
57-
deferred = refreshTokensDeferred.value
58-
val newValue = deferred ?: CompletableDeferred()
59-
if (refreshTokensDeferred.compareAndSet(deferred, newValue)) {
60-
newDeferred = newValue
61-
break
72+
val prevValue = value
73+
val lockedByLoad = isLoadRequest
74+
75+
return mutex.withLock {
76+
if (prevValue == value || lockedByLoad) { // Raced first
77+
val newValue = withContext(coroutineContext + setTokenMarker) {
78+
block()
79+
}
80+
81+
if (newValue != null) {
82+
value = newValue
83+
}
6284
}
85+
86+
value
6387
}
88+
}
6489

65-
try {
66-
val newToken = if (deferred == null) {
67-
val newTokens = block()
68-
69-
// [refreshTokensDeferred.value] could be null by now (if clearToken() was called while
70-
// suspended), which is why we are using [newDeferred] to complete the suspending callback.
71-
newDeferred.complete(newTokens)
72-
refreshTokensDeferred.value = null
73-
newTokens
74-
} else {
75-
deferred.await()
90+
/**
91+
* Resets the cached value.
92+
*/
93+
@OptIn(DelicateCoroutinesApi::class)
94+
internal fun clearToken() {
95+
if (mutex.tryLock()) {
96+
value = null
97+
mutex.unlock()
98+
} else {
99+
GlobalScope.launch {
100+
mutex.withLock {
101+
value = null
102+
}
76103
}
77-
loadTokensDeferred.value = CompletableDeferred(newToken)
78-
return newToken
79-
} catch (cause: Throwable) {
80-
newDeferred.completeExceptionally(cause)
81-
refreshTokensDeferred.compareAndSet(newDeferred, null)
82-
throw cause
83104
}
84105
}
85106
}

0 commit comments

Comments
 (0)