Make sure to save the tokens the Client might return when its session is restored (#3378)
* Use `ClientSessionDelegate` to ensure tokens are always updated. Refreshed tokens on client restoration might not have been stored to disk if the token refresh happened before `RustMatrixClient` was built and the `ClientDelegate` was set in it. Using `ClientSessionDelegate` should ensure the tokens refreshed callback is called at any point in time. * Improve how assigning the Client works, fix docs * Fix review comments
This commit is contained in:
parent
9fb82a1e86
commit
2c8b0d0b95
10 changed files with 161 additions and 98 deletions
|
|
@ -33,7 +33,7 @@ class DefaultLogoutUseCase @Inject constructor(
|
|||
return if (currentSession != null) {
|
||||
matrixClientProvider.getOrRestore(currentSession)
|
||||
.getOrThrow()
|
||||
.logout(ignoreSdkError = true)
|
||||
.logout(userInitiated = true, ignoreSdkError = true)
|
||||
} else {
|
||||
error("No session to sign out")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ class LogoutPresenter @Inject constructor(
|
|||
ignoreSdkError: Boolean,
|
||||
) = launch {
|
||||
suspend {
|
||||
matrixClient.logout(ignoreSdkError)
|
||||
matrixClient.logout(userInitiated = true, ignoreSdkError)
|
||||
}.runCatchingUpdatingState(logoutAction)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ class DefaultDirectLogoutPresenter @Inject constructor(
|
|||
ignoreSdkError: Boolean,
|
||||
) = launch {
|
||||
suspend {
|
||||
matrixClient.logout(ignoreSdkError)
|
||||
matrixClient.logout(userInitiated = true, ignoreSdkError)
|
||||
}.runCatchingUpdatingState(logoutAction)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ class LogoutPresenterTest {
|
|||
@Test
|
||||
fun `present - logout with error then cancel`() = runTest {
|
||||
val matrixClient = FakeMatrixClient().apply {
|
||||
logoutLambda = { _ ->
|
||||
logoutLambda = { _, _ ->
|
||||
throw A_THROWABLE
|
||||
}
|
||||
}
|
||||
|
|
@ -172,7 +172,7 @@ class LogoutPresenterTest {
|
|||
@Test
|
||||
fun `present - logout with error then force`() = runTest {
|
||||
val matrixClient = FakeMatrixClient().apply {
|
||||
logoutLambda = { ignoreSdkError ->
|
||||
logoutLambda = { ignoreSdkError, _ ->
|
||||
if (!ignoreSdkError) {
|
||||
throw A_THROWABLE
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class DefaultDirectLogoutPresenterTest {
|
|||
@Test
|
||||
fun `present - logout with error then cancel`() = runTest {
|
||||
val matrixClient = FakeMatrixClient().apply {
|
||||
logoutLambda = { _ ->
|
||||
logoutLambda = { _, _ ->
|
||||
throw A_THROWABLE
|
||||
}
|
||||
}
|
||||
|
|
@ -153,7 +153,7 @@ class DefaultDirectLogoutPresenterTest {
|
|||
@Test
|
||||
fun `present - logout with error then force`() = runTest {
|
||||
val matrixClient = FakeMatrixClient().apply {
|
||||
logoutLambda = { ignoreSdkError ->
|
||||
logoutLambda = { ignoreSdkError, _ ->
|
||||
if (!ignoreSdkError) {
|
||||
throw A_THROWABLE
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -88,9 +88,10 @@ interface MatrixClient : Closeable {
|
|||
* Logout the user.
|
||||
* Returns an optional URL. When the URL is there, it should be presented to the user after logout for
|
||||
* Relying Party (RP) initiated logout on their account page.
|
||||
* @param userInitiated if false, the logout came from the HS, no request will be made and the session entry will be kept in the store.
|
||||
* @param ignoreSdkError if true, the SDK will ignore any error and delete the session data anyway.
|
||||
*/
|
||||
suspend fun logout(ignoreSdkError: Boolean): String?
|
||||
suspend fun logout(userInitiated: Boolean, ignoreSdkError: Boolean): String?
|
||||
|
||||
/**
|
||||
* Retrieve the user profile, will also eventually emit a new value to [userProfile].
|
||||
|
|
|
|||
|
|
@ -0,0 +1,133 @@
|
|||
/*
|
||||
* Copyright (c) 2024 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.element.android.libraries.matrix.impl
|
||||
|
||||
import io.element.android.libraries.core.coroutine.CoroutineDispatchers
|
||||
import io.element.android.libraries.matrix.impl.mapper.toSessionData
|
||||
import io.element.android.libraries.matrix.impl.paths.getSessionPaths
|
||||
import io.element.android.libraries.matrix.impl.util.anonymizedTokens
|
||||
import io.element.android.libraries.sessionstorage.api.SessionStore
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.launch
|
||||
import org.matrix.rustcomponents.sdk.ClientDelegate
|
||||
import org.matrix.rustcomponents.sdk.ClientSessionDelegate
|
||||
import org.matrix.rustcomponents.sdk.Session
|
||||
import timber.log.Timber
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
/**
|
||||
* This class is responsible for handling the session data for the Rust SDK.
|
||||
*
|
||||
* It implements both [ClientSessionDelegate] and [ClientDelegate] to react to session data updates and auth errors.
|
||||
*
|
||||
* IMPORTANT: you must set the [client] property as soon as possible so [didReceiveAuthError] can work properly.
|
||||
*/
|
||||
@OptIn(ExperimentalCoroutinesApi::class)
|
||||
class RustClientSessionDelegate(
|
||||
private val sessionStore: SessionStore,
|
||||
private val appCoroutineScope: CoroutineScope,
|
||||
coroutineDispatchers: CoroutineDispatchers,
|
||||
) : ClientSessionDelegate, ClientDelegate {
|
||||
private val clientLog = Timber.tag("$this")
|
||||
|
||||
// Used to ensure several calls to `didReceiveAuthError` don't trigger multiple logouts
|
||||
private val isLoggingOut = AtomicBoolean(false)
|
||||
|
||||
// To make sure only one coroutine affecting the token persistence can run at a time
|
||||
private val updateTokensDispatcher = coroutineDispatchers.io.limitedParallelism(1)
|
||||
|
||||
// This Client needs to be set up as soon as possible so `didReceiveAuthError` can work properly.
|
||||
private var client: RustMatrixClient? = null
|
||||
|
||||
/**
|
||||
* Sets the [ClientDelegate] for the [RustMatrixClient], and keeps a reference to the client so it can be used later.
|
||||
*/
|
||||
fun bindClient(client: RustMatrixClient) {
|
||||
this.client = client
|
||||
client.setDelegate(this)
|
||||
}
|
||||
|
||||
override fun saveSessionInKeychain(session: Session) {
|
||||
appCoroutineScope.launch(updateTokensDispatcher) {
|
||||
val existingData = sessionStore.getSession(session.userId) ?: return@launch
|
||||
val (anonymizedAccessToken, anonymizedRefreshToken) = session.anonymizedTokens()
|
||||
clientLog.d(
|
||||
"Saving new session data with token: access token '$anonymizedAccessToken' and refresh token '$anonymizedRefreshToken'. " +
|
||||
"Was token valid: ${existingData.isTokenValid}"
|
||||
)
|
||||
val newData = session.toSessionData(
|
||||
isTokenValid = true,
|
||||
loginType = existingData.loginType,
|
||||
passphrase = existingData.passphrase,
|
||||
sessionPaths = existingData.getSessionPaths(),
|
||||
)
|
||||
sessionStore.updateData(newData)
|
||||
clientLog.d("Saved new session data with access token: '$anonymizedAccessToken'.")
|
||||
}.invokeOnCompletion {
|
||||
if (it != null) {
|
||||
clientLog.e(it, "Failed to save new session data.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun didReceiveAuthError(isSoftLogout: Boolean) {
|
||||
clientLog.w("didReceiveAuthError(isSoftLogout=$isSoftLogout)")
|
||||
if (isLoggingOut.getAndSet(true).not()) {
|
||||
clientLog.v("didReceiveAuthError -> do the cleanup")
|
||||
// TODO handle isSoftLogout parameter.
|
||||
appCoroutineScope.launch(updateTokensDispatcher) {
|
||||
val currentClient = client
|
||||
if (currentClient == null) {
|
||||
clientLog.w("didReceiveAuthError -> no client, exiting")
|
||||
isLoggingOut.set(false)
|
||||
return@launch
|
||||
}
|
||||
val existingData = sessionStore.getSession(currentClient.sessionId.value)
|
||||
val (anonymizedAccessToken, anonymizedRefreshToken) = existingData.anonymizedTokens()
|
||||
clientLog.d(
|
||||
"Removing session data with access token '$anonymizedAccessToken' " +
|
||||
"and refresh token '$anonymizedRefreshToken'."
|
||||
)
|
||||
if (existingData != null) {
|
||||
// Set isTokenValid to false
|
||||
val newData = existingData.copy(isTokenValid = false)
|
||||
sessionStore.updateData(newData)
|
||||
clientLog.d("Invalidated session data with access token: '$anonymizedAccessToken'.")
|
||||
} else {
|
||||
clientLog.d("No session data found.")
|
||||
}
|
||||
client?.logout(userInitiated = false, ignoreSdkError = true)
|
||||
}.invokeOnCompletion {
|
||||
if (it != null) {
|
||||
clientLog.e(it, "Failed to remove session data.")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
clientLog.v("didReceiveAuthError -> already cleaning up")
|
||||
}
|
||||
}
|
||||
|
||||
override fun didRefreshTokens() {
|
||||
// This is done in `saveSessionInKeychain(Session)` instead.
|
||||
}
|
||||
|
||||
override fun retrieveSessionFromKeychain(userId: String): Session {
|
||||
// This should never be called, as it's only used for multi-process setups
|
||||
error("retrieveSessionFromKeychain should never be called for Android")
|
||||
}
|
||||
}
|
||||
|
|
@ -51,12 +51,10 @@ import io.element.android.libraries.matrix.api.user.MatrixUser
|
|||
import io.element.android.libraries.matrix.api.verification.SessionVerificationService
|
||||
import io.element.android.libraries.matrix.impl.core.toProgressWatcher
|
||||
import io.element.android.libraries.matrix.impl.encryption.RustEncryptionService
|
||||
import io.element.android.libraries.matrix.impl.mapper.toSessionData
|
||||
import io.element.android.libraries.matrix.impl.media.RustMediaLoader
|
||||
import io.element.android.libraries.matrix.impl.notification.RustNotificationService
|
||||
import io.element.android.libraries.matrix.impl.notificationsettings.RustNotificationSettingsService
|
||||
import io.element.android.libraries.matrix.impl.oidc.toRustAction
|
||||
import io.element.android.libraries.matrix.impl.paths.getSessionPaths
|
||||
import io.element.android.libraries.matrix.impl.pushers.RustPushersService
|
||||
import io.element.android.libraries.matrix.impl.room.RoomContentForwarder
|
||||
import io.element.android.libraries.matrix.impl.room.RoomSyncSubscriber
|
||||
|
|
@ -69,7 +67,6 @@ import io.element.android.libraries.matrix.impl.sync.RustSyncService
|
|||
import io.element.android.libraries.matrix.impl.usersearch.UserProfileMapper
|
||||
import io.element.android.libraries.matrix.impl.usersearch.UserSearchResultMapper
|
||||
import io.element.android.libraries.matrix.impl.util.SessionPathsProvider
|
||||
import io.element.android.libraries.matrix.impl.util.anonymizedTokens
|
||||
import io.element.android.libraries.matrix.impl.util.cancelAndDestroy
|
||||
import io.element.android.libraries.matrix.impl.util.mxCallbackFlow
|
||||
import io.element.android.libraries.matrix.impl.verification.RustSessionVerificationService
|
||||
|
|
@ -100,7 +97,6 @@ import kotlinx.coroutines.withContext
|
|||
import kotlinx.coroutines.withTimeout
|
||||
import org.matrix.rustcomponents.sdk.BackupState
|
||||
import org.matrix.rustcomponents.sdk.Client
|
||||
import org.matrix.rustcomponents.sdk.ClientDelegate
|
||||
import org.matrix.rustcomponents.sdk.ClientException
|
||||
import org.matrix.rustcomponents.sdk.IgnoredUsersListener
|
||||
import org.matrix.rustcomponents.sdk.NotificationProcessSetup
|
||||
|
|
@ -111,7 +107,6 @@ import org.matrix.rustcomponents.sdk.use
|
|||
import timber.log.Timber
|
||||
import java.io.File
|
||||
import java.util.Optional
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import kotlin.time.Duration
|
||||
import kotlin.time.Duration.Companion.INFINITE
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
|
|
@ -130,6 +125,7 @@ class RustMatrixClient(
|
|||
private val baseDirectory: File,
|
||||
baseCacheDirectory: File,
|
||||
private val clock: SystemClock,
|
||||
sessionDelegate: RustClientSessionDelegate,
|
||||
) : MatrixClient {
|
||||
override val sessionId: UserId = UserId(client.userId())
|
||||
override val deviceId: String = client.deviceId()
|
||||
|
|
@ -138,8 +134,6 @@ class RustMatrixClient(
|
|||
private val innerRoomListService = syncService.roomListService()
|
||||
private val sessionDispatcher = dispatchers.io.limitedParallelism(64)
|
||||
|
||||
// To make sure only one coroutine affecting the token persistence can run at a time
|
||||
private val tokenRefreshDispatcher = sessionDispatcher.limitedParallelism(1)
|
||||
private val rustSyncService = RustSyncService(syncService, sessionCoroutineScope)
|
||||
private val pushersService = RustPushersService(
|
||||
client = client,
|
||||
|
|
@ -164,72 +158,6 @@ class RustMatrixClient(
|
|||
|
||||
private val sessionPathsProvider = SessionPathsProvider(sessionStore)
|
||||
|
||||
private val isLoggingOut = AtomicBoolean(false)
|
||||
|
||||
private val clientDelegate = object : ClientDelegate {
|
||||
private val clientLog get() = Timber.tag(this@RustMatrixClient.toString())
|
||||
|
||||
override fun didReceiveAuthError(isSoftLogout: Boolean) {
|
||||
clientLog.w("didReceiveAuthError(isSoftLogout=$isSoftLogout)")
|
||||
if (isLoggingOut.getAndSet(true).not()) {
|
||||
clientLog.v("didReceiveAuthError -> do the cleanup")
|
||||
// TODO handle isSoftLogout parameter.
|
||||
appCoroutineScope.launch(tokenRefreshDispatcher) {
|
||||
val existingData = sessionStore.getSession(client.userId())
|
||||
val (anonymizedAccessToken, anonymizedRefreshToken) = existingData.anonymizedTokens()
|
||||
clientLog.d(
|
||||
"Removing session data with access token '$anonymizedAccessToken' " +
|
||||
"and refresh token '$anonymizedRefreshToken'."
|
||||
)
|
||||
if (existingData != null) {
|
||||
// Set isTokenValid to false
|
||||
val newData = client.session().toSessionData(
|
||||
isTokenValid = false,
|
||||
loginType = existingData.loginType,
|
||||
passphrase = existingData.passphrase,
|
||||
sessionPaths = existingData.getSessionPaths(),
|
||||
)
|
||||
sessionStore.updateData(newData)
|
||||
clientLog.d("Removed session data with access token: '$anonymizedAccessToken'.")
|
||||
} else {
|
||||
clientLog.d("No session data found.")
|
||||
}
|
||||
doLogout(doRequest = false, removeSession = false, ignoreSdkError = false)
|
||||
}.invokeOnCompletion {
|
||||
if (it != null) {
|
||||
clientLog.e(it, "Failed to remove session data.")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
clientLog.v("didReceiveAuthError -> already cleaning up")
|
||||
}
|
||||
}
|
||||
|
||||
override fun didRefreshTokens() {
|
||||
clientLog.w("didRefreshTokens()")
|
||||
appCoroutineScope.launch(tokenRefreshDispatcher) {
|
||||
val existingData = sessionStore.getSession(client.userId()) ?: return@launch
|
||||
val (anonymizedAccessToken, anonymizedRefreshToken) = client.session().anonymizedTokens()
|
||||
clientLog.d(
|
||||
"Saving new session data with token: access token '$anonymizedAccessToken' and refresh token '$anonymizedRefreshToken'. " +
|
||||
"Was token valid: ${existingData.isTokenValid}"
|
||||
)
|
||||
val newData = client.session().toSessionData(
|
||||
isTokenValid = true,
|
||||
loginType = existingData.loginType,
|
||||
passphrase = existingData.passphrase,
|
||||
sessionPaths = existingData.getSessionPaths(),
|
||||
)
|
||||
sessionStore.updateData(newData)
|
||||
clientLog.d("Saved new session data with access token: '$anonymizedAccessToken'.")
|
||||
}.invokeOnCompletion {
|
||||
if (it != null) {
|
||||
clientLog.e(it, "Failed to save new session data.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val roomSyncSubscriber: RoomSyncSubscriber = RoomSyncSubscriber(innerRoomListService, dispatchers)
|
||||
|
||||
override val roomListService: RoomListService = RustRoomListService(
|
||||
|
|
@ -271,7 +199,7 @@ class RustMatrixClient(
|
|||
|
||||
private val roomMembershipObserver = RoomMembershipObserver()
|
||||
|
||||
private val clientDelegateTaskHandle: TaskHandle? = client.setDelegate(clientDelegate)
|
||||
private val clientDelegateTaskHandle: TaskHandle? = client.setDelegate(sessionDelegate)
|
||||
|
||||
private val _userProfile: MutableStateFlow<MatrixUser> = MutableStateFlow(
|
||||
MatrixUser(
|
||||
|
|
@ -295,6 +223,9 @@ class RustMatrixClient(
|
|||
.stateIn(sessionCoroutineScope, started = SharingStarted.Eagerly, initialValue = persistentListOf())
|
||||
|
||||
init {
|
||||
// Make sure the session delegate has a reference to the client to be able to logout on auth error
|
||||
sessionDelegate.bindClient(this)
|
||||
|
||||
sessionCoroutineScope.launch {
|
||||
// Force a refresh of the profile
|
||||
getUserProfile()
|
||||
|
|
@ -536,21 +467,11 @@ class RustMatrixClient(
|
|||
deleteSessionDirectory(deleteCryptoDb = false)
|
||||
}
|
||||
|
||||
override suspend fun logout(ignoreSdkError: Boolean): String? = doLogout(
|
||||
doRequest = true,
|
||||
removeSession = true,
|
||||
ignoreSdkError = ignoreSdkError,
|
||||
)
|
||||
|
||||
private suspend fun doLogout(
|
||||
doRequest: Boolean,
|
||||
removeSession: Boolean,
|
||||
ignoreSdkError: Boolean,
|
||||
): String? {
|
||||
override suspend fun logout(userInitiated: Boolean, ignoreSdkError: Boolean): String? {
|
||||
var result: String? = null
|
||||
syncService.stop()
|
||||
withContext(sessionDispatcher) {
|
||||
if (doRequest) {
|
||||
if (userInitiated) {
|
||||
try {
|
||||
result = client.logout()
|
||||
} catch (failure: Throwable) {
|
||||
|
|
@ -564,7 +485,7 @@ class RustMatrixClient(
|
|||
}
|
||||
close()
|
||||
deleteSessionDirectory(deleteCryptoDb = true)
|
||||
if (removeSession) {
|
||||
if (userInitiated) {
|
||||
sessionStore.removeSession(sessionId.value)
|
||||
}
|
||||
}
|
||||
|
|
@ -615,6 +536,10 @@ class RustMatrixClient(
|
|||
})
|
||||
}.buffer(Channel.UNLIMITED)
|
||||
|
||||
internal fun setDelegate(delegate: RustClientSessionDelegate) {
|
||||
client.setDelegate(delegate)
|
||||
}
|
||||
|
||||
private suspend fun File.getCacheSize(
|
||||
includeCryptoDb: Boolean = false,
|
||||
): Long = withContext(sessionDispatcher) {
|
||||
|
|
|
|||
|
|
@ -56,6 +56,8 @@ class RustMatrixClientFactory @Inject constructor(
|
|||
private val appPreferencesStore: AppPreferencesStore,
|
||||
) {
|
||||
suspend fun create(sessionData: SessionData): RustMatrixClient = withContext(coroutineDispatchers.io) {
|
||||
val sessionDelegate = RustClientSessionDelegate(sessionStore, appCoroutineScope, coroutineDispatchers)
|
||||
|
||||
val client = getBaseClientBuilder(
|
||||
sessionPaths = sessionData.getSessionPaths(),
|
||||
passphrase = sessionData.passphrase,
|
||||
|
|
@ -67,6 +69,7 @@ class RustMatrixClientFactory @Inject constructor(
|
|||
)
|
||||
.homeserverUrl(sessionData.homeserverUrl)
|
||||
.username(sessionData.userId)
|
||||
.setSessionDelegate(sessionDelegate)
|
||||
.use { it.build() }
|
||||
|
||||
client.restoreSession(sessionData.toSession())
|
||||
|
|
@ -86,6 +89,7 @@ class RustMatrixClientFactory @Inject constructor(
|
|||
baseDirectory = baseDirectory,
|
||||
baseCacheDirectory = cacheDirectory,
|
||||
clock = clock,
|
||||
sessionDelegate = sessionDelegate,
|
||||
).also {
|
||||
Timber.tag(it.toString()).d("Creating Client with access token '$anonymizedAccessToken' and refresh token '$anonymizedRefreshToken'")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ class FakeMatrixClient(
|
|||
var getRoomInfoFlowLambda = { _: RoomId ->
|
||||
flowOf<Optional<MatrixRoomInfo>>(Optional.empty())
|
||||
}
|
||||
var logoutLambda: (Boolean) -> String? = {
|
||||
var logoutLambda: (Boolean, Boolean) -> String? = { _, _ ->
|
||||
null
|
||||
}
|
||||
|
||||
|
|
@ -170,8 +170,8 @@ class FakeMatrixClient(
|
|||
clearCacheLambda()
|
||||
}
|
||||
|
||||
override suspend fun logout(ignoreSdkError: Boolean): String? = simulateLongTask {
|
||||
return logoutLambda(ignoreSdkError)
|
||||
override suspend fun logout(userInitiated: Boolean, ignoreSdkError: Boolean): String? = simulateLongTask {
|
||||
return logoutLambda(ignoreSdkError, userInitiated)
|
||||
}
|
||||
|
||||
override fun close() = Unit
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue