diff --git a/.changelog/1761080224.md b/.changelog/1761080224.md new file mode 100644 index 00000000000..e7c28ca564a --- /dev/null +++ b/.changelog/1761080224.md @@ -0,0 +1,13 @@ +--- +applies_to: +- server +authors: +- rcoh +references: ["smithy-rs#4356"] +breaking: true +new_feature: true +bug_fix: false +--- +Parse EventStream signed-frames for servers marked with `@sigv4`. + +This is a breaking change, because events from SigV4 services are wrapped in a SignedEvent frame. diff --git a/AGENTS.md b/AGENTS.md index f31b518ef27..65b34b033cf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -48,6 +48,26 @@ operation MyOperation { - **`codegen-core/common-test-models/constraints.smithy`** - Constraint validation tests with restJson1 - **`codegen-client-test/model/main.smithy`** - awsJson1_1 protocol tests +### httpQueryParams Bug Investigation + +When investigating the `@httpQueryParams` bug (where query parameters weren't appearing in requests), the issue was in `RequestBindingGenerator.kt` line 173. The bug occurred when: + +1. An operation had ONLY `@httpQueryParams` (no regular `@httpQuery` parameters) +2. The condition `if (dynamicParams.isEmpty() && literalParams.isEmpty() && mapParams.isEmpty())` would skip generating the `uri_query` function + +The fix was to ensure `mapParams.isEmpty()` was included in the condition check. The current implementation correctly generates query parameters for `@httpQueryParams` even when no other query parameters exist. + +**Testing httpQueryParams**: Create operations with only `@httpQueryParams` to ensure they generate proper query strings in requests. + +## rustTemplate Formatting + +**CRITICAL**: Because `#` is the formatting character in `rustTemplate`, Rust attributes must be escaped: + +❌ Wrong: `#[derive(Debug)]` +✅ Correct: `##[derive(Debug)]` + +This applies to ALL Rust attributes: `##[non_exhaustive]`, `##[derive(...)]`, `##[cfg(...)]`, etc. + ## preludeScope: Rust Prelude Types **Always use `preludeScope` for Rust prelude types:** diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy index 6a964cfca3e..4cd7156f24e 100644 --- a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -6,8 +6,10 @@ use smithy.framework#ValidationException use smithy.protocols#rpcv2Cbor use smithy.test#httpResponseTests use smithy.test#httpMalformedRequestTests +use aws.auth#sigv4 @rpcv2Cbor +@sigv4(name: "rpcv2-cbor") service RpcV2CborService { operations: [ SimpleStructOperation diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt index 78b348e6c2e..d9d01639563 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt @@ -83,9 +83,14 @@ fun Symbol.makeMaybeConstrained(): Symbol = * WARNING: This function does not update any symbol references (e.g., `symbol.addReference()`) on the * returned symbol. You will have to add those yourself if your logic relies on them. **/ -fun Symbol.mapRustType(f: (RustType) -> RustType): Symbol { +fun Symbol.mapRustType( + vararg dependencies: RuntimeType, + f: (RustType) -> RustType, +): Symbol { val newType = f(this.rustType()) - return Symbol.builder().rustType(newType) + val builder = this.toBuilder() + dependencies.forEach { builder.addReference(it.toSymbol()) } + return builder.rustType(newType) .name(newType.name) .build() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index bade1055401..4b6a39fc6e4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -96,6 +96,12 @@ sealed class HttpBindingSection(name: String) : Section(name) { data class AfterDeserializingIntoADateTimeOfHttpHeaders(val memberShape: MemberShape) : HttpBindingSection("AfterDeserializingIntoADateTimeOfHttpHeaders") + + data class BeforeCreatingEventStreamReceiver( + val operationShape: OperationShape, + val unionShape: UnionShape, + val unmarshallerVariableName: String, + ) : HttpBindingSection("BeforeCreatingEventStreamReceiver") } typealias HttpBindingCustomization = NamedCustomization @@ -272,11 +278,27 @@ class HttpBindingGenerator( rustTemplate( """ let unmarshaller = #{unmarshallerConstructorFn}(); + """, + "unmarshallerConstructorFn" to unmarshallerConstructorFn, + ) + + // Allow customizations to wrap the unmarshaller + for (customization in customizations) { + customization.section( + HttpBindingSection.BeforeCreatingEventStreamReceiver( + operationShape, + targetShape, + "unmarshaller", + ), + )(this) + } + + rustTemplate( + """ let body = std::mem::replace(body, #{SdkBody}::taken()); Ok(#{receiver:W}) """, "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - "unmarshallerConstructorFn" to unmarshallerConstructorFn, "receiver" to writable { if (codegenTarget == CodegenTarget.SERVER) { diff --git a/codegen-server-test/integration-tests/Cargo.lock b/codegen-server-test/integration-tests/Cargo.lock index e8c25db10ea..b752b100e92 100644 --- a/codegen-server-test/integration-tests/Cargo.lock +++ b/codegen-server-test/integration-tests/Cargo.lock @@ -59,7 +59,7 @@ dependencies = [ [[package]] name = "aws-smithy-cbor" -version = "0.61.2" +version = "0.61.3" dependencies = [ "aws-smithy-types", "minicbor", @@ -67,7 +67,7 @@ dependencies = [ [[package]] name = "aws-smithy-eventstream" -version = "0.60.12" +version = "0.60.13" dependencies = [ "aws-smithy-types", "bytes", @@ -96,7 +96,7 @@ dependencies = [ [[package]] name = "aws-smithy-http-client" -version = "1.1.3" +version = "1.1.4" dependencies = [ "aws-smithy-async", "aws-smithy-protocol-test", @@ -120,7 +120,7 @@ dependencies = [ [[package]] name = "aws-smithy-http-server" -version = "0.65.7" +version = "0.65.8" dependencies = [ "aws-smithy-cbor", "aws-smithy-http", @@ -148,7 +148,7 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.61.6" +version = "0.61.7" dependencies = [ "aws-smithy-types", ] @@ -162,7 +162,7 @@ dependencies = [ [[package]] name = "aws-smithy-protocol-test" -version = "0.63.5" +version = "0.63.6" dependencies = [ "assert-json-diff", "aws-smithy-runtime-api", @@ -202,7 +202,7 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.9.1" +version = "1.9.2" dependencies = [ "aws-smithy-async", "aws-smithy-types", @@ -216,7 +216,7 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.3.3" +version = "1.3.4" dependencies = [ "base64-simd", "bytes", @@ -241,7 +241,7 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.11" +version = "0.60.12" dependencies = [ "xmlparser", ] diff --git a/codegen-server-test/integration-tests/eventstreams/tests/structured_eventstream_tests.rs b/codegen-server-test/integration-tests/eventstreams/tests/structured_eventstream_tests.rs index 97eb11f67d8..4f52d7141fa 100644 --- a/codegen-server-test/integration-tests/eventstreams/tests/structured_eventstream_tests.rs +++ b/codegen-server-test/integration-tests/eventstreams/tests/structured_eventstream_tests.rs @@ -143,13 +143,15 @@ async fn streaming_operation_handler( state.lock().unwrap().streaming_operation.num_calls += 1; let ev = input.events.recv().await; - if let Ok(Some(event)) = &ev { + if let Ok(Some(signed_event)) = &ev { + // Extract the actual event from the SignedEvent wrapper + let actual_event = &signed_event.message; state .lock() .unwrap() .streaming_operation .events - .push(event.clone()); + .push(actual_event.clone()); } Ok(output::StreamingOperationOutput::builder() @@ -174,13 +176,15 @@ async fn streaming_operation_with_initial_data_handler( let ev = input.events.recv().await; - if let Ok(Some(event)) = &ev { + if let Ok(Some(signed_event)) = &ev { + // Extract the actual event from the SignedEvent wrapper + let actual_event = &signed_event.message; state .lock() .unwrap() .streaming_operation_with_initial_data .events - .push(event.clone()); + .push(actual_event.clone()); } Ok(output::StreamingOperationWithInitialDataOutput::builder() @@ -229,7 +233,7 @@ async fn streaming_operation_with_optional_data_handler( .unwrap() .streaming_operation_with_optional_data .events - .push(event.clone()); + .push(event.message.clone()); } Ok(output::StreamingOperationWithOptionalDataOutput::builder() @@ -348,6 +352,39 @@ fn build_event(event_type: &str) -> Message { Message::new_from_parts(headers, empty_cbor) } +fn build_sigv4_signed_event(event_type: &str) -> Message { + use aws_smithy_eventstream::frame::write_message_to; + use std::time::{SystemTime, UNIX_EPOCH}; + + // Build the inner event message + let inner_event = build_event(event_type); + + // Serialize the inner message to bytes + let mut inner_bytes = Vec::new(); + write_message_to(&inner_event, &mut inner_bytes).unwrap(); + + // Create the SigV4 envelope with signature headers + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let headers = vec![ + Header::new( + ":chunk-signature", + HeaderValue::ByteArray(Bytes::from( + "example298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + )), + ), + Header::new( + ":date", + HeaderValue::Timestamp(aws_smithy_types::DateTime::from_secs(timestamp as i64)), + ), + ]; + + Message::new_from_parts(headers, Bytes::from(inner_bytes)) +} + fn get_event_type(msg: &Message) -> &str { msg.headers() .iter() @@ -439,6 +476,24 @@ async fn test_streaming_operation_with_initial_data_missing() { ); } +/// Test that the server can handle SigV4 signed event stream messages. +/// The client wraps the actual event in a SigV4 envelope with signature headers. +#[tokio::test] +async fn test_sigv4_signed_event_stream() { + let mut harness = TestHarness::new("StreamingOperation").await; + + // Send a SigV4 signed event - the inner message is wrapped in an envelope + let signed_event = build_sigv4_signed_event("A"); + harness.client.send(signed_event).await.unwrap(); + + let resp = harness.expect_message().await; + assert_eq!(get_event_type(&resp), "A"); + assert_eq!( + harness.server.streaming_operation_events(), + vec![Events::A(Event {})] + ); +} + /// Test that when alwaysSendEventStreamInitialResponse is disabled, no initial-response is sent #[tokio::test] async fn test_server_no_initial_response_when_disabled() { diff --git a/codegen-server/codegen-server-typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt b/codegen-server/codegen-server-typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt index dea89fbc07e..53763211c73 100644 --- a/codegen-server/codegen-server-typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt +++ b/codegen-server/codegen-server-typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt @@ -62,7 +62,7 @@ class TsServerCodegenVisitor( ServerProtocolLoader( codegenDecorator.protocols( service.id, - ServerProtocolLoader.DefaultProtocols, + ServerProtocolLoader.defaultProtocols(), ), ) .protocolFor(context.model, service) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt index e6ab96bd869..a8d28741ac9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeSymbolProvi import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.server.smithy.customizations.CustomValidationExceptionWithReasonDecorator import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SigV4EventStreamDecorator import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionDecorator import software.amazon.smithy.rust.codegen.server.smithy.customizations.UserProvidedValidationExceptionDecorator import software.amazon.smithy.rust.codegen.server.smithy.customize.CombinedServerCodegenDecorator @@ -54,6 +55,7 @@ class RustServerCodegenPlugin : ServerDecoratableBuildPlugin() { UserProvidedValidationExceptionDecorator(), SmithyValidationExceptionDecorator(), CustomValidationExceptionWithReasonDecorator(), + SigV4EventStreamDecorator(), *decorator, ) logger.info("Loaded plugin to generate pure Rust bindings for the server SDK") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 9cfc3ca1063..7fd8d76b32c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -124,18 +124,7 @@ open class ServerCodegenVisitor( val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) - val (protocolShape, protocolGeneratorFactory) = - ServerProtocolLoader( - codegenDecorator.protocols( - service.id, - ServerProtocolLoader.DefaultProtocols, - ), - ) - .protocolFor(context.model, service) - this.protocolGeneratorFactory = protocolGeneratorFactory - model = codegenDecorator.transformModel(service, baseModel, settings) - val serverSymbolProviders = ServerSymbolProviders.from( settings, @@ -146,7 +135,19 @@ open class ServerCodegenVisitor( codegenDecorator, RustServerCodegenPlugin::baseSymbolProvider, ) - + val (protocolShape, protocolGeneratorFactory) = + ServerProtocolLoader( + codegenDecorator.protocols( + service.id, + ServerProtocolLoader.defaultProtocols { it -> + codegenDecorator.httpCustomizations( + serverSymbolProviders.symbolProvider, + it, + ) + }, + ), + ) + .protocolFor(context.model, service) codegenContext = ServerCodegenContext( model, @@ -160,6 +161,7 @@ open class ServerCodegenVisitor( serverSymbolProviders.constraintViolationSymbolProvider, serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, ) + this.protocolGeneratorFactory = protocolGeneratorFactory // We can use a not-null assertion because [CombinedServerCodegenDecorator] returns a not null value. validationExceptionConversionGenerator = codegenDecorator.validationExceptionConversion(codegenContext)!! diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamDecorator.kt new file mode 100644 index 00000000000..0cf966749e4 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamDecorator.kt @@ -0,0 +1,134 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.aws.traits.auth.SigV4Trait +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.knowledge.ServiceIndex +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.isEventStream +import software.amazon.smithy.rust.codegen.core.util.isInputEventStream +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator + +/** + * Decorator that adds SigV4 event stream unsigning support to server code generation. + */ +class SigV4EventStreamDecorator : ServerCodegenDecorator { + override val name: String = "SigV4EventStreamDecorator" + override val order: Byte = 0 + + override fun httpCustomizations( + symbolProvider: RustSymbolProvider, + protocol: ShapeId, + ): List { + return listOf(SigV4EventStreamCustomization(symbolProvider)) + } + + override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider { + // We need access to the service shape to check for SigV4 trait, but the base interface doesn't provide it. + // For now, we'll wrap all event streams and let the runtime code handle the detection. + return SigV4EventStreamSymbolProvider(base) + } +} + +internal fun RustSymbolProvider.usesSigAuth(): Boolean = + ServiceIndex.of(model).getAuthSchemes(moduleProviderContext.serviceShape!!).containsKey(SigV4Trait.ID) + +// Goes from `T` to `SignedEvent` +fun wrapInSignedEvent( + inner: Symbol, + runtimeConfig: RuntimeConfig, +) = inner.mapRustType { + RustType.Application( + SigV4EventStreamSupportStructures.signedEvent(runtimeConfig).toSymbol().rustType(), + listOf(inner.rustType()), + ) +} + +// Goes from `E` to `SignedEventError` +fun wrapInSignedEventError( + inner: Symbol, + runtimeConfig: RuntimeConfig, +) = inner.mapRustType { + RustType.Application( + SigV4EventStreamSupportStructures.signedEventError(runtimeConfig).toSymbol().rustType(), + listOf(inner.rustType()), + ) +} + +/** + * Symbol provider wrapper that modifies event stream types to support SigV4 signed messages. + */ +class SigV4EventStreamSymbolProvider( + base: RustSymbolProvider, +) : WrappingSymbolProvider(base) { + private val serviceIsSigv4 = base.usesSigAuth() + private val runtimeConfig = base.config.runtimeConfig + + override fun toSymbol(shape: Shape): Symbol { + val baseSymbol = super.toSymbol(shape) + if (!serviceIsSigv4) { + return baseSymbol + } + // We only want to wrap with Event Stream types when dealing with member shapes + if (shape is MemberShape && shape.isEventStream(model)) { + // Determine if the member has a container that is a synthetic input or output + val operationShape = + model.expectShape(shape.container).let { maybeInput -> + val operationId = + maybeInput.getTrait()?.operation + operationId?.let { model.expectShape(it, OperationShape::class.java) } + } + // If we find an operation shape, then we can wrap the type + if (operationShape != null) { + if (operationShape.isInputEventStream(model)) { + return SigV4EventStreamSupportStructures.wrapInEventStreamSigV4(baseSymbol, runtimeConfig) + } + } + } + + return baseSymbol + } +} + +class SigV4EventStreamCustomization(private val symbolProvider: RustSymbolProvider) : HttpBindingCustomization() { + override fun section(section: HttpBindingSection): Writable = + writable { + when (section) { + is HttpBindingSection.BeforeCreatingEventStreamReceiver -> { + // Check if this service uses SigV4 auth + if (symbolProvider.usesSigAuth()) { + val codegenScope = + SigV4EventStreamSupportStructures.codegenScope(symbolProvider.config.runtimeConfig) + rustTemplate( + """ + let ${section.unmarshallerVariableName} = #{SigV4Unmarshaller}::new(${section.unmarshallerVariableName}); + """, + *codegenScope, + ) + } + } + + else -> {} + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructures.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructures.kt new file mode 100644 index 00000000000..81f67f1dd41 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructures.kt @@ -0,0 +1,313 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.PANIC + +object SigV4EventStreamSupportStructures { + private val supportModule = RustModule.private("sigv4_event_stream") + + fun codegenScope(runtimeConfig: RuntimeConfig) = + arrayOf( + "SignatureInfo" to signatureInfo(), + "ExtractionError" to extractionError(runtimeConfig), + "SignedEventError" to signedEventError(runtimeConfig), + "SignedEvent" to signedEvent(runtimeConfig), + "SigV4Unmarshaller" to sigV4Unmarshaller(runtimeConfig), + "extract_signed_message" to extractSignedMessage(runtimeConfig), + ) + + /** + * Wraps an event stream Receiver type to handle SigV4 signed messages. + * Transforms: Receiver -> Receiver, SignedEventError> + */ + fun wrapInEventStreamSigV4( + symbol: Symbol, + runtimeConfig: RuntimeConfig, + ): Symbol { + val signedEvent = signedEvent(runtimeConfig) + val signedEventError = signedEventError(runtimeConfig) + return symbol.mapRustType(signedEvent, signedEventError) { rustType -> + // Expect Application(Receiver, [T, E]) + if (rustType is RustType.Application && rustType.name == "Receiver" && rustType.args.size == 2) { + val eventType = rustType.args[0] + val errorType = rustType.args[1] + + // Create SignedEvent and SignedEventError + val wrappedEventType = + RustType.Application( + signedEvent.toSymbol().rustType(), + listOf(eventType), + ) + val wrappedErrorType = + RustType.Application( + signedEventError.toSymbol().rustType(), + listOf(errorType), + ) + + // Create new Receiver, SignedEventError> + RustType.Application( + rustType.type, + listOf(wrappedEventType, wrappedErrorType), + ) + } else { + PANIC("Called wrap in EventStreamSigV4 on ${symbol.rustType()} which was not an event stream receiver") + } + } + } + + private fun signatureInfo(): RuntimeType = + RuntimeType.forInlineFun("SignatureInfo", supportModule) { + rustTemplate( + """ + /// Information extracted from a signed event stream message + ##[non_exhaustive] + ##[derive(Debug, Clone)] + pub struct SignatureInfo { + /// The chunk signature bytes from the `:chunk-signature` header + pub chunk_signature: Vec, + /// The timestamp from the `:date` header + pub timestamp: #{SystemTime}, + } + """, + "SystemTime" to RuntimeType.std.resolve("time::SystemTime"), + ) + } + + private fun extractionError(runtimeConfig: RuntimeConfig): RuntimeType = + RuntimeType.forInlineFun("ExtractionError", supportModule) { + rustTemplate( + """ + /// Error type for signed message extraction operations + ##[non_exhaustive] + ##[derive(Debug)] + pub enum ExtractionError { + /// The payload could not be decoded as a valid message + ##[non_exhaustive] + InvalidPayload { + error: #{EventStreamError}, + }, + /// The timestamp header is missing or has an invalid format + ##[non_exhaustive] + InvalidTimestamp, + } + """, + "EventStreamError" to CargoDependency.smithyEventStream(runtimeConfig).toType().resolve("error::Error"), + ) + } + + fun signedEventError(runtimeConfig: RuntimeConfig): RuntimeType = + RuntimeType.forInlineFun("SignedEventError", supportModule) { + rustTemplate( + """ + /// Error wrapper for signed event stream errors + ##[derive(Debug)] + pub enum SignedEventError { + /// Error from the underlying event stream + Event(E), + /// Error extracting signed message + InvalidSignedEvent(#{ExtractionError}), + } + + impl From for SignedEventError { + fn from(err: E) -> Self { + SignedEventError::Event(err) + } + } + """, + "ExtractionError" to extractionError(runtimeConfig), + ) + } + + fun signedEvent(runtimeConfig: RuntimeConfig): RuntimeType = + RuntimeType.forInlineFun("SignedEvent", supportModule) { + rustTemplate( + """ + /// Wrapper for event stream messages that may be signed + ##[derive(Debug)] + pub struct SignedEvent { + /// The actual event message + pub message: T, + /// Signature information if the message was signed + pub signature: #{Option}<#{SignatureInfo}>, + } + """, + "Option" to RuntimeType.std.resolve("option::Option"), + "SignatureInfo" to signatureInfo(), + ) + } + + private fun sigV4Unmarshaller(runtimeConfig: RuntimeConfig): RuntimeType = + RuntimeType.forInlineFun("SigV4Unmarshaller", supportModule) { + rustTemplate( + """ + /// Unmarshaller wrapper that handles SigV4 signed event stream messages + ##[derive(Debug)] + pub struct SigV4Unmarshaller { + inner: T, + } + + impl SigV4Unmarshaller { + pub fn new(inner: T) -> Self { + Self { inner } + } + } + + impl #{UnmarshallMessage} for SigV4Unmarshaller + where + T: #{UnmarshallMessage}, + { + type Output = #{SignedEvent}; + type Error = #{SignedEventError}; + + fn unmarshall(&self, message: &#{Message}) -> #{Result}<#{UnmarshalledMessage}, #{EventStreamError}> { + // First, try to extract the signed message + match #{extract_signed_message}(message) { + Ok(MaybeSignedMessage::Signed { message: inner_message, signature }) => { + // Process the inner message with the base unmarshaller + match self.inner.unmarshall(&inner_message) { + Ok(unmarshalled) => match unmarshalled { + #{UnmarshalledMessage}::Event(event) => { + Ok(#{UnmarshalledMessage}::Event(#{SignedEvent} { + message: event, + signature: Some(signature), + })) + } + #{UnmarshalledMessage}::Error(err) => { + Ok(#{UnmarshalledMessage}::Error(#{SignedEventError}::Event(err))) + } + }, + Err(err) => Err(err), + } + } + Ok(MaybeSignedMessage::Unsigned) => { + // Process unsigned message directly + match self.inner.unmarshall(message) { + Ok(unmarshalled) => match unmarshalled { + #{UnmarshalledMessage}::Event(event) => { + Ok(#{UnmarshalledMessage}::Event(#{SignedEvent} { + message: event, + signature: None, + })) + } + #{UnmarshalledMessage}::Error(err) => { + Ok(#{UnmarshalledMessage}::Error(#{SignedEventError}::Event(err))) + } + }, + Err(err) => Err(err), + } + } + Err(extraction_err) => Ok(#{UnmarshalledMessage}::Error(#{SignedEventError}::InvalidSignedEvent(extraction_err))), + } + } + } + """, + "UnmarshallMessage" to + CargoDependency.smithyEventStream(runtimeConfig).toType() + .resolve("frame::UnmarshallMessage"), + "UnmarshalledMessage" to + CargoDependency.smithyEventStream(runtimeConfig).toType() + .resolve("frame::UnmarshalledMessage"), + "Message" to CargoDependency.smithyTypes(runtimeConfig).toType().resolve("event_stream::Message"), + "EventStreamError" to CargoDependency.smithyEventStream(runtimeConfig).toType().resolve("error::Error"), + "SignedEvent" to signedEvent(runtimeConfig), + "SignedEventError" to signedEventError(runtimeConfig), + "extract_signed_message" to extractSignedMessage(runtimeConfig), + *RuntimeType.preludeScope, + ) + } + + private fun extractSignedMessage(runtimeConfig: RuntimeConfig): RuntimeType = + RuntimeType.forInlineFun("extract_signed_message", supportModule) { + rustTemplate( + """ + /// Result of extracting a potentially signed message + ##[derive(Debug)] + pub enum MaybeSignedMessage { + /// Message was signed and has been extracted + Signed { + /// The inner message that was signed + message: #{Message}, + /// Signature information from the outer message + signature: #{SignatureInfo}, + }, + /// Message was not signed (no `:chunk-signature` header present) + Unsigned, + } + + /// Extracts the inner message from a potentially signed event stream message. + pub fn extract_signed_message(message: &#{Message}) -> #{Result} { + // Check if message has chunk signature + let mut chunk_signature = None; + let mut timestamp = None; + + for header in message.headers() { + match header.name().as_str() { + ":chunk-signature" => { + if let #{HeaderValue}::ByteArray(bytes) = header.value() { + chunk_signature = Some(bytes.as_ref().to_vec()); + } + } + ":date" => { + if let #{HeaderValue}::Timestamp(ts) = header.value() { + timestamp = Some( + #{SystemTime}::try_from(*ts) + .map_err(|_err| #{ExtractionError}::InvalidTimestamp)?, + ); + } else { + return Err(#{ExtractionError}::InvalidTimestamp); + } + } + _ => {} + } + } + + let Some(chunk_signature) = chunk_signature else { + return Ok(MaybeSignedMessage::Unsigned); + }; + + let Some(timestamp) = timestamp else { + return Err(#{ExtractionError}::InvalidTimestamp); + }; + + // Extract inner message + let cursor = #{Cursor}::new(message.payload()); + let inner_message = #{read_message_from}(cursor) + .map_err(|err| #{ExtractionError}::InvalidPayload { error: err })?; + + Ok(MaybeSignedMessage::Signed { + message: inner_message, + signature: #{SignatureInfo} { + chunk_signature, + timestamp, + }, + }) + } + """, + "Message" to CargoDependency.smithyTypes(runtimeConfig).toType().resolve("event_stream::Message"), + "HeaderValue" to + CargoDependency.smithyTypes(runtimeConfig).toType() + .resolve("event_stream::HeaderValue"), + "SystemTime" to RuntimeType.std.resolve("time::SystemTime"), + "Cursor" to RuntimeType.std.resolve("io::Cursor"), + "read_message_from" to + CargoDependency.smithyEventStream(runtimeConfig).toType() + .resolve("frame::read_message_from"), + "SignatureInfo" to signatureInfo(), + "ExtractionError" to extractionError(runtimeConfig), + *RuntimeType.preludeScope, + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt index 6e23bbd6787..d82d3ad65fb 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt @@ -10,8 +10,10 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings @@ -32,6 +34,11 @@ interface ServerCodegenDecorator : CoreCodegenDecorator = emptyList() + fun validationExceptionConversion(codegenContext: ServerCodegenContext): ValidationExceptionConversionGenerator? = null @@ -95,6 +102,11 @@ class CombinedServerCodegenDecorator(decorators: List) : decorator.protocols(serviceId, protocolMap) } + override fun httpCustomizations( + symbolProvider: RustSymbolProvider, + protocol: ShapeId, + ): List = orderedDecorators.flatMap { it.httpCustomizations(symbolProvider, protocol) } + override fun validationExceptionConversion( codegenContext: ServerCodegenContext, ): ValidationExceptionConversionGenerator = diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt index 960cbd735d7..98434672d52 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt @@ -74,6 +74,7 @@ class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstr is HttpBindingSection.BeforeRenderingHeaderValue, is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, + is HttpBindingSection.BeforeCreatingEventStreamReceiver, -> emptySection } } @@ -88,7 +89,8 @@ class ServerResponseBeforeRenderingHeadersHttpBindingCustomization(val codegenCo when (section) { is HttpBindingSection.BeforeRenderingHeaderValue -> writable { - val isIntegral = section.context.shape is ByteShape || section.context.shape is ShortShape || section.context.shape is IntegerShape || section.context.shape is LongShape + val isIntegral = + section.context.shape is ByteShape || section.context.shape is ShortShape || section.context.shape is IntegerShape || section.context.shape is LongShape val isCollection = section.context.shape is CollectionShape val workingWithPublicWrapper = @@ -107,6 +109,7 @@ class ServerResponseBeforeRenderingHeadersHttpBindingCustomization(val codegenCo is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, + is HttpBindingSection.BeforeCreatingEventStreamReceiver, -> emptySection } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index a0532b79fe6..cd9fe80d502 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -179,6 +179,14 @@ class ServerHttpBoundProtocolTraitImplGenerator( private val httpBindingResolver = protocol.httpBindingResolver private val protocolFunctions = ProtocolFunctions(codegenContext) + fun withHttpBindingCustomizations( + customizations: List, + ): ServerHttpBoundProtocolTraitImplGenerator { + return ServerHttpBoundProtocolTraitImplGenerator( + codegenContext, protocol, this.customizations, additionalHttpBindingCustomizations + customizations, + ) + } + private val codegenScope = arrayOf( "AsyncTrait" to ServerCargoDependency.AsyncTrait.toType(), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index a121697eb7f..ea875699843 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -9,11 +9,13 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap @@ -29,7 +31,11 @@ class StreamPayloadSerializerCustomization : ServerHttpBoundProtocolCustomizatio if (section.params.shape.isOutputEventStream(section.params.codegenContext.model)) { // Event stream payload, of type `aws_smithy_http::event_stream::MessageStreamAdapter`, already // implements the `Stream` trait, so no need to wrap it in the new-type. - section.params.payloadGenerator.generatePayload(this, section.params.shapeName, section.params.shape) + section.params.payloadGenerator.generatePayload( + this, + section.params.shapeName, + section.params.shape, + ) } else { // Otherwise, the stream payload is `aws_smithy_types::byte_stream::ByteStream`. We wrap it in the // new-type to enable the `Stream` trait. @@ -54,39 +60,44 @@ class StreamPayloadSerializerCustomization : ServerHttpBoundProtocolCustomizatio class ServerProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { companion object { - val DefaultProtocols = - mapOf( - RestJson1Trait.ID to - ServerRestJsonFactory( - additionalServerHttpBoundProtocolCustomizations = - listOf( - StreamPayloadSerializerCustomization(), - ), - ), - RestXmlTrait.ID to - ServerRestXmlFactory( - additionalServerHttpBoundProtocolCustomizations = - listOf( - StreamPayloadSerializerCustomization(), - ), - ), - AwsJson1_0Trait.ID to - ServerAwsJsonFactory( - AwsJsonVersion.Json10, - additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), - ), - AwsJson1_1Trait.ID to - ServerAwsJsonFactory( - AwsJsonVersion.Json11, - additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), - ), - Rpcv2CborTrait.ID to - ServerRpcV2CborFactory( - additionalServerHttpBoundProtocolCustomizations = - listOf( - StreamPayloadSerializerCustomization(), - ), - ), - ) + fun defaultProtocols( + httpBindingCustomizations: (ShapeId) -> List = { _ -> listOf() }, + ) = mapOf( + RestJson1Trait.ID to + ServerRestJsonFactory( + additionalServerHttpBoundProtocolCustomizations = + listOf( + StreamPayloadSerializerCustomization(), + ), + additionalHttpBindingCustomizations = httpBindingCustomizations(RestJson1Trait.ID), + ), + RestXmlTrait.ID to + ServerRestXmlFactory( + additionalServerHttpBoundProtocolCustomizations = + listOf( + StreamPayloadSerializerCustomization(), + ), + ), + AwsJson1_0Trait.ID to + ServerAwsJsonFactory( + AwsJsonVersion.Json10, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + additionalHttpBindingCustomizations = httpBindingCustomizations(AwsJson1_0Trait.ID), + ), + AwsJson1_1Trait.ID to + ServerAwsJsonFactory( + AwsJsonVersion.Json11, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + additionalHttpBindingCustomizations = httpBindingCustomizations(AwsJson1_1Trait.ID), + ), + Rpcv2CborTrait.ID to + ServerRpcV2CborFactory( + additionalServerHttpBoundProtocolCustomizations = + listOf( + StreamPayloadSerializerCustomization(), + ), + additionalHttpBindingCustomizations = httpBindingCustomizations(Rpcv2CborTrait.ID), + ), + ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt index d4fa989dbbe..472a0c47991 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory @@ -14,6 +15,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser class ServerRpcV2CborFactory( private val additionalServerHttpBoundProtocolCustomizations: List = emptyList(), + private val additionalHttpBindingCustomizations: List = listOf(), ) : ProtocolGeneratorFactory { override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRpcV2CborProtocol(codegenContext) @@ -22,6 +24,7 @@ class ServerRpcV2CborFactory( codegenContext, ServerRpcV2CborProtocol(codegenContext), additionalServerHttpBoundProtocolCustomizations, + additionalHttpBindingCustomizations = additionalHttpBindingCustomizations, ) override fun support(): ProtocolSupport { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt index a03dfbdd4e3..840d82720db 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt @@ -135,7 +135,7 @@ fun serverTestCodegenContext( fun loadServerProtocol(model: Model): ServerProtocol { val codegenContext = serverTestCodegenContext(model) val (_, protocolGeneratorFactory) = - ServerProtocolLoader(ServerProtocolLoader.DefaultProtocols).protocolFor(model, codegenContext.serviceShape) + ServerProtocolLoader(ServerProtocolLoader.defaultProtocols()).protocolFor(model, codegenContext.serviceShape) return protocolGeneratorFactory.buildProtocolGenerator(codegenContext).protocol } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructuresTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructuresTest.kt new file mode 100644 index 00000000000..2f09c18ffcc --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructuresTest.kt @@ -0,0 +1,73 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest + +class SigV4EventStreamSupportStructuresTest { + private val runtimeConfig = TestRuntimeConfig + + @Test + fun `support structures compile`() { + val project = TestWorkspace.testProject() + project.withModule(RustModule.private("sigv4_event_stream")) { + val codegenScope = SigV4EventStreamSupportStructures.codegenScope(runtimeConfig) + + // Generate the support structures - RuntimeType.forInlineFun automatically generates the code + // when the RuntimeType is used, so we just need to reference them + rustTemplate( + """ + use std::time::SystemTime; + + // Reference the types to trigger their generation + fn _test_types() { + let _info: #{SignatureInfo}; + let _error: #{ExtractionError}; + let _signed_error: #{SignedEventError}; + let _signed_event: #{SignedEvent}; + let _unmarshaller: #{SigV4Unmarshaller}; + } + """, + *codegenScope, + ) + + unitTest("test_signature_info_creation") { + rustTemplate( + """ + let info = #{SignatureInfo} { + chunk_signature: vec![1, 2, 3], + timestamp: SystemTime::now(), + }; + assert_eq!(info.chunk_signature, vec![1, 2, 3]); + """, + *codegenScope, + ) + } + + unitTest("test_signed_event_creation") { + rustTemplate( + """ + let event = #{SignedEvent} { + message: "test".to_string(), + signature: None, + }; + assert_eq!(event.message, "test"); + assert!(event.signature.is_none()); + """, + *codegenScope, + ) + } + } + + project.compileAndTest() + } +} diff --git a/examples/pokemon-service-tls/tests/common/mod.rs b/examples/pokemon-service-tls/tests/common/mod.rs index 8954365a205..37a67c7c2d1 100644 --- a/examples/pokemon-service-tls/tests/common/mod.rs +++ b/examples/pokemon-service-tls/tests/common/mod.rs @@ -3,10 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -use std::{fs::File, io::BufReader, process::Command, time::Duration}; - -use assert_cmd::prelude::*; +use assert_cmd::cargo_bin; use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder; +use std::{fs::File, io::BufReader, process::Command, time::Duration}; use tokio::time::sleep; use pokemon_service_client::{Client, Config}; @@ -14,8 +13,7 @@ use pokemon_service_common::ChildDrop; use pokemon_service_tls::{DEFAULT_DOMAIN, DEFAULT_PORT, DEFAULT_TEST_CERT}; pub async fn run_server() -> ChildDrop { - let crate_name = std::env::var("CARGO_PKG_NAME").unwrap(); - let child = Command::cargo_bin(crate_name).unwrap().spawn().unwrap(); + let child = Command::new(cargo_bin!()).spawn().unwrap(); sleep(Duration::from_millis(500)).await; diff --git a/examples/pokemon-service/tests/common/mod.rs b/examples/pokemon-service/tests/common/mod.rs index 2d58f4a975a..f44080b19b1 100644 --- a/examples/pokemon-service/tests/common/mod.rs +++ b/examples/pokemon-service/tests/common/mod.rs @@ -3,9 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +use assert_cmd::cargo_bin; use std::{process::Command, time::Duration}; - -use assert_cmd::prelude::*; use tokio::time::sleep; use pokemon_service::{DEFAULT_ADDRESS, DEFAULT_PORT}; @@ -13,8 +12,7 @@ use pokemon_service_client::{Client, Config}; use pokemon_service_common::ChildDrop; pub async fn run_server() -> ChildDrop { - let crate_name = std::env::var("CARGO_PKG_NAME").unwrap(); - let child = Command::cargo_bin(crate_name).unwrap().spawn().unwrap(); + let child = Command::new(cargo_bin!()).spawn().unwrap(); sleep(Duration::from_millis(500)).await;