diff --git a/tests/multivalentTests/src/com/android/launcher3/icons/FastBitmapDrawableTest.java b/tests/multivalentTests/src/com/android/launcher3/icons/FastBitmapDrawableTest.java index fbbfb2af48..7e9b68d231 100644 --- a/tests/multivalentTests/src/com/android/launcher3/icons/FastBitmapDrawableTest.java +++ b/tests/multivalentTests/src/com/android/launcher3/icons/FastBitmapDrawableTest.java @@ -37,11 +37,15 @@ import android.view.animation.DecelerateInterpolator; import android.view.animation.PathInterpolator; import androidx.test.annotation.UiThreadTest; +import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.test.filters.SmallTest; -import androidx.test.runner.AndroidJUnit4; + +import com.android.launcher3.util.rule.RobolectricUiThreadRule; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TestRule; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -51,11 +55,14 @@ import org.mockito.Spy; * Tests for FastBitmapDrawable. */ @SmallTest -@UiThreadTest @RunWith(AndroidJUnit4.class) +@UiThreadTest public class FastBitmapDrawableTest { private static final float EPSILON = 0.00001f; + @Rule + public final TestRule roboUiThreadRule = new RobolectricUiThreadRule(); + @Spy FastBitmapDrawable mFastBitmapDrawable = spy(new FastBitmapDrawable(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888))); diff --git a/tests/multivalentTests/src/com/android/launcher3/logging/StartupLatencyLoggerTest.kt b/tests/multivalentTests/src/com/android/launcher3/logging/StartupLatencyLoggerTest.kt index 130dfad2ac..12f6c8cbed 100644 --- a/tests/multivalentTests/src/com/android/launcher3/logging/StartupLatencyLoggerTest.kt +++ b/tests/multivalentTests/src/com/android/launcher3/logging/StartupLatencyLoggerTest.kt @@ -4,8 +4,10 @@ import androidx.core.util.isEmpty import androidx.test.annotation.UiThreadTest import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.filters.SmallTest +import com.android.launcher3.util.rule.RobolectricUiThreadRule import com.google.common.truth.Truth.assertThat import org.junit.Before +import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith @@ -14,6 +16,8 @@ import org.junit.runner.RunWith @RunWith(AndroidJUnit4::class) class StartupLatencyLoggerTest { + @get:Rule val roboUiThreadRule = RobolectricUiThreadRule() + private val underTest = ColdRebootStartupLatencyLogger() @Before diff --git a/tests/multivalentTests/src/com/android/launcher3/util/rule/RobolectricUiThreadRule.kt b/tests/multivalentTests/src/com/android/launcher3/util/rule/RobolectricUiThreadRule.kt new file mode 100644 index 0000000000..18cd1e45a7 --- /dev/null +++ b/tests/multivalentTests/src/com/android/launcher3/util/rule/RobolectricUiThreadRule.kt @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.launcher3.util.rule + +import androidx.test.annotation.UiThreadTest +import androidx.test.platform.app.InstrumentationRegistry +import java.util.Locale +import java.util.concurrent.atomic.AtomicReference +import org.junit.rules.TestRule +import org.junit.runner.Description +import org.junit.runners.model.Statement + +/** + * A test rule to add support for @UiThreadTest annotations when running in robolectric until is it + * natively supported by the robolectric runner: + * https://github.com/robolectric/robolectric/issues/9026 + */ +class RobolectricUiThreadRule : TestRule { + + override fun apply(base: Statement, description: Description): Statement = + if (!shouldRunOnUiThread(description)) base else UiThreadStatement(base) + + private fun shouldRunOnUiThread(description: Description): Boolean { + if (!isRunningInRobolectric()) { + // If not running in robolectric, let the default runner handle this + return false + } + var clazz = description.testClass + try { + if ( + clazz + .getDeclaredMethod(description.methodName) + .getAnnotation(UiThreadTest::class.java) != null + ) { + return true + } + } catch (_: Exception) { + // Ignore + } + + while (!clazz.isAnnotationPresent(UiThreadTest::class.java)) { + clazz = clazz.superclass ?: return false + } + return true + } + + private fun isRunningInRobolectric(): Boolean { + if ( + System.getProperty("java.runtime.name") + .lowercase(Locale.getDefault()) + .contains("android") + ) + return false + return try { + // Check if robolectric runner exists + Class.forName("org.robolectric.RobolectricTestRunner") != null + } catch (e: ClassNotFoundException) { + false + } + } + + private class UiThreadStatement(val base: Statement) : Statement() { + + override fun evaluate() { + val exceptionRef = AtomicReference() + InstrumentationRegistry.getInstrumentation().runOnMainSync { + try { + base.evaluate() + } catch (throwable: Throwable) { + exceptionRef.set(throwable) + } + } + exceptionRef.get()?.let { throw it } + } + } +}