mirror of https://github.com/smithy-lang/smithy-rs
[Python] Support more testing model (#2541)
* Remove parameter from `Protocol`s `structuredDataParser`, `structuredDataSerializer` No implementation of the `Protocol` interface makes use of the `OperationShape` parameter in the `structuredDataParser` and `structuredDataSerializer` methods. * Remove the TypeConversionGenerator class in favor of using customizations for JsonParserGenerator and ServerHttpBoundProtocolGenerator. Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> * Make the additionaParserCustomizations default to empty list * Fix merge conflict * Fix missing ; * Use better defaults when checking for customizations * Use better defaults when checking for customizations * Add HttpBindingCustomization and relax the datetime symbol check * Support recursive shapes and add a lot more models to the tests Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> * Support naming obstacle course * Add support for constrained blobs conversions * Support constraint traits * Try to generate the full diff Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> * A better way of checking if we need to go into the Timestamp branch * Remove wheels folder --------- Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> Co-authored-by: david-perez <d@vidp.dev>
This commit is contained in:
parent
3aa4cc24a5
commit
d97defbd14
|
@ -56,3 +56,6 @@ target/
|
|||
|
||||
# tools
|
||||
.tool-versions
|
||||
|
||||
# python
|
||||
__pycache__
|
||||
|
|
|
@ -1,54 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.core.smithy.generators
|
||||
|
||||
import software.amazon.smithy.codegen.core.Symbol
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.BlobShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
import software.amazon.smithy.model.shapes.TimestampShape
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.rustType
|
||||
|
||||
/*
|
||||
* Utility class used to force casting a non primitive type into one overriden by a new symbol provider,
|
||||
* by explicitly calling `from()` or into().
|
||||
*
|
||||
* For example we use this in the server Python implementation, where we override types like [Blob] and [DateTime]
|
||||
* with wrappers compatible with Python, without touching the original implementation coming from `aws-smithy-types`.
|
||||
*/
|
||||
class TypeConversionGenerator(private val model: Model, private val symbolProvider: RustSymbolProvider, private val runtimeConfig: RuntimeConfig) {
|
||||
private fun findOldSymbol(shape: Shape): Symbol {
|
||||
return when (shape) {
|
||||
is BlobShape -> RuntimeType.blob(runtimeConfig).toSymbol()
|
||||
is TimestampShape -> RuntimeType.dateTime(runtimeConfig).toSymbol()
|
||||
else -> symbolProvider.toSymbol(shape)
|
||||
}
|
||||
}
|
||||
|
||||
fun convertViaFrom(shape: Shape): Writable =
|
||||
writable {
|
||||
val oldSymbol = findOldSymbol(shape)
|
||||
val newSymbol = symbolProvider.toSymbol(shape)
|
||||
if (oldSymbol.rustType() != newSymbol.rustType()) {
|
||||
rust(".map($newSymbol::from)")
|
||||
}
|
||||
}
|
||||
|
||||
fun convertViaInto(shape: Shape): Writable =
|
||||
writable {
|
||||
val oldSymbol = findOldSymbol(shape)
|
||||
val newSymbol = symbolProvider.toSymbol(shape)
|
||||
if (oldSymbol.rustType() != newSymbol.rustType()) {
|
||||
rust(".into()")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -88,6 +88,9 @@ sealed class HttpBindingSection(name: String) : Section(name) {
|
|||
|
||||
data class AfterDeserializingIntoAHashMapOfHttpPrefixHeaders(val memberShape: MemberShape) :
|
||||
HttpBindingSection("AfterDeserializingIntoAHashMapOfHttpPrefixHeaders")
|
||||
|
||||
data class AfterDeserializingIntoADateTimeOfHttpHeaders(val memberShape: MemberShape) :
|
||||
HttpBindingSection("AfterDeserializingIntoADateTimeOfHttpHeaders")
|
||||
}
|
||||
|
||||
typealias HttpBindingCustomization = NamedCustomization<HttpBindingSection>
|
||||
|
@ -353,7 +356,7 @@ class HttpBindingGenerator(
|
|||
rustType to targetShape
|
||||
}
|
||||
val parsedValue = safeName()
|
||||
if (coreType == dateTime) {
|
||||
if (coreShape.isTimestampShape()) {
|
||||
val timestampFormat =
|
||||
index.determineTimestampFormat(
|
||||
memberShape,
|
||||
|
@ -362,10 +365,14 @@ class HttpBindingGenerator(
|
|||
)
|
||||
val timestampFormatType = RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat)
|
||||
rust(
|
||||
"let $parsedValue: Vec<${coreType.render()}> = #T::many_dates(headers, #T)?;",
|
||||
"let $parsedValue: Vec<${coreType.render()}> = #T::many_dates(headers, #T)?",
|
||||
headerUtil,
|
||||
timestampFormatType,
|
||||
)
|
||||
for (customization in customizations) {
|
||||
customization.section(HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders(memberShape))(this)
|
||||
}
|
||||
rust(";")
|
||||
} else if (coreShape.isPrimitive()) {
|
||||
rust(
|
||||
"let $parsedValue = #T::read_many_primitive::<${coreType.render()}>(headers)?;",
|
||||
|
|
|
@ -39,7 +39,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
|||
import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
|
||||
|
@ -61,6 +60,12 @@ import software.amazon.smithy.utils.StringUtils
|
|||
*/
|
||||
sealed class JsonParserSection(name: String) : Section(name) {
|
||||
data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember")
|
||||
|
||||
data class AfterTimestampDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterTimestampDeserializedMember")
|
||||
|
||||
data class AfterBlobDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterBlobDeserializedMember")
|
||||
|
||||
data class AfterDocumentDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterDocumentDeserializedMember")
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -94,7 +99,6 @@ class JsonParserGenerator(
|
|||
private val runtimeConfig = codegenContext.runtimeConfig
|
||||
private val codegenTarget = codegenContext.target
|
||||
private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType()
|
||||
private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig)
|
||||
private val protocolFunctions = ProtocolFunctions(codegenContext)
|
||||
private val codegenScope = arrayOf(
|
||||
"Error" to smithyJson.resolve("deserialize::error::DeserializeError"),
|
||||
|
@ -276,13 +280,13 @@ class JsonParserGenerator(
|
|||
is StringShape -> deserializeString(target)
|
||||
is BooleanShape -> rustTemplate("#{expect_bool_or_null}(tokens.next())?", *codegenScope)
|
||||
is NumberShape -> deserializeNumber(target)
|
||||
is BlobShape -> deserializeBlob()
|
||||
is BlobShape -> deserializeBlob(memberShape)
|
||||
is TimestampShape -> deserializeTimestamp(target, memberShape)
|
||||
is CollectionShape -> deserializeCollection(target)
|
||||
is MapShape -> deserializeMap(target)
|
||||
is StructureShape -> deserializeStruct(target)
|
||||
is UnionShape -> deserializeUnion(target)
|
||||
is DocumentShape -> rustTemplate("Some(#{expect_document}(tokens)?)", *codegenScope)
|
||||
is DocumentShape -> deserializeDocument(memberShape)
|
||||
else -> PANIC("unexpected shape: $target")
|
||||
}
|
||||
val symbol = symbolProvider.toSymbol(memberShape)
|
||||
|
@ -294,11 +298,21 @@ class JsonParserGenerator(
|
|||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.deserializeBlob() {
|
||||
private fun RustWriter.deserializeDocument(member: MemberShape) {
|
||||
rustTemplate("Some(#{expect_document}(tokens)?)", *codegenScope)
|
||||
for (customization in customizations) {
|
||||
customization.section(JsonParserSection.AfterDocumentDeserializedMember(member))(this)
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.deserializeBlob(member: MemberShape) {
|
||||
rustTemplate(
|
||||
"#{expect_blob_or_null}(tokens.next())?",
|
||||
*codegenScope,
|
||||
)
|
||||
for (customization in customizations) {
|
||||
customization.section(JsonParserSection.AfterBlobDeserializedMember(member))(this)
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.deserializeStringInner(target: StringShape, escapedStrName: String) {
|
||||
|
@ -349,9 +363,12 @@ class JsonParserGenerator(
|
|||
)
|
||||
val timestampFormatType = RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat)
|
||||
rustTemplate(
|
||||
"#{expect_timestamp_or_null}(tokens.next(), #{T})?#{ConvertFrom:W}",
|
||||
"T" to timestampFormatType, "ConvertFrom" to typeConversionGenerator.convertViaFrom(shape), *codegenScope,
|
||||
"#{expect_timestamp_or_null}(tokens.next(), #{T})?",
|
||||
"T" to timestampFormatType, *codegenScope,
|
||||
)
|
||||
for (customization in customizations) {
|
||||
customization.section(JsonParserSection.AfterTimestampDeserializedMember(member))(this)
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.deserializeCollection(shape: CollectionShape) {
|
||||
|
|
|
@ -54,14 +54,49 @@ val allCodegenTests = "../../codegen-core/common-test-models".let { commonModels
|
|||
// TODO(https://github.com/awslabs/smithy-rs/issues/1401) `@uniqueItems` is used.
|
||||
extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
|
||||
),
|
||||
// TODO(https://github.com/awslabs/smithy-rs/issues/2476)
|
||||
CodegenTest(
|
||||
"aws.protocoltests.json#JsonProtocol",
|
||||
"json_rpc11",
|
||||
extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
|
||||
),
|
||||
CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"),
|
||||
CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"),
|
||||
CodegenTest(
|
||||
"aws.protocoltests.restjson#RestJsonExtras",
|
||||
"rest_json_extras",
|
||||
imports = listOf("$commonModels/rest-json-extras.smithy"),
|
||||
),
|
||||
// TODO(https://github.com/awslabs/smithy-rs/issues/2551)
|
||||
// CodegenTest(
|
||||
// "aws.protocoltests.json#JsonProtocol",
|
||||
// "json_rpc11",
|
||||
// "aws.protocoltests.restjson.validation#RestJsonValidation",
|
||||
// "rest_json_validation",
|
||||
// // `@range` trait is used on floating point shapes, which we deliberately don't want to support.
|
||||
// // See https://github.com/awslabs/smithy-rs/issues/1401.
|
||||
// extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
|
||||
// ),
|
||||
// TODO(https://github.com/awslabs/smithy-rs/issues/2479)
|
||||
// CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"),
|
||||
CodegenTest(
|
||||
"com.amazonaws.constraints#ConstraintsService",
|
||||
"constraints",
|
||||
imports = listOf("$commonModels/constraints.smithy"),
|
||||
),
|
||||
CodegenTest(
|
||||
"com.amazonaws.constraints#ConstraintsService",
|
||||
"constraints_without_public_constrained_types",
|
||||
imports = listOf("$commonModels/constraints.smithy"),
|
||||
extraConfig = """, "codegen": { "publicConstrainedTypes": false } """,
|
||||
),
|
||||
CodegenTest(
|
||||
"com.amazonaws.constraints#UniqueItemsService",
|
||||
"unique_items",
|
||||
imports = listOf("$commonModels/unique-items.smithy"),
|
||||
),
|
||||
CodegenTest(
|
||||
"naming_obs_structs#NamingObstacleCourseStructs",
|
||||
"naming_test_structs",
|
||||
imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"),
|
||||
),
|
||||
CodegenTest("casing#ACRONYMInside_Service", "naming_test_casing", imports = listOf("$commonModels/naming-obstacle-course-casing.smithy")),
|
||||
CodegenTest("crate#Config", "naming_test_ops", imports = listOf("$commonModels/naming-obstacle-course-ops.smithy")),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.python.smithy
|
|||
import software.amazon.smithy.build.PluginContext
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.knowledge.NullableIndex
|
||||
import software.amazon.smithy.model.shapes.BlobShape
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.ServiceShape
|
||||
|
@ -22,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
|
|||
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.util.getTrait
|
||||
import software.amazon.smithy.rust.codegen.core.util.isEventStream
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.ConstrainedPythonBlobGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonApplicationGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEventStreamErrorGenerator
|
||||
|
@ -30,6 +32,7 @@ import software.amazon.smithy.rust.codegen.server.python.smithy.generators.Pytho
|
|||
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerOperationHandlerGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerUnionGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.protocols.PythonServerProtocolLoader
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenVisitor
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerModuleDocProvider
|
||||
|
@ -42,8 +45,9 @@ import software.amazon.smithy.rust.codegen.server.smithy.createInlineModuleCreat
|
|||
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.withModuleOrWithStructureBuilderModule
|
||||
|
||||
/**
|
||||
* Entrypoint for Python server-side code generation. This class will walk the in-memory model and
|
||||
|
@ -68,10 +72,10 @@ class PythonServerCodegenVisitor(
|
|||
val baseModel = baselineTransform(context.model)
|
||||
val service = settings.getService(baseModel)
|
||||
val (protocol, generator) =
|
||||
ServerProtocolLoader(
|
||||
PythonServerProtocolLoader(
|
||||
codegenDecorator.protocols(
|
||||
service.id,
|
||||
ServerProtocolLoader.DefaultProtocols,
|
||||
PythonServerProtocolLoader.defaultProtocols(settings.runtimeConfig),
|
||||
),
|
||||
)
|
||||
.protocolFor(context.model, service)
|
||||
|
@ -258,4 +262,21 @@ class PythonServerCodegenVisitor(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun blobShape(shape: BlobShape) {
|
||||
logger.info("[python-server-codegen] Generating a service $shape")
|
||||
super.blobShape(shape)
|
||||
|
||||
if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) {
|
||||
rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.Model, shape, codegenContext) {
|
||||
ConstrainedPythonBlobGenerator(
|
||||
codegenContext,
|
||||
rustCrate.createInlineModuleCreator(),
|
||||
this,
|
||||
shape,
|
||||
validationExceptionConversionGenerator,
|
||||
).render()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.server.python.smithy.generators
|
||||
|
||||
import software.amazon.smithy.codegen.core.Symbol
|
||||
import software.amazon.smithy.model.shapes.BlobShape
|
||||
import software.amazon.smithy.model.traits.LengthTrait
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.join
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.render
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.rustType
|
||||
import software.amazon.smithy.rust.codegen.core.util.getTrait
|
||||
import software.amazon.smithy.rust.codegen.core.util.orNull
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerRuntimeType
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
|
||||
|
||||
class ConstrainedPythonBlobGenerator(
|
||||
val codegenContext: ServerCodegenContext,
|
||||
private val inlineModuleCreator: InlineModuleCreator,
|
||||
val writer: RustWriter,
|
||||
val shape: BlobShape,
|
||||
private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator,
|
||||
) {
|
||||
val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider
|
||||
val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes
|
||||
private val constraintViolationSymbolProvider =
|
||||
with(codegenContext.constraintViolationSymbolProvider) {
|
||||
if (publicConstrainedTypes) {
|
||||
this
|
||||
} else {
|
||||
PubCrateConstraintViolationSymbolProvider(this)
|
||||
}
|
||||
}
|
||||
val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape)
|
||||
private val blobConstraintsInfo: List<BlobLength> = listOf(LengthTrait::class.java)
|
||||
.mapNotNull { shape.getTrait(it).orNull() }
|
||||
.map { BlobLength(it) }
|
||||
private val constraintsInfo: List<TraitInfo> = blobConstraintsInfo.map { it.toTraitInfo() }
|
||||
|
||||
fun render() {
|
||||
val symbol = constrainedShapeSymbolProvider.toSymbol(shape)
|
||||
val blobType = PythonServerRuntimeType.blob(codegenContext.runtimeConfig).toSymbol().rustType()
|
||||
renderFrom(symbol, blobType)
|
||||
renderTryFrom(symbol, blobType)
|
||||
}
|
||||
|
||||
fun renderFrom(symbol: Symbol, blobType: RustType) {
|
||||
val name = symbol.name
|
||||
val inner = blobType.render()
|
||||
writer.rustTemplate(
|
||||
"""
|
||||
impl #{From}<$inner> for #{MaybeConstrained} {
|
||||
fn from(value: $inner) -> Self {
|
||||
Self::Unconstrained(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl #{From}<$name> for $inner {
|
||||
fn from(value: $name) -> Self {
|
||||
value.into_inner().into()
|
||||
}
|
||||
}
|
||||
""",
|
||||
"MaybeConstrained" to symbol.makeMaybeConstrained(),
|
||||
"From" to RuntimeType.From,
|
||||
)
|
||||
}
|
||||
|
||||
fun renderTryFrom(symbol: Symbol, blobType: RustType) {
|
||||
val name = symbol.name
|
||||
val inner = blobType.render()
|
||||
writer.rustTemplate(
|
||||
"""
|
||||
impl #{TryFrom}<$inner> for $name {
|
||||
type Error = #{ConstraintViolation};
|
||||
|
||||
fn try_from(value: $inner) -> Result<Self, Self::Error> {
|
||||
value.try_into()
|
||||
}
|
||||
}
|
||||
""",
|
||||
"TryFrom" to RuntimeType.TryFrom,
|
||||
"ConstraintViolation" to constraintViolation,
|
||||
"TryFromChecks" to constraintsInfo.map { it.tryFromCheck }.join("\n"),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators
|
|||
import software.amazon.smithy.model.knowledge.TopDownIndex
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.traits.DocumentationTrait
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
|
||||
|
@ -203,14 +204,13 @@ class PythonApplicationGenerator(
|
|||
*codegenScope,
|
||||
)
|
||||
for (operation in operations) {
|
||||
val operationName = symbolProvider.toSymbol(operation).name
|
||||
val name = operationName.toSnakeCase()
|
||||
val fnName = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operation).name.toSnakeCase())
|
||||
rustTemplate(
|
||||
"""
|
||||
let ${name}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
|
||||
let handler = self.handlers.get("$name").expect("Python handler for operation `$name` not found").clone();
|
||||
let builder = builder.$name(move |input, state| {
|
||||
#{pyo3_asyncio}::tokio::scope(${name}_locals.clone(), crate::python_operation_adaptor::$name(input, state, handler.clone()))
|
||||
let ${fnName}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
|
||||
let handler = self.handlers.get("$fnName").expect("Python handler for operation `$fnName` not found").clone();
|
||||
let builder = builder.$fnName(move |input, state| {
|
||||
#{pyo3_asyncio}::tokio::scope(${fnName}_locals.clone(), crate::python_operation_adaptor::$fnName(input, state, handler.clone()))
|
||||
});
|
||||
""",
|
||||
*codegenScope,
|
||||
|
@ -342,7 +342,7 @@ class PythonApplicationGenerator(
|
|||
)
|
||||
operations.map { operation ->
|
||||
val operationName = symbolProvider.toSymbol(operation).name
|
||||
val name = operationName.toSnakeCase()
|
||||
val fnName = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operation).name.toSnakeCase())
|
||||
|
||||
val input = PythonType.Opaque("${operationName}Input", "crate::input")
|
||||
val output = PythonType.Opaque("${operationName}Output", "crate::output")
|
||||
|
@ -363,15 +363,15 @@ class PythonApplicationGenerator(
|
|||
|
||||
rustTemplate(
|
||||
"""
|
||||
/// Method to register `$name` Python implementation inside the handlers map.
|
||||
/// Method to register `$fnName` Python implementation inside the handlers map.
|
||||
/// It can be used as a function decorator in Python.
|
||||
///
|
||||
/// :param func ${handler.renderAsDocstring()}:
|
||||
/// :rtype ${PythonType.None.renderAsDocstring()}:
|
||||
##[pyo3(text_signature = "(${'$'}self, func)")]
|
||||
pub fn $name(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
|
||||
pub fn $fnName(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
|
||||
use #{SmithyPython}::PyApp;
|
||||
self.register_operation(py, "$name", func)
|
||||
self.register_operation(py, "$fnName", func)
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
|
|
|
@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
|||
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
|
||||
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
|
||||
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerRustModule
|
||||
|
@ -71,13 +72,13 @@ class PythonServerModuleGenerator(
|
|||
when (shape) {
|
||||
is UnionShape -> rustTemplate(
|
||||
"""
|
||||
$moduleType.add_class::<crate::$moduleType::PyUnionMarker${shape.id.name}>()?;
|
||||
$moduleType.add_class::<crate::$moduleType::PyUnionMarker${shape.id.name.toPascalCase()}>()?;
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
else -> rustTemplate(
|
||||
"""
|
||||
$moduleType.add_class::<crate::$moduleType::${shape.id.name}>()?;
|
||||
$moduleType.add_class::<crate::$moduleType::${shape.id.name.toPascalCase()}>()?;
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
|
|
|
@ -6,11 +6,13 @@
|
|||
package software.amazon.smithy.rust.codegen.server.python.smithy.generators
|
||||
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
|
||||
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
|
||||
|
||||
|
@ -50,10 +52,11 @@ class PythonServerOperationHandlerGenerator(
|
|||
|
||||
private fun renderPythonOperationHandlerImpl(writer: RustWriter) {
|
||||
val operationName = symbolProvider.toSymbol(operation).name
|
||||
val input = "crate::input::${operationName}Input"
|
||||
val output = "crate::output::${operationName}Output"
|
||||
val input = "crate::input::${operationName.toPascalCase()}Input"
|
||||
val output = "crate::output::${operationName.toPascalCase()}Output"
|
||||
// TODO(https://github.com/awslabs/smithy-rs/issues/2552) - Use to pascalCase for error shapes.
|
||||
val error = "crate::error::${operationName}Error"
|
||||
val fnName = operationName.toSnakeCase()
|
||||
val fnName = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operation).name.toSnakeCase())
|
||||
|
||||
writer.rustTemplate(
|
||||
"""
|
||||
|
|
|
@ -60,6 +60,9 @@ class PythonServerStructureGenerator(
|
|||
writer.rustTemplate("#{ConstructorSignature:W}", "ConstructorSignature" to renderConstructorSignature())
|
||||
super.renderStructure()
|
||||
renderPyO3Methods()
|
||||
if (!shape.hasTrait<ErrorTrait>()) {
|
||||
renderPyBoxTraits()
|
||||
}
|
||||
}
|
||||
|
||||
override fun renderStructureMember(
|
||||
|
@ -101,6 +104,25 @@ class PythonServerStructureGenerator(
|
|||
)
|
||||
}
|
||||
|
||||
private fun renderPyBoxTraits() {
|
||||
writer.rustTemplate(
|
||||
"""
|
||||
impl<'source> #{pyo3}::FromPyObject<'source> for std::boxed::Box<$name> {
|
||||
fn extract(ob: &'source #{pyo3}::PyAny) -> #{pyo3}::PyResult<Self> {
|
||||
ob.extract::<$name>().map(Box::new)
|
||||
}
|
||||
}
|
||||
|
||||
impl #{pyo3}::IntoPy<#{pyo3}::PyObject> for std::boxed::Box<$name> {
|
||||
fn into_py(self, py: #{pyo3}::Python<'_>) -> #{pyo3}::PyObject {
|
||||
(*self).into_py(py)
|
||||
}
|
||||
}
|
||||
""",
|
||||
"pyo3" to pyO3,
|
||||
)
|
||||
}
|
||||
|
||||
private fun renderStructSignatureMembers(): Writable =
|
||||
writable {
|
||||
forEachMember(members) { _, memberName, memberSymbol ->
|
||||
|
|
|
@ -121,7 +121,7 @@ class PythonServerUnionGenerator(
|
|||
)
|
||||
writer.rust("/// :rtype ${unionSymbol.name}:")
|
||||
writer.rustBlock("pub fn $funcNamePart() -> Self") {
|
||||
rust("Self(${unionSymbol.name}::$variantName")
|
||||
rust("Self(${unionSymbol.name}::$variantName)")
|
||||
}
|
||||
} else {
|
||||
val memberSymbol = symbolProvider.toSymbol(member)
|
||||
|
@ -157,7 +157,7 @@ class PythonServerUnionGenerator(
|
|||
writer.rustBlockTemplate("pub fn as_$funcNamePart(&self) -> #{pyo3}::PyResult<()>", "pyo3" to pyo3) {
|
||||
rustTemplate(
|
||||
"""
|
||||
self.0.as_$funcNamePart().map_err(#{pyo3}::exceptions::PyValueError::new_err(
|
||||
self.0.as_$funcNamePart().map_err(|_| #{pyo3}::exceptions::PyValueError::new_err(
|
||||
"${unionSymbol.name} variant is not None"
|
||||
))
|
||||
""",
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
|
||||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.server.python.smithy.protocols
|
||||
|
||||
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
|
||||
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
|
||||
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserSection
|
||||
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerRuntimeType
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonFactory
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolCustomization
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolSection
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonFactory
|
||||
|
||||
/**
|
||||
* Customization class used to force casting a non primitive type into one overriden by a new symbol provider,
|
||||
* by explicitly calling `from()` on it.
|
||||
*
|
||||
* For example we use this in the server Python implementation, where we override types like [Blob], [DateTime] and [Document]
|
||||
* with wrappers compatible with Python, without touching the original implementation coming from `aws-smithy-types`.
|
||||
*/
|
||||
class PythonServerAfterDeserializedMemberJsonParserCustomization(private val runtimeConfig: RuntimeConfig) :
|
||||
JsonParserCustomization() {
|
||||
override fun section(section: JsonParserSection): Writable = when (section) {
|
||||
is JsonParserSection.AfterTimestampDeserializedMember -> writable {
|
||||
rust(".map(#T::from)", PythonServerRuntimeType.dateTime(runtimeConfig).toSymbol())
|
||||
}
|
||||
is JsonParserSection.AfterBlobDeserializedMember -> writable {
|
||||
rust(".map(#T::from)", PythonServerRuntimeType.blob(runtimeConfig).toSymbol())
|
||||
}
|
||||
is JsonParserSection.AfterDocumentDeserializedMember -> writable {
|
||||
rust(".map(#T::from)", PythonServerRuntimeType.document(runtimeConfig).toSymbol())
|
||||
}
|
||||
else -> emptySection
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Customization class used to force casting a non primitive type into one overriden by a new symbol provider,
|
||||
* by explicitly calling `into()` on it.
|
||||
*/
|
||||
class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() :
|
||||
ServerHttpBoundProtocolCustomization() {
|
||||
override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) {
|
||||
is ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember -> writable {
|
||||
rust(".into()")
|
||||
}
|
||||
else -> emptySection
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Customization class used to force casting a `Vec<DateTime>` into one a Python `Vec<DateTime>`
|
||||
*/
|
||||
class PythonServerAfterDeserializedMemberHttpBindingCustomization(private val runtimeConfig: RuntimeConfig) :
|
||||
HttpBindingCustomization() {
|
||||
override fun section(section: HttpBindingSection): Writable = when (section) {
|
||||
is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders -> writable {
|
||||
rust(".into_iter().map(#T::from).collect()", PythonServerRuntimeType.dateTime(runtimeConfig).toSymbol())
|
||||
}
|
||||
else -> emptySection
|
||||
}
|
||||
}
|
||||
|
||||
class PythonServerProtocolLoader(
|
||||
private val supportedProtocols: ProtocolMap<ServerProtocolGenerator, ServerCodegenContext>,
|
||||
) : ProtocolLoader<ServerProtocolGenerator, ServerCodegenContext>(supportedProtocols) {
|
||||
|
||||
companion object {
|
||||
fun defaultProtocols(runtimeConfig: RuntimeConfig) =
|
||||
mapOf(
|
||||
RestJson1Trait.ID to ServerRestJsonFactory(
|
||||
additionalParserCustomizations = listOf(
|
||||
PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig),
|
||||
),
|
||||
additionalServerHttpBoundProtocolCustomizations = listOf(
|
||||
PythonServerAfterDeserializedMemberServerHttpBoundCustomization(),
|
||||
),
|
||||
additionalHttpBindingCustomizations = listOf(
|
||||
PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig),
|
||||
),
|
||||
),
|
||||
AwsJson1_0Trait.ID to ServerAwsJsonFactory(
|
||||
AwsJsonVersion.Json10,
|
||||
additionalParserCustomizations = listOf(
|
||||
PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig),
|
||||
),
|
||||
additionalServerHttpBoundProtocolCustomizations = listOf(
|
||||
PythonServerAfterDeserializedMemberServerHttpBoundCustomization(),
|
||||
),
|
||||
additionalHttpBindingCustomizations = listOf(
|
||||
PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig),
|
||||
),
|
||||
),
|
||||
AwsJson1_1Trait.ID to ServerAwsJsonFactory(
|
||||
AwsJsonVersion.Json11,
|
||||
additionalParserCustomizations = listOf(
|
||||
PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig),
|
||||
),
|
||||
additionalServerHttpBoundProtocolCustomizations = listOf(
|
||||
PythonServerAfterDeserializedMemberServerHttpBoundCustomization(),
|
||||
),
|
||||
additionalHttpBindingCustomizations = listOf(
|
||||
PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -28,6 +28,7 @@ class ServerRequestBindingGenerator(
|
|||
protocol: Protocol,
|
||||
codegenContext: ServerCodegenContext,
|
||||
operationShape: OperationShape,
|
||||
additionalHttpBindingCustomizations: List<HttpBindingCustomization> = listOf(),
|
||||
) {
|
||||
private val httpBindingGenerator =
|
||||
HttpBindingGenerator(
|
||||
|
@ -39,7 +40,7 @@ class ServerRequestBindingGenerator(
|
|||
ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization(
|
||||
codegenContext,
|
||||
),
|
||||
),
|
||||
) + additionalHttpBindingCustomizations,
|
||||
)
|
||||
|
||||
fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType =
|
||||
|
@ -81,5 +82,6 @@ class ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUncons
|
|||
)
|
||||
}
|
||||
}
|
||||
else -> emptySection
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,6 +71,7 @@ class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstr
|
|||
|
||||
is HttpBindingSection.BeforeRenderingHeaderValue,
|
||||
is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders,
|
||||
is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders,
|
||||
-> emptySection
|
||||
}
|
||||
}
|
||||
|
@ -100,6 +101,7 @@ class ServerResponseBeforeRenderingHeadersHttpBindingCustomization(val codegenCo
|
|||
|
||||
is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders,
|
||||
is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders,
|
||||
is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders,
|
||||
-> emptySection
|
||||
}
|
||||
}
|
||||
|
|
|
@ -106,6 +106,7 @@ fun jsonParserGenerator(
|
|||
codegenContext: ServerCodegenContext,
|
||||
httpBindingResolver: HttpBindingResolver,
|
||||
jsonName: (MemberShape) -> String,
|
||||
additionalParserCustomizations: List<JsonParserCustomization> = listOf(),
|
||||
): JsonParserGenerator =
|
||||
JsonParserGenerator(
|
||||
codegenContext,
|
||||
|
@ -114,12 +115,13 @@ fun jsonParserGenerator(
|
|||
returnSymbolToParseFn(codegenContext),
|
||||
listOf(
|
||||
ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(codegenContext),
|
||||
),
|
||||
) + additionalParserCustomizations,
|
||||
)
|
||||
|
||||
class ServerAwsJsonProtocol(
|
||||
private val serverCodegenContext: ServerCodegenContext,
|
||||
awsJsonVersion: AwsJsonVersion,
|
||||
private val additionalParserCustomizations: List<JsonParserCustomization> = listOf(),
|
||||
) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol {
|
||||
private val runtimeConfig = codegenContext.runtimeConfig
|
||||
|
||||
|
@ -130,7 +132,7 @@ class ServerAwsJsonProtocol(
|
|||
}
|
||||
|
||||
override fun structuredDataParser(): StructuredDataParserGenerator =
|
||||
jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::awsJsonFieldName)
|
||||
jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::awsJsonFieldName, additionalParserCustomizations)
|
||||
|
||||
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
|
||||
ServerAwsJsonSerializerGenerator(serverCodegenContext, httpBindingResolver, awsJsonVersion)
|
||||
|
@ -183,13 +185,14 @@ private fun restRouterType(runtimeConfig: RuntimeConfig) =
|
|||
|
||||
class ServerRestJsonProtocol(
|
||||
private val serverCodegenContext: ServerCodegenContext,
|
||||
private val additionalParserCustomizations: List<JsonParserCustomization> = listOf(),
|
||||
) : RestJson(serverCodegenContext), ServerProtocol {
|
||||
val runtimeConfig = codegenContext.runtimeConfig
|
||||
|
||||
override val protocolModulePath: String = "rest_json_1"
|
||||
|
||||
override fun structuredDataParser(): StructuredDataParserGenerator =
|
||||
jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::restJsonFieldName)
|
||||
jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::restJsonFieldName, additionalParserCustomizations)
|
||||
|
||||
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
|
||||
ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver)
|
||||
|
@ -254,5 +257,6 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonPa
|
|||
rust(".map(|x| x.into())")
|
||||
}
|
||||
}
|
||||
else -> emptySection
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,11 +10,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
|||
import software.amazon.smithy.rust.codegen.core.rustlang.escape
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection
|
||||
|
@ -30,13 +32,22 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser
|
|||
* AwsJson 1.0 and 1.1 server-side protocol factory. This factory creates the [ServerHttpBoundProtocolGenerator]
|
||||
* with AwsJson specific configurations.
|
||||
*/
|
||||
class ServerAwsJsonFactory(private val version: AwsJsonVersion) :
|
||||
ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator, ServerCodegenContext> {
|
||||
class ServerAwsJsonFactory(
|
||||
private val version: AwsJsonVersion,
|
||||
private val additionalParserCustomizations: List<JsonParserCustomization> = listOf(),
|
||||
private val additionalServerHttpBoundProtocolCustomizations: List<ServerHttpBoundProtocolCustomization> = listOf(),
|
||||
private val additionalHttpBindingCustomizations: List<HttpBindingCustomization> = listOf(),
|
||||
) : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator, ServerCodegenContext> {
|
||||
override fun protocol(codegenContext: ServerCodegenContext): ServerProtocol =
|
||||
ServerAwsJsonProtocol(codegenContext, version)
|
||||
ServerAwsJsonProtocol(codegenContext, version, additionalParserCustomizations)
|
||||
|
||||
override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator =
|
||||
ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext))
|
||||
ServerHttpBoundProtocolGenerator(
|
||||
codegenContext,
|
||||
protocol(codegenContext),
|
||||
additionalServerHttpBoundProtocolCustomizations,
|
||||
additionalHttpBindingCustomizations,
|
||||
)
|
||||
|
||||
override fun support(): ProtocolSupport {
|
||||
return ProtocolSupport(
|
||||
|
|
|
@ -16,6 +16,7 @@ import software.amazon.smithy.model.pattern.UriPattern
|
|||
import software.amazon.smithy.model.shapes.BooleanShape
|
||||
import software.amazon.smithy.model.shapes.CollectionShape
|
||||
import software.amazon.smithy.model.shapes.MapShape
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.NumberShape
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
|
@ -42,8 +43,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
|
|||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
|
||||
|
@ -77,6 +80,18 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser
|
|||
import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol
|
||||
import java.util.logging.Logger
|
||||
|
||||
/**
|
||||
* Class describing a ServerHttpBoundProtocol section that can be used in a customization.
|
||||
*/
|
||||
sealed class ServerHttpBoundProtocolSection(name: String) : Section(name) {
|
||||
data class AfterTimestampDeserializedMember(val shape: MemberShape) : ServerHttpBoundProtocolSection("AfterTimestampDeserializedMember")
|
||||
}
|
||||
|
||||
/**
|
||||
* Customization for the ServerHttpBoundProtocol generator.
|
||||
*/
|
||||
typealias ServerHttpBoundProtocolCustomization = NamedCustomization<ServerHttpBoundProtocolSection>
|
||||
|
||||
/**
|
||||
* Implement operations' input parsing and output serialization. Protocols can plug their own implementations
|
||||
* and overrides by creating a protocol factory inheriting from this class and feeding it to the [ServerProtocolLoader].
|
||||
|
@ -85,10 +100,12 @@ import java.util.logging.Logger
|
|||
class ServerHttpBoundProtocolGenerator(
|
||||
codegenContext: ServerCodegenContext,
|
||||
protocol: ServerProtocol,
|
||||
customizations: List<ServerHttpBoundProtocolCustomization> = listOf(),
|
||||
additionalHttpBindingCustomizations: List<HttpBindingCustomization> = listOf(),
|
||||
) : ServerProtocolGenerator(
|
||||
codegenContext,
|
||||
protocol,
|
||||
ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol),
|
||||
ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations),
|
||||
) {
|
||||
// Define suffixes for operation input / output / error wrappers
|
||||
companion object {
|
||||
|
@ -104,6 +121,8 @@ class ServerHttpBoundProtocolGenerator(
|
|||
class ServerHttpBoundProtocolTraitImplGenerator(
|
||||
private val codegenContext: ServerCodegenContext,
|
||||
private val protocol: ServerProtocol,
|
||||
private val customizations: List<ServerHttpBoundProtocolCustomization>,
|
||||
private val additionalHttpBindingCustomizations: List<HttpBindingCustomization>,
|
||||
) {
|
||||
private val logger = Logger.getLogger(javaClass.name)
|
||||
private val symbolProvider = codegenContext.symbolProvider
|
||||
|
@ -111,7 +130,6 @@ class ServerHttpBoundProtocolTraitImplGenerator(
|
|||
private val model = codegenContext.model
|
||||
private val runtimeConfig = codegenContext.runtimeConfig
|
||||
private val httpBindingResolver = protocol.httpBindingResolver
|
||||
private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig)
|
||||
private val protocolFunctions = ProtocolFunctions(codegenContext)
|
||||
|
||||
private val codegenScope = arrayOf(
|
||||
|
@ -568,9 +586,9 @@ class ServerHttpBoundProtocolTraitImplGenerator(
|
|||
private fun serverRenderHttpResponseCode(defaultCode: Int) = writable {
|
||||
check(defaultCode in 100..999) {
|
||||
"""
|
||||
Smithy library lied to us. According to https://smithy.io/2.0/spec/http-bindings.html#http-trait,
|
||||
"The provided value SHOULD be between 100 and 599, and it MUST be between 100 and 999".
|
||||
""".replace("\n", "").trimIndent()
|
||||
Smithy library lied to us. According to https://smithy.io/2.0/spec/http-bindings.html#http-trait,
|
||||
"The provided value SHOULD be between 100 and 599, and it MUST be between 100 and 999".
|
||||
""".replace("\n", "").trimIndent()
|
||||
}
|
||||
rustTemplate(
|
||||
"""
|
||||
|
@ -611,7 +629,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
|
|||
inputShape: StructureShape,
|
||||
bindings: List<HttpBindingDescriptor>,
|
||||
) {
|
||||
val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape)
|
||||
val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
|
||||
val structuredDataParser = protocol.structuredDataParser()
|
||||
Attribute.AllowUnusedMut.render(this)
|
||||
rust(
|
||||
|
@ -952,12 +970,15 @@ class ServerHttpBoundProtocolTraitImplGenerator(
|
|||
val timestampFormatType = RuntimeType.parseTimestampFormat(CodegenTarget.SERVER, runtimeConfig, timestampFormat)
|
||||
rustTemplate(
|
||||
"""
|
||||
let v = #{DateTime}::from_str(&v, #{format})?#{ConvertInto:W};
|
||||
let v = #{DateTime}::from_str(&v, #{format})?
|
||||
""".trimIndent(),
|
||||
*codegenScope,
|
||||
"format" to timestampFormatType,
|
||||
"ConvertInto" to typeConversionGenerator.convertViaInto(memberShape),
|
||||
)
|
||||
for (customization in customizations) {
|
||||
customization.section(ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember(it.member))(this)
|
||||
}
|
||||
rust(";")
|
||||
}
|
||||
else -> { // Number or boolean.
|
||||
rust(
|
||||
|
@ -1047,7 +1068,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
|
|||
}
|
||||
|
||||
private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) {
|
||||
val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape)
|
||||
val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
|
||||
val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding)
|
||||
writer.rustTemplate(
|
||||
"""
|
||||
|
@ -1109,22 +1130,24 @@ class ServerHttpBoundProtocolTraitImplGenerator(
|
|||
rustTemplate(
|
||||
"""
|
||||
let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
|
||||
let value = #{DateTime}::from_str(value.as_ref(), #{format})?#{ConvertInto:W};
|
||||
let value = #{DateTime}::from_str(value.as_ref(), #{format})?
|
||||
""",
|
||||
*codegenScope,
|
||||
"format" to timestampFormatType,
|
||||
"ConvertInto" to typeConversionGenerator.convertViaInto(target),
|
||||
)
|
||||
} else {
|
||||
rustTemplate(
|
||||
"""
|
||||
let value = #{DateTime}::from_str(value, #{format})?#{ConvertInto:W};
|
||||
let value = #{DateTime}::from_str(value, #{format})?
|
||||
""",
|
||||
*codegenScope,
|
||||
"format" to timestampFormatType,
|
||||
"ConvertInto" to typeConversionGenerator.convertViaInto(target),
|
||||
)
|
||||
}
|
||||
for (customization in customizations) {
|
||||
customization.section(ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember(binding.member))(this)
|
||||
}
|
||||
rust(";")
|
||||
}
|
||||
else -> {
|
||||
check(target is NumberShape || target is BooleanShape)
|
||||
|
|
|
@ -5,10 +5,12 @@
|
|||
|
||||
package software.amazon.smithy.rust.codegen.server.smithy.protocols
|
||||
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator
|
||||
|
@ -21,11 +23,23 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser
|
|||
* RestJson1 server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator]
|
||||
* with RestJson1 specific configurations.
|
||||
*/
|
||||
class ServerRestJsonFactory : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator, ServerCodegenContext> {
|
||||
override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRestJsonProtocol(codegenContext)
|
||||
class ServerRestJsonFactory(
|
||||
private val additionalParserCustomizations: List<JsonParserCustomization> = listOf(),
|
||||
private val additionalServerHttpBoundProtocolCustomizations: List<ServerHttpBoundProtocolCustomization> = listOf(),
|
||||
private val additionalHttpBindingCustomizations: List<HttpBindingCustomization> = listOf(),
|
||||
) : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator, ServerCodegenContext> {
|
||||
override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRestJsonProtocol(codegenContext, additionalParserCustomizations)
|
||||
|
||||
override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator =
|
||||
ServerHttpBoundProtocolGenerator(codegenContext, ServerRestJsonProtocol(codegenContext))
|
||||
ServerHttpBoundProtocolGenerator(
|
||||
codegenContext,
|
||||
ServerRestJsonProtocol(
|
||||
codegenContext,
|
||||
additionalParserCustomizations,
|
||||
),
|
||||
additionalServerHttpBoundProtocolCustomizations,
|
||||
additionalHttpBindingCustomizations,
|
||||
)
|
||||
|
||||
override fun support(): ProtocolSupport {
|
||||
return ProtocolSupport(
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
HEAD_BRANCH_NAME = "__tmp-localonly-head"
|
||||
BASE_BRANCH_NAME = "__tmp-localonly-base"
|
||||
|
@ -15,10 +15,10 @@ COMMIT_AUTHOR_EMAIL = "generated-code-action@github.com"
|
|||
|
||||
CDN_URL = "https://d2luzm2xt3nokh.cloudfront.net"
|
||||
|
||||
PYTHON_EXAMPLES_PATH = "rust-runtime/aws-smithy-http-server-python/examples"
|
||||
|
||||
target_codegen_client = 'codegen-client-test'
|
||||
target_codegen_server = 'codegen-server-test'
|
||||
target_codegen_server_python = 'codegen-server-test:python'
|
||||
target_codegen_server_typescript = 'codegen-server-test:typescript'
|
||||
target_aws_sdk = 'aws:sdk'
|
||||
|
||||
|
||||
|
@ -38,19 +38,19 @@ def checkout_commit_and_generate(revision_sha, branch_name, targets=None):
|
|||
|
||||
|
||||
def generate_and_commit_generated_code(revision_sha, targets=None):
|
||||
targets = targets or [target_codegen_client, target_codegen_server, target_aws_sdk]
|
||||
targets = targets or [
|
||||
target_codegen_client,
|
||||
target_codegen_server,
|
||||
target_aws_sdk,
|
||||
target_codegen_server_python,
|
||||
target_codegen_server_typescript
|
||||
]
|
||||
# Clean the build artifacts before continuing
|
||||
assemble_tasks = ' '.join([f'{t}:assemble' for t in targets])
|
||||
clean_tasks = ' '.join([f'{t}:clean' for t in targets])
|
||||
get_cmd_output("rm -rf aws/sdk/build")
|
||||
if target_codegen_server in targets:
|
||||
get_cmd_output("make distclean", shell=True, cwd=PYTHON_EXAMPLES_PATH)
|
||||
get_cmd_output("./gradlew codegen-core:clean codegen-client:clean codegen-server:clean aws:sdk-codegen:clean")
|
||||
|
||||
# Generate code
|
||||
tasks = ' '.join([f'{t}:assemble' for t in targets])
|
||||
get_cmd_output(f"./gradlew --rerun-tasks {tasks}")
|
||||
if target_codegen_server in targets:
|
||||
get_cmd_output("make build", shell=True, check=False, cwd=PYTHON_EXAMPLES_PATH)
|
||||
get_cmd_output(f"./gradlew --rerun-tasks codegen-server-test:typescript:assemble")
|
||||
get_cmd_output(f"./gradlew --rerun-tasks {clean_tasks}")
|
||||
get_cmd_output(f"./gradlew --rerun-tasks {assemble_tasks}")
|
||||
|
||||
# Move generated code into codegen-diff/ directory
|
||||
get_cmd_output(f"rm -rf {OUTPUT_PATH}")
|
||||
|
@ -61,12 +61,8 @@ def generate_and_commit_generated_code(revision_sha, targets=None):
|
|||
if target in targets:
|
||||
get_cmd_output(f"mv {target}/build/smithyprojections/{target} {OUTPUT_PATH}/")
|
||||
if target == target_codegen_server:
|
||||
get_cmd_output(
|
||||
f"mv {PYTHON_EXAMPLES_PATH}/pokemon-service-server-sdk/ {OUTPUT_PATH}/codegen-server-test-python/",
|
||||
check=False)
|
||||
get_cmd_output(
|
||||
f"mv codegen-server-test/typescript/build/smithyprojections/codegen-server-test-typescript {OUTPUT_PATH}/",
|
||||
check=False)
|
||||
get_cmd_output(f"mv {target}/python/build/smithyprojections/{target}-python {OUTPUT_PATH}/")
|
||||
get_cmd_output(f"mv {target}/typescript/build/smithyprojections/{target}-typescript {OUTPUT_PATH}/")
|
||||
|
||||
# Clean up the SDK directory
|
||||
get_cmd_output(f"rm -f {OUTPUT_PATH}/aws-sdk/versions.toml")
|
||||
|
@ -79,6 +75,7 @@ def generate_and_commit_generated_code(revision_sha, targets=None):
|
|||
|
||||
# Clean up the server-test folder
|
||||
get_cmd_output(f"rm -rf {OUTPUT_PATH}/codegen-server-test/source")
|
||||
get_cmd_output(f"rm -rf {OUTPUT_PATH}/codegen-server-test-python/source")
|
||||
get_cmd_output(f"rm -rf {OUTPUT_PATH}/codegen-server-test-typescript/source")
|
||||
run(f"find {OUTPUT_PATH}/codegen-server-test | "
|
||||
f"grep -E 'smithy-build-info.json|sources/manifest|model.json' | "
|
||||
|
|
Loading…
Reference in New Issue