control error correction on a client/server basis

This commit is contained in:
Russell Cohen 2023-09-14 12:36:53 -04:00
parent 0db25dc16d
commit a484e5ed0e
7 changed files with 38 additions and 21 deletions

View File

@ -412,7 +412,11 @@ class BuilderGenerator(
val default = generator.defaultValue(member)
if (!memberSymbol.isOptional()) {
if (default != null) {
rust(".unwrap_or_else(#T)", default)
if (default.isRustDefault) {
rust(".unwrap_or_default()")
} else {
rust(".unwrap_or_else(#T)", default)
}
} else {
if (errorCorrection) {
generator.errorCorrection(member)?.also { correction -> rust(".or_else(||#T)", correction) }

View File

@ -39,15 +39,17 @@ class DefaultValueGenerator(
) {
private val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider)
data class DefaultValue(val isRustDefault: Boolean, val expr: Writable)
/** Returns the default value as set by the defaultValue trait */
fun defaultValue(member: MemberShape): Writable? {
fun defaultValue(member: MemberShape): DefaultValue? {
val target = model.expectShape(member.target)
return when (val default = symbolProvider.toSymbol(member).defaultValue()) {
is Default.NoDefault -> null
is Default.RustDefault -> writable("Default::default")
is Default.RustDefault -> DefaultValue(isRustDefault = true, writable("Default::default"))
is Default.NonZeroDefault -> {
val instantiation = instantiator.instantiate(target as SimpleShape, default.value)
writable { rust("||#T", instantiation) }
DefaultValue(isRustDefault = false, writable { rust("||#T", instantiation) })
}
}
}
@ -62,10 +64,12 @@ class DefaultValueGenerator(
when (target) {
is EnumShape -> rustTemplate(""""no value was set".parse::<#{Shape}>().ok()""", "Shape" to symbol)
is BooleanShape, is NumberShape, is StringShape, is DocumentShape, is ListShape, is MapShape -> rust("Some(Default::default())")
is StructureShape -> rust(
"#T::default().build_with_error_correction().ok()",
symbolProvider.symbolForBuilder(target),
is StructureShape -> rustTemplate(
"#{error_correct}(#{Builder}::default()).ok()",
"Builder" to symbolProvider.symbolForBuilder(target),
"error_correct" to errorCorrectingBuilder(target, symbolProvider, model)!!,
)
is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this)
is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this)

View File

@ -122,6 +122,7 @@ class AwsJsonSerializerGenerator(
open class AwsJson(
val codegenContext: CodegenContext,
val awsJsonVersion: AwsJsonVersion,
val enableErrorCorrection: Boolean,
) : Protocol {
private val runtimeConfig = codegenContext.runtimeConfig
private val errorScope = arrayOf(
@ -148,6 +149,7 @@ open class AwsJson(
codegenContext,
httpBindingResolver,
::awsJsonFieldName,
enableErrorCorrection = enableErrorCorrection,
)
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =

View File

@ -59,7 +59,7 @@ class RestJsonHttpBindingResolver(
}
}
open class RestJson(val codegenContext: CodegenContext) : Protocol {
open class RestJson(val codegenContext: CodegenContext, private val enableErrorCorrection: Boolean) : Protocol {
private val runtimeConfig = codegenContext.runtimeConfig
private val errorScope = arrayOf(
"Bytes" to RuntimeType.Bytes,
@ -95,7 +95,7 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol {
listOf("x-amzn-errortype" to errorShape.id.name)
override fun structuredDataParser(): StructuredDataParserGenerator =
JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName, enableErrorCorrection = enableErrorCorrection)
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)

View File

@ -33,11 +33,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.errorCorrectingBuilder
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
@ -95,6 +97,7 @@ class JsonParserGenerator(
private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape ->
ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false)
},
private val enableErrorCorrection: Boolean,
private val customizations: List<JsonParserCustomization> = listOf(),
) : StructuredDataParserGenerator {
private val model = codegenContext.model
@ -515,17 +518,20 @@ class JsonParserGenerator(
// Only call `build()` if the builder is not fallible. Otherwise, return the builder.
if (returnSymbolToParse.isUnconstrained) {
rust("Ok(Some(builder))")
} else {
} else if (enableErrorCorrection && BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) {
val errorCorrection = errorCorrectingBuilder(shape, symbolProvider, model)
if (errorCorrection != null) {
rustTemplate(
"""
Ok(Some(#{correct_errors}(builder).map_err(|err|#{Error}::custom_source("Response was invalid", err))?))""",
"correct_errors" to errorCorrection, *codegenScope,
)
val buildExpr = if (errorCorrection != null) {
writable { rustTemplate("#{correct_errors}(builder)", "correctErrors" to errorCorrection) }
} else {
rust("Ok(Some(builder.build()))")
writable { rustTemplate("builder.build()") }
}
rustTemplate(
"""Ok(Some(#{build}.map_err(|err|#{Error}::custom_source("Response was invalid", err))?))""",
"build" to buildExpr,
*codegenScope,
)
} else {
rust("Ok(Some(builder.build()))")
}
}
}

View File

@ -129,7 +129,7 @@ object EventStreamTestModels {
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) },
) { RestJson(it, enableErrorCorrection = false) },
//
// awsJson1_1
@ -145,7 +145,7 @@ object EventStreamTestModels {
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { AwsJson(it, AwsJsonVersion.Json11) },
) { AwsJson(it, AwsJsonVersion.Json11, enableErrorCorrection = false) },
//
// restXml

View File

@ -113,6 +113,7 @@ fun jsonParserGenerator(
httpBindingResolver,
jsonName,
returnSymbolToParseFn(codegenContext),
enableErrorCorrection = false,
listOf(
ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(codegenContext),
) + additionalParserCustomizations,
@ -122,7 +123,7 @@ class ServerAwsJsonProtocol(
private val serverCodegenContext: ServerCodegenContext,
awsJsonVersion: AwsJsonVersion,
private val additionalParserCustomizations: List<JsonParserCustomization> = listOf(),
) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol {
) : AwsJson(serverCodegenContext, awsJsonVersion, enableErrorCorrection = false), ServerProtocol {
private val runtimeConfig = codegenContext.runtimeConfig
override val protocolModulePath: String
@ -186,7 +187,7 @@ private fun restRouterType(runtimeConfig: RuntimeConfig) =
class ServerRestJsonProtocol(
private val serverCodegenContext: ServerCodegenContext,
private val additionalParserCustomizations: List<JsonParserCustomization> = listOf(),
) : RestJson(serverCodegenContext), ServerProtocol {
) : RestJson(serverCodegenContext, enableErrorCorrection = false), ServerProtocol {
val runtimeConfig = codegenContext.runtimeConfig
override val protocolModulePath: String = "rest_json_1"