mirror of https://github.com/smithy-lang/smithy-rs
control error correction on a client/server basis
This commit is contained in:
parent
0db25dc16d
commit
a484e5ed0e
|
@ -412,7 +412,11 @@ class BuilderGenerator(
|
|||
val default = generator.defaultValue(member)
|
||||
if (!memberSymbol.isOptional()) {
|
||||
if (default != null) {
|
||||
rust(".unwrap_or_else(#T)", default)
|
||||
if (default.isRustDefault) {
|
||||
rust(".unwrap_or_default()")
|
||||
} else {
|
||||
rust(".unwrap_or_else(#T)", default)
|
||||
}
|
||||
} else {
|
||||
if (errorCorrection) {
|
||||
generator.errorCorrection(member)?.also { correction -> rust(".or_else(||#T)", correction) }
|
||||
|
|
|
@ -39,15 +39,17 @@ class DefaultValueGenerator(
|
|||
) {
|
||||
private val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider)
|
||||
|
||||
data class DefaultValue(val isRustDefault: Boolean, val expr: Writable)
|
||||
|
||||
/** Returns the default value as set by the defaultValue trait */
|
||||
fun defaultValue(member: MemberShape): Writable? {
|
||||
fun defaultValue(member: MemberShape): DefaultValue? {
|
||||
val target = model.expectShape(member.target)
|
||||
return when (val default = symbolProvider.toSymbol(member).defaultValue()) {
|
||||
is Default.NoDefault -> null
|
||||
is Default.RustDefault -> writable("Default::default")
|
||||
is Default.RustDefault -> DefaultValue(isRustDefault = true, writable("Default::default"))
|
||||
is Default.NonZeroDefault -> {
|
||||
val instantiation = instantiator.instantiate(target as SimpleShape, default.value)
|
||||
writable { rust("||#T", instantiation) }
|
||||
DefaultValue(isRustDefault = false, writable { rust("||#T", instantiation) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -62,10 +64,12 @@ class DefaultValueGenerator(
|
|||
when (target) {
|
||||
is EnumShape -> rustTemplate(""""no value was set".parse::<#{Shape}>().ok()""", "Shape" to symbol)
|
||||
is BooleanShape, is NumberShape, is StringShape, is DocumentShape, is ListShape, is MapShape -> rust("Some(Default::default())")
|
||||
is StructureShape -> rust(
|
||||
"#T::default().build_with_error_correction().ok()",
|
||||
symbolProvider.symbolForBuilder(target),
|
||||
is StructureShape -> rustTemplate(
|
||||
"#{error_correct}(#{Builder}::default()).ok()",
|
||||
"Builder" to symbolProvider.symbolForBuilder(target),
|
||||
"error_correct" to errorCorrectingBuilder(target, symbolProvider, model)!!,
|
||||
)
|
||||
|
||||
is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this)
|
||||
is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this)
|
||||
|
||||
|
|
|
@ -122,6 +122,7 @@ class AwsJsonSerializerGenerator(
|
|||
open class AwsJson(
|
||||
val codegenContext: CodegenContext,
|
||||
val awsJsonVersion: AwsJsonVersion,
|
||||
val enableErrorCorrection: Boolean,
|
||||
) : Protocol {
|
||||
private val runtimeConfig = codegenContext.runtimeConfig
|
||||
private val errorScope = arrayOf(
|
||||
|
@ -148,6 +149,7 @@ open class AwsJson(
|
|||
codegenContext,
|
||||
httpBindingResolver,
|
||||
::awsJsonFieldName,
|
||||
enableErrorCorrection = enableErrorCorrection,
|
||||
)
|
||||
|
||||
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
|
||||
|
|
|
@ -59,7 +59,7 @@ class RestJsonHttpBindingResolver(
|
|||
}
|
||||
}
|
||||
|
||||
open class RestJson(val codegenContext: CodegenContext) : Protocol {
|
||||
open class RestJson(val codegenContext: CodegenContext, private val enableErrorCorrection: Boolean) : Protocol {
|
||||
private val runtimeConfig = codegenContext.runtimeConfig
|
||||
private val errorScope = arrayOf(
|
||||
"Bytes" to RuntimeType.Bytes,
|
||||
|
@ -95,7 +95,7 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol {
|
|||
listOf("x-amzn-errortype" to errorShape.id.name)
|
||||
|
||||
override fun structuredDataParser(): StructuredDataParserGenerator =
|
||||
JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
|
||||
JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName, enableErrorCorrection = enableErrorCorrection)
|
||||
|
||||
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
|
||||
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
|
||||
|
|
|
@ -33,11 +33,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
|
|||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.errorCorrectingBuilder
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
|
||||
|
@ -95,6 +97,7 @@ class JsonParserGenerator(
|
|||
private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape ->
|
||||
ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false)
|
||||
},
|
||||
private val enableErrorCorrection: Boolean,
|
||||
private val customizations: List<JsonParserCustomization> = listOf(),
|
||||
) : StructuredDataParserGenerator {
|
||||
private val model = codegenContext.model
|
||||
|
@ -515,17 +518,20 @@ class JsonParserGenerator(
|
|||
// Only call `build()` if the builder is not fallible. Otherwise, return the builder.
|
||||
if (returnSymbolToParse.isUnconstrained) {
|
||||
rust("Ok(Some(builder))")
|
||||
} else {
|
||||
} else if (enableErrorCorrection && BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) {
|
||||
val errorCorrection = errorCorrectingBuilder(shape, symbolProvider, model)
|
||||
if (errorCorrection != null) {
|
||||
rustTemplate(
|
||||
"""
|
||||
Ok(Some(#{correct_errors}(builder).map_err(|err|#{Error}::custom_source("Response was invalid", err))?))""",
|
||||
"correct_errors" to errorCorrection, *codegenScope,
|
||||
)
|
||||
val buildExpr = if (errorCorrection != null) {
|
||||
writable { rustTemplate("#{correct_errors}(builder)", "correctErrors" to errorCorrection) }
|
||||
} else {
|
||||
rust("Ok(Some(builder.build()))")
|
||||
writable { rustTemplate("builder.build()") }
|
||||
}
|
||||
rustTemplate(
|
||||
"""Ok(Some(#{build}.map_err(|err|#{Error}::custom_source("Response was invalid", err))?))""",
|
||||
"build" to buildExpr,
|
||||
*codegenScope,
|
||||
)
|
||||
} else {
|
||||
rust("Ok(Some(builder.build()))")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -129,7 +129,7 @@ object EventStreamTestModels {
|
|||
validTestUnion = """{"Foo":"hello"}""",
|
||||
validSomeError = """{"Message":"some error"}""",
|
||||
validUnmodeledError = """{"Message":"unmodeled error"}""",
|
||||
) { RestJson(it) },
|
||||
) { RestJson(it, enableErrorCorrection = false) },
|
||||
|
||||
//
|
||||
// awsJson1_1
|
||||
|
@ -145,7 +145,7 @@ object EventStreamTestModels {
|
|||
validTestUnion = """{"Foo":"hello"}""",
|
||||
validSomeError = """{"Message":"some error"}""",
|
||||
validUnmodeledError = """{"Message":"unmodeled error"}""",
|
||||
) { AwsJson(it, AwsJsonVersion.Json11) },
|
||||
) { AwsJson(it, AwsJsonVersion.Json11, enableErrorCorrection = false) },
|
||||
|
||||
//
|
||||
// restXml
|
||||
|
|
|
@ -113,6 +113,7 @@ fun jsonParserGenerator(
|
|||
httpBindingResolver,
|
||||
jsonName,
|
||||
returnSymbolToParseFn(codegenContext),
|
||||
enableErrorCorrection = false,
|
||||
listOf(
|
||||
ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(codegenContext),
|
||||
) + additionalParserCustomizations,
|
||||
|
@ -122,7 +123,7 @@ class ServerAwsJsonProtocol(
|
|||
private val serverCodegenContext: ServerCodegenContext,
|
||||
awsJsonVersion: AwsJsonVersion,
|
||||
private val additionalParserCustomizations: List<JsonParserCustomization> = listOf(),
|
||||
) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol {
|
||||
) : AwsJson(serverCodegenContext, awsJsonVersion, enableErrorCorrection = false), ServerProtocol {
|
||||
private val runtimeConfig = codegenContext.runtimeConfig
|
||||
|
||||
override val protocolModulePath: String
|
||||
|
@ -186,7 +187,7 @@ private fun restRouterType(runtimeConfig: RuntimeConfig) =
|
|||
class ServerRestJsonProtocol(
|
||||
private val serverCodegenContext: ServerCodegenContext,
|
||||
private val additionalParserCustomizations: List<JsonParserCustomization> = listOf(),
|
||||
) : RestJson(serverCodegenContext), ServerProtocol {
|
||||
) : RestJson(serverCodegenContext, enableErrorCorrection = false), ServerProtocol {
|
||||
val runtimeConfig = codegenContext.runtimeConfig
|
||||
|
||||
override val protocolModulePath: String = "rest_json_1"
|
||||
|
|
Loading…
Reference in New Issue