diff --git a/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/DefaultLogoutUseCase.kt b/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/DefaultLogoutUseCase.kt index 0ae58aa61e..85300034b9 100644 --- a/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/DefaultLogoutUseCase.kt +++ b/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/DefaultLogoutUseCase.kt @@ -33,7 +33,7 @@ class DefaultLogoutUseCase @Inject constructor( return if (currentSession != null) { matrixClientProvider.getOrRestore(currentSession) .getOrThrow() - .logout(ignoreSdkError = true) + .logout(userInitiated = true, ignoreSdkError = true) } else { error("No session to sign out") } diff --git a/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/LogoutPresenter.kt b/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/LogoutPresenter.kt index 34f3ca84a6..cafab35913 100644 --- a/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/LogoutPresenter.kt +++ b/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/LogoutPresenter.kt @@ -104,7 +104,7 @@ class LogoutPresenter @Inject constructor( ignoreSdkError: Boolean, ) = launch { suspend { - matrixClient.logout(ignoreSdkError) + matrixClient.logout(userInitiated = true, ignoreSdkError) }.runCatchingUpdatingState(logoutAction) } } diff --git a/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenter.kt b/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenter.kt index b0924e61e0..1177f42186 100644 --- a/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenter.kt +++ b/features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenter.kt @@ -86,7 +86,7 @@ class DefaultDirectLogoutPresenter @Inject constructor( ignoreSdkError: Boolean, ) = launch { suspend { - matrixClient.logout(ignoreSdkError) + matrixClient.logout(userInitiated = true, ignoreSdkError) }.runCatchingUpdatingState(logoutAction) } } diff --git a/features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/LogoutPresenterTest.kt b/features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/LogoutPresenterTest.kt index 70b346ba8d..0303ab45e4 100644 --- a/features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/LogoutPresenterTest.kt +++ b/features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/LogoutPresenterTest.kt @@ -144,7 +144,7 @@ class LogoutPresenterTest { @Test fun `present - logout with error then cancel`() = runTest { val matrixClient = FakeMatrixClient().apply { - logoutLambda = { _ -> + logoutLambda = { _, _ -> throw A_THROWABLE } } @@ -172,7 +172,7 @@ class LogoutPresenterTest { @Test fun `present - logout with error then force`() = runTest { val matrixClient = FakeMatrixClient().apply { - logoutLambda = { ignoreSdkError -> + logoutLambda = { ignoreSdkError, _ -> if (!ignoreSdkError) { throw A_THROWABLE } else { diff --git a/features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenterTest.kt b/features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenterTest.kt index 14d340570c..aed5fbc1de 100644 --- a/features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenterTest.kt +++ b/features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenterTest.kt @@ -125,7 +125,7 @@ class DefaultDirectLogoutPresenterTest { @Test fun `present - logout with error then cancel`() = runTest { val matrixClient = FakeMatrixClient().apply { - logoutLambda = { _ -> + logoutLambda = { _, _ -> throw A_THROWABLE } } @@ -153,7 +153,7 @@ class DefaultDirectLogoutPresenterTest { @Test fun `present - logout with error then force`() = runTest { val matrixClient = FakeMatrixClient().apply { - logoutLambda = { ignoreSdkError -> + logoutLambda = { ignoreSdkError, _ -> if (!ignoreSdkError) { throw A_THROWABLE } else { diff --git a/libraries/matrix/api/src/main/kotlin/io/element/android/libraries/matrix/api/MatrixClient.kt b/libraries/matrix/api/src/main/kotlin/io/element/android/libraries/matrix/api/MatrixClient.kt index 3209d49e03..042902cc71 100644 --- a/libraries/matrix/api/src/main/kotlin/io/element/android/libraries/matrix/api/MatrixClient.kt +++ b/libraries/matrix/api/src/main/kotlin/io/element/android/libraries/matrix/api/MatrixClient.kt @@ -88,9 +88,10 @@ interface MatrixClient : Closeable { * Logout the user. * Returns an optional URL. When the URL is there, it should be presented to the user after logout for * Relying Party (RP) initiated logout on their account page. + * @param userInitiated if false, the logout came from the HS, no request will be made and the session entry will be kept in the store. * @param ignoreSdkError if true, the SDK will ignore any error and delete the session data anyway. */ - suspend fun logout(ignoreSdkError: Boolean): String? + suspend fun logout(userInitiated: Boolean, ignoreSdkError: Boolean): String? /** * Retrieve the user profile, will also eventually emit a new value to [userProfile]. diff --git a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustClientSessionDelegate.kt b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustClientSessionDelegate.kt new file mode 100644 index 0000000000..2b856ec984 --- /dev/null +++ b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustClientSessionDelegate.kt @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2024 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.element.android.libraries.matrix.impl + +import io.element.android.libraries.core.coroutine.CoroutineDispatchers +import io.element.android.libraries.matrix.impl.mapper.toSessionData +import io.element.android.libraries.matrix.impl.paths.getSessionPaths +import io.element.android.libraries.matrix.impl.util.anonymizedTokens +import io.element.android.libraries.sessionstorage.api.SessionStore +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.launch +import org.matrix.rustcomponents.sdk.ClientDelegate +import org.matrix.rustcomponents.sdk.ClientSessionDelegate +import org.matrix.rustcomponents.sdk.Session +import timber.log.Timber +import java.util.concurrent.atomic.AtomicBoolean + +/** + * This class is responsible for handling the session data for the Rust SDK. + * + * It implements both [ClientSessionDelegate] and [ClientDelegate] to react to session data updates and auth errors. + * + * IMPORTANT: you must set the [client] property as soon as possible so [didReceiveAuthError] can work properly. + */ +@OptIn(ExperimentalCoroutinesApi::class) +class RustClientSessionDelegate( + private val sessionStore: SessionStore, + private val appCoroutineScope: CoroutineScope, + coroutineDispatchers: CoroutineDispatchers, +) : ClientSessionDelegate, ClientDelegate { + private val clientLog = Timber.tag("$this") + + // Used to ensure several calls to `didReceiveAuthError` don't trigger multiple logouts + private val isLoggingOut = AtomicBoolean(false) + + // To make sure only one coroutine affecting the token persistence can run at a time + private val updateTokensDispatcher = coroutineDispatchers.io.limitedParallelism(1) + + // This Client needs to be set up as soon as possible so `didReceiveAuthError` can work properly. + private var client: RustMatrixClient? = null + + /** + * Sets the [ClientDelegate] for the [RustMatrixClient], and keeps a reference to the client so it can be used later. + */ + fun bindClient(client: RustMatrixClient) { + this.client = client + client.setDelegate(this) + } + + override fun saveSessionInKeychain(session: Session) { + appCoroutineScope.launch(updateTokensDispatcher) { + val existingData = sessionStore.getSession(session.userId) ?: return@launch + val (anonymizedAccessToken, anonymizedRefreshToken) = session.anonymizedTokens() + clientLog.d( + "Saving new session data with token: access token '$anonymizedAccessToken' and refresh token '$anonymizedRefreshToken'. " + + "Was token valid: ${existingData.isTokenValid}" + ) + val newData = session.toSessionData( + isTokenValid = true, + loginType = existingData.loginType, + passphrase = existingData.passphrase, + sessionPaths = existingData.getSessionPaths(), + ) + sessionStore.updateData(newData) + clientLog.d("Saved new session data with access token: '$anonymizedAccessToken'.") + }.invokeOnCompletion { + if (it != null) { + clientLog.e(it, "Failed to save new session data.") + } + } + } + + override fun didReceiveAuthError(isSoftLogout: Boolean) { + clientLog.w("didReceiveAuthError(isSoftLogout=$isSoftLogout)") + if (isLoggingOut.getAndSet(true).not()) { + clientLog.v("didReceiveAuthError -> do the cleanup") + // TODO handle isSoftLogout parameter. + appCoroutineScope.launch(updateTokensDispatcher) { + val currentClient = client + if (currentClient == null) { + clientLog.w("didReceiveAuthError -> no client, exiting") + isLoggingOut.set(false) + return@launch + } + val existingData = sessionStore.getSession(currentClient.sessionId.value) + val (anonymizedAccessToken, anonymizedRefreshToken) = existingData.anonymizedTokens() + clientLog.d( + "Removing session data with access token '$anonymizedAccessToken' " + + "and refresh token '$anonymizedRefreshToken'." + ) + if (existingData != null) { + // Set isTokenValid to false + val newData = existingData.copy(isTokenValid = false) + sessionStore.updateData(newData) + clientLog.d("Invalidated session data with access token: '$anonymizedAccessToken'.") + } else { + clientLog.d("No session data found.") + } + client?.logout(userInitiated = false, ignoreSdkError = true) + }.invokeOnCompletion { + if (it != null) { + clientLog.e(it, "Failed to remove session data.") + } + } + } else { + clientLog.v("didReceiveAuthError -> already cleaning up") + } + } + + override fun didRefreshTokens() { + // This is done in `saveSessionInKeychain(Session)` instead. + } + + override fun retrieveSessionFromKeychain(userId: String): Session { + // This should never be called, as it's only used for multi-process setups + error("retrieveSessionFromKeychain should never be called for Android") + } +} diff --git a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClient.kt b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClient.kt index 089240c94c..0ff6006f36 100644 --- a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClient.kt +++ b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClient.kt @@ -51,12 +51,10 @@ import io.element.android.libraries.matrix.api.user.MatrixUser import io.element.android.libraries.matrix.api.verification.SessionVerificationService import io.element.android.libraries.matrix.impl.core.toProgressWatcher import io.element.android.libraries.matrix.impl.encryption.RustEncryptionService -import io.element.android.libraries.matrix.impl.mapper.toSessionData import io.element.android.libraries.matrix.impl.media.RustMediaLoader import io.element.android.libraries.matrix.impl.notification.RustNotificationService import io.element.android.libraries.matrix.impl.notificationsettings.RustNotificationSettingsService import io.element.android.libraries.matrix.impl.oidc.toRustAction -import io.element.android.libraries.matrix.impl.paths.getSessionPaths import io.element.android.libraries.matrix.impl.pushers.RustPushersService import io.element.android.libraries.matrix.impl.room.RoomContentForwarder import io.element.android.libraries.matrix.impl.room.RoomSyncSubscriber @@ -69,7 +67,6 @@ import io.element.android.libraries.matrix.impl.sync.RustSyncService import io.element.android.libraries.matrix.impl.usersearch.UserProfileMapper import io.element.android.libraries.matrix.impl.usersearch.UserSearchResultMapper import io.element.android.libraries.matrix.impl.util.SessionPathsProvider -import io.element.android.libraries.matrix.impl.util.anonymizedTokens import io.element.android.libraries.matrix.impl.util.cancelAndDestroy import io.element.android.libraries.matrix.impl.util.mxCallbackFlow import io.element.android.libraries.matrix.impl.verification.RustSessionVerificationService @@ -100,7 +97,6 @@ import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout import org.matrix.rustcomponents.sdk.BackupState import org.matrix.rustcomponents.sdk.Client -import org.matrix.rustcomponents.sdk.ClientDelegate import org.matrix.rustcomponents.sdk.ClientException import org.matrix.rustcomponents.sdk.IgnoredUsersListener import org.matrix.rustcomponents.sdk.NotificationProcessSetup @@ -111,7 +107,6 @@ import org.matrix.rustcomponents.sdk.use import timber.log.Timber import java.io.File import java.util.Optional -import java.util.concurrent.atomic.AtomicBoolean import kotlin.time.Duration import kotlin.time.Duration.Companion.INFINITE import kotlin.time.Duration.Companion.seconds @@ -130,6 +125,7 @@ class RustMatrixClient( private val baseDirectory: File, baseCacheDirectory: File, private val clock: SystemClock, + sessionDelegate: RustClientSessionDelegate, ) : MatrixClient { override val sessionId: UserId = UserId(client.userId()) override val deviceId: String = client.deviceId() @@ -138,8 +134,6 @@ class RustMatrixClient( private val innerRoomListService = syncService.roomListService() private val sessionDispatcher = dispatchers.io.limitedParallelism(64) - // To make sure only one coroutine affecting the token persistence can run at a time - private val tokenRefreshDispatcher = sessionDispatcher.limitedParallelism(1) private val rustSyncService = RustSyncService(syncService, sessionCoroutineScope) private val pushersService = RustPushersService( client = client, @@ -164,72 +158,6 @@ class RustMatrixClient( private val sessionPathsProvider = SessionPathsProvider(sessionStore) - private val isLoggingOut = AtomicBoolean(false) - - private val clientDelegate = object : ClientDelegate { - private val clientLog get() = Timber.tag(this@RustMatrixClient.toString()) - - override fun didReceiveAuthError(isSoftLogout: Boolean) { - clientLog.w("didReceiveAuthError(isSoftLogout=$isSoftLogout)") - if (isLoggingOut.getAndSet(true).not()) { - clientLog.v("didReceiveAuthError -> do the cleanup") - // TODO handle isSoftLogout parameter. - appCoroutineScope.launch(tokenRefreshDispatcher) { - val existingData = sessionStore.getSession(client.userId()) - val (anonymizedAccessToken, anonymizedRefreshToken) = existingData.anonymizedTokens() - clientLog.d( - "Removing session data with access token '$anonymizedAccessToken' " + - "and refresh token '$anonymizedRefreshToken'." - ) - if (existingData != null) { - // Set isTokenValid to false - val newData = client.session().toSessionData( - isTokenValid = false, - loginType = existingData.loginType, - passphrase = existingData.passphrase, - sessionPaths = existingData.getSessionPaths(), - ) - sessionStore.updateData(newData) - clientLog.d("Removed session data with access token: '$anonymizedAccessToken'.") - } else { - clientLog.d("No session data found.") - } - doLogout(doRequest = false, removeSession = false, ignoreSdkError = false) - }.invokeOnCompletion { - if (it != null) { - clientLog.e(it, "Failed to remove session data.") - } - } - } else { - clientLog.v("didReceiveAuthError -> already cleaning up") - } - } - - override fun didRefreshTokens() { - clientLog.w("didRefreshTokens()") - appCoroutineScope.launch(tokenRefreshDispatcher) { - val existingData = sessionStore.getSession(client.userId()) ?: return@launch - val (anonymizedAccessToken, anonymizedRefreshToken) = client.session().anonymizedTokens() - clientLog.d( - "Saving new session data with token: access token '$anonymizedAccessToken' and refresh token '$anonymizedRefreshToken'. " + - "Was token valid: ${existingData.isTokenValid}" - ) - val newData = client.session().toSessionData( - isTokenValid = true, - loginType = existingData.loginType, - passphrase = existingData.passphrase, - sessionPaths = existingData.getSessionPaths(), - ) - sessionStore.updateData(newData) - clientLog.d("Saved new session data with access token: '$anonymizedAccessToken'.") - }.invokeOnCompletion { - if (it != null) { - clientLog.e(it, "Failed to save new session data.") - } - } - } - } - private val roomSyncSubscriber: RoomSyncSubscriber = RoomSyncSubscriber(innerRoomListService, dispatchers) override val roomListService: RoomListService = RustRoomListService( @@ -271,7 +199,7 @@ class RustMatrixClient( private val roomMembershipObserver = RoomMembershipObserver() - private val clientDelegateTaskHandle: TaskHandle? = client.setDelegate(clientDelegate) + private val clientDelegateTaskHandle: TaskHandle? = client.setDelegate(sessionDelegate) private val _userProfile: MutableStateFlow = MutableStateFlow( MatrixUser( @@ -295,6 +223,9 @@ class RustMatrixClient( .stateIn(sessionCoroutineScope, started = SharingStarted.Eagerly, initialValue = persistentListOf()) init { + // Make sure the session delegate has a reference to the client to be able to logout on auth error + sessionDelegate.bindClient(this) + sessionCoroutineScope.launch { // Force a refresh of the profile getUserProfile() @@ -536,21 +467,11 @@ class RustMatrixClient( deleteSessionDirectory(deleteCryptoDb = false) } - override suspend fun logout(ignoreSdkError: Boolean): String? = doLogout( - doRequest = true, - removeSession = true, - ignoreSdkError = ignoreSdkError, - ) - - private suspend fun doLogout( - doRequest: Boolean, - removeSession: Boolean, - ignoreSdkError: Boolean, - ): String? { + override suspend fun logout(userInitiated: Boolean, ignoreSdkError: Boolean): String? { var result: String? = null syncService.stop() withContext(sessionDispatcher) { - if (doRequest) { + if (userInitiated) { try { result = client.logout() } catch (failure: Throwable) { @@ -564,7 +485,7 @@ class RustMatrixClient( } close() deleteSessionDirectory(deleteCryptoDb = true) - if (removeSession) { + if (userInitiated) { sessionStore.removeSession(sessionId.value) } } @@ -615,6 +536,10 @@ class RustMatrixClient( }) }.buffer(Channel.UNLIMITED) + internal fun setDelegate(delegate: RustClientSessionDelegate) { + client.setDelegate(delegate) + } + private suspend fun File.getCacheSize( includeCryptoDb: Boolean = false, ): Long = withContext(sessionDispatcher) { diff --git a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt index c62b4a0da5..7a73c04313 100644 --- a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt +++ b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt @@ -56,6 +56,8 @@ class RustMatrixClientFactory @Inject constructor( private val appPreferencesStore: AppPreferencesStore, ) { suspend fun create(sessionData: SessionData): RustMatrixClient = withContext(coroutineDispatchers.io) { + val sessionDelegate = RustClientSessionDelegate(sessionStore, appCoroutineScope, coroutineDispatchers) + val client = getBaseClientBuilder( sessionPaths = sessionData.getSessionPaths(), passphrase = sessionData.passphrase, @@ -67,6 +69,7 @@ class RustMatrixClientFactory @Inject constructor( ) .homeserverUrl(sessionData.homeserverUrl) .username(sessionData.userId) + .setSessionDelegate(sessionDelegate) .use { it.build() } client.restoreSession(sessionData.toSession()) @@ -86,6 +89,7 @@ class RustMatrixClientFactory @Inject constructor( baseDirectory = baseDirectory, baseCacheDirectory = cacheDirectory, clock = clock, + sessionDelegate = sessionDelegate, ).also { Timber.tag(it.toString()).d("Creating Client with access token '$anonymizedAccessToken' and refresh token '$anonymizedRefreshToken'") } diff --git a/libraries/matrix/test/src/main/kotlin/io/element/android/libraries/matrix/test/FakeMatrixClient.kt b/libraries/matrix/test/src/main/kotlin/io/element/android/libraries/matrix/test/FakeMatrixClient.kt index 1fb8878488..132d737ff0 100644 --- a/libraries/matrix/test/src/main/kotlin/io/element/android/libraries/matrix/test/FakeMatrixClient.kt +++ b/libraries/matrix/test/src/main/kotlin/io/element/android/libraries/matrix/test/FakeMatrixClient.kt @@ -122,7 +122,7 @@ class FakeMatrixClient( var getRoomInfoFlowLambda = { _: RoomId -> flowOf>(Optional.empty()) } - var logoutLambda: (Boolean) -> String? = { + var logoutLambda: (Boolean, Boolean) -> String? = { _, _ -> null } @@ -170,8 +170,8 @@ class FakeMatrixClient( clearCacheLambda() } - override suspend fun logout(ignoreSdkError: Boolean): String? = simulateLongTask { - return logoutLambda(ignoreSdkError) + override suspend fun logout(userInitiated: Boolean, ignoreSdkError: Boolean): String? = simulateLongTask { + return logoutLambda(ignoreSdkError, userInitiated) } override fun close() = Unit