diff --git a/libraries/push/impl/src/main/kotlin/io/element/android/libraries/push/impl/push/DefaultPushHandler.kt b/libraries/push/impl/src/main/kotlin/io/element/android/libraries/push/impl/push/DefaultPushHandler.kt index 44cf6edefc..d5c3f04348 100644 --- a/libraries/push/impl/src/main/kotlin/io/element/android/libraries/push/impl/push/DefaultPushHandler.kt +++ b/libraries/push/impl/src/main/kotlin/io/element/android/libraries/push/impl/push/DefaultPushHandler.kt @@ -11,6 +11,7 @@ package io.element.android.libraries.push.impl.push import dev.zacsweers.metro.AppScope import dev.zacsweers.metro.ContributesBinding import dev.zacsweers.metro.SingleIn +import io.element.android.libraries.core.coroutine.CoroutineDispatchers import io.element.android.libraries.core.log.logger.LoggerTag import io.element.android.libraries.core.meta.BuildMeta import io.element.android.libraries.push.impl.db.PushRequest @@ -35,6 +36,7 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.flow.first import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import timber.log.Timber private val loggerTag = LoggerTag("PushHandler", LoggerTag.PushLoggerTag) @@ -53,6 +55,7 @@ class DefaultPushHandler( private val workManagerScheduler: WorkManagerScheduler, private val syncPendingNotificationsRequestFactory: SyncPendingNotificationsRequestBuilder.Factory, resultProcessor: NotificationResultProcessor, + private val dispatchers: CoroutineDispatchers, ) : PushHandler { init { resultProcessor.start() @@ -64,7 +67,7 @@ class DefaultPushHandler( * @param pushData the data received in the push. * @param providerInfo the provider info. */ - override suspend fun handle(pushData: PushData, providerInfo: String): Boolean { + override suspend fun handle(pushData: PushData, providerInfo: String): Boolean = withContext(dispatchers.computation) { // Start measuring how long it takes to display a notification from when the push is received Timber.d("Calculating push-to-notification for event ${pushData.eventId}") val parent = analyticsService.startLongRunningTransaction(AnalyticsLongRunningTransaction.PushToNotification(pushData.eventId.value)) @@ -81,7 +84,7 @@ class DefaultPushHandler( } // Diagnostic Push - return if (pushData.eventId == DefaultTestPush.TEST_EVENT_ID) { + if (pushData.eventId == DefaultTestPush.TEST_EVENT_ID) { pushHistoryService.onDiagnosticPush(providerInfo) diagnosticPushHandler.handlePush() false @@ -90,7 +93,7 @@ class DefaultPushHandler( } } - override suspend fun handleInvalid(providerInfo: String, data: String) { + override suspend fun handleInvalid(providerInfo: String, data: String) = withContext(dispatchers.computation) { incrementPushDataStore.incrementPushCounter() pushHistoryService.onInvalidPushReceived(providerInfo, data) } diff --git a/libraries/push/impl/src/test/kotlin/io/element/android/libraries/push/impl/push/DefaultPushHandlerTest.kt b/libraries/push/impl/src/test/kotlin/io/element/android/libraries/push/impl/push/DefaultPushHandlerTest.kt index a16568d400..f0dee4446c 100644 --- a/libraries/push/impl/src/test/kotlin/io/element/android/libraries/push/impl/push/DefaultPushHandlerTest.kt +++ b/libraries/push/impl/src/test/kotlin/io/element/android/libraries/push/impl/push/DefaultPushHandlerTest.kt @@ -11,6 +11,7 @@ package io.element.android.libraries.push.impl.push import app.cash.turbine.test +import io.element.android.libraries.core.coroutine.CoroutineDispatchers import io.element.android.libraries.core.meta.BuildMeta import io.element.android.libraries.matrix.api.core.EventId import io.element.android.libraries.matrix.api.core.RoomId @@ -40,7 +41,9 @@ import io.element.android.services.toolbox.test.systemclock.FakeSystemClock import io.element.android.tests.testutils.lambda.lambdaError import io.element.android.tests.testutils.lambda.lambdaRecorder import io.element.android.tests.testutils.lambda.value +import io.element.android.tests.testutils.testCoroutineDispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.TestScope import kotlinx.coroutines.test.advanceTimeBy import kotlinx.coroutines.test.runCurrent import kotlinx.coroutines.test.runTest @@ -212,7 +215,7 @@ class DefaultPushHandlerTest { .isCalledOnce() } - private fun createDefaultPushHandler( + private fun TestScope.createDefaultPushHandler( incrementPushCounterResult: () -> Unit = { lambdaError() }, userPushStore: FakeUserPushStore = FakeUserPushStore(), pushClientSecret: PushClientSecret = FakePushClientSecret(), @@ -227,6 +230,7 @@ class DefaultPushHandlerTest { start = {}, stop = {}, ), + dispatchers: CoroutineDispatchers = testCoroutineDispatchers(), ): DefaultPushHandler { return DefaultPushHandler( incrementPushDataStore = object : IncrementPushDataStore { @@ -246,7 +250,8 @@ class DefaultPushHandlerTest { resultProcessor = resultProcessor, syncPendingNotificationsRequestFactory = SyncPendingNotificationsRequestBuilder.Factory { FakeSyncPendingNotificationsRequestBuilder() - } + }, + dispatchers = dispatchers, ) } }