diff --git a/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java index 9f7074121f..577d9c8de7 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java +++ b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java @@ -39,23 +39,14 @@ final class ExceptionHandlingFrameWriter implements FrameWriter { private final FrameWriter frameWriter; - private final OkHttpFrameLogger frameLogger; + private final OkHttpFrameLogger frameLogger = + new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class); ExceptionHandlingFrameWriter( TransportExceptionHandler transportExceptionHandler, FrameWriter frameWriter) { - this(transportExceptionHandler, frameWriter, - new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class)); - } - - @VisibleForTesting - ExceptionHandlingFrameWriter( - TransportExceptionHandler transportExceptionHandler, - FrameWriter frameWriter, - OkHttpFrameLogger frameLogger) { this.transportExceptionHandler = checkNotNull(transportExceptionHandler, "transportExceptionHandler"); this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); - this.frameLogger = Preconditions.checkNotNull(frameLogger, "frameLogger"); } @Override diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index e233fa2002..ad39112d8c 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -148,9 +148,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep // Returns new unstarted stopwatches private final Supplier stopwatchFactory; private final int initialWindowSize; + private final Variant variant; private Listener listener; - private FrameReader testFrameReader; - private OkHttpFrameLogger testFrameLogger; @GuardedBy("lock") private ExceptionHandlingFrameWriter frameWriter; private OutboundFlowController outboundFlow; @@ -192,7 +191,6 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep @GuardedBy("lock") private final Deque pendingStreams = new LinkedList<>(); private final ConnectionSpec connectionSpec; - private FrameWriter testFrameWriter; private ScheduledExecutorService scheduler; private KeepAliveManager keepAliveManager; private boolean enableKeepAlive; @@ -228,7 +226,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep Runnable connectingCallback; SettableFuture connectedFuture; - OkHttpClientTransport( + public OkHttpClientTransport( InetSocketAddress address, String authority, @Nullable String userAgent, @@ -245,6 +243,46 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep int maxInboundMetadataSize, TransportTracer transportTracer, boolean useGetForSafeMethods) { + this( + address, + authority, + userAgent, + eagAttrs, + executor, + socketFactory, + sslSocketFactory, + hostnameVerifier, + connectionSpec, + GrpcUtil.STOPWATCH_SUPPLIER, + new Http2(), + maxMessageSize, + initialWindowSize, + proxiedAddr, + tooManyPingsRunnable, + maxInboundMetadataSize, + transportTracer, + useGetForSafeMethods); + } + + private OkHttpClientTransport( + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + Executor executor, + @Nullable SocketFactory socketFactory, + @Nullable SSLSocketFactory sslSocketFactory, + @Nullable HostnameVerifier hostnameVerifier, + ConnectionSpec connectionSpec, + Supplier stopwatchFactory, + Variant variant, + int maxMessageSize, + int initialWindowSize, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + int maxInboundMetadataSize, + TransportTracer transportTracer, + boolean useGetForSafeMethods) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; this.maxMessageSize = maxMessageSize; @@ -258,7 +296,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep this.sslSocketFactory = sslSocketFactory; this.hostnameVerifier = hostnameVerifier; this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); - this.stopwatchFactory = GrpcUtil.STOPWATCH_SUPPLIER; + this.stopwatchFactory = Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory"); + this.variant = Preconditions.checkNotNull(variant, "variant"); this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); this.proxiedAddr = proxiedAddr; this.tooManyPingsRunnable = @@ -279,43 +318,36 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep OkHttpClientTransport( String userAgent, Executor executor, - FrameReader frameReader, - FrameWriter testFrameWriter, - OkHttpFrameLogger testFrameLogger, - int nextStreamId, - Socket socket, + @Nullable SocketFactory socketFactory, Supplier stopwatchFactory, + Variant variant, @Nullable Runnable connectingCallback, SettableFuture connectedFuture, int maxMessageSize, int initialWindowSize, Runnable tooManyPingsRunnable, TransportTracer transportTracer) { - useGetForSafeMethods = false; - address = null; - this.maxMessageSize = maxMessageSize; - this.initialWindowSize = initialWindowSize; - defaultAuthority = "notarealauthority:80"; - this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); - this.executor = Preconditions.checkNotNull(executor, "executor"); - serializingExecutor = new SerializingExecutor(executor); - this.socketFactory = SocketFactory.getDefault(); - this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader"); - this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter"); - this.testFrameLogger = Preconditions.checkNotNull(testFrameLogger, "testFrameLogger"); - this.socket = Preconditions.checkNotNull(socket, "socket"); - this.nextStreamId = nextStreamId; - this.stopwatchFactory = stopwatchFactory; - this.connectionSpec = null; + this( + new InetSocketAddress("127.0.0.1", 80), + "notarealauthority:80", + userAgent, + Attributes.EMPTY, + executor, + socketFactory, + null, + null, + OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC, + stopwatchFactory, + variant, + maxMessageSize, + initialWindowSize, + null, + tooManyPingsRunnable, + Integer.MAX_VALUE, + transportTracer, + false); this.connectingCallback = connectingCallback; this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture"); - this.proxiedAddr = null; - this.tooManyPingsRunnable = - Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); - this.maxInboundMetadataSize = Integer.MAX_VALUE; - this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer"); - this.logId = InternalLogId.allocate(getClass(), String.valueOf(socket.getInetAddress())); - initTransportTracer(); } // sslSocketFactory is set to null when use plaintext. @@ -349,10 +381,6 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep this.keepAliveWithoutCalls = keepAliveWithoutCalls; } - private boolean isForTest() { - return address == null; - } - @Override public void ping(final PingCallback callback, Executor executor) { long data = 0; @@ -488,32 +516,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep keepAliveWithoutCalls); keepAliveManager.onTransportStarted(); } - if (isForTest()) { - synchronized (lock) { - frameWriter = new ExceptionHandlingFrameWriter(OkHttpClientTransport.this, testFrameWriter, - testFrameLogger); - outboundFlow = new OutboundFlowController(OkHttpClientTransport.this, frameWriter); - } - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - if (connectingCallback != null) { - connectingCallback.run(); - } - clientFrameHandler = new ClientFrameHandler(testFrameReader, testFrameLogger); - executor.execute(clientFrameHandler); - synchronized (lock) { - maxConcurrentStreams = Integer.MAX_VALUE; - startPendingStreams(); - } - connectedFuture.set(null); - } - }); - return null; - } final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this); - final Variant variant = new Http2(); FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), true); synchronized (lock) { @@ -616,6 +620,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep serializingExecutor.execute(new Runnable() { @Override public void run() { + if (connectingCallback != null) { + connectingCallback.run(); + } // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it // may send goAway immediately. executor.execute(clientFrameHandler); @@ -623,6 +630,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep maxConcurrentStreams = Integer.MAX_VALUE; startPendingStreams(); } + if (connectedFuture != null) { + connectedFuture.set(null); + } } }); return null; @@ -631,8 +641,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep /** * Should only be called once when the transport is first established. */ - @VisibleForTesting - void sendConnectionPrefaceAndSettings() { + private void sendConnectionPrefaceAndSettings() { synchronized (lock) { frameWriter.connectionPreface(); Settings settings = new Settings(); @@ -855,6 +864,13 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep } } + @VisibleForTesting + void setNextStreamId(int nextStreamId) { + synchronized (lock) { + this.nextStreamId = nextStreamId; + } + } + /** * Finish all active streams due to an IOException, then close the transport. */ @@ -1081,21 +1097,15 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep /** * Runnable which reads frames and dispatches them to in flight calls. */ - @VisibleForTesting class ClientFrameHandler implements FrameReader.Handler, Runnable { - private final OkHttpFrameLogger logger; + private final OkHttpFrameLogger logger = + new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class); FrameReader frameReader; boolean firstSettings = true; ClientFrameHandler(FrameReader frameReader) { - this(frameReader, new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class)); - } - - @VisibleForTesting - ClientFrameHandler(FrameReader frameReader, OkHttpFrameLogger frameLogger) { this.frameReader = frameReader; - logger = frameLogger; } @Override diff --git a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java index c26edcd0df..a9d3908884 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java @@ -50,8 +50,7 @@ public class ExceptionHandlingFrameWriterTest { private final TransportExceptionHandler transportExceptionHandler = mock(TransportExceptionHandler.class); private final ExceptionHandlingFrameWriter exceptionHandlingFrameWriter = - new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter, - new OkHttpFrameLogger(Level.FINE, logger)); + new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter); @Test public void exception() throws IOException { @@ -194,4 +193,4 @@ public class ExceptionHandlingFrameWriterTest { logger.removeHandler(handler); } -} \ No newline at end of file +} diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index 1628df4d3c..6e74983240 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -48,6 +48,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; import com.google.common.base.Ticker; @@ -78,18 +79,23 @@ import io.grpc.internal.TransportTracer; import io.grpc.okhttp.OkHttpClientTransport.ClientFrameHandler; import io.grpc.okhttp.OkHttpFrameLogger.Direction; import io.grpc.okhttp.internal.ConnectionSpec; +import io.grpc.okhttp.internal.Protocol; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.FrameReader; import io.grpc.okhttp.internal.framed.FrameWriter; import io.grpc.okhttp.internal.framed.Header; import io.grpc.okhttp.internal.framed.HeadersMode; import io.grpc.okhttp.internal.framed.Settings; +import io.grpc.okhttp.internal.framed.Variant; import io.grpc.testing.TestMethodDescriptors; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -116,6 +122,8 @@ import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLSocketFactory; import okio.Buffer; +import okio.BufferedSink; +import okio.BufferedSource; import okio.ByteString; import org.junit.After; import org.junit.Before; @@ -154,8 +162,6 @@ public class OkHttpClientTransportTest { @Rule public final Timeout globalTimeout = Timeout.seconds(10); - private FrameWriter frameWriter; - private MethodDescriptor method = TestMethodDescriptors.voidMethod(); @Mock @@ -167,8 +173,10 @@ public class OkHttpClientTransportTest { private final TransportTracer transportTracer = new TransportTracer(); private final Queue capturedBuffer = new ArrayDeque<>(); private OkHttpClientTransport clientTransport; - private MockFrameReader frameReader; - private Socket socket; + private final MockFrameReader frameReader = new MockFrameReader(); + private final Socket socket = new MockSocket(frameReader); + private FrameWriter frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo( + new MockFrameWriter(socket, capturedBuffer))); private ExecutorService executor = Executors.newCachedThreadPool(); private long nanoTime; // backs a ticker, for testing ping round-trip time measurement private SettableFuture connectedFuture; @@ -183,10 +191,6 @@ public class OkHttpClientTransportTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - frameReader = new MockFrameReader(); - socket = new MockSocket(frameReader); - MockFrameWriter mockFrameWriter = new MockFrameWriter(socket, capturedBuffer); - frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo(mockFrameWriter)); } @After @@ -233,12 +237,9 @@ public class OkHttpClientTransportTest { clientTransport = new OkHttpClientTransport( userAgent, executor, - frameReader, - frameWriter, - new OkHttpFrameLogger(Level.FINE, logger), - startId, - socket, + new FakeSocketFactory(socket), stopwatchSupplier, + new FakeVariant(frameReader, frameWriter), connectingCallback, connectedFuture, maxMessageSize, @@ -249,6 +250,9 @@ public class OkHttpClientTransportTest { if (waitingForConnected) { connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS); } + if (startId != DEFAULT_START_STREAM_ID) { + clientTransport.setNextStreamId(startId); + } } @Test @@ -301,6 +305,10 @@ public class OkHttpClientTransportTest { logger.setLevel(Level.ALL); initTransport(); + assertThat(logs).hasSize(1); + LogRecord log = logs.remove(0); + assertThat(log.getMessage()).startsWith(Direction.OUTBOUND + " SETTINGS: ack=false"); + assertThat(log.getLevel()).isEqualTo(Level.FINE); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = @@ -310,7 +318,7 @@ public class OkHttpClientTransportTest { frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); assertThat(logs).hasSize(1); - LogRecord log = logs.remove(0); + log = logs.remove(0); assertThat(log.getMessage()).startsWith(Direction.INBOUND + " HEADERS: streamId=" + 3); assertThat(log.getLevel()).isEqualTo(Level.FINE); @@ -414,7 +422,6 @@ public class OkHttpClientTransportTest { int initialWindowSize = 65535; startTransport( DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null); - clientTransport.sendConnectionPrefaceAndSettings(); ArgumentCaptor settings = ArgumentCaptor.forClass(Settings.class); verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture()); @@ -430,7 +437,6 @@ public class OkHttpClientTransportTest { int initialWindowSize = 75535; // 65535 + 10000 startTransport( DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null); - clientTransport.sendConnectionPrefaceAndSettings(); ArgumentCaptor settings = ArgumentCaptor.forClass(Settings.class); verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture()); @@ -1697,6 +1703,7 @@ public class OkHttpClientTransportTest { @Test public void writeBeforeConnected() throws Exception { initTransportAndDelayConnected(); + reset(frameWriter); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = @@ -1722,6 +1729,7 @@ public class OkHttpClientTransportTest { @Test public void cancelBeforeConnected() throws Exception { initTransportAndDelayConnected(); + reset(frameWriter); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = @@ -2385,10 +2393,20 @@ public class OkHttpClientTransportTest { } private static class MockSocket extends Socket { - MockFrameReader frameReader; + final MockFrameReader frameReader; + private final PipedOutputStream outputStream = new PipedOutputStream(); + private final PipedInputStream outputStreamSink = new PipedInputStream(); + private final PipedOutputStream inputStreamSource = new PipedOutputStream(); + private final PipedInputStream inputStream = new PipedInputStream(); MockSocket(MockFrameReader frameReader) { this.frameReader = frameReader; + try { + outputStreamSink.connect(outputStream); + inputStream.connect(inputStreamSource); + } catch (IOException ex) { + throw new AssertionError(ex); + } } @Override @@ -2400,6 +2418,16 @@ public class OkHttpClientTransportTest { public SocketAddress getLocalSocketAddress() { return InetSocketAddress.createUnresolved("localhost", 4000); } + + @Override + public OutputStream getOutputStream() { + return outputStream; + } + + @Override + public InputStream getInputStream() { + return inputStream; + } } static class PingCallbackImpl implements ClientTransport.PingCallback { @@ -2559,4 +2587,65 @@ public class OkHttpClientTransportTest { throw exception; } } + + static class FakeSocketFactory extends SocketFactory { + private Socket socket; + + public FakeSocketFactory(Socket socket) { + this.socket = Preconditions.checkNotNull(socket, "socket"); + } + + @Override public Socket createSocket() { + Preconditions.checkNotNull(this.socket, "socket"); + Socket socket = this.socket; + this.socket = null; + return socket; + } + + @Override public Socket createSocket(InetAddress host, int port) { + return createSocket(); + } + + @Override public Socket createSocket( + InetAddress host, int port, InetAddress localAddress, int localPort) { + return createSocket(); + } + + @Override public Socket createSocket(String host, int port) { + return createSocket(); + } + + @Override public Socket createSocket( + String host, int port, InetAddress localHost, int localPort) { + return createSocket(); + } + } + + static class FakeVariant implements Variant { + private FrameReader frameReader; + private FrameWriter frameWriter; + + public FakeVariant(FrameReader frameReader, FrameWriter frameWriter) { + this.frameReader = frameReader; + this.frameWriter = frameWriter; + } + + @Override public Protocol getProtocol() { + return Protocol.HTTP_2; + } + + @Override public FrameReader newReader(BufferedSource source, boolean client) { + Preconditions.checkNotNull(this.frameReader, "frameReader"); + FrameReader frameReader = this.frameReader; + this.frameReader = null; + return frameReader; + } + + @Override public FrameWriter newWriter(BufferedSink sink, boolean client) { + Preconditions.checkNotNull(this.frameWriter, "frameWriter"); + FrameWriter frameWriter = this.frameWriter; + this.frameWriter = null; + return frameWriter; + } + } }