diff --git a/api/src/main/java/io/grpc/ServerInterceptors.java b/api/src/main/java/io/grpc/ServerInterceptors.java index 153c7b4068..b44a7b6c88 100644 --- a/api/src/main/java/io/grpc/ServerInterceptors.java +++ b/api/src/main/java/io/grpc/ServerInterceptors.java @@ -184,6 +184,31 @@ public final class ServerInterceptors { public static ServerServiceDefinition useMarshalledMessages( final ServerServiceDefinition serviceDef, final MethodDescriptor.Marshaller marshaller) { + return useMarshalledMessages(serviceDef, marshaller, marshaller); + } + + /** + * Create a new {@code ServerServiceDefinition} with {@link MethodDescriptor} for deserializing + * requests and separate {@link MethodDescriptor} for serializing responses. The {@code + * ServerCallHandler} created will automatically convert back to the original types for request + * and response before calling the existing {@code ServerCallHandler}. Calling this method + * combined with the intercept methods will allow the developer to choose whether to intercept + * messages of ReqT/RespT, or the modeled types of their application. This can also be chained + * to allow for interceptors to handle messages as multiple different ReqT/RespT types within + * the chain if the added cost of serialization is not a concern. + * + * @param serviceDef the sevice definition to add request and response marshallers to. + * @param requestMarshaller request marshaller + * @param responseMarshaller response marshaller + * @param the request payload type + * @param the response payload type. + * @return a wrapped version of {@code serviceDef} with the ReqT and RespT conversion applied. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9870") + public static ServerServiceDefinition useMarshalledMessages( + final ServerServiceDefinition serviceDef, + final MethodDescriptor.Marshaller requestMarshaller, + final MethodDescriptor.Marshaller responseMarshaller) { List> wrappedMethods = new ArrayList<>(); List> wrappedDescriptors = @@ -191,8 +216,8 @@ public final class ServerInterceptors { // Wrap the descriptors for (final ServerMethodDefinition definition : serviceDef.getMethods()) { final MethodDescriptor originalMethodDescriptor = definition.getMethodDescriptor(); - final MethodDescriptor wrappedMethodDescriptor = - originalMethodDescriptor.toBuilder(marshaller, marshaller).build(); + final MethodDescriptor wrappedMethodDescriptor = + originalMethodDescriptor.toBuilder(requestMarshaller, responseMarshaller).build(); wrappedDescriptors.add(wrappedMethodDescriptor); wrappedMethods.add(wrapMethod(definition, wrappedMethodDescriptor)); } diff --git a/api/src/test/java/io/grpc/ServerInterceptorsTest.java b/api/src/test/java/io/grpc/ServerInterceptorsTest.java index fa4e7e747a..c9e426a7f9 100644 --- a/api/src/test/java/io/grpc/ServerInterceptorsTest.java +++ b/api/src/test/java/io/grpc/ServerInterceptorsTest.java @@ -425,6 +425,80 @@ public class ServerInterceptorsTest { order); } + /** + * Tests the ServerInterceptors#useMarshalledMessages()} with two marshallers. Makes sure that + * on incoming request the request marshaller's stream method is called and on response the + * response marshaller's parse method is called + */ + @Test + @SuppressWarnings("unchecked") + public void distinctMarshallerForRequestAndResponse() { + final List requestFlowOrder = new ArrayList<>(); + + final Marshaller requestMarshaller = new Marshaller() { + @Override + public InputStream stream(String value) { + requestFlowOrder.add("RequestStream"); + return null; + } + + @Override + public String parse(InputStream stream) { + requestFlowOrder.add("RequestParse"); + return null; + } + }; + final Marshaller responseMarshaller = new Marshaller() { + @Override + public InputStream stream(String value) { + requestFlowOrder.add("ResponseStream"); + return null; + } + + @Override + public String parse(InputStream stream) { + requestFlowOrder.add("ResponseParse"); + return null; + } + }; + final Marshaller dummyMarshaller = new Marshaller() { + @Override + public InputStream stream(Holder value) { + return value.get(); + } + + @Override + public Holder parse(InputStream stream) { + return new Holder(stream); + } + }; + ServerCallHandler handler = (call, headers) -> new Listener() { + @Override + public void onMessage(Holder message) { + requestFlowOrder.add("handler"); + call.sendMessage(message); + } + }; + + MethodDescriptor wrappedMethod = MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("basic/wrapped") + .setRequestMarshaller(dummyMarshaller) + .setResponseMarshaller(dummyMarshaller) + .build(); + ServerServiceDefinition serviceDef = ServerServiceDefinition.builder( + new ServiceDescriptor("basic", wrappedMethod)) + .addMethod(wrappedMethod, handler).build(); + ServerServiceDefinition intercepted = ServerInterceptors.useMarshalledMessages(serviceDef, + requestMarshaller, responseMarshaller); + ServerMethodDefinition serverMethod = + (ServerMethodDefinition) intercepted.getMethod("basic/wrapped"); + ServerCall serverCall = new NoopServerCall<>(); + serverMethod.getServerCallHandler().startCall(serverCall, headers).onMessage("TestMessage"); + + assertEquals(Arrays.asList("RequestStream", "handler", "ResponseParse"), requestFlowOrder); + } + @SuppressWarnings("unchecked") private static ServerMethodDefinition getSoleMethod( ServerServiceDefinition serviceDef) {