diff --git a/extensions/kotlin/deployment/pom.xml b/extensions/kotlin/deployment/pom.xml index dbd398037a393..1f73426ae3768 100644 --- a/extensions/kotlin/deployment/pom.xml +++ b/extensions/kotlin/deployment/pom.xml @@ -33,6 +33,30 @@ quarkus-vertx-kotlin-deployment true + + + io.quarkus + quarkus-junit5-internal + test + + + + io.quarkus + quarkus-arc-test-supplement + test + + + + io.quarkus + quarkus-arc-test-supplement-decorator + test + + + + org.jetbrains.kotlinx + kotlinx-coroutines-test + test + @@ -54,6 +78,33 @@ + + org.jetbrains.kotlin + kotlin-maven-plugin + ${kotlin.version} + + + compile + + compile + + + + test-compile + + test-compile + + + + src/test/kotlin + + + + + + ${maven.compiler.target} + + diff --git a/extensions/kotlin/deployment/src/test/kotlin/io/quarkus/kotlin/arc/RequestContextCoroutineContextTest.kt b/extensions/kotlin/deployment/src/test/kotlin/io/quarkus/kotlin/arc/RequestContextCoroutineContextTest.kt new file mode 100644 index 0000000000000..9f05a9afb6f8a --- /dev/null +++ b/extensions/kotlin/deployment/src/test/kotlin/io/quarkus/kotlin/arc/RequestContextCoroutineContextTest.kt @@ -0,0 +1,517 @@ +package io.quarkus.kotlin.arc + +import io.quarkus.arc.Arc +import io.quarkus.test.QuarkusUnitTest +import jakarta.enterprise.context.RequestScoped +import jakarta.enterprise.context.control.ActivateRequestContext +import jakarta.inject.Inject +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import org.jboss.shrinkwrap.api.spec.JavaArchive +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertNotEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import kotlin.coroutines.EmptyCoroutineContext + +class RequestContextCoroutineContextTest { + + companion object { + @RegisterExtension + @JvmStatic + val TEST = + QuarkusUnitTest().withApplicationRoot { jar: JavaArchive -> + jar.addClasses(RequestData::class.java) + } + } + + @Inject + lateinit var requestData: RequestData + + lateinit var expectedClassLoader: ClassLoader + + @BeforeEach + @AfterEach + fun setUp() { + this.expectedClassLoader = Thread.currentThread().contextClassLoader + } + + private fun assertThatCallersClassLoaderIsExpected() { + assertEquals( + expectedClassLoader, + Thread.currentThread().contextClassLoader, + "Thread context class loader should be the expected one", + ) + } + + @Test + @ActivateRequestContext + fun `caller with active scope maintained in runBlocking`() = runBlocking { + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + } + + @Test + @ActivateRequestContext + fun `runBlocking with active request`() { + // GIVEN an active request context + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND a given number and an expected post-async number + val givenNumber = 1234L + val expectedPostAsyncNumber = 5432L + + // AND we set the number value in the request data + requestData.numberValue = givenNumber + + // WHEN we run a block with the request context + + runBlocking(context = EmptyCoroutineContext.withCdiContext()) { + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND the number value should match the given number + assertEquals(givenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + // WHEN we set the number value to the expected post-async number + requestData.numberValue = expectedPostAsyncNumber + } + + // THEN the number value should match the expected post-async number after the block + // execution + assertEquals(expectedPostAsyncNumber, requestData.numberValue) + } + + @Test + fun `caller without active scope maintained in runBlocking`() = runBlocking { + assertFalse( + Arc.container().requestContext().isActive, + "Request context should not be active", + ) + } + + @Test + fun `without an active request scope on withContext`() { + // GIVEN no active request context + assertFalse( + Arc.container().requestContext().isActive, + "Request context should not be active", + ) + + // WHEN we run a block + runTest { + withContext(Dispatchers.IO.withCdiContext()) { + // THEN the request context should not be active + assertFalse( + Arc.container().requestContext().isActive, + "Request context should not be active", + ) + + assertThatCallersClassLoaderIsExpected() + } + } + } + + @Test + fun `without an active request scope on async`() { + // GIVEN no active request context + assertFalse( + Arc.container().requestContext().isActive, + "Request context should not be active", + ) + + // WHEN we run a block with async + runTest { + coroutineScope { + async(Dispatchers.IO.withCdiContext()) { + // THEN the request context should not be active + assertFalse( + Arc.container().requestContext().isActive, + "Request context should not be active", + ) + + assertThatCallersClassLoaderIsExpected() + } + .await() + } + } + } + + @Test + @ActivateRequestContext + fun `with an active request scope on withContext`() { + // GIVEN an active request context + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND a given number and an expected post-async number + val givenNumber = 1234L + val expectedPostAsyncNumber = 5432L + + // AND we set the number value in the request data + requestData.numberValue = givenNumber + + // WHEN we run a block with the request context + runTest { + withContext(Dispatchers.IO.withCdiContext()) { + // THEN the request context should be active + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND the number value should match the given number + assertEquals(givenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + // AND the number value should match the given number after a short delay + delay(10) + assertEquals(givenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + // WHEN we set the number value to the expected post-async number + requestData.numberValue = expectedPostAsyncNumber + + // THEN the number value should match the expected post-async number after a short + // delay + delay(10) + assertEquals(expectedPostAsyncNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + } + } + + // THEN the number value should match the expected post-async number after the block + // execution + assertEquals(expectedPostAsyncNumber, requestData.numberValue) + } + + @Test + @ActivateRequestContext + fun `with an active request scope on async`() { + // GIVEN an active request context + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND a given number and an expected post-async number + val givenNumber = 1234L + val expectedPostAsyncNumber = 5432L + + // AND we set the number value in the request data + requestData.numberValue = givenNumber + + // WHEN we run a block with async + runTest { + coroutineScope { + async(Dispatchers.IO.withCdiContext()) { + // THEN the request context should be active + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND the number value should match the given number + assertEquals(givenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + // AND the number value should match the given number after a short delay + delay(10) + assertEquals(givenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + // WHEN we set the number value to the expected post-async number + requestData.numberValue = expectedPostAsyncNumber + + // THEN the number value should match the expected post-async number after a + // short delay + delay(10) + assertEquals(expectedPostAsyncNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + } + .await() + } + } + + // THEN the number value should match the expected post-async number after the block + // execution + assertEquals(expectedPostAsyncNumber, requestData.numberValue) + } + + @Test + @ActivateRequestContext + fun `with an active request scope on inner coroutine scope in async`() { + // GIVEN an active request context + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND a given number and an expected post-async number + val givenNumber = 1234L + val expectedPostAsyncNumber = 5432L + + // AND we set the number value in the request data + requestData.numberValue = givenNumber + + // WHEN we run a block with async + runTest { + coroutineScope { + async(Dispatchers.IO.withCdiContext()) { + coroutineScope { + // THEN the request context should be active + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND the number value should match the given number + assertEquals(givenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + // AND the number value should match the given number after a short + // delay + delay(10) + assertEquals(givenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + // WHEN we set the number value to the expected post-async number + requestData.numberValue = expectedPostAsyncNumber + + // THEN the number value should match the expected post-async number + // after a short delay + delay(10) + assertEquals( + expectedPostAsyncNumber, + requestData.numberValue, + ) + + assertThatCallersClassLoaderIsExpected() + } + } + .await() + } + } + + // THEN the number value should match the expected post-async number after the block + // execution + assertEquals(expectedPostAsyncNumber, requestData.numberValue) + } + + @Test + @ActivateRequestContext + fun `with a terminated request scope while on async (undefined behavior)`() { + // GIVEN an active request context + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND a given number + val givenNumber = 1234L + + // AND we set the number value in the request data + requestData.numberValue = givenNumber + + val asyncStarted = CompletableDeferred() + val requestScopeTerminated = CompletableDeferred() + + // WHEN we run a block with async + runTest { + val job = launch { + async(Dispatchers.IO.withCdiContext()) { + // THEN the request context should be active + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND the number value should match the given number + assertEquals(givenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + asyncStarted.complete(Unit) + + // WHEN we wait for the request context to be terminated by another thread + requestScopeTerminated.await() + + // THEN the request context should not be active (undefined behavior) + assertFalse( + Arc.container().requestContext().isActive, + "Request context should not be active", + ) + } + .await() + } + + launch { + asyncStarted.await() + + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + assertEquals(givenNumber, requestData.numberValue) + Arc.container().requestContext().terminate() + assertFalse( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + requestScopeTerminated.complete(Unit) + } + + job.join() + } + } + + @Test + fun `with two active request scope on async on same coroutine`() { + // GIVEN two active request contexts with different giveNumbers + Arc.container().requestContext().activate() + + val firstRequestState = Arc.container().requestContext().stateIfActive + assertNotNull(firstRequestState) + + val firstGivenNumber = 1234L + val expectedFirstGivenNumber = 91234L + requestData.numberValue = firstGivenNumber + + Arc.container().requestContext().activate() + + val secondRequestState = Arc.container().requestContext().stateIfActive + assertNotNull(secondRequestState) + assertNotEquals(firstGivenNumber, requestData.numberValue) + + val secondGivenNumber = 5432L + val expectedSecondGivenNumber = 95432L + requestData.numberValue = secondGivenNumber + + assertNotEquals(firstRequestState, secondRequestState) + + val waitForFirstEnd = CompletableDeferred() + val waitForSecondMiddle = CompletableDeferred() + val waitForSecondEnd = CompletableDeferred() + + // WHEN we run a block with async + runTest { + val jobFirstRequest = launch { + Arc.container().requestContext().activate(firstRequestState) + async(Dispatchers.IO.withCdiContext()) { + // THEN the request context should be active + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND the number value should match the given number + assertEquals(firstGivenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + waitForSecondMiddle.await() + + // THEN the request context should be active + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND the number value should match the given number + assertEquals(firstGivenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + requestData.numberValue = expectedFirstGivenNumber + + waitForFirstEnd.complete(Unit) + } + .await() + } + + val jobSecondRequest = launch { + Arc.container().requestContext().activate(secondRequestState) + async(Dispatchers.IO.withCdiContext()) { + // THEN the request context should be active + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND the number value should match the given number + assertEquals(secondGivenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + waitForSecondMiddle.complete(Unit) + + waitForFirstEnd.await() + + // THEN the request context should be active + assertTrue( + Arc.container().requestContext().isActive, + "Request context should be active", + ) + + // AND the number value should match the given number + assertEquals(secondGivenNumber, requestData.numberValue) + + assertThatCallersClassLoaderIsExpected() + + requestData.numberValue = expectedSecondGivenNumber + + waitForSecondEnd.complete(Unit) + } + .await() + } + + jobSecondRequest.join() + jobFirstRequest.join() + } + + Arc.container().requestContext().activate(secondRequestState) + assertEquals(expectedSecondGivenNumber, requestData.numberValue) + Arc.container().requestContext().terminate() + + Arc.container().requestContext().activate(firstRequestState) + assertEquals(expectedFirstGivenNumber, requestData.numberValue) + Arc.container().requestContext().terminate() + } + + @RequestScoped + class RequestData { + var numberValue = 0L + } +} diff --git a/extensions/kotlin/runtime/pom.xml b/extensions/kotlin/runtime/pom.xml index 7af7abdf3c66c..f95b671e572a8 100644 --- a/extensions/kotlin/runtime/pom.xml +++ b/extensions/kotlin/runtime/pom.xml @@ -14,6 +14,8 @@ Quarkus - Kotlin - Runtime Write your services in Kotlin + ${project.basedir}/src/main/kotlin + ${project.basedir}/src/test/kotlin io.quarkus @@ -30,6 +32,28 @@ + + org.jetbrains.kotlin + kotlin-maven-plugin + ${kotlin.version} + + + compile + + compile + + + + test-compile + + test-compile + + + + + ${maven.compiler.target} + + diff --git a/extensions/kotlin/runtime/src/main/kotlin/io/quarkus/kotlin/arc/RequestContextCoroutineContext.kt b/extensions/kotlin/runtime/src/main/kotlin/io/quarkus/kotlin/arc/RequestContextCoroutineContext.kt new file mode 100644 index 0000000000000..76f77a7d6fbe9 --- /dev/null +++ b/extensions/kotlin/runtime/src/main/kotlin/io/quarkus/kotlin/arc/RequestContextCoroutineContext.kt @@ -0,0 +1,105 @@ +package io.quarkus.kotlin.arc + +import io.quarkus.arc.Arc +import io.quarkus.arc.InjectableContext +import io.quarkus.arc.ManagedContext +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.ThreadContextElement + +/** + * This function extends the CoroutineContext to include the Quarkus Request Context if it is + * active. + * + * If the caller finalizes the request context before the coroutine resumes, it results in undefined + * behavior. + * + * Will not start a request context if there is none active at the time of invocation. + */ +fun CoroutineContext.withCdiContext(): CoroutineContext { + val requestContext: ManagedContext? = Arc.container()?.requestContext() + return if (requestContext == null) { + this + } else { + this + RequestContextCoroutineContext(requestContext = requestContext) + } +} + +/** + * A CoroutineContext.Element to propagate the Quarkus Request Context. + * + * This element captures the active request context when a coroutine is launched and ensures it is + * activated whenever the coroutine resumes on a thread. + * + * @param requestContext The Quarkus ManagedContext for the request scope. + */ +class RequestContextCoroutineContext(private val requestContext: ManagedContext) : + ThreadContextElement { + + private val state: InjectableContext.ContextState? = requestContext.stateIfActive + private val classLoader: ClassLoader = Thread.currentThread().contextClassLoader + + fun InjectableContext.ContextState?.isNullOrInvalid(): Boolean { + return this == null || !this.isValid + } + + /** A companion object to act as the Key for this context element. */ + companion object Key : CoroutineContext.Key + + /** The key that identifies this element in a CoroutineContext. */ + override val key: CoroutineContext.Key<*> + get() = Key + + /** + * This function is invoked when the coroutine resumes execution on a thread. It activates the + * captured request context. + * + * @param context The coroutine context. + * @return The state of the request context *before* this element activated its captured state. + * This is used by `restoreThreadContext` to correctly reset the context later. + */ + override fun updateThreadContext(context: CoroutineContext): ContextSnapshot { + // Capture the state of the current thread's context before we change it. + val oldState = requestContext.stateIfActive + + val oldClassLoader = Thread.currentThread().contextClassLoader + + Thread.currentThread().contextClassLoader = classLoader + + // If the coroutine was launched from a thread without an active request context, + // we should deactivate any context that might be active on the current thread. + if (state.isNullOrInvalid()) { + requestContext.deactivate() + } else { + // Activate the request context that we captured when the coroutine was created. + requestContext.activate(state) + } + + return ContextSnapshot(oldState, oldClassLoader) + } + + /** + * This function is invoked when the coroutine suspends or completes. It restores the request + * context of the thread to its original state. + * + * @param context The coroutine context. + * @param oldState The state that was returned by `updateThreadContext`. + */ + override fun restoreThreadContext(context: CoroutineContext, oldState: ContextSnapshot) { + + Thread.currentThread().contextClassLoader = oldState.classLoader + + // We must restore the request context on the thread to whatever it was before + // this coroutine resumed. + val oldContext = oldState.contextState + if (oldContext.isNullOrInvalid()) { + requestContext.deactivate() + } else { + requestContext.activate(oldContext) + } + } + + data class ContextSnapshot( + val contextState: InjectableContext.ContextState? = null, + val classLoader: ClassLoader, + ) +}