mirror of https://github.com/smithy-lang/smithy-rs
Server streaming body (#1023)
Add support for server blob streaming requests and responses Data is streamed over the HTTP body. Signed-off-by: Guy Margalit <guymguym@gmail.com> Co-authored-by: david-perez <d@vidp.dev>
This commit is contained in:
parent
907c0f3ff9
commit
f76bc159bf
|
@ -18,6 +18,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpPro
|
|||
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
|
||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
|
||||
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
|
||||
import software.amazon.smithy.rust.codegen.util.inputShape
|
||||
import software.amazon.smithy.rust.codegen.util.outputShape
|
||||
|
||||
/**
|
||||
|
@ -39,6 +41,7 @@ class ServerOperationHandlerGenerator(
|
|||
"PinProjectLite" to ServerCargoDependency.PinProjectLite.asType(),
|
||||
"Tower" to ServerCargoDependency.Tower.asType(),
|
||||
"FuturesUtil" to ServerCargoDependency.FuturesUtil.asType(),
|
||||
"SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(),
|
||||
"SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
|
||||
"SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
|
||||
"Phantom" to ServerRuntimeType.Phantom,
|
||||
|
@ -132,13 +135,18 @@ class ServerOperationHandlerGenerator(
|
|||
} else {
|
||||
symbolProvider.toSymbol(operation.outputShape(model)).fullName
|
||||
}
|
||||
val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) {
|
||||
"\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>,"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
return """
|
||||
$inputFn
|
||||
Fut: std::future::Future<Output = $outputType> + Send,
|
||||
B: $serverCrate::HttpBody + Send + 'static,
|
||||
B: $serverCrate::HttpBody + Send + 'static, $streamingBodyTraitBounds
|
||||
B::Data: Send,
|
||||
B::Error: Into<$serverCrate::BoxError>,
|
||||
$serverCrate::rejection::SmithyRejection: From<<B as $serverCrate::HttpBody>::Error>
|
||||
"""
|
||||
""".trimIndent()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ import software.amazon.smithy.rust.codegen.util.getTrait
|
|||
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.util.inputShape
|
||||
import software.amazon.smithy.rust.codegen.util.isStreaming
|
||||
import software.amazon.smithy.rust.codegen.util.orNull
|
||||
import software.amazon.smithy.rust.codegen.util.outputShape
|
||||
import software.amazon.smithy.rust.codegen.util.toSnakeCase
|
||||
|
@ -212,16 +213,16 @@ class ServerProtocolTestGenerator(
|
|||
|
||||
rustTemplate(
|
||||
"""
|
||||
##[allow(unused_mut)] let mut http_request = http::Request::builder()
|
||||
.uri("${httpRequestTestCase.uri}")
|
||||
""",
|
||||
*codegenScope
|
||||
##[allow(unused_mut)] let mut http_request = http::Request::builder()
|
||||
.uri("${httpRequestTestCase.uri}")
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
for (header in httpRequestTestCase.headers) {
|
||||
rust(".header(${header.key.dq()}, ${header.value.dq()})")
|
||||
}
|
||||
rustTemplate(
|
||||
"""
|
||||
"""
|
||||
.body(#{SmithyHttpServer}::Body::from(#{Bytes}::from_static(b${httpRequestTestCase.body.orNull()?.dq()})))
|
||||
.unwrap();
|
||||
""",
|
||||
|
@ -326,15 +327,37 @@ class ServerProtocolTestGenerator(
|
|||
"""
|
||||
use #{AxumCore}::extract::FromRequest;
|
||||
let mut http_request = #{AxumCore}::extract::RequestParts::new(http_request);
|
||||
let input_wrapper = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request");
|
||||
let input = input_wrapper.0;
|
||||
let parsed = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request").0;
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
if (operationShape.outputShape(model).hasStreamingMember(model)) {
|
||||
rustWriter.rust("""todo!("streaming types aren't supported yet");""")
|
||||
|
||||
if (inputShape.hasStreamingMember(model)) {
|
||||
// A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members
|
||||
// and handle the equality assertion separately.
|
||||
for (member in inputShape.members()) {
|
||||
val memberName = codegenContext.symbolProvider.toMemberName(member)
|
||||
if (member.isStreaming(codegenContext.model)) {
|
||||
rustWriter.rustTemplate(
|
||||
"""
|
||||
#{AssertEq}(
|
||||
parsed.$memberName.collect().await.unwrap().into_bytes(),
|
||||
expected.$memberName.collect().await.unwrap().into_bytes()
|
||||
);
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
} else {
|
||||
rustWriter.rustTemplate(
|
||||
"""
|
||||
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rustWriter.rustTemplate("#{AssertEq}(input, expected);", *codegenScope)
|
||||
rustWriter.rustTemplate("#{AssertEq}(parsed, expected);", *codegenScope)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -357,7 +380,7 @@ class ServerProtocolTestGenerator(
|
|||
assertOk(rustWriter) {
|
||||
rustWriter.write(
|
||||
"#T(&body, ${
|
||||
rustWriter.escape(body).dq()
|
||||
rustWriter.escape(body).dq()
|
||||
}, #T::from(${(mediaType ?: "unknown").dq()}))",
|
||||
RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "validate_body"),
|
||||
RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "MediaType")
|
||||
|
@ -386,9 +409,9 @@ class ServerProtocolTestGenerator(
|
|||
basicCheck(
|
||||
requireHeaders,
|
||||
rustWriter,
|
||||
"required_headers",
|
||||
actualExpression,
|
||||
"require_headers"
|
||||
"required_headers",
|
||||
actualExpression,
|
||||
"require_headers"
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -396,9 +419,9 @@ class ServerProtocolTestGenerator(
|
|||
basicCheck(
|
||||
forbidHeaders,
|
||||
rustWriter,
|
||||
"forbidden_headers",
|
||||
"forbidden_headers",
|
||||
actualExpression,
|
||||
"forbid_headers"
|
||||
"forbid_headers"
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -511,16 +534,7 @@ class ServerProtocolTestGenerator(
|
|||
FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
|
||||
FailingTest(RestJson, "RestJsonNoInputAndOutputWithJson", Action.Response),
|
||||
FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Request),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Request),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Response),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Response),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Request),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Request),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Response),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Response),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Request),
|
||||
FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Response),
|
||||
FailingTest(RestJson, "RestJsonHttpWithEmptyBlobPayload", Action.Request),
|
||||
FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", Action.Request),
|
||||
|
||||
|
@ -591,56 +605,64 @@ class ServerProtocolTestGenerator(
|
|||
).asObjectNode().get()
|
||||
).build()
|
||||
private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase =
|
||||
testCase.toBuilder().params(
|
||||
Node.parse("""{
|
||||
"queryString": "Hello there",
|
||||
"queryStringList": ["a", "b", "c"],
|
||||
"queryStringSet": ["a", "b", "c"],
|
||||
"queryByte": 1,
|
||||
"queryShort": 2,
|
||||
"queryInteger": 3,
|
||||
"queryIntegerList": [1, 2, 3],
|
||||
"queryIntegerSet": [1, 2, 3],
|
||||
"queryLong": 4,
|
||||
"queryFloat": 1.1,
|
||||
"queryDouble": 1.1,
|
||||
"queryDoubleList": [1.1, 2.1, 3.1],
|
||||
"queryBoolean": true,
|
||||
"queryBooleanList": [true, false, true],
|
||||
"queryTimestamp": 1,
|
||||
"queryTimestampList": [1, 2, 3],
|
||||
"queryEnum": "Foo",
|
||||
"queryEnumList": ["Foo", "Baz", "Bar"],
|
||||
"queryParamsMapOfStringList": {
|
||||
"String": ["Hello there"],
|
||||
"StringList": ["a", "b", "c"],
|
||||
"StringSet": ["a", "b", "c"],
|
||||
"Byte": ["1"],
|
||||
"Short": ["2"],
|
||||
"Integer": ["3"],
|
||||
"IntegerList": ["1", "2", "3"],
|
||||
"IntegerSet": ["1", "2", "3"],
|
||||
"Long": ["4"],
|
||||
"Float": ["1.1"],
|
||||
"Double": ["1.1"],
|
||||
"DoubleList": ["1.1", "2.1", "3.1"],
|
||||
"Boolean": ["true"],
|
||||
"BooleanList": ["true", "false", "true"],
|
||||
"Timestamp": ["1970-01-01T00:00:01Z"],
|
||||
"TimestampList": ["1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z", "1970-01-01T00:00:03Z"],
|
||||
"Enum": ["Foo"],
|
||||
"EnumList": ["Foo", "Baz", "Bar"]
|
||||
testCase.toBuilder().params(
|
||||
Node.parse(
|
||||
"""
|
||||
{
|
||||
"queryString": "Hello there",
|
||||
"queryStringList": ["a", "b", "c"],
|
||||
"queryStringSet": ["a", "b", "c"],
|
||||
"queryByte": 1,
|
||||
"queryShort": 2,
|
||||
"queryInteger": 3,
|
||||
"queryIntegerList": [1, 2, 3],
|
||||
"queryIntegerSet": [1, 2, 3],
|
||||
"queryLong": 4,
|
||||
"queryFloat": 1.1,
|
||||
"queryDouble": 1.1,
|
||||
"queryDoubleList": [1.1, 2.1, 3.1],
|
||||
"queryBoolean": true,
|
||||
"queryBooleanList": [true, false, true],
|
||||
"queryTimestamp": 1,
|
||||
"queryTimestampList": [1, 2, 3],
|
||||
"queryEnum": "Foo",
|
||||
"queryEnumList": ["Foo", "Baz", "Bar"],
|
||||
"queryParamsMapOfStringList": {
|
||||
"String": ["Hello there"],
|
||||
"StringList": ["a", "b", "c"],
|
||||
"StringSet": ["a", "b", "c"],
|
||||
"Byte": ["1"],
|
||||
"Short": ["2"],
|
||||
"Integer": ["3"],
|
||||
"IntegerList": ["1", "2", "3"],
|
||||
"IntegerSet": ["1", "2", "3"],
|
||||
"Long": ["4"],
|
||||
"Float": ["1.1"],
|
||||
"Double": ["1.1"],
|
||||
"DoubleList": ["1.1", "2.1", "3.1"],
|
||||
"Boolean": ["true"],
|
||||
"BooleanList": ["true", "false", "true"],
|
||||
"Timestamp": ["1970-01-01T00:00:01Z"],
|
||||
"TimestampList": ["1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z", "1970-01-01T00:00:03Z"],
|
||||
"Enum": ["Foo"],
|
||||
"EnumList": ["Foo", "Baz", "Bar"]
|
||||
}
|
||||
}
|
||||
}""".trimMargin()).asObjectNode().get()
|
||||
).build()
|
||||
""".trimMargin()
|
||||
).asObjectNode().get()
|
||||
).build()
|
||||
private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase): HttpRequestTestCase =
|
||||
testCase.toBuilder().params(
|
||||
Node.parse("""{
|
||||
"queryString": "%:/?#[]@!${'$'}&'()*+,;=😹",
|
||||
"queryParamsMapOfStringList": {
|
||||
"String": ["%:/?#[]@!${'$'}&'()*+,;=😹"]
|
||||
}
|
||||
}""".trimMargin()).asObjectNode().get()
|
||||
Node.parse(
|
||||
"""
|
||||
{
|
||||
"queryString": "%:/?#[]@!${'$'}&'()*+,;=😹",
|
||||
"queryParamsMapOfStringList": {
|
||||
"String": ["%:/?#[]@!${'$'}&'()*+,;=😹"]
|
||||
}
|
||||
}
|
||||
""".trimMargin()
|
||||
).asObjectNode().get()
|
||||
).build()
|
||||
// This test assumes that errors in responses are identified by an `X-Amzn-Errortype` header with the error shape name.
|
||||
// However, Smithy specifications for AWS protocols that serialize to JSON recommend that new server implementations
|
||||
|
|
|
@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols
|
|||
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
|
||||
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
|
||||
import software.amazon.smithy.codegen.core.Symbol
|
||||
import software.amazon.smithy.model.knowledge.HttpBinding
|
||||
import software.amazon.smithy.model.knowledge.HttpBindingIndex
|
||||
import software.amazon.smithy.model.node.ExpectationNotMetException
|
||||
import software.amazon.smithy.model.shapes.CollectionShape
|
||||
|
@ -55,6 +54,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredData
|
|||
import software.amazon.smithy.rust.codegen.util.UNREACHABLE
|
||||
import software.amazon.smithy.rust.codegen.util.dq
|
||||
import software.amazon.smithy.rust.codegen.util.expectTrait
|
||||
import software.amazon.smithy.rust.codegen.util.findStreamingMember
|
||||
import software.amazon.smithy.rust.codegen.util.getTrait
|
||||
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
|
||||
import software.amazon.smithy.rust.codegen.util.inputShape
|
||||
|
@ -119,6 +119,7 @@ private class ServerHttpProtocolImplGenerator(
|
|||
"PercentEncoding" to CargoDependency.PercentEncoding.asType(),
|
||||
"Regex" to CargoDependency.Regex.asType(),
|
||||
"SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
|
||||
"SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(),
|
||||
"SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
|
||||
"SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
|
||||
"http" to RuntimeType.http
|
||||
|
@ -132,13 +133,12 @@ private class ServerHttpProtocolImplGenerator(
|
|||
}
|
||||
|
||||
/*
|
||||
* Generation of `FromRequest` and `IntoResponse`. They are currently only implemented for non-streaming request
|
||||
* and response bodies, that is, models without streaming traits
|
||||
* (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html).
|
||||
* For non-streaming request bodies, we require the HTTP body to be fully read in memory before parsing or
|
||||
* deserialization. From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize
|
||||
* Generation of `FromRequest` and `IntoResponse`.
|
||||
* For non-streaming request bodies, that is, models without streaming traits
|
||||
* (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html)
|
||||
* we require the HTTP body to be fully read in memory before parsing or deserialization.
|
||||
* From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize
|
||||
* an HTTP response to `Bytes`.
|
||||
* TODO Add support for streaming.
|
||||
* These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server.
|
||||
*/
|
||||
private fun RustWriter.renderTraits(
|
||||
|
@ -147,38 +147,24 @@ private class ServerHttpProtocolImplGenerator(
|
|||
operationShape: OperationShape
|
||||
) {
|
||||
val operationName = symbolProvider.toSymbol(operationShape).name
|
||||
// Implement Axum `FromRequest` trait for input types.
|
||||
val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"
|
||||
|
||||
val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) {
|
||||
// For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait.
|
||||
// It will first offer the streaming input to the parser and potentially read the body into memory
|
||||
// if an error occurred or if the streaming parser indicates that it needs the full data to proceed.
|
||||
"""
|
||||
async fn from_request(_req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
todo!("Streaming support for input shapes is not yet supported in `smithy-rs`")
|
||||
}
|
||||
""".trimIndent()
|
||||
} else {
|
||||
"""
|
||||
async fn from_request(req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
Ok($inputName(#{parse_request}(req).await?))
|
||||
}
|
||||
""".trimIndent()
|
||||
}
|
||||
// Implement Axum `FromRequest` trait for input types.
|
||||
rustTemplate(
|
||||
"""
|
||||
pub struct $inputName(pub #{I});
|
||||
##[#{AsyncTrait}::async_trait]
|
||||
impl<B> #{AxumCore}::extract::FromRequest<B> for $inputName
|
||||
where
|
||||
B: #{SmithyHttpServer}::HttpBody + Send,
|
||||
B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)}
|
||||
B::Data: Send,
|
||||
B::Error: Into<#{SmithyHttpServer}::BoxError>,
|
||||
#{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
|
||||
{
|
||||
type Rejection = #{SmithyRejection};
|
||||
$fromRequest
|
||||
async fn from_request(req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
Ok($inputName(#{parse_request}(req).await?))
|
||||
}
|
||||
}
|
||||
""".trimIndent(),
|
||||
*codegenScope,
|
||||
|
@ -187,21 +173,19 @@ private class ServerHttpProtocolImplGenerator(
|
|||
)
|
||||
|
||||
// Implement Axum `IntoResponse` for output types.
|
||||
|
||||
val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
|
||||
val errorSymbol = operationShape.errorSymbol(symbolProvider)
|
||||
|
||||
val httpExtensions = setHttpExtensions(operationShape)
|
||||
// For streaming response bodies, we need to generate a different implementation of the `IntoResponse` trait.
|
||||
// The body type will have to be a `StreamBody`. The service implementer will return a `Stream` from their handler.
|
||||
val intoResponseStreaming = "todo!(\"Streaming support for output shapes is not yet supported in `smithy-rs`\")"
|
||||
|
||||
if (operationShape.errors.isNotEmpty()) {
|
||||
val intoResponseImpl = if (operationShape.outputShape(model).hasStreamingMember(model)) {
|
||||
intoResponseStreaming
|
||||
} else {
|
||||
// The output of fallible operations is a `Result` which we convert into an
|
||||
// isomorphic `enum` type we control that can in turn be converted into a response.
|
||||
val intoResponseImpl =
|
||||
"""
|
||||
let mut response = match self {
|
||||
Self::Output(o) => {
|
||||
match #{serialize_response}(&o) {
|
||||
match #{serialize_response}(o) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
e.into_response()
|
||||
|
@ -223,9 +207,7 @@ private class ServerHttpProtocolImplGenerator(
|
|||
$httpExtensions
|
||||
response
|
||||
""".trimIndent()
|
||||
}
|
||||
// The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control
|
||||
// that can in turn be converted into a response.
|
||||
|
||||
rustTemplate(
|
||||
"""
|
||||
pub enum $outputName {
|
||||
|
@ -246,27 +228,25 @@ private class ServerHttpProtocolImplGenerator(
|
|||
"serialize_error" to serverSerializeError(operationShape)
|
||||
)
|
||||
} else {
|
||||
val handleSerializeOutput = if (operationShape.outputShape(model).hasStreamingMember(model)) {
|
||||
intoResponseStreaming
|
||||
} else {
|
||||
// The output of non-fallible operations is a model type which we convert into
|
||||
// a "wrapper" unit `struct` type we control that can in turn be converted into a response.
|
||||
val intoResponseImpl =
|
||||
"""
|
||||
let mut response = match #{serialize_response}(&self.0) {
|
||||
let mut response = match #{serialize_response}(self.0) {
|
||||
Ok(response) => response,
|
||||
Err(e) => e.into_response()
|
||||
};
|
||||
$httpExtensions
|
||||
response
|
||||
""".trimIndent()
|
||||
}
|
||||
// The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
|
||||
// we control that can in turn be converted into a response.
|
||||
|
||||
rustTemplate(
|
||||
"""
|
||||
pub struct $outputName(pub #{O});
|
||||
##[#{AsyncTrait}::async_trait]
|
||||
impl #{AxumCore}::response::IntoResponse for $outputName {
|
||||
fn into_response(self) -> #{AxumCore}::response::Response {
|
||||
$handleSerializeOutput
|
||||
$intoResponseImpl
|
||||
}
|
||||
}
|
||||
""".trimIndent(),
|
||||
|
@ -335,6 +315,7 @@ private class ServerHttpProtocolImplGenerator(
|
|||
val inputSymbol = symbolProvider.toSymbol(inputShape)
|
||||
val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT)
|
||||
val unusedVars = if (includedMembers.isEmpty()) "##[allow(unused_variables)] " else ""
|
||||
|
||||
return RuntimeType.forInlineFun(fnName, operationDeserModule) {
|
||||
Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it)
|
||||
it.rustBlockTemplate(
|
||||
|
@ -346,11 +327,11 @@ private class ServerHttpProtocolImplGenerator(
|
|||
#{SmithyRejection}
|
||||
>
|
||||
where
|
||||
B: #{SmithyHttpServer}::HttpBody + Send,
|
||||
B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)}
|
||||
B::Data: Send,
|
||||
B::Error: Into<#{SmithyHttpServer}::BoxError>,
|
||||
#{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
|
||||
""",
|
||||
""".trimIndent(),
|
||||
*codegenScope,
|
||||
"I" to inputSymbol,
|
||||
) {
|
||||
|
@ -371,8 +352,12 @@ private class ServerHttpProtocolImplGenerator(
|
|||
val outputSymbol = symbolProvider.toSymbol(outputShape)
|
||||
return RuntimeType.forInlineFun(fnName, operationSerModule) {
|
||||
Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it)
|
||||
|
||||
// Note we only need to take ownership of the output in the case that it contains streaming members.
|
||||
// However we currently always take ownership here, but worth noting in case in the future we want
|
||||
// to generate different signatures for streaming vs non-streaming for some reason.
|
||||
it.rustBlockTemplate(
|
||||
"pub fn $fnName(output: &#{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>",
|
||||
"pub fn $fnName(output: #{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>",
|
||||
*codegenScope,
|
||||
"O" to outputSymbol,
|
||||
) {
|
||||
|
@ -459,13 +444,6 @@ private class ServerHttpProtocolImplGenerator(
|
|||
operationShape: OperationShape,
|
||||
bindings: List<HttpBindingDescriptor>,
|
||||
) {
|
||||
val structuredDataSerializer = protocol.structuredDataSerializer(operationShape)
|
||||
structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer ->
|
||||
rust(
|
||||
"let payload = #T(output)?;",
|
||||
serializer
|
||||
)
|
||||
} ?: rust("""let payload = "";""")
|
||||
// avoid non-usage warnings for response
|
||||
Attribute.AllowUnusedMut.render(this)
|
||||
rustTemplate("let mut builder = #{http}::Response::builder();", *codegenScope)
|
||||
|
@ -477,6 +455,24 @@ private class ServerHttpProtocolImplGenerator(
|
|||
serializedValue(this)
|
||||
}
|
||||
}
|
||||
val streamingMember = operationShape.outputShape(model).findStreamingMember(model)
|
||||
if (streamingMember != null) {
|
||||
val memberName = symbolProvider.toMemberName(streamingMember)
|
||||
rustTemplate(
|
||||
"""
|
||||
let payload = #{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName);
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
} else {
|
||||
val structuredDataSerializer = protocol.structuredDataSerializer(operationShape)
|
||||
structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer ->
|
||||
rust(
|
||||
"let payload = #T(&output)?;",
|
||||
serializer
|
||||
)
|
||||
} ?: rust("""let payload = "";""")
|
||||
}
|
||||
rustTemplate(
|
||||
"""
|
||||
builder.body(#{SmithyHttpServer}::body::to_boxed(payload))?
|
||||
|
@ -510,11 +506,13 @@ private class ServerHttpProtocolImplGenerator(
|
|||
}
|
||||
|
||||
val bindingGenerator = ServerResponseBindingGenerator(protocol, codegenContext, operationShape)
|
||||
val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape?: operationShape)
|
||||
val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape ?: operationShape)
|
||||
if (addHeadersFn != null) {
|
||||
// notice that we need to borrow the output only for output shapes but not for error shapes
|
||||
val outputOwnedOrBorrow = if (errorShape == null) "&output" else "output"
|
||||
rust(
|
||||
"""
|
||||
builder = #{T}(output, builder)?;
|
||||
builder = #{T}($outputOwnedOrBorrow, builder)?;
|
||||
""".trimIndent(),
|
||||
addHeadersFn
|
||||
)
|
||||
|
@ -528,12 +526,11 @@ private class ServerHttpProtocolImplGenerator(
|
|||
val operationName = symbolProvider.toSymbol(operationShape).name
|
||||
val member = binding.member
|
||||
return when (binding.location) {
|
||||
HttpLocation.HEADER, HttpLocation.PREFIX_HEADERS, HttpLocation.DOCUMENT -> {
|
||||
// All of these are handled separately.
|
||||
null
|
||||
}
|
||||
HttpLocation.HEADER,
|
||||
HttpLocation.PREFIX_HEADERS,
|
||||
HttpLocation.DOCUMENT,
|
||||
HttpLocation.PAYLOAD -> {
|
||||
logger.warning("[rust-server-codegen] $operationName: response serialization does not currently support ${binding.location} bindings")
|
||||
// All of these are handled separately.
|
||||
null
|
||||
}
|
||||
HttpLocation.RESPONSE_CODE -> writable {
|
||||
|
@ -608,18 +605,28 @@ private class ServerHttpProtocolImplGenerator(
|
|||
return when (binding.location) {
|
||||
HttpLocation.HEADER -> writable { serverRenderHeaderParser(this, binding, operationShape) }
|
||||
HttpLocation.PAYLOAD -> {
|
||||
val structureShapeHandler: RustWriter.(String) -> Unit = { body ->
|
||||
rust("#T($body)", structuredDataParser.payloadParser(binding.member))
|
||||
}
|
||||
val deserializer = httpBindingGenerator.generateDeserializePayloadFn(
|
||||
operationShape,
|
||||
binding,
|
||||
errorSymbol,
|
||||
structuredHandler = structureShapeHandler
|
||||
)
|
||||
return if (binding.member.isStreaming(model)) {
|
||||
writable { rust("""todo!("streaming request bodies");""") }
|
||||
writable {
|
||||
rustTemplate(
|
||||
"""
|
||||
{
|
||||
let body = request.take_body().ok_or(#{SmithyHttpServer}::rejection::BodyAlreadyExtracted)?;
|
||||
Some(body.into())
|
||||
}
|
||||
""".trimIndent(),
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
} else {
|
||||
val structureShapeHandler: RustWriter.(String) -> Unit = { body ->
|
||||
rust("#T($body)", structuredDataParser.payloadParser(binding.member))
|
||||
}
|
||||
val deserializer = httpBindingGenerator.generateDeserializePayloadFn(
|
||||
operationShape,
|
||||
binding,
|
||||
errorSymbol,
|
||||
structuredHandler = structureShapeHandler
|
||||
)
|
||||
writable {
|
||||
rustTemplate(
|
||||
"""
|
||||
|
@ -1047,4 +1054,12 @@ private class ServerHttpProtocolImplGenerator(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun getStreamingBodyTraitBounds(operationShape: OperationShape): String {
|
||||
if (operationShape.inputShape(model).hasStreamingMember(model)) {
|
||||
return "\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>,"
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ bytes = "1.1"
|
|||
futures-util = { version = "0.3", default-features = false }
|
||||
http = "0.2"
|
||||
http-body = "0.4"
|
||||
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp"] }
|
||||
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] }
|
||||
mime = "0.3"
|
||||
nom = "7"
|
||||
pin-project-lite = "0.2"
|
||||
|
|
|
@ -178,6 +178,12 @@ impl From<aws_smithy_types::date_time::DateTimeParseError> for SmithyRejection {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<aws_smithy_types::date_time::DateTimeFormatError> for SmithyRejection {
|
||||
fn from(err: aws_smithy_types::date_time::DateTimeFormatError) -> Self {
|
||||
SmithyRejection::Serialize(Serialize::from_err(err))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<aws_smithy_types::primitive::PrimitiveParseError> for SmithyRejection {
|
||||
fn from(err: aws_smithy_types::primitive::PrimitiveParseError) -> Self {
|
||||
SmithyRejection::Deserialize(Deserialize::from_err(err))
|
||||
|
|
|
@ -326,6 +326,14 @@ impl From<Vec<u8>> for ByteStream {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<hyper::Body> for ByteStream {
|
||||
fn from(input: hyper::Body) -> Self {
|
||||
ByteStream::new(SdkBody::from_dyn(
|
||||
input.map_err(|e| e.into_cause().unwrap()).boxed(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Error(Box<dyn StdError + Send + Sync + 'static>);
|
||||
|
||||
|
|
Loading…
Reference in New Issue