okhttp: Use a real socket during server transport testing

The PipeSocket was convenient and avoided real I/O, but the
shutdown/close while connecting/handshaking tests were triggering a
Socket bug in Java (https://bugs.openjdk.org/browse/JDK-8278326). Using
a real socket doesn't trigger the bug because the test stops sharing
state with the code under test.

Fixes #10228

```
Details
==================
WARNING: ThreadSanitizer: data race (pid=4528)
  Write of size 1 at 0x0000cfb9d5f4 by thread T36 (mutexes: write M0):
    #0 java.net.Socket.setCreated()V Socket.java:687
    #1 java.net.AbstractPlainSocketImpl.create(Z)V AbstractPlainSocketImpl.java:149
    #2 java.net.Socket.createImpl(Z)V Socket.java:477
    #3 java.net.Socket.getImpl()Ljava/net/SocketImpl; Socket.java:540
    #4 java.net.Socket.setTcpNoDelay(Z)V Socket.java:998
    #5 io.grpc.okhttp.OkHttpServerTransport.startIo(Lio/grpc/internal/SerializingExecutor;)V OkHttpServerTransport.java:164
    #6 io.grpc.okhttp.OkHttpServerTransport.lambda$start$0(Lio/grpc/internal/SerializingExecutor;)V OkHttpServerTransport.java:159
    #7 io.grpc.okhttp.OkHttpServerTransport$$Lambda$56.run()V ??
    #8 io.grpc.internal.SerializingExecutor.run()V SerializingExecutor.java:133
    #9 java.util.concurrent.ThreadPoolExecutor.runWorker(Ljava/util/concurrent/ThreadPoolExecutor$Worker;)V ThreadPoolExecutor.java:1130
    #10 java.util.concurrent.ThreadPoolExecutor$Worker.run()V ThreadPoolExecutor.java:630
    #11 java.lang.Thread.run()V Thread.java:830
    #12 (Generated Stub) <null>

  Previous read of size 1 at 0x0000cfb9d5f4 by thread T35 (mutexes: write M1, write M2):
    #0 java.net.Socket.close()V Socket.java:1512
    #1 io.grpc.okhttp.OkHttpServerTransportTest$PipeSocket.close()V OkHttpServerTransportTest.java:1384
    #2 io.grpc.okhttp.OkHttpServerTransportTest.clientCloseDuringHandshake()V OkHttpServerTransportTest.java:290
```
This commit is contained in:
Eric Anderson 2023-06-13 16:49:23 -07:00
parent 4d3a29b2af
commit d4e26cc689
1 changed files with 47 additions and 61 deletions

View File

@ -56,11 +56,9 @@ import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
@ -69,6 +67,7 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import okio.Buffer;
import okio.BufferedSource;
@ -96,14 +95,14 @@ public class OkHttpServerTransportTest {
private ServerTransportListener transportListener
= mock(ServerTransportListener.class, delegatesTo(mockTransportListener));
private OkHttpServerTransport serverTransport;
private final PipeSocket socket = new PipeSocket();
private final ExecutorService threadPool = Executors.newCachedThreadPool();
private final SocketPair socketPair = SocketPair.create(threadPool);
private final FrameWriter clientFrameWriter
= new Http2().newWriter(Okio.buffer(Okio.sink(socket.inputStreamSource)), true);
= new Http2().newWriter(Okio.buffer(Okio.sink(socketPair.getClientOutputStream())), true);
private final FrameReader clientFrameReader
= new Http2().newReader(Okio.buffer(Okio.source(socket.outputStreamSink)), true);
= new Http2().newReader(Okio.buffer(Okio.source(socketPair.getClientInputStream())), true);
private final FrameReader.Handler clientFramesRead = mock(FrameReader.Handler.class);
private final DataFrameHandler clientDataFrames = mock(DataFrameHandler.class);
private ExecutorService threadPool = Executors.newCachedThreadPool();
private HandshakerSocketFactory handshakerSocketFactory
= mock(HandshakerSocketFactory.class, delegatesTo(new PlaintextHandshakerSocketFactory()));
private final FakeClock fakeClock = new FakeClock();
@ -142,7 +141,11 @@ public class OkHttpServerTransportTest {
@After
public void tearDown() throws Exception {
threadPool.shutdownNow();
socket.closeSourceAndSink();
try {
socketPair.client.close();
} finally {
socketPair.server.close();
}
}
@Test
@ -172,7 +175,7 @@ public class OkHttpServerTransportTest {
verifyGracefulShutdown(1);
pingPong();
fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(3));
assertThat(socket.isClosed()).isTrue();
assertThat(socketPair.server.isClosed()).isTrue();
}
@Test
@ -254,7 +257,7 @@ public class OkHttpServerTransportTest {
@Test
public void shutdownDuringHandshake() throws Exception {
doAnswer(invocation -> {
socket.getInputStream().read();
((Socket) invocation.getArguments()[0]).getInputStream().read();
throw new IOException("handshake purposefully failed");
}).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class));
serverBuilder.transportExecutor(threadPool);
@ -268,7 +271,7 @@ public class OkHttpServerTransportTest {
@Test
public void shutdownNowDuringHandshake() throws Exception {
doAnswer(invocation -> {
socket.getInputStream().read();
((Socket) invocation.getArguments()[0]).getInputStream().read();
throw new IOException("handshake purposefully failed");
}).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class));
serverBuilder.transportExecutor(threadPool);
@ -282,12 +285,12 @@ public class OkHttpServerTransportTest {
@Test
public void clientCloseDuringHandshake() throws Exception {
doAnswer(invocation -> {
socket.getInputStream().read();
((Socket) invocation.getArguments()[0]).getInputStream().read();
throw new IOException("handshake purposefully failed");
}).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class));
serverBuilder.transportExecutor(threadPool);
initTransport();
socket.close();
socketPair.client.close();
verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
verify(transportListener, never()).transportReady(any(Attributes.class));
@ -296,7 +299,7 @@ public class OkHttpServerTransportTest {
@Test
public void closeDuringHttp2Preface() throws Exception {
initTransport();
socket.close();
socketPair.client.close();
verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
verify(transportListener, never()).transportReady(any(Attributes.class));
@ -307,7 +310,7 @@ public class OkHttpServerTransportTest {
initTransport();
clientFrameWriter.connectionPreface();
clientFrameWriter.flush();
socket.close();
socketPair.client.close();
verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
verify(transportListener, never()).transportReady(any(Attributes.class));
@ -329,7 +332,7 @@ public class OkHttpServerTransportTest {
initTransport();
handshake();
socket.closeSourceAndSink();
socketPair.client.close();
verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
}
@ -1086,8 +1089,8 @@ public class OkHttpServerTransportTest {
assertThat(stats.data.messagesReceived).isEqualTo(0);
assertThat(stats.data.remoteFlowControlWindow).isEqualTo(30000); // Lower bound
assertThat(stats.data.localFlowControlWindow).isEqualTo(66535);
assertThat(stats.local).isEqualTo(new InetSocketAddress("127.0.0.1", 4000));
assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000));
assertThat(stats.local).isEqualTo(socketPair.server.getLocalSocketAddress());
assertThat(stats.remote).isEqualTo(socketPair.server.getRemoteSocketAddress());
}
@Test
@ -1188,7 +1191,7 @@ public class OkHttpServerTransportTest {
private void initTransport() throws Exception {
serverTransport = new OkHttpServerTransport(
new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()),
socket);
socketPair.server);
serverTransport.start(transportListener);
}
@ -1357,61 +1360,44 @@ public class OkHttpServerTransportTest {
}
}
private static class PipeSocket extends Socket {
private final PipedOutputStream outputStream = new PipedOutputStream();
private final PipedInputStream outputStreamSink = new PipedInputStream();
private final PipedOutputStream inputStreamSource = new PipedOutputStream();
private final PipedInputStream inputStream = new PipedInputStream();
private static class SocketPair {
public final Socket client;
public final Socket server;
public PipeSocket() {
public SocketPair(Socket client, Socket server) {
this.client = client;
this.server = server;
}
public InputStream getClientInputStream() {
try {
outputStreamSink.connect(outputStream);
inputStream.connect(inputStreamSource);
return client.getInputStream();
} catch (IOException ex) {
throw new AssertionError(ex);
throw new RuntimeException(ex);
}
}
@Override
public synchronized void close() throws IOException {
public OutputStream getClientOutputStream() {
try {
outputStream.close();
} finally {
inputStream.close();
// PipedInputStream can only be woken by PipedOutputStream, so PipedOutputStream.close() is
// a better imitation of Socket.close().
inputStreamSource.close();
super.close();
return client.getOutputStream();
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
public void closeSourceAndSink() throws IOException {
public static SocketPair create(ExecutorService threadPool) {
try {
outputStreamSink.close();
} finally {
inputStreamSource.close();
try (ServerSocket serverSocket = new ServerSocket(0)) {
Future<Socket> serverFuture = threadPool.submit(() -> serverSocket.accept());
Socket client = new Socket();
client.connect(serverSocket.getLocalSocketAddress());
Socket server = serverFuture.get();
return new SocketPair(client, server);
}
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}
@Override
public SocketAddress getLocalSocketAddress() {
return new InetSocketAddress("127.0.0.1", 4000);
}
@Override
public SocketAddress getRemoteSocketAddress() {
return new InetSocketAddress("127.0.0.2", 5000);
}
@Override
public OutputStream getOutputStream() {
return outputStream;
}
@Override
public InputStream getInputStream() {
return inputStream;
}
}
private interface DataFrameHandler {