diff --git a/features/login/impl/src/main/kotlin/io/element/android/features/login/impl/screens/createaccount/CreateAccountPresenter.kt b/features/login/impl/src/main/kotlin/io/element/android/features/login/impl/screens/createaccount/CreateAccountPresenter.kt index f7a23df7d1..bf7af111f4 100644 --- a/features/login/impl/src/main/kotlin/io/element/android/features/login/impl/screens/createaccount/CreateAccountPresenter.kt +++ b/features/login/impl/src/main/kotlin/io/element/android/features/login/impl/screens/createaccount/CreateAccountPresenter.kt @@ -19,24 +19,18 @@ import dev.zacsweers.metro.AssistedFactory import dev.zacsweers.metro.AssistedInject import io.element.android.libraries.architecture.AsyncAction import io.element.android.libraries.architecture.Presenter -import io.element.android.libraries.core.data.tryOrNull import io.element.android.libraries.core.extensions.flatMap import io.element.android.libraries.core.extensions.runCatchingExceptions import io.element.android.libraries.core.meta.BuildMeta -import io.element.android.libraries.matrix.api.MatrixClientProvider import io.element.android.libraries.matrix.api.auth.MatrixAuthenticationService import io.element.android.libraries.matrix.api.core.SessionId import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.flow.first import kotlinx.coroutines.launch -import kotlinx.coroutines.withTimeout -import kotlin.time.Duration.Companion.seconds @AssistedInject class CreateAccountPresenter( @Assisted private val url: String, private val authenticationService: MatrixAuthenticationService, - private val clientProvider: MatrixClientProvider, private val messageParser: MessageParser, private val buildMeta: BuildMeta, ) : Presenter { @@ -80,12 +74,6 @@ class CreateAccountPresenter( }.flatMap { externalSession -> authenticationService.importCreatedSession(externalSession) }.onSuccess { sessionId -> - tryOrNull { - // Wait until the session is verified - val client = clientProvider.getOrRestore(sessionId).getOrThrow() - val sessionVerificationService = client.sessionVerificationService - withTimeout(10.seconds) { sessionVerificationService.sessionVerifiedStatus.first { it.isVerified() } } - } loggedInState.value = AsyncAction.Success(sessionId) }.onFailure { failure -> loggedInState.value = AsyncAction.Failure(failure) diff --git a/features/login/impl/src/test/kotlin/io/element/android/features/login/impl/screens/createaccount/CreateAccountPresenterTest.kt b/features/login/impl/src/test/kotlin/io/element/android/features/login/impl/screens/createaccount/CreateAccountPresenterTest.kt index a9d3c02368..c70a5c9a9a 100644 --- a/features/login/impl/src/test/kotlin/io/element/android/features/login/impl/screens/createaccount/CreateAccountPresenterTest.kt +++ b/features/login/impl/src/test/kotlin/io/element/android/features/login/impl/screens/createaccount/CreateAccountPresenterTest.kt @@ -16,8 +16,6 @@ import io.element.android.libraries.matrix.api.auth.external.ExternalSession import io.element.android.libraries.matrix.api.verification.SessionVerifiedStatus import io.element.android.libraries.matrix.test.AN_EXCEPTION import io.element.android.libraries.matrix.test.A_SESSION_ID -import io.element.android.libraries.matrix.test.FakeMatrixClient -import io.element.android.libraries.matrix.test.FakeMatrixClientProvider import io.element.android.libraries.matrix.test.auth.FakeMatrixAuthenticationService import io.element.android.libraries.matrix.test.core.aBuildMeta import io.element.android.libraries.matrix.test.verification.FakeSessionVerificationService @@ -80,14 +78,11 @@ class CreateAccountPresenterTest { fun `present - receiving a message able to be parsed change the state to success`() = runTest { val lambda = lambdaRecorder { _ -> anExternalSession() } val sessionVerificationService = FakeSessionVerificationService() - val client = FakeMatrixClient(sessionVerificationService = sessionVerificationService) - val clientProvider = FakeMatrixClientProvider(getClient = { Result.success(client) }) val presenter = createPresenter( authenticationService = FakeMatrixAuthenticationService( importCreatedSessionLambda = { Result.success(A_SESSION_ID) } ), messageParser = FakeMessageParser(lambda), - clientProvider = clientProvider, ) presenter.test { val initialState = awaitItem() @@ -120,12 +115,10 @@ class CreateAccountPresenterTest { authenticationService: MatrixAuthenticationService = FakeMatrixAuthenticationService(), messageParser: MessageParser = FakeMessageParser(), buildMeta: BuildMeta = aBuildMeta(), - clientProvider: FakeMatrixClientProvider = FakeMatrixClientProvider(), ) = CreateAccountPresenter( url = url, authenticationService = authenticationService, messageParser = messageParser, buildMeta = buildMeta, - clientProvider = clientProvider, ) } diff --git a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt index 88207049d2..5932acec20 100644 --- a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt +++ b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt @@ -193,7 +193,7 @@ sealed interface ClientBuilderSlidingSync { data object Native : ClientBuilderSlidingSync } -private fun SessionData.toSession() = Session( +fun SessionData.toSession() = Session( accessToken = accessToken, refreshToken = refreshToken, userId = userId, diff --git a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/auth/RustMatrixAuthenticationService.kt b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/auth/RustMatrixAuthenticationService.kt index 97719bed36..7cd9fbedf5 100644 --- a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/auth/RustMatrixAuthenticationService.kt +++ b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/auth/RustMatrixAuthenticationService.kt @@ -25,6 +25,7 @@ import io.element.android.libraries.matrix.api.auth.external.ExternalSession import io.element.android.libraries.matrix.api.auth.qrlogin.MatrixQrCodeLoginData import io.element.android.libraries.matrix.api.auth.qrlogin.QrCodeLoginStep import io.element.android.libraries.matrix.api.core.SessionId +import io.element.android.libraries.matrix.api.verification.SessionVerifiedStatus import io.element.android.libraries.matrix.impl.ClientBuilderSlidingSync import io.element.android.libraries.matrix.impl.RustMatrixClientFactory import io.element.android.libraries.matrix.impl.auth.qrlogin.QrErrorMapper @@ -35,10 +36,13 @@ import io.element.android.libraries.matrix.impl.keys.PassphraseGenerator import io.element.android.libraries.matrix.impl.mapper.toSessionData import io.element.android.libraries.matrix.impl.paths.SessionPaths import io.element.android.libraries.matrix.impl.paths.SessionPathsFactory +import io.element.android.libraries.matrix.impl.toSession import io.element.android.libraries.sessionstorage.api.LoginType import io.element.android.libraries.sessionstorage.api.SessionStore import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.flow.first import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeoutOrNull import org.matrix.rustcomponents.sdk.Client import org.matrix.rustcomponents.sdk.ClientBuilder import org.matrix.rustcomponents.sdk.HumanQrLoginException @@ -48,6 +52,7 @@ import org.matrix.rustcomponents.sdk.QrLoginProgress import org.matrix.rustcomponents.sdk.QrLoginProgressListener import timber.log.Timber import uniffi.matrix_sdk.OAuthAuthorizationData +import kotlin.time.Duration.Companion.seconds @ContributesBinding(AppScope::class) @SingleIn(AppScope::class) @@ -160,7 +165,7 @@ class RustMatrixAuthenticationService( override suspend fun importCreatedSession(externalSession: ExternalSession): Result = withContext(coroutineDispatchers.io) { runCatchingExceptions { - currentClient ?: error("You need to call `setHomeserver()` first") + val client = currentClient ?: error("You need to call `setHomeserver()` first") val currentSessionPaths = sessionPaths ?: error("You need to call `setHomeserver()` first") val sessionData = externalSession.toSessionData( isTokenValid = true, @@ -168,8 +173,21 @@ class RustMatrixAuthenticationService( passphrase = pendingPassphrase, sessionPaths = currentSessionPaths, ) - clear() + + // We restore the client using the just retrieved session data + client.restoreSession(sessionData.toSession()) + val matrixClient = rustMatrixClientFactory.create(client) + + // We wait for the verification state to be known + matrixClient.waitForKnownVerificationState() + + // And once it's ready we share it and save the actual session data + newMatrixClientObservers.forEach { it.invoke(matrixClient) } sessionStore.addSession(sessionData) + + // Clean up the strong reference held here since it's no longer necessary + currentClient = null + SessionId(sessionData.userId) } } @@ -238,6 +256,8 @@ class RustMatrixAuthenticationService( sessionPaths = currentSessionPaths, ) val matrixClient = rustMatrixClientFactory.create(client) + matrixClient.waitForKnownVerificationState() + newMatrixClientObservers.forEach { it.invoke(matrixClient) } sessionStore.addSession(sessionData) @@ -356,4 +376,12 @@ class RustMatrixAuthenticationService( currentClient?.close() currentClient = null } + + private suspend fun MatrixClient.waitForKnownVerificationState() { + withTimeoutOrNull(10.seconds) { + Timber.d("Waiting for a known verification status...") + val status = sessionVerificationService.sessionVerifiedStatus.first { it != SessionVerifiedStatus.Unknown } + Timber.d("Finished waiting for a known verification status: $status") + } ?: Timber.w("Timed out waiting for a known verification status") + } } diff --git a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/verification/RustSessionVerificationService.kt b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/verification/RustSessionVerificationService.kt index 3014618b9d..f613ce8dff 100644 --- a/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/verification/RustSessionVerificationService.kt +++ b/libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/verification/RustSessionVerificationService.kt @@ -44,6 +44,7 @@ import org.matrix.rustcomponents.sdk.VerificationState import org.matrix.rustcomponents.sdk.VerificationStateListener import org.matrix.rustcomponents.sdk.use import timber.log.Timber +import java.util.concurrent.atomic.AtomicBoolean import kotlin.time.Duration.Companion.seconds import org.matrix.rustcomponents.sdk.SessionVerificationData as RustSessionVerificationData import org.matrix.rustcomponents.sdk.SessionVerificationRequestDetails as RustSessionVerificationRequestDetails @@ -66,9 +67,16 @@ class RustSessionVerificationService( private val recoveryState = MutableStateFlow(RecoveryState.UNKNOWN) + private val isInitialized = AtomicBoolean(false) + // Listen for changes in verification status and update accordingly private val verificationStateListenerTaskHandle = encryptionService.verificationStateListener(object : VerificationStateListener { override fun onUpdate(status: VerificationState) { + if (!isInitialized.get()) { + Timber.d("Discarding new verifications state: $status. E2EE is not initialised yet") + return + } + Timber.d("New verification state: $status") _sessionVerifiedStatus.value = status.map() } @@ -77,6 +85,11 @@ class RustSessionVerificationService( // In case we enter the recovery key instead we check changes in the recovery state, since the listener above won't be triggered private val recoveryStateListenerTaskHandle = encryptionService.recoveryStateListener(object : RecoveryStateListener { override fun onUpdate(status: RecoveryState) { + if (!isInitialized.get()) { + Timber.d("Discarding new recovery state: $status. E2EE is not initialised yet") + return + } + Timber.d("New recovery state: $status") // We could check the `RecoveryState`, but it's easier to just use the verification state directly recoveryState.value = status @@ -87,7 +100,7 @@ class RustSessionVerificationService( * The internal service that checks verification can only run after the initial sync. * This [StateFlow] will notify consumers when the service is ready to be used. */ - private val isReady = isSyncServiceReady.stateIn(sessionCoroutineScope, SharingStarted.Eagerly, false) + private val canVerify = isSyncServiceReady.stateIn(sessionCoroutineScope, SharingStarted.Eagerly, false) override val needsSessionVerification = sessionVerifiedStatus.map { verificationStatus -> verificationStatus == SessionVerifiedStatus.NotVerified @@ -99,14 +112,11 @@ class RustSessionVerificationService( private var listener: SessionVerificationServiceListener? = null + private val initializationMutex = Mutex() + init { // Instantiate the verification controller when possible, this is needed to get incoming verification requests - sessionCoroutineScope.launch { - tryOrNull { - encryptionService.waitForE2eeInitializationTasks() - initVerificationControllerIfNeeded() - } - } + sessionCoroutineScope.launch { ensureEncryptionIsInitialized() } } override fun setListener(listener: SessionVerificationServiceListener?) { @@ -114,13 +124,13 @@ class RustSessionVerificationService( } override suspend fun requestCurrentSessionVerification() = tryOrFail { - initVerificationControllerIfNeeded() + ensureEncryptionIsInitialized() verificationController.requestDeviceVerification() currentVerificationRequest = VerificationRequest.Outgoing.CurrentSession } override suspend fun requestUserVerification(userId: UserId) = tryOrFail { - initVerificationControllerIfNeeded() + ensureEncryptionIsInitialized() verificationController.requestUserVerification(userId.value) currentVerificationRequest = VerificationRequest.Outgoing.User(userId) } @@ -140,7 +150,7 @@ class RustSessionVerificationService( } override suspend fun acknowledgeVerificationRequest(verificationRequest: VerificationRequest.Incoming) = tryOrFail { - initVerificationControllerIfNeeded() + ensureEncryptionIsInitialized() verificationController.acknowledgeVerificationRequest( senderId = verificationRequest.details.senderProfile.userId.value, flowId = verificationRequest.details.flowId.value, @@ -225,7 +235,7 @@ class RustSessionVerificationService( override suspend fun reset(cancelAnyPendingVerificationAttempt: Boolean) { currentVerificationRequest = null - if (isReady.value && cancelAnyPendingVerificationAttempt) { + if (canVerify.value && cancelAnyPendingVerificationAttempt) { // Cancel any pending verification attempt tryOrNull { verificationController.cancelVerification() } } @@ -241,23 +251,28 @@ class RustSessionVerificationService( } } - private var initControllerMutex = Mutex() - - private suspend fun initVerificationControllerIfNeeded() = initControllerMutex.withLock { - if (!this::verificationController.isInitialized) { - tryOrFail { - verificationController = client.getSessionVerificationController() - verificationController.setDelegate(this) - } - } - } - private fun updateVerificationStatus() { runCatchingExceptions { _sessionVerifiedStatus.value = encryptionService.verificationState().map() Timber.d("New verification status: ${_sessionVerifiedStatus.value}") } } + + private suspend fun ensureEncryptionIsInitialized() = initializationMutex.withLock { + // We're keeping the separate checks instead of unconditionally calling the suspend methods + // so we can skip crossing the FFI layer when it's not needed + tryOrFail { + if (!isInitialized.get()) { + encryptionService.waitForE2eeInitializationTasks() + isInitialized.set(true) + } + + if (!this::verificationController.isInitialized) { + verificationController = client.getSessionVerificationController() + verificationController.setDelegate(this) + } + } + } } private fun VerificationState.map() = when (this) {