diff --git a/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadWorkerTest.java b/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadWorkerTest.java index ddcd93d0d2..0d40b2f6b5 100644 --- a/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadWorkerTest.java +++ b/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadWorkerTest.java @@ -19,6 +19,7 @@ package io.grpc.benchmarks.driver; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.common.util.concurrent.SettableFuture; import io.grpc.ManagedChannel; import io.grpc.benchmarks.Utils; import io.grpc.benchmarks.proto.Control; @@ -26,10 +27,9 @@ import io.grpc.benchmarks.proto.Stats; import io.grpc.benchmarks.proto.WorkerServiceGrpc; import io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.StreamObserver; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -51,6 +51,7 @@ public class LoadWorkerTest { private ManagedChannel channel; private WorkerServiceGrpc.WorkerServiceStub workerServiceStub; private LinkedBlockingQueue marksQueue; + private StreamObserver serverLifetime; @Before public void setup() throws Exception { @@ -62,6 +63,18 @@ public class LoadWorkerTest { marksQueue = new LinkedBlockingQueue<>(); } + @After + public void tearDown() { + if (serverLifetime != null) { + serverLifetime.onCompleted(); + } + try { + WorkerServiceGrpc.newBlockingStub(channel).quitWorker(Control.Void.getDefaultInstance()); + } finally { + channel.shutdownNow(); + } + } + @Test public void runUnaryBlockingClosedLoop() throws Exception { Control.ServerArgs.Builder serverArgsBuilder = Control.ServerArgs.newBuilder(); @@ -203,13 +216,13 @@ public class LoadWorkerTest { } private StreamObserver startClient(Control.ClientArgs clientArgs) - throws InterruptedException { - final CountDownLatch clientReady = new CountDownLatch(1); + throws Exception { + final SettableFuture clientReady = SettableFuture.create(); StreamObserver clientObserver = workerServiceStub.runClient( new StreamObserver() { @Override public void onNext(Control.ClientStatus value) { - clientReady.countDown(); + clientReady.set(null); if (value.hasStats()) { marksQueue.add(value.getStats()); } @@ -217,45 +230,43 @@ public class LoadWorkerTest { @Override public void onError(Throwable t) { + clientReady.setException(t); } @Override public void onCompleted() { + clientReady.setException( + new RuntimeException("onCompleted() before receiving response")); } }); // Start the client clientObserver.onNext(clientArgs); - if (!clientReady.await(TIMEOUT, TimeUnit.SECONDS)) { - fail("Client failed to start"); - } + clientReady.get(TIMEOUT, TimeUnit.SECONDS); return clientObserver; } - private int startServer(Control.ServerArgs serverArgs) throws InterruptedException { - final AtomicInteger serverPort = new AtomicInteger(); - final CountDownLatch serverReady = new CountDownLatch(1); - StreamObserver serverObserver = + private int startServer(Control.ServerArgs serverArgs) throws Exception { + final SettableFuture port = SettableFuture.create(); + serverLifetime = workerServiceStub.runServer(new StreamObserver() { @Override public void onNext(Control.ServerStatus value) { - serverPort.set(value.getPort()); - serverReady.countDown(); + port.set(value.getPort()); } @Override public void onError(Throwable t) { + port.setException(t); } @Override public void onCompleted() { + port.setException(new RuntimeException("onCompleted() before receiving response")); } }); // trigger server startup - serverObserver.onNext(serverArgs); - if (!serverReady.await(TIMEOUT, TimeUnit.SECONDS)) { - fail("Server failed to start"); - } - return serverPort.get(); + serverLifetime.onNext(serverArgs); + return port.get(TIMEOUT, TimeUnit.SECONDS); } }