Merge pull request #6682 from element-hq/feature/bma/customMasScheme

Add a way to tweak MAS url.
This commit is contained in:
Benoit Marty 2026-05-07 10:51:32 +02:00 committed by GitHub
commit 2f45ca8835
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 656 additions and 17 deletions

@ -1 +1 @@
Subproject commit cdde60c158ecd0987a3ba6fd79a4617551aff463
Subproject commit fb7e9287d9d446012925139842d9aaa8e99a74dc

View file

@ -16,6 +16,7 @@ import kotlinx.coroutines.flow.Flow
interface EnterpriseService {
val isEnterpriseBuild: Boolean
suspend fun isEnterpriseUser(sessionId: SessionId): Boolean
suspend fun tweakMasUrl(url: String, homeserver: String): String
fun defaultHomeserverList(): List<String>
suspend fun isAllowedToConnectToHomeserver(homeserverUrl: String): Boolean

View file

@ -10,6 +10,7 @@ package io.element.android.features.enterprise.api
interface SessionEnterpriseService {
suspend fun isElementCallAvailable(): Boolean
suspend fun tweakMasUrl(url: String): String
suspend fun init()
}

View file

@ -23,7 +23,7 @@ class DefaultEnterpriseService : EnterpriseService {
override val isEnterpriseBuild = false
override suspend fun isEnterpriseUser(sessionId: SessionId) = false
override suspend fun tweakMasUrl(url: String, homeserver: String) = url
override fun defaultHomeserverList(): List<String> = emptyList()
override suspend fun isAllowedToConnectToHomeserver(homeserverUrl: String) = true

View file

@ -15,5 +15,6 @@ import io.element.android.libraries.di.SessionScope
@ContributesBinding(SessionScope::class)
class DefaultSessionEnterpriseService : SessionEnterpriseService {
override suspend fun init() = Unit
override suspend fun tweakMasUrl(url: String): String = url
override suspend fun isElementCallAvailable(): Boolean = true
}

View file

@ -30,6 +30,7 @@ class FakeEnterpriseService(
private val firebasePushGatewayResult: () -> String? = { lambdaError() },
private val unifiedPushDefaultPushGatewayResult: () -> String? = { lambdaError() },
private val getNoisyNotificationChannelIdResult: (SessionId?) -> String? = { lambdaError() },
private val tweakMasUrlResult: (String, String) -> String = { _, _ -> lambdaError() },
) : EnterpriseService {
private val brandColorState = MutableStateFlow(initialBrandColor)
private val semanticColorsState = MutableStateFlow(initialSemanticColors)
@ -38,6 +39,10 @@ class FakeEnterpriseService(
isEnterpriseUserResult(sessionId)
}
override suspend fun tweakMasUrl(url: String, homeserver: String): String = simulateLongTask {
tweakMasUrlResult(url, homeserver)
}
override fun defaultHomeserverList(): List<String> {
return defaultHomeserverListResult()
}

View file

@ -14,10 +14,15 @@ import io.element.android.tests.testutils.simulateLongTask
class FakeSessionEnterpriseService(
private val isElementCallAvailableResult: () -> Boolean = { lambdaError() },
private val tweakMasUrlResult: (String) -> String = { lambdaError() },
) : SessionEnterpriseService {
override suspend fun init() {
}
override suspend fun tweakMasUrl(url: String): String = simulateLongTask {
tweakMasUrlResult(url)
}
override suspend fun isElementCallAvailable(): Boolean = simulateLongTask {
isElementCallAvailableResult()
}

View file

@ -26,6 +26,7 @@ import dev.zacsweers.metro.Assisted
import dev.zacsweers.metro.AssistedInject
import io.element.android.annotations.ContributesNode
import io.element.android.compound.theme.ElementTheme
import io.element.android.features.enterprise.api.SessionEnterpriseService
import io.element.android.features.linknewdevice.api.LinkNewDeviceEntryPoint
import io.element.android.features.linknewdevice.impl.screens.confirmation.CodeConfirmationNode
import io.element.android.features.linknewdevice.impl.screens.desktop.DesktopNoticeNode
@ -65,6 +66,7 @@ class LinkNewDeviceFlowNode(
private val sessionCoroutineScope: CoroutineScope,
private val linkNewMobileHandler: LinkNewMobileHandler,
private val linkNewDesktopHandler: LinkNewDesktopHandler,
private val sessionEnterpriseService: SessionEnterpriseService,
) : BaseFlowNode<LinkNewDeviceFlowNode.NavTarget>(
backstack = BackStack(
initialElement = NavTarget.Root,
@ -298,8 +300,12 @@ class LinkNewDeviceFlowNode(
}
}
private fun navigateToBrowser(url: String) {
activity?.openUrlInChromeCustomTab(null, darkTheme, url)
private suspend fun navigateToBrowser(url: String) {
activity?.openUrlInChromeCustomTab(
session = null,
darkTheme = darkTheme,
url = sessionEnterpriseService.tweakMasUrl(url),
)
}
@Composable

View file

@ -11,6 +11,7 @@ import androidx.arch.core.executor.testing.InstantTaskExecutorRule
import com.bumble.appyx.core.modality.BuildContext
import com.bumble.appyx.testing.junit4.util.MainDispatcherRule
import com.google.common.truth.Truth.assertThat
import io.element.android.features.enterprise.test.FakeSessionEnterpriseService
import io.element.android.features.linknewdevice.api.LinkNewDeviceEntryPoint
import io.element.android.libraries.matrix.test.FakeMatrixClient
import io.element.android.tests.testutils.lambda.lambdaError
@ -37,6 +38,7 @@ class DefaultLinkNewDeviceEntryPointTest {
sessionCoroutineScope = backgroundScope,
linkNewMobileHandler = LinkNewMobileHandler(client),
linkNewDesktopHandler = LinkNewDesktopHandler(client),
sessionEnterpriseService = FakeSessionEnterpriseService(),
)
}
val callback: LinkNewDeviceEntryPoint.Callback = object : LinkNewDeviceEntryPoint.Callback {

View file

@ -51,6 +51,7 @@ dependencies {
implementation(projects.appconfig)
implementation(projects.libraries.core)
implementation(projects.libraries.architecture)
implementation(projects.libraries.cachestore.api)
implementation(projects.libraries.matrix.api)
implementation(projects.libraries.designsystem)
implementation(projects.libraries.featureflag.api)
@ -115,6 +116,7 @@ dependencies {
testImplementation(projects.features.logout.test)
testImplementation(projects.libraries.indicator.test)
testImplementation(projects.libraries.pushproviders.test)
testImplementation(projects.libraries.cachestore.test)
testImplementation(projects.libraries.sessionStorage.test)
testImplementation(projects.services.appnavstate.impl)
testImplementation(projects.services.analytics.test)

View file

@ -19,6 +19,7 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import dev.zacsweers.metro.Inject
import io.element.android.features.enterprise.api.SessionEnterpriseService
import io.element.android.features.logout.api.direct.DirectLogoutState
import io.element.android.features.preferences.impl.utils.ShowDeveloperSettingsProvider
import io.element.android.features.rageshake.api.RageshakeFeatureAvailability
@ -55,6 +56,7 @@ class PreferencesRootPresenter(
private val rageshakeFeatureAvailability: RageshakeFeatureAvailability,
private val featureFlagService: FeatureFlagService,
private val sessionStore: SessionStore,
private val sessionEnterpriseService: SessionEnterpriseService,
) : Presenter<PreferencesRootState> {
@Composable
override fun present(): PreferencesRootState {
@ -158,6 +160,10 @@ class PreferencesRootPresenter(
private fun CoroutineScope.initAccountManagementUrl(
accountManagementUrl: MutableState<String?>,
) = launch {
accountManagementUrl.value = matrixClient.getAccountManagementUrl(null).getOrNull()
accountManagementUrl.value = matrixClient.getAccountManagementUrl(null)
.getOrNull()
?.let {
sessionEnterpriseService.tweakMasUrl(it)
}
}
}

View file

@ -14,6 +14,7 @@ import dev.zacsweers.metro.ContributesBinding
import dev.zacsweers.metro.Provider
import io.element.android.features.invite.api.SeenInvitesStore
import io.element.android.features.preferences.impl.DefaultCacheService
import io.element.android.libraries.cachestore.api.CacheStore
import io.element.android.libraries.core.coroutine.CoroutineDispatchers
import io.element.android.libraries.di.SessionScope
import io.element.android.libraries.di.annotations.ApplicationContext
@ -37,8 +38,11 @@ class DefaultClearCacheUseCase(
private val pushService: PushService,
private val seenInvitesStore: SeenInvitesStore,
private val activeRoomsHolder: ActiveRoomsHolder,
private val cacheStore: CacheStore,
) : ClearCacheUseCase {
override suspend fun invoke() = withContext(coroutineDispatchers.io) {
// Clear cache store
cacheStore.deleteAll()
// Active rooms should be disposed of before clearing the cache
activeRoomsHolder.clear(matrixClient.sessionId)
// Clear Matrix cache

View file

@ -12,6 +12,8 @@ package io.element.android.features.preferences.impl.root
import app.cash.turbine.ReceiveTurbine
import com.google.common.truth.Truth.assertThat
import io.element.android.features.enterprise.api.SessionEnterpriseService
import io.element.android.features.enterprise.test.FakeSessionEnterpriseService
import io.element.android.features.logout.api.direct.aDirectLogoutState
import io.element.android.features.preferences.impl.utils.ShowDeveloperSettingsProvider
import io.element.android.features.rageshake.api.RageshakeFeatureAvailability
@ -65,6 +67,9 @@ class PreferencesRootPresenterTest {
)
createPresenter(
matrixClient = matrixClient,
sessionEnterpriseService = FakeSessionEnterpriseService(
tweakMasUrlResult = { "tweaked $it" },
),
).test {
val initialState = awaitItem()
assertThat(initialState.myUser).isEqualTo(
@ -100,7 +105,7 @@ class PreferencesRootPresenterTest {
val finalState = awaitItem()
accountManagementUrlResult.assertions().isCalledOnce()
.with(value(null))
assertThat(finalState.accountManagementUrl).isEqualTo("null url")
assertThat(finalState.accountManagementUrl).isEqualTo("tweaked null url")
}
}
@ -327,6 +332,7 @@ class PreferencesRootPresenterTest {
indicatorService: IndicatorService = FakeIndicatorService(),
featureFlagService: FeatureFlagService = FakeFeatureFlagService(),
sessionStore: SessionStore = InMemorySessionStore(),
sessionEnterpriseService: SessionEnterpriseService = FakeSessionEnterpriseService(),
) = PreferencesRootPresenter(
matrixClient = matrixClient,
sessionVerificationService = sessionVerificationService,
@ -339,5 +345,6 @@ class PreferencesRootPresenterTest {
rageshakeFeatureAvailability = rageshakeFeatureAvailability,
featureFlagService = featureFlagService,
sessionStore = sessionStore,
sessionEnterpriseService = sessionEnterpriseService,
)
}

View file

@ -19,6 +19,8 @@ import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.matrix.test.FakeMatrixClient
import io.element.android.libraries.matrix.test.room.FakeJoinedRoom
import io.element.android.libraries.push.test.FakePushService
import io.element.android.libraries.sessionstorage.test.InMemoryCacheStore
import io.element.android.libraries.sessionstorage.test.aCacheData
import io.element.android.services.appnavstate.impl.DefaultActiveRoomsHolder
import io.element.android.tests.testutils.lambda.lambdaRecorder
import io.element.android.tests.testutils.lambda.value
@ -49,6 +51,9 @@ class DefaultClearCacheUseCaseTest {
)
val seenInvitesStore = InMemorySeenInvitesStore(setOf(A_ROOM_ID))
assertThat(seenInvitesStore.seenRoomIds().first()).isNotEmpty()
val cacheStore = InMemoryCacheStore(
initialData = mapOf("key1" to aCacheData())
)
val sut = DefaultClearCacheUseCase(
context = InstrumentationRegistry.getInstrumentation().context,
matrixClient = matrixClient,
@ -58,9 +63,11 @@ class DefaultClearCacheUseCaseTest {
pushService = pushService,
seenInvitesStore = seenInvitesStore,
activeRoomsHolder = activeRoomsHolder,
cacheStore = cacheStore,
)
defaultCacheService.clearedCacheEventFlow.test {
sut.invoke()
assertThat(cacheStore.dataMap).isEmpty()
clearCacheLambda.assertions().isCalledOnce()
setIgnoreRegistrationErrorLambda.assertions().isCalledOnce()
.with(value(matrixClient.sessionId), value(false))

View file

@ -28,6 +28,7 @@ setupDependencyInjection()
dependencies {
implementation(projects.appconfig)
implementation(projects.features.enterprise.api)
implementation(projects.libraries.core)
implementation(projects.libraries.androidutils)
implementation(projects.libraries.architecture)

View file

@ -25,6 +25,7 @@ import dev.zacsweers.metro.Assisted
import dev.zacsweers.metro.AssistedInject
import io.element.android.annotations.ContributesNode
import io.element.android.compound.theme.ElementTheme
import io.element.android.features.enterprise.api.SessionEnterpriseService
import io.element.android.features.securebackup.impl.reset.password.ResetIdentityPasswordNode
import io.element.android.features.securebackup.impl.reset.root.ResetIdentityRootNode
import io.element.android.libraries.androidutils.browser.openUrlInChromeCustomTab
@ -53,6 +54,7 @@ class ResetIdentityFlowNode(
private val resetIdentityFlowManager: ResetIdentityFlowManager,
@SessionCoroutineScope
private val sessionCoroutineScope: CoroutineScope,
private val sessionEnterpriseService: SessionEnterpriseService,
) : BaseFlowNode<ResetIdentityFlowNode.NavTarget>(
backstack = BackStack(initialElement = NavTarget.Root, savedStateMap = buildContext.savedStateMap),
buildContext = buildContext,
@ -125,7 +127,8 @@ class ResetIdentityFlowNode(
}
is IdentityOAuthResetHandle -> {
Timber.d("Launching reset confirmation in MAS")
activity.openUrlInChromeCustomTab(null, darkTheme, handle.url)
val url = sessionEnterpriseService.tweakMasUrl(handle.url)
activity.openUrlInChromeCustomTab(null, darkTheme, url)
Timber.d("Starting resetOAuth")
resetJob = launch { handle.resetOAuth() }
resetJob?.invokeOnCompletion { Timber.d("resetOAuth ended") }

View file

@ -0,0 +1,13 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
plugins {
id("io.element.android-library")
}
android {
namespace = "io.element.android.libraries.cachestore.api"
}

View file

@ -0,0 +1,13 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.cachestore.api
data class CacheData(
val value: String,
val updatedAt: Long,
)

View file

@ -0,0 +1,15 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.cachestore.api
interface CacheStore {
suspend fun storeData(key: String, data: CacheData)
suspend fun getData(key: String): CacheData?
suspend fun deleteData(key: String)
suspend fun deleteAll()
}

View file

@ -0,0 +1,48 @@
import extension.setupDependencyInjection
import extension.testCommonDependencies
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
plugins {
id("io.element.android-library")
alias(libs.plugins.sqldelight)
}
android {
namespace = "io.element.android.libraries.cachestore.impl"
}
setupDependencyInjection()
dependencies {
implementation(projects.libraries.androidutils)
implementation(projects.libraries.core)
implementation(projects.libraries.encryptedDb)
api(projects.libraries.cachestore.api)
implementation(libs.sqldelight.driver.android)
implementation(libs.sqlcipher)
implementation(libs.sqlite)
implementation(projects.libraries.di)
implementation(libs.sqldelight.coroutines)
testCommonDependencies(libs)
testImplementation(libs.sqldelight.driver.jvm)
}
sqldelight {
databases {
create("CacheDatabase") {
// https://sqldelight.github.io/sqldelight/2.1.0/android_sqlite/migrations/
// To generate a .db file from your latest schema, run this task
// ./gradlew generateDebugCacheDatabaseSchema
// Test migration by running
// ./gradlew verifySqlDelightMigration
schemaOutputDirectory = File("src/main/sqldelight/databases")
verifyMigrations = true
}
}
}

View file

@ -0,0 +1,26 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.cachestore.impl
import io.element.android.libraries.cachestore.api.CacheData
import io.element.android.libraries.cachestore.CacheData as DbCacheData
internal fun CacheData.toDbModel(key: String): DbCacheData {
return DbCacheData(
key = key,
value_ = value,
updatedAt = updatedAt,
)
}
internal fun DbCacheData.toApiModel(): CacheData {
return CacheData(
value = value_,
updatedAt = updatedAt,
)
}

View file

@ -0,0 +1,40 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.cachestore.impl
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.ContributesBinding
import dev.zacsweers.metro.SingleIn
import io.element.android.libraries.cachestore.api.CacheData
import io.element.android.libraries.cachestore.api.CacheStore
@SingleIn(AppScope::class)
@ContributesBinding(AppScope::class)
class DatabaseCacheStore(
private val database: CacheDatabase,
) : CacheStore {
override suspend fun getData(key: String): CacheData? {
return database.cacheDataQueries.selectData(key)
.executeAsOneOrNull()
?.toApiModel()
}
override suspend fun storeData(key: String, data: CacheData) {
database.cacheDataQueries.insertData(
data.toDbModel(key)
).await()
}
override suspend fun deleteData(key: String) {
database.cacheDataQueries.deleteData(key).await()
}
override suspend fun deleteAll() {
database.cacheDataQueries.deleteAll().await()
}
}

View file

@ -0,0 +1,43 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.cachestore.impl.di
import android.content.Context
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.BindingContainer
import dev.zacsweers.metro.ContributesTo
import dev.zacsweers.metro.Provides
import dev.zacsweers.metro.SingleIn
import io.element.android.libraries.cachestore.impl.CacheDatabase
import io.element.android.libraries.di.annotations.ApplicationContext
import io.element.encrypteddb.SqlCipherDriverFactory
import io.element.encrypteddb.passphrase.RandomSecretPassphraseProvider
@BindingContainer
@ContributesTo(AppScope::class)
object CacheStoreModule {
@Provides
@SingleIn(AppScope::class)
fun provideCacheDatabase(
@ApplicationContext context: Context,
): CacheDatabase {
val name = "cache_database"
val secretFile = context.getDatabasePath("$name.key")
// Make sure the parent directory of the key file exists, otherwise it will crash in older Android versions
val parentDir = secretFile.parentFile
if (parentDir != null && !parentDir.exists()) {
parentDir.mkdirs()
}
val passphraseProvider = RandomSecretPassphraseProvider(context, secretFile)
val driver = SqlCipherDriverFactory(passphraseProvider)
.create(CacheDatabase.Schema, "$name.db", context)
return CacheDatabase(driver)
}
}

View file

@ -0,0 +1,28 @@
--------------------------------------------------------------------
-- Current version of the DB is the highest value of filename
-- in the folder `sqldelight/databases`.
--
-- When upgrading the schema, you have to create a file .sqm in the
-- `sqldelight/databases` folder and run the following task to
-- generate a .db file using the latest schema
-- > ./gradlew generateDebugCacheDatabaseSchema
--------------------------------------------------------------------
CREATE TABLE CacheData (
key TEXT NOT NULL PRIMARY KEY,
value TEXT NOT NULL,
updatedAt INTEGER NOT NULL
);
selectData:
SELECT * FROM CacheData WHERE key = ?;
insertData:
INSERT OR REPLACE INTO CacheData VALUES ?;
deleteData:
DELETE FROM CacheData WHERE key = ?;
deleteAll:
DELETE FROM CacheData;

View file

@ -0,0 +1,86 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.sessionstorage.impl
import app.cash.sqldelight.driver.jdbc.sqlite.JdbcSqliteDriver
import com.google.common.truth.Truth.assertThat
import io.element.android.libraries.cachestore.api.CacheData
import io.element.android.libraries.cachestore.impl.CacheDatabase
import io.element.android.libraries.cachestore.impl.DatabaseCacheStore
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runTest
import org.junit.Before
import org.junit.Test
import io.element.android.libraries.cachestore.CacheData as DbCacheData
private const val A_KEY = "aKey"
private const val A_DATA_1 = "aData1"
private const val A_DATA_2 = "aData2"
class DatabaseCacheStoreTest {
private lateinit var database: CacheDatabase
private lateinit var databaseCacheStore: DatabaseCacheStore
@OptIn(ExperimentalCoroutinesApi::class)
@Before
fun setup() {
// Initialise in memory SQLite driver
val driver = JdbcSqliteDriver(JdbcSqliteDriver.IN_MEMORY)
CacheDatabase.Schema.create(driver)
database = CacheDatabase(driver)
databaseCacheStore = DatabaseCacheStore(
database = database,
)
}
@Test
fun `storeData persists the CacheData into the DB, deleteData deletes it`() = runTest {
// Assert that no data is stored for the key
assertThat(database.cacheDataQueries.selectData(A_KEY).executeAsOneOrNull()).isNull()
// Store data
databaseCacheStore.storeData(A_KEY, CacheData(A_DATA_1, 1))
assertThat(database.cacheDataQueries.selectData(A_KEY).executeAsOneOrNull()).isEqualTo(
DbCacheData(
key = A_KEY,
value_ = A_DATA_1,
updatedAt = 1,
)
)
// Update data
databaseCacheStore.storeData(A_KEY, CacheData(A_DATA_2, 2))
assertThat(database.cacheDataQueries.selectData(A_KEY).executeAsOneOrNull()).isEqualTo(
DbCacheData(
key = A_KEY,
value_ = A_DATA_2,
updatedAt = 2,
)
)
// Delete data
databaseCacheStore.deleteData(A_KEY)
assertThat(database.cacheDataQueries.selectData(A_KEY).executeAsOneOrNull()).isNull()
}
@Test
fun `deleteAll deletes all the data`() = runTest {
// Assert that no data is stored for the key
assertThat(database.cacheDataQueries.selectData(A_KEY).executeAsOneOrNull()).isNull()
// Store data
databaseCacheStore.storeData(A_KEY, CacheData(A_DATA_1, 1))
assertThat(database.cacheDataQueries.selectData(A_KEY).executeAsOneOrNull()).isEqualTo(
DbCacheData(
key = A_KEY,
value_ = A_DATA_1,
updatedAt = 1,
)
)
// Delete all data
databaseCacheStore.deleteAll()
assertThat(database.cacheDataQueries.selectData(A_KEY).executeAsOneOrNull()).isNull()
}
}

View file

@ -0,0 +1,21 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.sessionstorage.impl
import io.element.android.libraries.cachestore.CacheData
import java.util.Date
internal fun aCacheData(
key: String = "aKey",
value: String = "aValue",
updatedAt: Date = Date(),
) = CacheData(
key = key,
value_ = value,
updatedAt = updatedAt.time,
)

View file

@ -0,0 +1,17 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
plugins {
id("io.element.android-library")
}
android {
namespace = "io.element.android.libraries.cachestore.test"
}
dependencies {
implementation(projects.libraries.cachestore.api)
}

View file

@ -0,0 +1,18 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.sessionstorage.test
import io.element.android.libraries.cachestore.api.CacheData
fun aCacheData(
value: String = "aValue",
updatedAt: Long = 0,
) = CacheData(
value = value,
updatedAt = updatedAt,
)

View file

@ -0,0 +1,33 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.sessionstorage.test
import io.element.android.libraries.cachestore.api.CacheData
import io.element.android.libraries.cachestore.api.CacheStore
class InMemoryCacheStore(
initialData: Map<String, CacheData> = emptyMap(),
) : CacheStore {
val dataMap = initialData.toMutableMap()
override suspend fun storeData(key: String, data: CacheData) {
dataMap[key] = data
}
override suspend fun getData(key: String): CacheData? {
return dataMap[key]
}
override suspend fun deleteData(key: String) {
dataMap.remove(key)
}
override suspend fun deleteAll() {
dataMap.clear()
}
}

View file

@ -31,6 +31,7 @@ dependencies {
implementation(projects.libraries.rustlsTls)
implementation(projects.appconfig)
implementation(projects.features.enterprise.api)
implementation(projects.libraries.androidutils)
implementation(projects.libraries.architecture)
implementation(projects.libraries.di)
@ -49,6 +50,7 @@ dependencies {
implementation(libs.kotlinx.collections.immutable)
testCommonDependencies(libs)
testImplementation(projects.features.enterprise.test)
testImplementation(projects.libraries.featureflag.test)
testImplementation(projects.libraries.matrix.test)
testImplementation(projects.libraries.preferences.test)

View file

@ -11,6 +11,7 @@ package io.element.android.libraries.matrix.impl.auth
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.ContributesBinding
import dev.zacsweers.metro.SingleIn
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.libraries.core.coroutine.CoroutineDispatchers
import io.element.android.libraries.core.extensions.mapFailure
import io.element.android.libraries.core.extensions.runCatchingExceptions
@ -66,6 +67,7 @@ class RustMatrixAuthenticationService(
private val rustMatrixClientFactory: RustMatrixClientFactory,
private val passphraseGenerator: PassphraseGenerator,
private val oAuthConfigurationProvider: OAuthConfigurationProvider,
private val enterpriseService: EnterpriseService,
) : MatrixAuthenticationService {
// Any existing Element Classic session that we want to try to import secrets from during login.
private var elementClassicSession: ElementClassicSession? = null
@ -269,6 +271,12 @@ class RustMatrixAuthenticationService(
additionalScopes = emptyList(),
)
val url = oAuthAuthorizationData.loginUrl()
.let {
enterpriseService.tweakMasUrl(
url = it,
homeserver = client.server() ?: client.homeserver(),
)
}
pendingOAuthAuthorizationData = oAuthAuthorizationData
OAuthDetails(url)
}.mapFailure { failure ->

View file

@ -9,6 +9,8 @@
package io.element.android.libraries.matrix.impl.auth
import com.google.common.truth.Truth.assertThat
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.features.enterprise.test.FakeEnterpriseService
import io.element.android.libraries.matrix.impl.ClientBuilderProvider
import io.element.android.libraries.matrix.impl.FakeClientBuilderProvider
import io.element.android.libraries.matrix.impl.createRustMatrixClientFactory
@ -50,6 +52,7 @@ class RustMatrixAuthenticationServiceTest {
private fun TestScope.createRustMatrixAuthenticationService(
sessionStore: SessionStore = InMemorySessionStore(),
clientBuilderProvider: ClientBuilderProvider = FakeClientBuilderProvider(),
enterpriseService: EnterpriseService = FakeEnterpriseService(),
): RustMatrixAuthenticationService {
val baseDirectory = File("/base")
val cacheDirectory = File("/cache")
@ -68,6 +71,7 @@ class RustMatrixAuthenticationServiceTest {
buildMeta = aBuildMeta(),
oAuthRedirectUrlProvider = FakeOAuthRedirectUrlProvider(),
),
enterpriseService = enterpriseService,
)
}
}

View file

@ -14,4 +14,5 @@ data class ElementWellKnown(
val rageshakeUrl: String?,
val brandColor: String?,
val notificationSound: String?,
val identityProviderAppScheme: String?,
)

View file

@ -33,9 +33,13 @@ dependencies {
implementation(projects.libraries.architecture)
implementation(projects.libraries.matrix.api)
implementation(projects.libraries.network)
implementation(projects.libraries.cachestore.api)
implementation(projects.services.toolbox.api)
testCommonDependencies(libs)
testImplementation(libs.coroutines.core)
testImplementation(projects.libraries.cachestore.test)
testImplementation(projects.libraries.matrix.test)
testImplementation(projects.libraries.wellknown.test)
testImplementation(projects.services.toolbox.test)
}

View file

@ -10,29 +10,70 @@ package io.element.android.libraries.wellknown.impl
import dev.zacsweers.metro.ContributesBinding
import io.element.android.libraries.androidutils.json.JsonProvider
import io.element.android.libraries.cachestore.api.CacheData
import io.element.android.libraries.cachestore.api.CacheStore
import io.element.android.libraries.core.extensions.mapCatchingExceptions
import io.element.android.libraries.di.SessionScope
import io.element.android.libraries.di.annotations.SessionCoroutineScope
import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.api.exception.ClientException
import io.element.android.libraries.wellknown.api.ElementWellKnown
import io.element.android.libraries.wellknown.api.SessionWellknownRetriever
import io.element.android.libraries.wellknown.api.WellknownRetrieverResult
import io.element.android.services.toolbox.api.systemclock.SystemClock
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import timber.log.Timber
@ContributesBinding(SessionScope::class)
class DefaultSessionWellknownRetriever(
private val matrixClient: MatrixClient,
private val json: JsonProvider,
private val cacheStore: CacheStore,
private val systemClock: SystemClock,
@SessionCoroutineScope
private val sessionCoroutineScope: CoroutineScope,
) : SessionWellknownRetriever {
private val domain by lazy { matrixClient.userIdServerName() }
override suspend fun getElementWellKnown(): WellknownRetrieverResult<ElementWellKnown> {
val url = "https://$domain/.well-known/element/element.json"
val cacheData = cacheStore.getData(url)
if (cacheData != null) {
Timber.d("Element .well-known data retrieved from cache for $domain")
// If the cache is outdated, trigger a refresh in background but still return the cached value
if (systemClock.epochMillis() > cacheData.updatedAt + CACHE_VALIDITY_MILLIS) {
sessionCoroutineScope.launch {
fetchElementWellKnown(url)
}
}
try {
val parsed = json().decodeFromString<InternalElementWellKnown>(cacheData.value).map()
return WellknownRetrieverResult.Success(parsed)
} catch (e: Exception) {
Timber.e(e, "Failed to parse cached Element .well-known data for $domain, deleting cache")
cacheStore.deleteData(url)
}
}
return fetchElementWellKnown(url)
}
private suspend fun fetchElementWellKnown(url: String): WellknownRetrieverResult<ElementWellKnown> {
return matrixClient
.getUrl(url)
.mapCatchingExceptions {
val data = String(it)
json().decodeFromString<InternalElementWellKnown>(data).map()
val parsed = json().decodeFromString<InternalElementWellKnown>(data).map()
// Also store in cache, if valid
cacheStore.storeData(
key = url,
data = CacheData(
value = data,
updatedAt = systemClock.epochMillis(),
)
)
parsed
}
.toWellknownRetrieverResult()
}
@ -51,4 +92,9 @@ class DefaultSessionWellknownRetriever(
}
}
)
companion object {
// 1 day
private const val CACHE_VALIDITY_MILLIS = 1 * 24 * 60 * 60 * 1000L
}
}

View file

@ -32,4 +32,6 @@ data class InternalElementWellKnown(
val brandColor: String? = null,
@SerialName("notification_sound")
val notificationSound: String? = null,
@SerialName("idp_app_scheme")
val identityProviderAppScheme: String? = null,
)

View file

@ -16,4 +16,5 @@ internal fun InternalElementWellKnown.map() = ElementWellKnown(
rageshakeUrl = rageshakeUrl,
brandColor = brandColor,
notificationSound = notificationSound,
identityProviderAppScheme = identityProviderAppScheme,
)

View file

@ -6,16 +6,30 @@
* Please see LICENSE files in the repository root for full details.
*/
@file:OptIn(ExperimentalCoroutinesApi::class)
package io.element.android.libraries.wellknown.impl
import com.google.common.truth.Truth.assertThat
import io.element.android.features.wellknown.test.anElementWellKnown
import io.element.android.libraries.androidutils.json.DefaultJsonProvider
import io.element.android.libraries.androidutils.json.JsonProvider
import io.element.android.libraries.cachestore.api.CacheData
import io.element.android.libraries.cachestore.api.CacheStore
import io.element.android.libraries.matrix.test.AN_EXCEPTION
import io.element.android.libraries.matrix.test.FakeMatrixClient
import io.element.android.libraries.sessionstorage.test.InMemoryCacheStore
import io.element.android.libraries.wellknown.api.ElementWellKnown
import io.element.android.libraries.wellknown.api.WellknownRetrieverResult
import io.element.android.services.toolbox.api.systemclock.SystemClock
import io.element.android.services.toolbox.test.systemclock.A_FAKE_TIMESTAMP
import io.element.android.services.toolbox.test.systemclock.FakeSystemClock
import io.element.android.tests.testutils.lambda.lambdaError
import io.element.android.tests.testutils.lambda.lambdaRecorder
import io.element.android.tests.testutils.lambda.value
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.runCurrent
import kotlinx.coroutines.test.runTest
import org.junit.Test
@ -36,6 +50,7 @@ class DefaultSessionWellknownRetrieverTest {
rageshakeUrl = null,
brandColor = null,
notificationSound = null,
identityProviderAppScheme = null,
)
)
)
@ -48,13 +63,7 @@ class DefaultSessionWellknownRetrieverTest {
val sut = createDefaultSessionWellknownRetriever(
getUrlLambda = {
Result.success(
"""{
"registration_helper_url": "a_registration_url",
"enforce_element_pro": true,
"rageshake_url": "a_rageshake_url",
"brand_color": "#FF0000",
"notification_sound": "a_notification_sound.flac"
}""".trimIndent().toByteArray()
WELLKNOWN_CONTENT.toByteArray()
)
}
)
@ -66,6 +75,7 @@ class DefaultSessionWellknownRetrieverTest {
rageshakeUrl = "a_rageshake_url",
brandColor = "#FF0000",
notificationSound = "a_notification_sound.flac",
identityProviderAppScheme = "an_app_scheme",
)
)
)
@ -94,6 +104,7 @@ class DefaultSessionWellknownRetrieverTest {
rageshakeUrl = "a_rageshake_url",
brandColor = null,
notificationSound = null,
identityProviderAppScheme = null,
)
)
)
@ -124,13 +135,118 @@ class DefaultSessionWellknownRetrieverTest {
assertThat(sut.getElementWellKnown()).isInstanceOf(WellknownRetrieverResult.Error::class.java)
}
private fun createDefaultSessionWellknownRetriever(
@Test
fun `get element wellknown hitting cache`() = runTest {
val sut = createDefaultSessionWellknownRetriever(
getUrlLambda = { lambdaError() },
cacheStore = InMemoryCacheStore(
initialData = mapOf(
WELLKNOWN_URL to CacheData(
value = WELLKNOWN_CONTENT,
updatedAt = A_FAKE_TIMESTAMP,
)
)
)
)
assertThat(sut.getElementWellKnown()).isEqualTo(
WellknownRetrieverResult.Success(
ElementWellKnown(
registrationHelperUrl = "a_registration_url",
enforceElementPro = true,
rageshakeUrl = "a_rageshake_url",
brandColor = "#FF0000",
notificationSound = "a_notification_sound.flac",
identityProviderAppScheme = "an_app_scheme",
)
)
)
}
@Test
fun `get element wellknown hitting cache containing invalid json`() = runTest {
val cacheStore = InMemoryCacheStore(
initialData = mapOf(
WELLKNOWN_URL to CacheData(
value = WELLKNOWN_CONTENT,
updatedAt = A_FAKE_TIMESTAMP,
)
)
)
val sut = createDefaultSessionWellknownRetriever(
getUrlLambda = {
Result.success("{}".toByteArray())
},
cacheStore = cacheStore,
jsonProvider = JsonProvider { error("Failed to parse JSON") }
)
assertThat(sut.getElementWellKnown()).isInstanceOf(WellknownRetrieverResult.Error::class.java)
// Ensure that the cache is deleted after the failure to parse it
assertThat(cacheStore.dataMap).isEmpty()
}
@Test
fun `get element wellknown hitting outdated cache`() = runTest {
val sut = createDefaultSessionWellknownRetriever(
getUrlLambda = {
Result.success("{}".toByteArray())
},
cacheStore = InMemoryCacheStore(
initialData = mapOf(
WELLKNOWN_URL to CacheData(
value = WELLKNOWN_CONTENT,
updatedAt = 0L,
)
),
),
// 3 days later, so the cache is outdated
systemClock = FakeSystemClock(3 * 24 * 60 * 60 * 1000L)
)
assertThat(sut.getElementWellKnown()).isEqualTo(
WellknownRetrieverResult.Success(
ElementWellKnown(
registrationHelperUrl = "a_registration_url",
enforceElementPro = true,
rageshakeUrl = "a_rageshake_url",
brandColor = "#FF0000",
notificationSound = "a_notification_sound.flac",
identityProviderAppScheme = "an_app_scheme",
)
)
)
// Next call returns the updated value
runCurrent()
assertThat(sut.getElementWellKnown()).isEqualTo(
WellknownRetrieverResult.Success(
anElementWellKnown()
)
)
}
private fun TestScope.createDefaultSessionWellknownRetriever(
getUrlLambda: (String) -> Result<ByteArray>,
jsonProvider: JsonProvider = DefaultJsonProvider(),
cacheStore: CacheStore = InMemoryCacheStore(),
systemClock: SystemClock = FakeSystemClock(),
) = DefaultSessionWellknownRetriever(
matrixClient = FakeMatrixClient(
userIdServerNameLambda = { "user.domain.org" },
getUrlLambda = getUrlLambda,
),
json = DefaultJsonProvider(),
json = jsonProvider,
cacheStore = cacheStore,
systemClock = systemClock,
sessionCoroutineScope = backgroundScope,
)
companion object {
private const val WELLKNOWN_URL = "https://user.domain.org/.well-known/element/element.json"
private const val WELLKNOWN_CONTENT = """{
"registration_helper_url": "a_registration_url",
"enforce_element_pro": true,
"rageshake_url": "a_rageshake_url",
"brand_color": "#FF0000",
"notification_sound": "a_notification_sound.flac",
"idp_app_scheme": "an_app_scheme"
}"""
}
}

View file

@ -16,10 +16,12 @@ fun anElementWellKnown(
rageshakeUrl: String? = null,
brandColor: String? = null,
notificationSound: String? = null,
identityProviderAppScheme: String? = null,
) = ElementWellKnown(
registrationHelperUrl = registrationHelperUrl,
enforceElementPro = enforceElementPro,
rageshakeUrl = rageshakeUrl,
brandColor = brandColor,
notificationSound = notificationSound,
identityProviderAppScheme = identityProviderAppScheme,
)

View file

@ -104,6 +104,7 @@ fun DependencyHandlerScope.allLibrariesImpl() {
implementation(project(":libraries:architecture"))
implementation(project(":libraries:dateformatter:impl"))
implementation(project(":libraries:di"))
implementation(project(":libraries:cachestore:impl"))
implementation(project(":libraries:session-storage:impl"))
implementation(project(":libraries:mediapickers:impl"))
implementation(project(":libraries:mediaupload:impl"))