From a978c9edc0f3ec33b7b13aad9161d1929d494e5f Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 4 Apr 2022 10:54:58 -0700 Subject: [PATCH] okhttp: Avoid test-specific transport.start() With the completely different constructor it was hard to track which fields were different during the test and reduced confidence. Now the test code flows are much closer to the real-life code flows. --- .../okhttp/ExceptionHandlingFrameWriter.java | 13 +- .../io/grpc/okhttp/OkHttpClientTransport.java | 152 ++++++++++-------- .../ExceptionHandlingFrameWriterTest.java | 5 +- .../okhttp/OkHttpClientTransportTest.java | 123 ++++++++++++-- 4 files changed, 191 insertions(+), 102 deletions(-) diff --git a/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java index 9f7074121f..577d9c8de7 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java +++ b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java @@ -39,23 +39,14 @@ final class ExceptionHandlingFrameWriter implements FrameWriter { private final FrameWriter frameWriter; - private final OkHttpFrameLogger frameLogger; + private final OkHttpFrameLogger frameLogger = + new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class); ExceptionHandlingFrameWriter( TransportExceptionHandler transportExceptionHandler, FrameWriter frameWriter) { - this(transportExceptionHandler, frameWriter, - new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class)); - } - - @VisibleForTesting - ExceptionHandlingFrameWriter( - TransportExceptionHandler transportExceptionHandler, - FrameWriter frameWriter, - OkHttpFrameLogger frameLogger) { this.transportExceptionHandler = checkNotNull(transportExceptionHandler, "transportExceptionHandler"); this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); - this.frameLogger = Preconditions.checkNotNull(frameLogger, "frameLogger"); } @Override diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index e233fa2002..ad39112d8c 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -148,9 +148,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep // Returns new unstarted stopwatches private final Supplier stopwatchFactory; private final int initialWindowSize; + private final Variant variant; private Listener listener; - private FrameReader testFrameReader; - private OkHttpFrameLogger testFrameLogger; @GuardedBy("lock") private ExceptionHandlingFrameWriter frameWriter; private OutboundFlowController outboundFlow; @@ -192,7 +191,6 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep @GuardedBy("lock") private final Deque pendingStreams = new LinkedList<>(); private final ConnectionSpec connectionSpec; - private FrameWriter testFrameWriter; private ScheduledExecutorService scheduler; private KeepAliveManager keepAliveManager; private boolean enableKeepAlive; @@ -228,7 +226,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep Runnable connectingCallback; SettableFuture connectedFuture; - OkHttpClientTransport( + public OkHttpClientTransport( InetSocketAddress address, String authority, @Nullable String userAgent, @@ -245,6 +243,46 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep int maxInboundMetadataSize, TransportTracer transportTracer, boolean useGetForSafeMethods) { + this( + address, + authority, + userAgent, + eagAttrs, + executor, + socketFactory, + sslSocketFactory, + hostnameVerifier, + connectionSpec, + GrpcUtil.STOPWATCH_SUPPLIER, + new Http2(), + maxMessageSize, + initialWindowSize, + proxiedAddr, + tooManyPingsRunnable, + maxInboundMetadataSize, + transportTracer, + useGetForSafeMethods); + } + + private OkHttpClientTransport( + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + Executor executor, + @Nullable SocketFactory socketFactory, + @Nullable SSLSocketFactory sslSocketFactory, + @Nullable HostnameVerifier hostnameVerifier, + ConnectionSpec connectionSpec, + Supplier stopwatchFactory, + Variant variant, + int maxMessageSize, + int initialWindowSize, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + int maxInboundMetadataSize, + TransportTracer transportTracer, + boolean useGetForSafeMethods) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; this.maxMessageSize = maxMessageSize; @@ -258,7 +296,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep this.sslSocketFactory = sslSocketFactory; this.hostnameVerifier = hostnameVerifier; this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); - this.stopwatchFactory = GrpcUtil.STOPWATCH_SUPPLIER; + this.stopwatchFactory = Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory"); + this.variant = Preconditions.checkNotNull(variant, "variant"); this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); this.proxiedAddr = proxiedAddr; this.tooManyPingsRunnable = @@ -279,43 +318,36 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep OkHttpClientTransport( String userAgent, Executor executor, - FrameReader frameReader, - FrameWriter testFrameWriter, - OkHttpFrameLogger testFrameLogger, - int nextStreamId, - Socket socket, + @Nullable SocketFactory socketFactory, Supplier stopwatchFactory, + Variant variant, @Nullable Runnable connectingCallback, SettableFuture connectedFuture, int maxMessageSize, int initialWindowSize, Runnable tooManyPingsRunnable, TransportTracer transportTracer) { - useGetForSafeMethods = false; - address = null; - this.maxMessageSize = maxMessageSize; - this.initialWindowSize = initialWindowSize; - defaultAuthority = "notarealauthority:80"; - this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); - this.executor = Preconditions.checkNotNull(executor, "executor"); - serializingExecutor = new SerializingExecutor(executor); - this.socketFactory = SocketFactory.getDefault(); - this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader"); - this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter"); - this.testFrameLogger = Preconditions.checkNotNull(testFrameLogger, "testFrameLogger"); - this.socket = Preconditions.checkNotNull(socket, "socket"); - this.nextStreamId = nextStreamId; - this.stopwatchFactory = stopwatchFactory; - this.connectionSpec = null; + this( + new InetSocketAddress("127.0.0.1", 80), + "notarealauthority:80", + userAgent, + Attributes.EMPTY, + executor, + socketFactory, + null, + null, + OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC, + stopwatchFactory, + variant, + maxMessageSize, + initialWindowSize, + null, + tooManyPingsRunnable, + Integer.MAX_VALUE, + transportTracer, + false); this.connectingCallback = connectingCallback; this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture"); - this.proxiedAddr = null; - this.tooManyPingsRunnable = - Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); - this.maxInboundMetadataSize = Integer.MAX_VALUE; - this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer"); - this.logId = InternalLogId.allocate(getClass(), String.valueOf(socket.getInetAddress())); - initTransportTracer(); } // sslSocketFactory is set to null when use plaintext. @@ -349,10 +381,6 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep this.keepAliveWithoutCalls = keepAliveWithoutCalls; } - private boolean isForTest() { - return address == null; - } - @Override public void ping(final PingCallback callback, Executor executor) { long data = 0; @@ -488,32 +516,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep keepAliveWithoutCalls); keepAliveManager.onTransportStarted(); } - if (isForTest()) { - synchronized (lock) { - frameWriter = new ExceptionHandlingFrameWriter(OkHttpClientTransport.this, testFrameWriter, - testFrameLogger); - outboundFlow = new OutboundFlowController(OkHttpClientTransport.this, frameWriter); - } - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - if (connectingCallback != null) { - connectingCallback.run(); - } - clientFrameHandler = new ClientFrameHandler(testFrameReader, testFrameLogger); - executor.execute(clientFrameHandler); - synchronized (lock) { - maxConcurrentStreams = Integer.MAX_VALUE; - startPendingStreams(); - } - connectedFuture.set(null); - } - }); - return null; - } final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this); - final Variant variant = new Http2(); FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), true); synchronized (lock) { @@ -616,6 +620,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep serializingExecutor.execute(new Runnable() { @Override public void run() { + if (connectingCallback != null) { + connectingCallback.run(); + } // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it // may send goAway immediately. executor.execute(clientFrameHandler); @@ -623,6 +630,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep maxConcurrentStreams = Integer.MAX_VALUE; startPendingStreams(); } + if (connectedFuture != null) { + connectedFuture.set(null); + } } }); return null; @@ -631,8 +641,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep /** * Should only be called once when the transport is first established. */ - @VisibleForTesting - void sendConnectionPrefaceAndSettings() { + private void sendConnectionPrefaceAndSettings() { synchronized (lock) { frameWriter.connectionPreface(); Settings settings = new Settings(); @@ -855,6 +864,13 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep } } + @VisibleForTesting + void setNextStreamId(int nextStreamId) { + synchronized (lock) { + this.nextStreamId = nextStreamId; + } + } + /** * Finish all active streams due to an IOException, then close the transport. */ @@ -1081,21 +1097,15 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep /** * Runnable which reads frames and dispatches them to in flight calls. */ - @VisibleForTesting class ClientFrameHandler implements FrameReader.Handler, Runnable { - private final OkHttpFrameLogger logger; + private final OkHttpFrameLogger logger = + new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class); FrameReader frameReader; boolean firstSettings = true; ClientFrameHandler(FrameReader frameReader) { - this(frameReader, new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class)); - } - - @VisibleForTesting - ClientFrameHandler(FrameReader frameReader, OkHttpFrameLogger frameLogger) { this.frameReader = frameReader; - logger = frameLogger; } @Override diff --git a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java index c26edcd0df..a9d3908884 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java @@ -50,8 +50,7 @@ public class ExceptionHandlingFrameWriterTest { private final TransportExceptionHandler transportExceptionHandler = mock(TransportExceptionHandler.class); private final ExceptionHandlingFrameWriter exceptionHandlingFrameWriter = - new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter, - new OkHttpFrameLogger(Level.FINE, logger)); + new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter); @Test public void exception() throws IOException { @@ -194,4 +193,4 @@ public class ExceptionHandlingFrameWriterTest { logger.removeHandler(handler); } -} \ No newline at end of file +} diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index 1628df4d3c..6e74983240 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -48,6 +48,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; import com.google.common.base.Ticker; @@ -78,18 +79,23 @@ import io.grpc.internal.TransportTracer; import io.grpc.okhttp.OkHttpClientTransport.ClientFrameHandler; import io.grpc.okhttp.OkHttpFrameLogger.Direction; import io.grpc.okhttp.internal.ConnectionSpec; +import io.grpc.okhttp.internal.Protocol; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.FrameReader; import io.grpc.okhttp.internal.framed.FrameWriter; import io.grpc.okhttp.internal.framed.Header; import io.grpc.okhttp.internal.framed.HeadersMode; import io.grpc.okhttp.internal.framed.Settings; +import io.grpc.okhttp.internal.framed.Variant; import io.grpc.testing.TestMethodDescriptors; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -116,6 +122,8 @@ import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLSocketFactory; import okio.Buffer; +import okio.BufferedSink; +import okio.BufferedSource; import okio.ByteString; import org.junit.After; import org.junit.Before; @@ -154,8 +162,6 @@ public class OkHttpClientTransportTest { @Rule public final Timeout globalTimeout = Timeout.seconds(10); - private FrameWriter frameWriter; - private MethodDescriptor method = TestMethodDescriptors.voidMethod(); @Mock @@ -167,8 +173,10 @@ public class OkHttpClientTransportTest { private final TransportTracer transportTracer = new TransportTracer(); private final Queue capturedBuffer = new ArrayDeque<>(); private OkHttpClientTransport clientTransport; - private MockFrameReader frameReader; - private Socket socket; + private final MockFrameReader frameReader = new MockFrameReader(); + private final Socket socket = new MockSocket(frameReader); + private FrameWriter frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo( + new MockFrameWriter(socket, capturedBuffer))); private ExecutorService executor = Executors.newCachedThreadPool(); private long nanoTime; // backs a ticker, for testing ping round-trip time measurement private SettableFuture connectedFuture; @@ -183,10 +191,6 @@ public class OkHttpClientTransportTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - frameReader = new MockFrameReader(); - socket = new MockSocket(frameReader); - MockFrameWriter mockFrameWriter = new MockFrameWriter(socket, capturedBuffer); - frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo(mockFrameWriter)); } @After @@ -233,12 +237,9 @@ public class OkHttpClientTransportTest { clientTransport = new OkHttpClientTransport( userAgent, executor, - frameReader, - frameWriter, - new OkHttpFrameLogger(Level.FINE, logger), - startId, - socket, + new FakeSocketFactory(socket), stopwatchSupplier, + new FakeVariant(frameReader, frameWriter), connectingCallback, connectedFuture, maxMessageSize, @@ -249,6 +250,9 @@ public class OkHttpClientTransportTest { if (waitingForConnected) { connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS); } + if (startId != DEFAULT_START_STREAM_ID) { + clientTransport.setNextStreamId(startId); + } } @Test @@ -301,6 +305,10 @@ public class OkHttpClientTransportTest { logger.setLevel(Level.ALL); initTransport(); + assertThat(logs).hasSize(1); + LogRecord log = logs.remove(0); + assertThat(log.getMessage()).startsWith(Direction.OUTBOUND + " SETTINGS: ack=false"); + assertThat(log.getLevel()).isEqualTo(Level.FINE); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = @@ -310,7 +318,7 @@ public class OkHttpClientTransportTest { frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); assertThat(logs).hasSize(1); - LogRecord log = logs.remove(0); + log = logs.remove(0); assertThat(log.getMessage()).startsWith(Direction.INBOUND + " HEADERS: streamId=" + 3); assertThat(log.getLevel()).isEqualTo(Level.FINE); @@ -414,7 +422,6 @@ public class OkHttpClientTransportTest { int initialWindowSize = 65535; startTransport( DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null); - clientTransport.sendConnectionPrefaceAndSettings(); ArgumentCaptor settings = ArgumentCaptor.forClass(Settings.class); verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture()); @@ -430,7 +437,6 @@ public class OkHttpClientTransportTest { int initialWindowSize = 75535; // 65535 + 10000 startTransport( DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null); - clientTransport.sendConnectionPrefaceAndSettings(); ArgumentCaptor settings = ArgumentCaptor.forClass(Settings.class); verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture()); @@ -1697,6 +1703,7 @@ public class OkHttpClientTransportTest { @Test public void writeBeforeConnected() throws Exception { initTransportAndDelayConnected(); + reset(frameWriter); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = @@ -1722,6 +1729,7 @@ public class OkHttpClientTransportTest { @Test public void cancelBeforeConnected() throws Exception { initTransportAndDelayConnected(); + reset(frameWriter); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = @@ -2385,10 +2393,20 @@ public class OkHttpClientTransportTest { } private static class MockSocket extends Socket { - MockFrameReader frameReader; + final MockFrameReader frameReader; + private final PipedOutputStream outputStream = new PipedOutputStream(); + private final PipedInputStream outputStreamSink = new PipedInputStream(); + private final PipedOutputStream inputStreamSource = new PipedOutputStream(); + private final PipedInputStream inputStream = new PipedInputStream(); MockSocket(MockFrameReader frameReader) { this.frameReader = frameReader; + try { + outputStreamSink.connect(outputStream); + inputStream.connect(inputStreamSource); + } catch (IOException ex) { + throw new AssertionError(ex); + } } @Override @@ -2400,6 +2418,16 @@ public class OkHttpClientTransportTest { public SocketAddress getLocalSocketAddress() { return InetSocketAddress.createUnresolved("localhost", 4000); } + + @Override + public OutputStream getOutputStream() { + return outputStream; + } + + @Override + public InputStream getInputStream() { + return inputStream; + } } static class PingCallbackImpl implements ClientTransport.PingCallback { @@ -2559,4 +2587,65 @@ public class OkHttpClientTransportTest { throw exception; } } + + static class FakeSocketFactory extends SocketFactory { + private Socket socket; + + public FakeSocketFactory(Socket socket) { + this.socket = Preconditions.checkNotNull(socket, "socket"); + } + + @Override public Socket createSocket() { + Preconditions.checkNotNull(this.socket, "socket"); + Socket socket = this.socket; + this.socket = null; + return socket; + } + + @Override public Socket createSocket(InetAddress host, int port) { + return createSocket(); + } + + @Override public Socket createSocket( + InetAddress host, int port, InetAddress localAddress, int localPort) { + return createSocket(); + } + + @Override public Socket createSocket(String host, int port) { + return createSocket(); + } + + @Override public Socket createSocket( + String host, int port, InetAddress localHost, int localPort) { + return createSocket(); + } + } + + static class FakeVariant implements Variant { + private FrameReader frameReader; + private FrameWriter frameWriter; + + public FakeVariant(FrameReader frameReader, FrameWriter frameWriter) { + this.frameReader = frameReader; + this.frameWriter = frameWriter; + } + + @Override public Protocol getProtocol() { + return Protocol.HTTP_2; + } + + @Override public FrameReader newReader(BufferedSource source, boolean client) { + Preconditions.checkNotNull(this.frameReader, "frameReader"); + FrameReader frameReader = this.frameReader; + this.frameReader = null; + return frameReader; + } + + @Override public FrameWriter newWriter(BufferedSink sink, boolean client) { + Preconditions.checkNotNull(this.frameWriter, "frameWriter"); + FrameWriter frameWriter = this.frameWriter; + this.frameWriter = null; + return frameWriter; + } + } }