|
| 1 | +package io.quarkus.arc.kotlin |
| 2 | + |
| 3 | +import io.quarkus.arc.Arc |
| 4 | +import io.quarkus.arc.InjectableContext |
| 5 | +import io.quarkus.arc.ManagedContext |
| 6 | +import kotlin.coroutines.CoroutineContext |
| 7 | +import kotlin.coroutines.EmptyCoroutineContext |
| 8 | +import kotlinx.coroutines.CoroutineScope |
| 9 | +import kotlinx.coroutines.CoroutineStart |
| 10 | +import kotlinx.coroutines.Deferred |
| 11 | +import kotlinx.coroutines.ThreadContextElement |
| 12 | +import kotlinx.coroutines.async |
| 13 | +import kotlinx.coroutines.withContext |
| 14 | + |
| 15 | +/** |
| 16 | + * A suspending function that executes a block of code within the Quarkus Request Context. |
| 17 | + * |
| 18 | + * This function captures the current request context and ensures it is activated when the coroutine |
| 19 | + * resumes on a thread. |
| 20 | + * |
| 21 | + * If the request context is finalized before the block completes, it results in undefined behavior. |
| 22 | + * |
| 23 | + * Will not start a request context if there is none active at the time of invocation. |
| 24 | + * |
| 25 | + * @param context The CoroutineContext to use for the coroutine. |
| 26 | + * @param block The block of code to execute within the request context. |
| 27 | + * @return The result of the block execution. |
| 28 | + */ |
| 29 | +suspend fun <T> withPropagatedContext( |
| 30 | + context: CoroutineContext, |
| 31 | + block: suspend CoroutineScope.() -> T, |
| 32 | +): T { |
| 33 | + return withContext(context = context.appendRequestContextToCoroutineContext(), block = block) |
| 34 | +} |
| 35 | + |
| 36 | +/** |
| 37 | + * An async function that executes a block of code within the Quarkus Request Context. |
| 38 | + * |
| 39 | + * This function captures the current request context and ensures it is activated when the coroutine |
| 40 | + * resumes on a thread. |
| 41 | + * |
| 42 | + * If the caller finalizes the request context before the block is executed, results in undefined |
| 43 | + * behavior. |
| 44 | + * |
| 45 | + * Will not start a request context if there is none active at the time of invocation. |
| 46 | + * |
| 47 | + * @param context The CoroutineContext to use for the coroutine. |
| 48 | + * @param block The block of code to execute within the request context. |
| 49 | + */ |
| 50 | +fun <T> CoroutineScope.asyncWithPropagatedContext( |
| 51 | + context: CoroutineContext = EmptyCoroutineContext, |
| 52 | + start: CoroutineStart = CoroutineStart.DEFAULT, |
| 53 | + block: suspend CoroutineScope.() -> T, |
| 54 | +): Deferred<T> { |
| 55 | + return async( |
| 56 | + context = context.appendRequestContextToCoroutineContext(), |
| 57 | + start = start, |
| 58 | + block = block, |
| 59 | + ) |
| 60 | +} |
| 61 | + |
| 62 | +fun CoroutineContext.appendRequestContextToCoroutineContext(): CoroutineContext { |
| 63 | + val requestContext: ManagedContext? = Arc.container()?.requestContext() |
| 64 | + return if (requestContext == null) { |
| 65 | + this |
| 66 | + } else { |
| 67 | + this + RequestContextCoroutineContext(requestContext = requestContext) |
| 68 | + } |
| 69 | +} |
| 70 | + |
| 71 | +/** |
| 72 | + * A CoroutineContext.Element to propagate the Quarkus Request Context. |
| 73 | + * |
| 74 | + * This element captures the active request context when a coroutine is launched and ensures it is |
| 75 | + * activated whenever the coroutine resumes on a thread. |
| 76 | + * |
| 77 | + * @param requestContext The Quarkus ManagedContext for the request scope. |
| 78 | + */ |
| 79 | +class RequestContextCoroutineContext(private val requestContext: ManagedContext) : |
| 80 | + ThreadContextElement<RequestContextCoroutineContext.ContextSnapshot> { |
| 81 | + |
| 82 | + private val state: InjectableContext.ContextState? = requestContext.stateIfActive |
| 83 | + private val classLoader: ClassLoader = Thread.currentThread().contextClassLoader |
| 84 | + |
| 85 | + fun InjectableContext.ContextState?.isNullOrInvalid(): Boolean { |
| 86 | + return this == null || !this.isValid |
| 87 | + } |
| 88 | + |
| 89 | + /** A companion object to act as the Key for this context element. */ |
| 90 | + companion object Key : CoroutineContext.Key<RequestContextCoroutineContext> |
| 91 | + |
| 92 | + /** The key that identifies this element in a CoroutineContext. */ |
| 93 | + override val key: CoroutineContext.Key<*> |
| 94 | + get() = Key |
| 95 | + |
| 96 | + /** |
| 97 | + * This function is invoked when the coroutine resumes execution on a thread. It activates the |
| 98 | + * captured request context. |
| 99 | + * |
| 100 | + * @param context The coroutine context. |
| 101 | + * @return The state of the request context *before* this element activated its captured state. |
| 102 | + * This is used by `restoreThreadContext` to correctly reset the context later. |
| 103 | + */ |
| 104 | + override fun updateThreadContext(context: CoroutineContext): ContextSnapshot { |
| 105 | + // Capture the state of the current thread's context before we change it. |
| 106 | + val oldState = requestContext.stateIfActive |
| 107 | + |
| 108 | + val oldClassLoader = Thread.currentThread().contextClassLoader |
| 109 | + |
| 110 | + Thread.currentThread().contextClassLoader = classLoader |
| 111 | + |
| 112 | + // If the coroutine was launched from a thread without an active request context, |
| 113 | + // we should deactivate any context that might be active on the current thread. |
| 114 | + if (state.isNullOrInvalid()) { |
| 115 | + requestContext.deactivate() |
| 116 | + } else { |
| 117 | + // Activate the request context that we captured when the coroutine was created. |
| 118 | + requestContext.activate(state) |
| 119 | + } |
| 120 | + |
| 121 | + return ContextSnapshot(oldState, oldClassLoader) |
| 122 | + } |
| 123 | + |
| 124 | + /** |
| 125 | + * This function is invoked when the coroutine suspends or completes. It restores the request |
| 126 | + * context of the thread to its original state. |
| 127 | + * |
| 128 | + * @param context The coroutine context. |
| 129 | + * @param oldState The state that was returned by `updateThreadContext`. |
| 130 | + */ |
| 131 | + override fun restoreThreadContext(context: CoroutineContext, oldState: ContextSnapshot) { |
| 132 | + |
| 133 | + Thread.currentThread().contextClassLoader = oldState.classLoader |
| 134 | + |
| 135 | + // We must restore the request context on the thread to whatever it was before |
| 136 | + // this coroutine resumed. |
| 137 | + val oldContext = oldState.contextState |
| 138 | + if (oldContext.isNullOrInvalid()) { |
| 139 | + requestContext.deactivate() |
| 140 | + } else { |
| 141 | + requestContext.activate(oldContext) |
| 142 | + } |
| 143 | + } |
| 144 | + |
| 145 | + data class ContextSnapshot( |
| 146 | + val contextState: InjectableContext.ContextState? = null, |
| 147 | + val classLoader: ClassLoader, |
| 148 | + ) |
| 149 | +} |
0 commit comments