okhttp: Avoid test-specific transport.start()

With the completely different constructor it was hard to track which
fields were different during the test and reduced confidence. Now the
test code flows are much closer to the real-life code flows.
This commit is contained in:
Eric Anderson 2022-04-04 10:54:58 -07:00
parent 004ee10a73
commit a978c9edc0
4 changed files with 191 additions and 102 deletions

View File

@ -39,23 +39,14 @@ final class ExceptionHandlingFrameWriter implements FrameWriter {
private final FrameWriter frameWriter; private final FrameWriter frameWriter;
private final OkHttpFrameLogger frameLogger; private final OkHttpFrameLogger frameLogger =
new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class);
ExceptionHandlingFrameWriter( ExceptionHandlingFrameWriter(
TransportExceptionHandler transportExceptionHandler, FrameWriter frameWriter) { TransportExceptionHandler transportExceptionHandler, FrameWriter frameWriter) {
this(transportExceptionHandler, frameWriter,
new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class));
}
@VisibleForTesting
ExceptionHandlingFrameWriter(
TransportExceptionHandler transportExceptionHandler,
FrameWriter frameWriter,
OkHttpFrameLogger frameLogger) {
this.transportExceptionHandler = this.transportExceptionHandler =
checkNotNull(transportExceptionHandler, "transportExceptionHandler"); checkNotNull(transportExceptionHandler, "transportExceptionHandler");
this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter");
this.frameLogger = Preconditions.checkNotNull(frameLogger, "frameLogger");
} }
@Override @Override

View File

@ -148,9 +148,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
// Returns new unstarted stopwatches // Returns new unstarted stopwatches
private final Supplier<Stopwatch> stopwatchFactory; private final Supplier<Stopwatch> stopwatchFactory;
private final int initialWindowSize; private final int initialWindowSize;
private final Variant variant;
private Listener listener; private Listener listener;
private FrameReader testFrameReader;
private OkHttpFrameLogger testFrameLogger;
@GuardedBy("lock") @GuardedBy("lock")
private ExceptionHandlingFrameWriter frameWriter; private ExceptionHandlingFrameWriter frameWriter;
private OutboundFlowController outboundFlow; private OutboundFlowController outboundFlow;
@ -192,7 +191,6 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
@GuardedBy("lock") @GuardedBy("lock")
private final Deque<OkHttpClientStream> pendingStreams = new LinkedList<>(); private final Deque<OkHttpClientStream> pendingStreams = new LinkedList<>();
private final ConnectionSpec connectionSpec; private final ConnectionSpec connectionSpec;
private FrameWriter testFrameWriter;
private ScheduledExecutorService scheduler; private ScheduledExecutorService scheduler;
private KeepAliveManager keepAliveManager; private KeepAliveManager keepAliveManager;
private boolean enableKeepAlive; private boolean enableKeepAlive;
@ -228,7 +226,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
Runnable connectingCallback; Runnable connectingCallback;
SettableFuture<Void> connectedFuture; SettableFuture<Void> connectedFuture;
OkHttpClientTransport( public OkHttpClientTransport(
InetSocketAddress address, InetSocketAddress address,
String authority, String authority,
@Nullable String userAgent, @Nullable String userAgent,
@ -245,6 +243,46 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
int maxInboundMetadataSize, int maxInboundMetadataSize,
TransportTracer transportTracer, TransportTracer transportTracer,
boolean useGetForSafeMethods) { boolean useGetForSafeMethods) {
this(
address,
authority,
userAgent,
eagAttrs,
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
connectionSpec,
GrpcUtil.STOPWATCH_SUPPLIER,
new Http2(),
maxMessageSize,
initialWindowSize,
proxiedAddr,
tooManyPingsRunnable,
maxInboundMetadataSize,
transportTracer,
useGetForSafeMethods);
}
private OkHttpClientTransport(
InetSocketAddress address,
String authority,
@Nullable String userAgent,
Attributes eagAttrs,
Executor executor,
@Nullable SocketFactory socketFactory,
@Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier,
ConnectionSpec connectionSpec,
Supplier<Stopwatch> stopwatchFactory,
Variant variant,
int maxMessageSize,
int initialWindowSize,
@Nullable HttpConnectProxiedSocketAddress proxiedAddr,
Runnable tooManyPingsRunnable,
int maxInboundMetadataSize,
TransportTracer transportTracer,
boolean useGetForSafeMethods) {
this.address = Preconditions.checkNotNull(address, "address"); this.address = Preconditions.checkNotNull(address, "address");
this.defaultAuthority = authority; this.defaultAuthority = authority;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
@ -258,7 +296,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
this.sslSocketFactory = sslSocketFactory; this.sslSocketFactory = sslSocketFactory;
this.hostnameVerifier = hostnameVerifier; this.hostnameVerifier = hostnameVerifier;
this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec");
this.stopwatchFactory = GrpcUtil.STOPWATCH_SUPPLIER; this.stopwatchFactory = Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory");
this.variant = Preconditions.checkNotNull(variant, "variant");
this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent);
this.proxiedAddr = proxiedAddr; this.proxiedAddr = proxiedAddr;
this.tooManyPingsRunnable = this.tooManyPingsRunnable =
@ -279,43 +318,36 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
OkHttpClientTransport( OkHttpClientTransport(
String userAgent, String userAgent,
Executor executor, Executor executor,
FrameReader frameReader, @Nullable SocketFactory socketFactory,
FrameWriter testFrameWriter,
OkHttpFrameLogger testFrameLogger,
int nextStreamId,
Socket socket,
Supplier<Stopwatch> stopwatchFactory, Supplier<Stopwatch> stopwatchFactory,
Variant variant,
@Nullable Runnable connectingCallback, @Nullable Runnable connectingCallback,
SettableFuture<Void> connectedFuture, SettableFuture<Void> connectedFuture,
int maxMessageSize, int maxMessageSize,
int initialWindowSize, int initialWindowSize,
Runnable tooManyPingsRunnable, Runnable tooManyPingsRunnable,
TransportTracer transportTracer) { TransportTracer transportTracer) {
useGetForSafeMethods = false; this(
address = null; new InetSocketAddress("127.0.0.1", 80),
this.maxMessageSize = maxMessageSize; "notarealauthority:80",
this.initialWindowSize = initialWindowSize; userAgent,
defaultAuthority = "notarealauthority:80"; Attributes.EMPTY,
this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); executor,
this.executor = Preconditions.checkNotNull(executor, "executor"); socketFactory,
serializingExecutor = new SerializingExecutor(executor); null,
this.socketFactory = SocketFactory.getDefault(); null,
this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader"); OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC,
this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter"); stopwatchFactory,
this.testFrameLogger = Preconditions.checkNotNull(testFrameLogger, "testFrameLogger"); variant,
this.socket = Preconditions.checkNotNull(socket, "socket"); maxMessageSize,
this.nextStreamId = nextStreamId; initialWindowSize,
this.stopwatchFactory = stopwatchFactory; null,
this.connectionSpec = null; tooManyPingsRunnable,
Integer.MAX_VALUE,
transportTracer,
false);
this.connectingCallback = connectingCallback; this.connectingCallback = connectingCallback;
this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture"); this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture");
this.proxiedAddr = null;
this.tooManyPingsRunnable =
Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable");
this.maxInboundMetadataSize = Integer.MAX_VALUE;
this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer");
this.logId = InternalLogId.allocate(getClass(), String.valueOf(socket.getInetAddress()));
initTransportTracer();
} }
// sslSocketFactory is set to null when use plaintext. // sslSocketFactory is set to null when use plaintext.
@ -349,10 +381,6 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
this.keepAliveWithoutCalls = keepAliveWithoutCalls; this.keepAliveWithoutCalls = keepAliveWithoutCalls;
} }
private boolean isForTest() {
return address == null;
}
@Override @Override
public void ping(final PingCallback callback, Executor executor) { public void ping(final PingCallback callback, Executor executor) {
long data = 0; long data = 0;
@ -488,32 +516,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
keepAliveWithoutCalls); keepAliveWithoutCalls);
keepAliveManager.onTransportStarted(); keepAliveManager.onTransportStarted();
} }
if (isForTest()) {
synchronized (lock) {
frameWriter = new ExceptionHandlingFrameWriter(OkHttpClientTransport.this, testFrameWriter,
testFrameLogger);
outboundFlow = new OutboundFlowController(OkHttpClientTransport.this, frameWriter);
}
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
if (connectingCallback != null) {
connectingCallback.run();
}
clientFrameHandler = new ClientFrameHandler(testFrameReader, testFrameLogger);
executor.execute(clientFrameHandler);
synchronized (lock) {
maxConcurrentStreams = Integer.MAX_VALUE;
startPendingStreams();
}
connectedFuture.set(null);
}
});
return null;
}
final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this); final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this);
final Variant variant = new Http2();
FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), true); FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), true);
synchronized (lock) { synchronized (lock) {
@ -616,6 +620,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
serializingExecutor.execute(new Runnable() { serializingExecutor.execute(new Runnable() {
@Override @Override
public void run() { public void run() {
if (connectingCallback != null) {
connectingCallback.run();
}
// ClientFrameHandler need to be started after connectionPreface / settings, otherwise it // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it
// may send goAway immediately. // may send goAway immediately.
executor.execute(clientFrameHandler); executor.execute(clientFrameHandler);
@ -623,6 +630,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
maxConcurrentStreams = Integer.MAX_VALUE; maxConcurrentStreams = Integer.MAX_VALUE;
startPendingStreams(); startPendingStreams();
} }
if (connectedFuture != null) {
connectedFuture.set(null);
}
} }
}); });
return null; return null;
@ -631,8 +641,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
/** /**
* Should only be called once when the transport is first established. * Should only be called once when the transport is first established.
*/ */
@VisibleForTesting private void sendConnectionPrefaceAndSettings() {
void sendConnectionPrefaceAndSettings() {
synchronized (lock) { synchronized (lock) {
frameWriter.connectionPreface(); frameWriter.connectionPreface();
Settings settings = new Settings(); Settings settings = new Settings();
@ -855,6 +864,13 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
} }
} }
@VisibleForTesting
void setNextStreamId(int nextStreamId) {
synchronized (lock) {
this.nextStreamId = nextStreamId;
}
}
/** /**
* Finish all active streams due to an IOException, then close the transport. * Finish all active streams due to an IOException, then close the transport.
*/ */
@ -1081,21 +1097,15 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
/** /**
* Runnable which reads frames and dispatches them to in flight calls. * Runnable which reads frames and dispatches them to in flight calls.
*/ */
@VisibleForTesting
class ClientFrameHandler implements FrameReader.Handler, Runnable { class ClientFrameHandler implements FrameReader.Handler, Runnable {
private final OkHttpFrameLogger logger; private final OkHttpFrameLogger logger =
new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class);
FrameReader frameReader; FrameReader frameReader;
boolean firstSettings = true; boolean firstSettings = true;
ClientFrameHandler(FrameReader frameReader) { ClientFrameHandler(FrameReader frameReader) {
this(frameReader, new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class));
}
@VisibleForTesting
ClientFrameHandler(FrameReader frameReader, OkHttpFrameLogger frameLogger) {
this.frameReader = frameReader; this.frameReader = frameReader;
logger = frameLogger;
} }
@Override @Override

View File

@ -50,8 +50,7 @@ public class ExceptionHandlingFrameWriterTest {
private final TransportExceptionHandler transportExceptionHandler = private final TransportExceptionHandler transportExceptionHandler =
mock(TransportExceptionHandler.class); mock(TransportExceptionHandler.class);
private final ExceptionHandlingFrameWriter exceptionHandlingFrameWriter = private final ExceptionHandlingFrameWriter exceptionHandlingFrameWriter =
new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter, new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter);
new OkHttpFrameLogger(Level.FINE, logger));
@Test @Test
public void exception() throws IOException { public void exception() throws IOException {

View File

@ -48,6 +48,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch; import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import com.google.common.base.Ticker; import com.google.common.base.Ticker;
@ -78,18 +79,23 @@ import io.grpc.internal.TransportTracer;
import io.grpc.okhttp.OkHttpClientTransport.ClientFrameHandler; import io.grpc.okhttp.OkHttpClientTransport.ClientFrameHandler;
import io.grpc.okhttp.OkHttpFrameLogger.Direction; import io.grpc.okhttp.OkHttpFrameLogger.Direction;
import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.ConnectionSpec;
import io.grpc.okhttp.internal.Protocol;
import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.FrameReader; import io.grpc.okhttp.internal.framed.FrameReader;
import io.grpc.okhttp.internal.framed.FrameWriter; import io.grpc.okhttp.internal.framed.FrameWriter;
import io.grpc.okhttp.internal.framed.Header; import io.grpc.okhttp.internal.framed.Header;
import io.grpc.okhttp.internal.framed.HeadersMode; import io.grpc.okhttp.internal.framed.HeadersMode;
import io.grpc.okhttp.internal.framed.Settings; import io.grpc.okhttp.internal.framed.Settings;
import io.grpc.okhttp.internal.framed.Variant;
import io.grpc.testing.TestMethodDescriptors; import io.grpc.testing.TestMethodDescriptors;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.ServerSocket; import java.net.ServerSocket;
@ -116,6 +122,8 @@ import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
import okio.Buffer; import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.ByteString; import okio.ByteString;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -154,8 +162,6 @@ public class OkHttpClientTransportTest {
@Rule public final Timeout globalTimeout = Timeout.seconds(10); @Rule public final Timeout globalTimeout = Timeout.seconds(10);
private FrameWriter frameWriter;
private MethodDescriptor<Void, Void> method = TestMethodDescriptors.voidMethod(); private MethodDescriptor<Void, Void> method = TestMethodDescriptors.voidMethod();
@Mock @Mock
@ -167,8 +173,10 @@ public class OkHttpClientTransportTest {
private final TransportTracer transportTracer = new TransportTracer(); private final TransportTracer transportTracer = new TransportTracer();
private final Queue<Buffer> capturedBuffer = new ArrayDeque<>(); private final Queue<Buffer> capturedBuffer = new ArrayDeque<>();
private OkHttpClientTransport clientTransport; private OkHttpClientTransport clientTransport;
private MockFrameReader frameReader; private final MockFrameReader frameReader = new MockFrameReader();
private Socket socket; private final Socket socket = new MockSocket(frameReader);
private FrameWriter frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo(
new MockFrameWriter(socket, capturedBuffer)));
private ExecutorService executor = Executors.newCachedThreadPool(); private ExecutorService executor = Executors.newCachedThreadPool();
private long nanoTime; // backs a ticker, for testing ping round-trip time measurement private long nanoTime; // backs a ticker, for testing ping round-trip time measurement
private SettableFuture<Void> connectedFuture; private SettableFuture<Void> connectedFuture;
@ -183,10 +191,6 @@ public class OkHttpClientTransportTest {
@Before @Before
public void setUp() { public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
frameReader = new MockFrameReader();
socket = new MockSocket(frameReader);
MockFrameWriter mockFrameWriter = new MockFrameWriter(socket, capturedBuffer);
frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo(mockFrameWriter));
} }
@After @After
@ -233,12 +237,9 @@ public class OkHttpClientTransportTest {
clientTransport = new OkHttpClientTransport( clientTransport = new OkHttpClientTransport(
userAgent, userAgent,
executor, executor,
frameReader, new FakeSocketFactory(socket),
frameWriter,
new OkHttpFrameLogger(Level.FINE, logger),
startId,
socket,
stopwatchSupplier, stopwatchSupplier,
new FakeVariant(frameReader, frameWriter),
connectingCallback, connectingCallback,
connectedFuture, connectedFuture,
maxMessageSize, maxMessageSize,
@ -249,6 +250,9 @@ public class OkHttpClientTransportTest {
if (waitingForConnected) { if (waitingForConnected) {
connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS); connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS);
} }
if (startId != DEFAULT_START_STREAM_ID) {
clientTransport.setNextStreamId(startId);
}
} }
@Test @Test
@ -301,6 +305,10 @@ public class OkHttpClientTransportTest {
logger.setLevel(Level.ALL); logger.setLevel(Level.ALL);
initTransport(); initTransport();
assertThat(logs).hasSize(1);
LogRecord log = logs.remove(0);
assertThat(log.getMessage()).startsWith(Direction.OUTBOUND + " SETTINGS: ack=false");
assertThat(log.getLevel()).isEqualTo(Level.FINE);
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
OkHttpClientStream stream = OkHttpClientStream stream =
@ -310,7 +318,7 @@ public class OkHttpClientTransportTest {
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
assertThat(logs).hasSize(1); assertThat(logs).hasSize(1);
LogRecord log = logs.remove(0); log = logs.remove(0);
assertThat(log.getMessage()).startsWith(Direction.INBOUND + " HEADERS: streamId=" + 3); assertThat(log.getMessage()).startsWith(Direction.INBOUND + " HEADERS: streamId=" + 3);
assertThat(log.getLevel()).isEqualTo(Level.FINE); assertThat(log.getLevel()).isEqualTo(Level.FINE);
@ -414,7 +422,6 @@ public class OkHttpClientTransportTest {
int initialWindowSize = 65535; int initialWindowSize = 65535;
startTransport( startTransport(
DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null); DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null);
clientTransport.sendConnectionPrefaceAndSettings();
ArgumentCaptor<Settings> settings = ArgumentCaptor.forClass(Settings.class); ArgumentCaptor<Settings> settings = ArgumentCaptor.forClass(Settings.class);
verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture()); verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture());
@ -430,7 +437,6 @@ public class OkHttpClientTransportTest {
int initialWindowSize = 75535; // 65535 + 10000 int initialWindowSize = 75535; // 65535 + 10000
startTransport( startTransport(
DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null); DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null);
clientTransport.sendConnectionPrefaceAndSettings();
ArgumentCaptor<Settings> settings = ArgumentCaptor.forClass(Settings.class); ArgumentCaptor<Settings> settings = ArgumentCaptor.forClass(Settings.class);
verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture()); verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture());
@ -1697,6 +1703,7 @@ public class OkHttpClientTransportTest {
@Test @Test
public void writeBeforeConnected() throws Exception { public void writeBeforeConnected() throws Exception {
initTransportAndDelayConnected(); initTransportAndDelayConnected();
reset(frameWriter);
final String message = "Hello Server"; final String message = "Hello Server";
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
OkHttpClientStream stream = OkHttpClientStream stream =
@ -1722,6 +1729,7 @@ public class OkHttpClientTransportTest {
@Test @Test
public void cancelBeforeConnected() throws Exception { public void cancelBeforeConnected() throws Exception {
initTransportAndDelayConnected(); initTransportAndDelayConnected();
reset(frameWriter);
final String message = "Hello Server"; final String message = "Hello Server";
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
OkHttpClientStream stream = OkHttpClientStream stream =
@ -2385,10 +2393,20 @@ public class OkHttpClientTransportTest {
} }
private static class MockSocket extends Socket { private static class MockSocket extends Socket {
MockFrameReader frameReader; final MockFrameReader frameReader;
private final PipedOutputStream outputStream = new PipedOutputStream();
private final PipedInputStream outputStreamSink = new PipedInputStream();
private final PipedOutputStream inputStreamSource = new PipedOutputStream();
private final PipedInputStream inputStream = new PipedInputStream();
MockSocket(MockFrameReader frameReader) { MockSocket(MockFrameReader frameReader) {
this.frameReader = frameReader; this.frameReader = frameReader;
try {
outputStreamSink.connect(outputStream);
inputStream.connect(inputStreamSource);
} catch (IOException ex) {
throw new AssertionError(ex);
}
} }
@Override @Override
@ -2400,6 +2418,16 @@ public class OkHttpClientTransportTest {
public SocketAddress getLocalSocketAddress() { public SocketAddress getLocalSocketAddress() {
return InetSocketAddress.createUnresolved("localhost", 4000); return InetSocketAddress.createUnresolved("localhost", 4000);
} }
@Override
public OutputStream getOutputStream() {
return outputStream;
}
@Override
public InputStream getInputStream() {
return inputStream;
}
} }
static class PingCallbackImpl implements ClientTransport.PingCallback { static class PingCallbackImpl implements ClientTransport.PingCallback {
@ -2559,4 +2587,65 @@ public class OkHttpClientTransportTest {
throw exception; throw exception;
} }
} }
static class FakeSocketFactory extends SocketFactory {
private Socket socket;
public FakeSocketFactory(Socket socket) {
this.socket = Preconditions.checkNotNull(socket, "socket");
}
@Override public Socket createSocket() {
Preconditions.checkNotNull(this.socket, "socket");
Socket socket = this.socket;
this.socket = null;
return socket;
}
@Override public Socket createSocket(InetAddress host, int port) {
return createSocket();
}
@Override public Socket createSocket(
InetAddress host, int port, InetAddress localAddress, int localPort) {
return createSocket();
}
@Override public Socket createSocket(String host, int port) {
return createSocket();
}
@Override public Socket createSocket(
String host, int port, InetAddress localHost, int localPort) {
return createSocket();
}
}
static class FakeVariant implements Variant {
private FrameReader frameReader;
private FrameWriter frameWriter;
public FakeVariant(FrameReader frameReader, FrameWriter frameWriter) {
this.frameReader = frameReader;
this.frameWriter = frameWriter;
}
@Override public Protocol getProtocol() {
return Protocol.HTTP_2;
}
@Override public FrameReader newReader(BufferedSource source, boolean client) {
Preconditions.checkNotNull(this.frameReader, "frameReader");
FrameReader frameReader = this.frameReader;
this.frameReader = null;
return frameReader;
}
@Override public FrameWriter newWriter(BufferedSink sink, boolean client) {
Preconditions.checkNotNull(this.frameWriter, "frameWriter");
FrameWriter frameWriter = this.frameWriter;
this.frameWriter = null;
return frameWriter;
}
}
} }