From d97defbd14906fff39ef77d1ffea82f0202cbd35 Mon Sep 17 00:00:00 2001 From: Matteo Bigoi <1781140+crisidev@users.noreply.github.com> Date: Thu, 13 Apr 2023 14:55:33 +0100 Subject: [PATCH] [Python] Support more testing model (#2541) * Remove parameter from `Protocol`s `structuredDataParser`, `structuredDataSerializer` No implementation of the `Protocol` interface makes use of the `OperationShape` parameter in the `structuredDataParser` and `structuredDataSerializer` methods. * Remove the TypeConversionGenerator class in favor of using customizations for JsonParserGenerator and ServerHttpBoundProtocolGenerator. Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> * Make the additionaParserCustomizations default to empty list * Fix merge conflict * Fix missing ; * Use better defaults when checking for customizations * Use better defaults when checking for customizations * Add HttpBindingCustomization and relax the datetime symbol check * Support recursive shapes and add a lot more models to the tests Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> * Support naming obstacle course * Add support for constrained blobs conversions * Support constraint traits * Try to generate the full diff Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> * A better way of checking if we need to go into the Timestamp branch * Remove wheels folder --------- Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> Co-authored-by: david-perez --- .gitignore | 3 + .../generators/TypeConversionGenerator.kt | 54 -------- .../generators/http/HttpBindingGenerator.kt | 11 +- .../protocols/parse/JsonParserGenerator.kt | 31 ++++- codegen-server-test/python/build.gradle.kts | 45 ++++++- .../smithy/PythonServerCodegenVisitor.kt | 27 +++- .../ConstrainedPythonBlobGenerator.kt | 100 ++++++++++++++ .../generators/PythonApplicationGenerator.kt | 20 +-- .../generators/PythonServerModuleGenerator.kt | 5 +- .../PythonServerOperationHandlerGenerator.kt | 9 +- .../PythonServerStructureGenerator.kt | 22 +++ .../generators/PythonServerUnionGenerator.kt | 4 +- .../protocols/PythonServerProtocolLoader.kt | 125 ++++++++++++++++++ .../http/ServerRequestBindingGenerator.kt | 4 +- .../http/ServerResponseBindingGenerator.kt | 2 + .../generators/protocol/ServerProtocol.kt | 10 +- .../server/smithy/protocols/ServerAwsJson.kt | 19 ++- .../ServerHttpBoundProtocolGenerator.kt | 51 +++++-- .../server/smithy/protocols/ServerRestJson.kt | 20 ++- tools/ci-scripts/codegen-diff/diff_lib.py | 39 +++--- 20 files changed, 467 insertions(+), 134 deletions(-) delete mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TypeConversionGenerator.kt create mode 100644 codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/ConstrainedPythonBlobGenerator.kt create mode 100644 codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt diff --git a/.gitignore b/.gitignore index 710a23e16..268344ecb 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,6 @@ target/ # tools .tool-versions + +# python +__pycache__ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TypeConversionGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TypeConversionGenerator.kt deleted file mode 100644 index b79200f95..000000000 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TypeConversionGenerator.kt +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.core.smithy.generators - -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.BlobShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.TimestampShape -import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.rustType - -/* - * Utility class used to force casting a non primitive type into one overriden by a new symbol provider, - * by explicitly calling `from()` or into(). - * - * For example we use this in the server Python implementation, where we override types like [Blob] and [DateTime] - * with wrappers compatible with Python, without touching the original implementation coming from `aws-smithy-types`. - */ -class TypeConversionGenerator(private val model: Model, private val symbolProvider: RustSymbolProvider, private val runtimeConfig: RuntimeConfig) { - private fun findOldSymbol(shape: Shape): Symbol { - return when (shape) { - is BlobShape -> RuntimeType.blob(runtimeConfig).toSymbol() - is TimestampShape -> RuntimeType.dateTime(runtimeConfig).toSymbol() - else -> symbolProvider.toSymbol(shape) - } - } - - fun convertViaFrom(shape: Shape): Writable = - writable { - val oldSymbol = findOldSymbol(shape) - val newSymbol = symbolProvider.toSymbol(shape) - if (oldSymbol.rustType() != newSymbol.rustType()) { - rust(".map($newSymbol::from)") - } - } - - fun convertViaInto(shape: Shape): Writable = - writable { - val oldSymbol = findOldSymbol(shape) - val newSymbol = symbolProvider.toSymbol(shape) - if (oldSymbol.rustType() != newSymbol.rustType()) { - rust(".into()") - } - } -} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index ab920bc53..881d0f77a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -88,6 +88,9 @@ sealed class HttpBindingSection(name: String) : Section(name) { data class AfterDeserializingIntoAHashMapOfHttpPrefixHeaders(val memberShape: MemberShape) : HttpBindingSection("AfterDeserializingIntoAHashMapOfHttpPrefixHeaders") + + data class AfterDeserializingIntoADateTimeOfHttpHeaders(val memberShape: MemberShape) : + HttpBindingSection("AfterDeserializingIntoADateTimeOfHttpHeaders") } typealias HttpBindingCustomization = NamedCustomization @@ -353,7 +356,7 @@ class HttpBindingGenerator( rustType to targetShape } val parsedValue = safeName() - if (coreType == dateTime) { + if (coreShape.isTimestampShape()) { val timestampFormat = index.determineTimestampFormat( memberShape, @@ -362,10 +365,14 @@ class HttpBindingGenerator( ) val timestampFormatType = RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat) rust( - "let $parsedValue: Vec<${coreType.render()}> = #T::many_dates(headers, #T)?;", + "let $parsedValue: Vec<${coreType.render()}> = #T::many_dates(headers, #T)?", headerUtil, timestampFormatType, ) + for (customization in customizations) { + customization.section(HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders(memberShape))(this) + } + rust(";") } else if (coreShape.isPrimitive()) { rust( "let $parsedValue = #T::read_many_primitive::<${coreType.render()}>(headers)?;", diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index dd7255e47..1566d84bf 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -39,7 +39,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section -import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName @@ -61,6 +60,12 @@ import software.amazon.smithy.utils.StringUtils */ sealed class JsonParserSection(name: String) : Section(name) { data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember") + + data class AfterTimestampDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterTimestampDeserializedMember") + + data class AfterBlobDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterBlobDeserializedMember") + + data class AfterDocumentDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterDocumentDeserializedMember") } /** @@ -94,7 +99,6 @@ class JsonParserGenerator( private val runtimeConfig = codegenContext.runtimeConfig private val codegenTarget = codegenContext.target private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() - private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig) private val protocolFunctions = ProtocolFunctions(codegenContext) private val codegenScope = arrayOf( "Error" to smithyJson.resolve("deserialize::error::DeserializeError"), @@ -276,13 +280,13 @@ class JsonParserGenerator( is StringShape -> deserializeString(target) is BooleanShape -> rustTemplate("#{expect_bool_or_null}(tokens.next())?", *codegenScope) is NumberShape -> deserializeNumber(target) - is BlobShape -> deserializeBlob() + is BlobShape -> deserializeBlob(memberShape) is TimestampShape -> deserializeTimestamp(target, memberShape) is CollectionShape -> deserializeCollection(target) is MapShape -> deserializeMap(target) is StructureShape -> deserializeStruct(target) is UnionShape -> deserializeUnion(target) - is DocumentShape -> rustTemplate("Some(#{expect_document}(tokens)?)", *codegenScope) + is DocumentShape -> deserializeDocument(memberShape) else -> PANIC("unexpected shape: $target") } val symbol = symbolProvider.toSymbol(memberShape) @@ -294,11 +298,21 @@ class JsonParserGenerator( } } - private fun RustWriter.deserializeBlob() { + private fun RustWriter.deserializeDocument(member: MemberShape) { + rustTemplate("Some(#{expect_document}(tokens)?)", *codegenScope) + for (customization in customizations) { + customization.section(JsonParserSection.AfterDocumentDeserializedMember(member))(this) + } + } + + private fun RustWriter.deserializeBlob(member: MemberShape) { rustTemplate( "#{expect_blob_or_null}(tokens.next())?", *codegenScope, ) + for (customization in customizations) { + customization.section(JsonParserSection.AfterBlobDeserializedMember(member))(this) + } } private fun RustWriter.deserializeStringInner(target: StringShape, escapedStrName: String) { @@ -349,9 +363,12 @@ class JsonParserGenerator( ) val timestampFormatType = RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat) rustTemplate( - "#{expect_timestamp_or_null}(tokens.next(), #{T})?#{ConvertFrom:W}", - "T" to timestampFormatType, "ConvertFrom" to typeConversionGenerator.convertViaFrom(shape), *codegenScope, + "#{expect_timestamp_or_null}(tokens.next(), #{T})?", + "T" to timestampFormatType, *codegenScope, ) + for (customization in customizations) { + customization.section(JsonParserSection.AfterTimestampDeserializedMember(member))(this) + } } private fun RustWriter.deserializeCollection(shape: CollectionShape) { diff --git a/codegen-server-test/python/build.gradle.kts b/codegen-server-test/python/build.gradle.kts index 35144f42f..87ff7cc5b 100644 --- a/codegen-server-test/python/build.gradle.kts +++ b/codegen-server-test/python/build.gradle.kts @@ -54,14 +54,49 @@ val allCodegenTests = "../../codegen-core/common-test-models".let { commonModels // TODO(https://github.com/awslabs/smithy-rs/issues/1401) `@uniqueItems` is used. extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, ), - // TODO(https://github.com/awslabs/smithy-rs/issues/2476) + CodegenTest( + "aws.protocoltests.json#JsonProtocol", + "json_rpc11", + extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, + ), + CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"), + CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"), + CodegenTest( + "aws.protocoltests.restjson#RestJsonExtras", + "rest_json_extras", + imports = listOf("$commonModels/rest-json-extras.smithy"), + ), + // TODO(https://github.com/awslabs/smithy-rs/issues/2551) // CodegenTest( - // "aws.protocoltests.json#JsonProtocol", - // "json_rpc11", + // "aws.protocoltests.restjson.validation#RestJsonValidation", + // "rest_json_validation", + // // `@range` trait is used on floating point shapes, which we deliberately don't want to support. + // // See https://github.com/awslabs/smithy-rs/issues/1401. // extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, // ), - // TODO(https://github.com/awslabs/smithy-rs/issues/2479) - // CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"), + CodegenTest( + "com.amazonaws.constraints#ConstraintsService", + "constraints", + imports = listOf("$commonModels/constraints.smithy"), + ), + CodegenTest( + "com.amazonaws.constraints#ConstraintsService", + "constraints_without_public_constrained_types", + imports = listOf("$commonModels/constraints.smithy"), + extraConfig = """, "codegen": { "publicConstrainedTypes": false } """, + ), + CodegenTest( + "com.amazonaws.constraints#UniqueItemsService", + "unique_items", + imports = listOf("$commonModels/unique-items.smithy"), + ), + CodegenTest( + "naming_obs_structs#NamingObstacleCourseStructs", + "naming_test_structs", + imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"), + ), + CodegenTest("casing#ACRONYMInside_Service", "naming_test_casing", imports = listOf("$commonModels/naming-obstacle-course-casing.smithy")), + CodegenTest("crate#Config", "naming_test_ops", imports = listOf("$commonModels/naming-obstacle-course-ops.smithy")), ) } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index 788b6e695..fa1f8eeb6 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.python.smithy import software.amazon.smithy.build.PluginContext import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape @@ -22,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.isEventStream +import software.amazon.smithy.rust.codegen.server.python.smithy.generators.ConstrainedPythonBlobGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonApplicationGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEventStreamErrorGenerator @@ -30,6 +32,7 @@ import software.amazon.smithy.rust.codegen.server.python.smithy.generators.Pytho import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerOperationHandlerGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerUnionGenerator +import software.amazon.smithy.rust.codegen.server.python.smithy.protocols.PythonServerProtocolLoader import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenVisitor import software.amazon.smithy.rust.codegen.server.smithy.ServerModuleDocProvider @@ -42,8 +45,9 @@ import software.amazon.smithy.rust.codegen.server.smithy.createInlineModuleCreat import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol -import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.withModuleOrWithStructureBuilderModule /** * Entrypoint for Python server-side code generation. This class will walk the in-memory model and @@ -68,10 +72,10 @@ class PythonServerCodegenVisitor( val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) val (protocol, generator) = - ServerProtocolLoader( + PythonServerProtocolLoader( codegenDecorator.protocols( service.id, - ServerProtocolLoader.DefaultProtocols, + PythonServerProtocolLoader.defaultProtocols(settings.runtimeConfig), ), ) .protocolFor(context.model, service) @@ -258,4 +262,21 @@ class PythonServerCodegenVisitor( } } } + + override fun blobShape(shape: BlobShape) { + logger.info("[python-server-codegen] Generating a service $shape") + super.blobShape(shape) + + if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) { + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.Model, shape, codegenContext) { + ConstrainedPythonBlobGenerator( + codegenContext, + rustCrate.createInlineModuleCreator(), + this, + shape, + validationExceptionConversionGenerator, + ).render() + } + } + } } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/ConstrainedPythonBlobGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/ConstrainedPythonBlobGenerator.kt new file mode 100644 index 000000000..a9c202152 --- /dev/null +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/ConstrainedPythonBlobGenerator.kt @@ -0,0 +1,100 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.python.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.join +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.orNull +import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength +import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo +import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator + +class ConstrainedPythonBlobGenerator( + val codegenContext: ServerCodegenContext, + private val inlineModuleCreator: InlineModuleCreator, + val writer: RustWriter, + val shape: BlobShape, + private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, +) { + val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) + private val blobConstraintsInfo: List = listOf(LengthTrait::class.java) + .mapNotNull { shape.getTrait(it).orNull() } + .map { BlobLength(it) } + private val constraintsInfo: List = blobConstraintsInfo.map { it.toTraitInfo() } + + fun render() { + val symbol = constrainedShapeSymbolProvider.toSymbol(shape) + val blobType = PythonServerRuntimeType.blob(codegenContext.runtimeConfig).toSymbol().rustType() + renderFrom(symbol, blobType) + renderTryFrom(symbol, blobType) + } + + fun renderFrom(symbol: Symbol, blobType: RustType) { + val name = symbol.name + val inner = blobType.render() + writer.rustTemplate( + """ + impl #{From}<$inner> for #{MaybeConstrained} { + fn from(value: $inner) -> Self { + Self::Unconstrained(value.into()) + } + } + + impl #{From}<$name> for $inner { + fn from(value: $name) -> Self { + value.into_inner().into() + } + } + """, + "MaybeConstrained" to symbol.makeMaybeConstrained(), + "From" to RuntimeType.From, + ) + } + + fun renderTryFrom(symbol: Symbol, blobType: RustType) { + val name = symbol.name + val inner = blobType.render() + writer.rustTemplate( + """ + impl #{TryFrom}<$inner> for $name { + type Error = #{ConstraintViolation}; + + fn try_from(value: $inner) -> Result { + value.try_into() + } + } + """, + "TryFrom" to RuntimeType.TryFrom, + "ConstraintViolation" to constraintViolation, + "TryFromChecks" to constraintsInfo.map { it.tryFromCheck }.join("\n"), + ) + } +} diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index cc4804bb2..deec6f6b2 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.traits.DocumentationTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate @@ -203,14 +204,13 @@ class PythonApplicationGenerator( *codegenScope, ) for (operation in operations) { - val operationName = symbolProvider.toSymbol(operation).name - val name = operationName.toSnakeCase() + val fnName = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operation).name.toSnakeCase()) rustTemplate( """ - let ${name}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop); - let handler = self.handlers.get("$name").expect("Python handler for operation `$name` not found").clone(); - let builder = builder.$name(move |input, state| { - #{pyo3_asyncio}::tokio::scope(${name}_locals.clone(), crate::python_operation_adaptor::$name(input, state, handler.clone())) + let ${fnName}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop); + let handler = self.handlers.get("$fnName").expect("Python handler for operation `$fnName` not found").clone(); + let builder = builder.$fnName(move |input, state| { + #{pyo3_asyncio}::tokio::scope(${fnName}_locals.clone(), crate::python_operation_adaptor::$fnName(input, state, handler.clone())) }); """, *codegenScope, @@ -342,7 +342,7 @@ class PythonApplicationGenerator( ) operations.map { operation -> val operationName = symbolProvider.toSymbol(operation).name - val name = operationName.toSnakeCase() + val fnName = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operation).name.toSnakeCase()) val input = PythonType.Opaque("${operationName}Input", "crate::input") val output = PythonType.Opaque("${operationName}Output", "crate::output") @@ -363,15 +363,15 @@ class PythonApplicationGenerator( rustTemplate( """ - /// Method to register `$name` Python implementation inside the handlers map. + /// Method to register `$fnName` Python implementation inside the handlers map. /// It can be used as a function decorator in Python. /// /// :param func ${handler.renderAsDocstring()}: /// :rtype ${PythonType.None.renderAsDocstring()}: ##[pyo3(text_signature = "(${'$'}self, func)")] - pub fn $name(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { + pub fn $fnName(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { use #{SmithyPython}::PyApp; - self.register_operation(py, "$name", func) + self.register_operation(py, "$fnName", func) } """, *codegenScope, diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index 8a622c51e..5baf6e83e 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerRustModule @@ -71,13 +72,13 @@ class PythonServerModuleGenerator( when (shape) { is UnionShape -> rustTemplate( """ - $moduleType.add_class::()?; + $moduleType.add_class::()?; """, *codegenScope, ) else -> rustTemplate( """ - $moduleType.add_class::()?; + $moduleType.add_class::()?; """, *codegenScope, ) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index f10780609..e8a704ecd 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -6,11 +6,13 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency @@ -50,10 +52,11 @@ class PythonServerOperationHandlerGenerator( private fun renderPythonOperationHandlerImpl(writer: RustWriter) { val operationName = symbolProvider.toSymbol(operation).name - val input = "crate::input::${operationName}Input" - val output = "crate::output::${operationName}Output" + val input = "crate::input::${operationName.toPascalCase()}Input" + val output = "crate::output::${operationName.toPascalCase()}Output" + // TODO(https://github.com/awslabs/smithy-rs/issues/2552) - Use to pascalCase for error shapes. val error = "crate::error::${operationName}Error" - val fnName = operationName.toSnakeCase() + val fnName = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operation).name.toSnakeCase()) writer.rustTemplate( """ diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt index 28b48d31e..1502f5422 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt @@ -60,6 +60,9 @@ class PythonServerStructureGenerator( writer.rustTemplate("#{ConstructorSignature:W}", "ConstructorSignature" to renderConstructorSignature()) super.renderStructure() renderPyO3Methods() + if (!shape.hasTrait()) { + renderPyBoxTraits() + } } override fun renderStructureMember( @@ -101,6 +104,25 @@ class PythonServerStructureGenerator( ) } + private fun renderPyBoxTraits() { + writer.rustTemplate( + """ + impl<'source> #{pyo3}::FromPyObject<'source> for std::boxed::Box<$name> { + fn extract(ob: &'source #{pyo3}::PyAny) -> #{pyo3}::PyResult { + ob.extract::<$name>().map(Box::new) + } + } + + impl #{pyo3}::IntoPy<#{pyo3}::PyObject> for std::boxed::Box<$name> { + fn into_py(self, py: #{pyo3}::Python<'_>) -> #{pyo3}::PyObject { + (*self).into_py(py) + } + } + """, + "pyo3" to pyO3, + ) + } + private fun renderStructSignatureMembers(): Writable = writable { forEachMember(members) { _, memberName, memberSymbol -> diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerUnionGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerUnionGenerator.kt index df083751a..6b7a7bee8 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerUnionGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerUnionGenerator.kt @@ -121,7 +121,7 @@ class PythonServerUnionGenerator( ) writer.rust("/// :rtype ${unionSymbol.name}:") writer.rustBlock("pub fn $funcNamePart() -> Self") { - rust("Self(${unionSymbol.name}::$variantName") + rust("Self(${unionSymbol.name}::$variantName)") } } else { val memberSymbol = symbolProvider.toSymbol(member) @@ -157,7 +157,7 @@ class PythonServerUnionGenerator( writer.rustBlockTemplate("pub fn as_$funcNamePart(&self) -> #{pyo3}::PyResult<()>", "pyo3" to pyo3) { rustTemplate( """ - self.0.as_$funcNamePart().map_err(#{pyo3}::exceptions::PyValueError::new_err( + self.0.as_$funcNamePart().map_err(|_| #{pyo3}::exceptions::PyValueError::new_err( "${unionSymbol.name} variant is not None" )) """, diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt new file mode 100644 index 000000000..f84d35e7c --- /dev/null +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt @@ -0,0 +1,125 @@ + +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.python.smithy.protocols + +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserSection +import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonFactory +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolCustomization +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolSection +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonFactory + +/** + * Customization class used to force casting a non primitive type into one overriden by a new symbol provider, + * by explicitly calling `from()` on it. + * + * For example we use this in the server Python implementation, where we override types like [Blob], [DateTime] and [Document] + * with wrappers compatible with Python, without touching the original implementation coming from `aws-smithy-types`. + */ +class PythonServerAfterDeserializedMemberJsonParserCustomization(private val runtimeConfig: RuntimeConfig) : + JsonParserCustomization() { + override fun section(section: JsonParserSection): Writable = when (section) { + is JsonParserSection.AfterTimestampDeserializedMember -> writable { + rust(".map(#T::from)", PythonServerRuntimeType.dateTime(runtimeConfig).toSymbol()) + } + is JsonParserSection.AfterBlobDeserializedMember -> writable { + rust(".map(#T::from)", PythonServerRuntimeType.blob(runtimeConfig).toSymbol()) + } + is JsonParserSection.AfterDocumentDeserializedMember -> writable { + rust(".map(#T::from)", PythonServerRuntimeType.document(runtimeConfig).toSymbol()) + } + else -> emptySection + } +} + +/** + * Customization class used to force casting a non primitive type into one overriden by a new symbol provider, + * by explicitly calling `into()` on it. + */ +class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() : + ServerHttpBoundProtocolCustomization() { + override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { + is ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember -> writable { + rust(".into()") + } + else -> emptySection + } +} + +/** + * Customization class used to force casting a `Vec` into one a Python `Vec` + */ +class PythonServerAfterDeserializedMemberHttpBindingCustomization(private val runtimeConfig: RuntimeConfig) : + HttpBindingCustomization() { + override fun section(section: HttpBindingSection): Writable = when (section) { + is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders -> writable { + rust(".into_iter().map(#T::from).collect()", PythonServerRuntimeType.dateTime(runtimeConfig).toSymbol()) + } + else -> emptySection + } +} + +class PythonServerProtocolLoader( + private val supportedProtocols: ProtocolMap, +) : ProtocolLoader(supportedProtocols) { + + companion object { + fun defaultProtocols(runtimeConfig: RuntimeConfig) = + mapOf( + RestJson1Trait.ID to ServerRestJsonFactory( + additionalParserCustomizations = listOf( + PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig), + ), + additionalServerHttpBoundProtocolCustomizations = listOf( + PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + ), + additionalHttpBindingCustomizations = listOf( + PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), + ), + ), + AwsJson1_0Trait.ID to ServerAwsJsonFactory( + AwsJsonVersion.Json10, + additionalParserCustomizations = listOf( + PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig), + ), + additionalServerHttpBoundProtocolCustomizations = listOf( + PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + ), + additionalHttpBindingCustomizations = listOf( + PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), + ), + ), + AwsJson1_1Trait.ID to ServerAwsJsonFactory( + AwsJsonVersion.Json11, + additionalParserCustomizations = listOf( + PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig), + ), + additionalServerHttpBoundProtocolCustomizations = listOf( + PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + ), + additionalHttpBindingCustomizations = listOf( + PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), + ), + ), + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt index 41201b869..e1e6c747f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt @@ -28,6 +28,7 @@ class ServerRequestBindingGenerator( protocol: Protocol, codegenContext: ServerCodegenContext, operationShape: OperationShape, + additionalHttpBindingCustomizations: List = listOf(), ) { private val httpBindingGenerator = HttpBindingGenerator( @@ -39,7 +40,7 @@ class ServerRequestBindingGenerator( ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization( codegenContext, ), - ), + ) + additionalHttpBindingCustomizations, ) fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType = @@ -81,5 +82,6 @@ class ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUncons ) } } + else -> emptySection } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt index cc4783005..01448d27a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt @@ -71,6 +71,7 @@ class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstr is HttpBindingSection.BeforeRenderingHeaderValue, is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, + is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, -> emptySection } } @@ -100,6 +101,7 @@ class ServerResponseBeforeRenderingHeadersHttpBindingCustomization(val codegenCo is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, + is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, -> emptySection } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 0b466f95c..3ee16c233 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -106,6 +106,7 @@ fun jsonParserGenerator( codegenContext: ServerCodegenContext, httpBindingResolver: HttpBindingResolver, jsonName: (MemberShape) -> String, + additionalParserCustomizations: List = listOf(), ): JsonParserGenerator = JsonParserGenerator( codegenContext, @@ -114,12 +115,13 @@ fun jsonParserGenerator( returnSymbolToParseFn(codegenContext), listOf( ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(codegenContext), - ), + ) + additionalParserCustomizations, ) class ServerAwsJsonProtocol( private val serverCodegenContext: ServerCodegenContext, awsJsonVersion: AwsJsonVersion, + private val additionalParserCustomizations: List = listOf(), ) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol { private val runtimeConfig = codegenContext.runtimeConfig @@ -130,7 +132,7 @@ class ServerAwsJsonProtocol( } override fun structuredDataParser(): StructuredDataParserGenerator = - jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::awsJsonFieldName) + jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::awsJsonFieldName, additionalParserCustomizations) override fun structuredDataSerializer(): StructuredDataSerializerGenerator = ServerAwsJsonSerializerGenerator(serverCodegenContext, httpBindingResolver, awsJsonVersion) @@ -183,13 +185,14 @@ private fun restRouterType(runtimeConfig: RuntimeConfig) = class ServerRestJsonProtocol( private val serverCodegenContext: ServerCodegenContext, + private val additionalParserCustomizations: List = listOf(), ) : RestJson(serverCodegenContext), ServerProtocol { val runtimeConfig = codegenContext.runtimeConfig override val protocolModulePath: String = "rest_json_1" override fun structuredDataParser(): StructuredDataParserGenerator = - jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::restJsonFieldName) + jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::restJsonFieldName, additionalParserCustomizations) override fun structuredDataSerializer(): StructuredDataSerializerGenerator = ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver) @@ -254,5 +257,6 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonPa rust(".map(|x| x.into())") } } + else -> emptySection } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt index 2a3467cd6..920813e34 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt @@ -10,11 +10,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection @@ -30,13 +32,22 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser * AwsJson 1.0 and 1.1 server-side protocol factory. This factory creates the [ServerHttpBoundProtocolGenerator] * with AwsJson specific configurations. */ -class ServerAwsJsonFactory(private val version: AwsJsonVersion) : - ProtocolGeneratorFactory { +class ServerAwsJsonFactory( + private val version: AwsJsonVersion, + private val additionalParserCustomizations: List = listOf(), + private val additionalServerHttpBoundProtocolCustomizations: List = listOf(), + private val additionalHttpBindingCustomizations: List = listOf(), +) : ProtocolGeneratorFactory { override fun protocol(codegenContext: ServerCodegenContext): ServerProtocol = - ServerAwsJsonProtocol(codegenContext, version) + ServerAwsJsonProtocol(codegenContext, version, additionalParserCustomizations) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = - ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) + ServerHttpBoundProtocolGenerator( + codegenContext, + protocol(codegenContext), + additionalServerHttpBoundProtocolCustomizations, + additionalHttpBindingCustomizations, + ) override fun support(): ProtocolSupport { return ProtocolSupport( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 9a7e45f11..889967412 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape @@ -42,8 +43,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional @@ -77,6 +80,18 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol import java.util.logging.Logger +/** + * Class describing a ServerHttpBoundProtocol section that can be used in a customization. + */ +sealed class ServerHttpBoundProtocolSection(name: String) : Section(name) { + data class AfterTimestampDeserializedMember(val shape: MemberShape) : ServerHttpBoundProtocolSection("AfterTimestampDeserializedMember") +} + +/** + * Customization for the ServerHttpBoundProtocol generator. + */ +typealias ServerHttpBoundProtocolCustomization = NamedCustomization + /** * Implement operations' input parsing and output serialization. Protocols can plug their own implementations * and overrides by creating a protocol factory inheriting from this class and feeding it to the [ServerProtocolLoader]. @@ -85,10 +100,12 @@ import java.util.logging.Logger class ServerHttpBoundProtocolGenerator( codegenContext: ServerCodegenContext, protocol: ServerProtocol, + customizations: List = listOf(), + additionalHttpBindingCustomizations: List = listOf(), ) : ServerProtocolGenerator( codegenContext, protocol, - ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol), + ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), ) { // Define suffixes for operation input / output / error wrappers companion object { @@ -104,6 +121,8 @@ class ServerHttpBoundProtocolGenerator( class ServerHttpBoundProtocolTraitImplGenerator( private val codegenContext: ServerCodegenContext, private val protocol: ServerProtocol, + private val customizations: List, + private val additionalHttpBindingCustomizations: List, ) { private val logger = Logger.getLogger(javaClass.name) private val symbolProvider = codegenContext.symbolProvider @@ -111,7 +130,6 @@ class ServerHttpBoundProtocolTraitImplGenerator( private val model = codegenContext.model private val runtimeConfig = codegenContext.runtimeConfig private val httpBindingResolver = protocol.httpBindingResolver - private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig) private val protocolFunctions = ProtocolFunctions(codegenContext) private val codegenScope = arrayOf( @@ -568,9 +586,9 @@ class ServerHttpBoundProtocolTraitImplGenerator( private fun serverRenderHttpResponseCode(defaultCode: Int) = writable { check(defaultCode in 100..999) { """ - Smithy library lied to us. According to https://smithy.io/2.0/spec/http-bindings.html#http-trait, - "The provided value SHOULD be between 100 and 599, and it MUST be between 100 and 999". - """.replace("\n", "").trimIndent() + Smithy library lied to us. According to https://smithy.io/2.0/spec/http-bindings.html#http-trait, + "The provided value SHOULD be between 100 and 599, and it MUST be between 100 and 999". + """.replace("\n", "").trimIndent() } rustTemplate( """ @@ -611,7 +629,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( inputShape: StructureShape, bindings: List, ) { - val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) + val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) val structuredDataParser = protocol.structuredDataParser() Attribute.AllowUnusedMut.render(this) rust( @@ -952,12 +970,15 @@ class ServerHttpBoundProtocolTraitImplGenerator( val timestampFormatType = RuntimeType.parseTimestampFormat(CodegenTarget.SERVER, runtimeConfig, timestampFormat) rustTemplate( """ - let v = #{DateTime}::from_str(&v, #{format})?#{ConvertInto:W}; + let v = #{DateTime}::from_str(&v, #{format})? """.trimIndent(), *codegenScope, "format" to timestampFormatType, - "ConvertInto" to typeConversionGenerator.convertViaInto(memberShape), ) + for (customization in customizations) { + customization.section(ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember(it.member))(this) + } + rust(";") } else -> { // Number or boolean. rust( @@ -1047,7 +1068,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( } private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { - val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) + val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding) writer.rustTemplate( """ @@ -1109,22 +1130,24 @@ class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate( """ let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?; - let value = #{DateTime}::from_str(value.as_ref(), #{format})?#{ConvertInto:W}; + let value = #{DateTime}::from_str(value.as_ref(), #{format})? """, *codegenScope, "format" to timestampFormatType, - "ConvertInto" to typeConversionGenerator.convertViaInto(target), ) } else { rustTemplate( """ - let value = #{DateTime}::from_str(value, #{format})?#{ConvertInto:W}; + let value = #{DateTime}::from_str(value, #{format})? """, *codegenScope, "format" to timestampFormatType, - "ConvertInto" to typeConversionGenerator.convertViaInto(target), ) } + for (customization in customizations) { + customization.section(ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember(binding.member))(this) + } + rust(";") } else -> { check(target is NumberShape || target is BooleanShape) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt index 744ef93bc..ddf1ca08c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt @@ -5,10 +5,12 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator @@ -21,11 +23,23 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser * RestJson1 server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator] * with RestJson1 specific configurations. */ -class ServerRestJsonFactory : ProtocolGeneratorFactory { - override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRestJsonProtocol(codegenContext) +class ServerRestJsonFactory( + private val additionalParserCustomizations: List = listOf(), + private val additionalServerHttpBoundProtocolCustomizations: List = listOf(), + private val additionalHttpBindingCustomizations: List = listOf(), +) : ProtocolGeneratorFactory { + override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRestJsonProtocol(codegenContext, additionalParserCustomizations) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = - ServerHttpBoundProtocolGenerator(codegenContext, ServerRestJsonProtocol(codegenContext)) + ServerHttpBoundProtocolGenerator( + codegenContext, + ServerRestJsonProtocol( + codegenContext, + additionalParserCustomizations, + ), + additionalServerHttpBoundProtocolCustomizations, + additionalHttpBindingCustomizations, + ) override fun support(): ProtocolSupport { return ProtocolSupport( diff --git a/tools/ci-scripts/codegen-diff/diff_lib.py b/tools/ci-scripts/codegen-diff/diff_lib.py index 60fa3d6ed..09f7edbf8 100644 --- a/tools/ci-scripts/codegen-diff/diff_lib.py +++ b/tools/ci-scripts/codegen-diff/diff_lib.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import os -import sys -import subprocess import shlex +import subprocess +import sys HEAD_BRANCH_NAME = "__tmp-localonly-head" BASE_BRANCH_NAME = "__tmp-localonly-base" @@ -15,10 +15,10 @@ COMMIT_AUTHOR_EMAIL = "generated-code-action@github.com" CDN_URL = "https://d2luzm2xt3nokh.cloudfront.net" -PYTHON_EXAMPLES_PATH = "rust-runtime/aws-smithy-http-server-python/examples" - target_codegen_client = 'codegen-client-test' target_codegen_server = 'codegen-server-test' +target_codegen_server_python = 'codegen-server-test:python' +target_codegen_server_typescript = 'codegen-server-test:typescript' target_aws_sdk = 'aws:sdk' @@ -38,19 +38,19 @@ def checkout_commit_and_generate(revision_sha, branch_name, targets=None): def generate_and_commit_generated_code(revision_sha, targets=None): - targets = targets or [target_codegen_client, target_codegen_server, target_aws_sdk] + targets = targets or [ + target_codegen_client, + target_codegen_server, + target_aws_sdk, + target_codegen_server_python, + target_codegen_server_typescript + ] # Clean the build artifacts before continuing + assemble_tasks = ' '.join([f'{t}:assemble' for t in targets]) + clean_tasks = ' '.join([f'{t}:clean' for t in targets]) get_cmd_output("rm -rf aws/sdk/build") - if target_codegen_server in targets: - get_cmd_output("make distclean", shell=True, cwd=PYTHON_EXAMPLES_PATH) - get_cmd_output("./gradlew codegen-core:clean codegen-client:clean codegen-server:clean aws:sdk-codegen:clean") - - # Generate code - tasks = ' '.join([f'{t}:assemble' for t in targets]) - get_cmd_output(f"./gradlew --rerun-tasks {tasks}") - if target_codegen_server in targets: - get_cmd_output("make build", shell=True, check=False, cwd=PYTHON_EXAMPLES_PATH) - get_cmd_output(f"./gradlew --rerun-tasks codegen-server-test:typescript:assemble") + get_cmd_output(f"./gradlew --rerun-tasks {clean_tasks}") + get_cmd_output(f"./gradlew --rerun-tasks {assemble_tasks}") # Move generated code into codegen-diff/ directory get_cmd_output(f"rm -rf {OUTPUT_PATH}") @@ -61,12 +61,8 @@ def generate_and_commit_generated_code(revision_sha, targets=None): if target in targets: get_cmd_output(f"mv {target}/build/smithyprojections/{target} {OUTPUT_PATH}/") if target == target_codegen_server: - get_cmd_output( - f"mv {PYTHON_EXAMPLES_PATH}/pokemon-service-server-sdk/ {OUTPUT_PATH}/codegen-server-test-python/", - check=False) - get_cmd_output( - f"mv codegen-server-test/typescript/build/smithyprojections/codegen-server-test-typescript {OUTPUT_PATH}/", - check=False) + get_cmd_output(f"mv {target}/python/build/smithyprojections/{target}-python {OUTPUT_PATH}/") + get_cmd_output(f"mv {target}/typescript/build/smithyprojections/{target}-typescript {OUTPUT_PATH}/") # Clean up the SDK directory get_cmd_output(f"rm -f {OUTPUT_PATH}/aws-sdk/versions.toml") @@ -79,6 +75,7 @@ def generate_and_commit_generated_code(revision_sha, targets=None): # Clean up the server-test folder get_cmd_output(f"rm -rf {OUTPUT_PATH}/codegen-server-test/source") + get_cmd_output(f"rm -rf {OUTPUT_PATH}/codegen-server-test-python/source") get_cmd_output(f"rm -rf {OUTPUT_PATH}/codegen-server-test-typescript/source") run(f"find {OUTPUT_PATH}/codegen-server-test | " f"grep -E 'smithy-build-info.json|sources/manifest|model.json' | "