benchmarks: Propagate errors in LoadWorkerTest startup

Also clean up resources at the end of test.
This commit is contained in:
Eric Anderson 2022-05-03 16:55:08 -07:00 committed by GitHub
parent 41c027c11b
commit de7db565a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 30 additions and 19 deletions

View File

@ -19,6 +19,7 @@ package io.grpc.benchmarks.driver;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.benchmarks.Utils; import io.grpc.benchmarks.Utils;
import io.grpc.benchmarks.proto.Control; import io.grpc.benchmarks.proto.Control;
@ -26,10 +27,9 @@ import io.grpc.benchmarks.proto.Stats;
import io.grpc.benchmarks.proto.WorkerServiceGrpc; import io.grpc.benchmarks.proto.WorkerServiceGrpc;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver; import io.grpc.stub.StreamObserver;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -51,6 +51,7 @@ public class LoadWorkerTest {
private ManagedChannel channel; private ManagedChannel channel;
private WorkerServiceGrpc.WorkerServiceStub workerServiceStub; private WorkerServiceGrpc.WorkerServiceStub workerServiceStub;
private LinkedBlockingQueue<Stats.ClientStats> marksQueue; private LinkedBlockingQueue<Stats.ClientStats> marksQueue;
private StreamObserver<Control.ServerArgs> serverLifetime;
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
@ -62,6 +63,18 @@ public class LoadWorkerTest {
marksQueue = new LinkedBlockingQueue<>(); marksQueue = new LinkedBlockingQueue<>();
} }
@After
public void tearDown() {
if (serverLifetime != null) {
serverLifetime.onCompleted();
}
try {
WorkerServiceGrpc.newBlockingStub(channel).quitWorker(Control.Void.getDefaultInstance());
} finally {
channel.shutdownNow();
}
}
@Test @Test
public void runUnaryBlockingClosedLoop() throws Exception { public void runUnaryBlockingClosedLoop() throws Exception {
Control.ServerArgs.Builder serverArgsBuilder = Control.ServerArgs.newBuilder(); Control.ServerArgs.Builder serverArgsBuilder = Control.ServerArgs.newBuilder();
@ -203,13 +216,13 @@ public class LoadWorkerTest {
} }
private StreamObserver<Control.ClientArgs> startClient(Control.ClientArgs clientArgs) private StreamObserver<Control.ClientArgs> startClient(Control.ClientArgs clientArgs)
throws InterruptedException { throws Exception {
final CountDownLatch clientReady = new CountDownLatch(1); final SettableFuture<Void> clientReady = SettableFuture.create();
StreamObserver<Control.ClientArgs> clientObserver = workerServiceStub.runClient( StreamObserver<Control.ClientArgs> clientObserver = workerServiceStub.runClient(
new StreamObserver<Control.ClientStatus>() { new StreamObserver<Control.ClientStatus>() {
@Override @Override
public void onNext(Control.ClientStatus value) { public void onNext(Control.ClientStatus value) {
clientReady.countDown(); clientReady.set(null);
if (value.hasStats()) { if (value.hasStats()) {
marksQueue.add(value.getStats()); marksQueue.add(value.getStats());
} }
@ -217,45 +230,43 @@ public class LoadWorkerTest {
@Override @Override
public void onError(Throwable t) { public void onError(Throwable t) {
clientReady.setException(t);
} }
@Override @Override
public void onCompleted() { public void onCompleted() {
clientReady.setException(
new RuntimeException("onCompleted() before receiving response"));
} }
}); });
// Start the client // Start the client
clientObserver.onNext(clientArgs); clientObserver.onNext(clientArgs);
if (!clientReady.await(TIMEOUT, TimeUnit.SECONDS)) { clientReady.get(TIMEOUT, TimeUnit.SECONDS);
fail("Client failed to start");
}
return clientObserver; return clientObserver;
} }
private int startServer(Control.ServerArgs serverArgs) throws InterruptedException { private int startServer(Control.ServerArgs serverArgs) throws Exception {
final AtomicInteger serverPort = new AtomicInteger(); final SettableFuture<Integer> port = SettableFuture.create();
final CountDownLatch serverReady = new CountDownLatch(1); serverLifetime =
StreamObserver<Control.ServerArgs> serverObserver =
workerServiceStub.runServer(new StreamObserver<Control.ServerStatus>() { workerServiceStub.runServer(new StreamObserver<Control.ServerStatus>() {
@Override @Override
public void onNext(Control.ServerStatus value) { public void onNext(Control.ServerStatus value) {
serverPort.set(value.getPort()); port.set(value.getPort());
serverReady.countDown();
} }
@Override @Override
public void onError(Throwable t) { public void onError(Throwable t) {
port.setException(t);
} }
@Override @Override
public void onCompleted() { public void onCompleted() {
port.setException(new RuntimeException("onCompleted() before receiving response"));
} }
}); });
// trigger server startup // trigger server startup
serverObserver.onNext(serverArgs); serverLifetime.onNext(serverArgs);
if (!serverReady.await(TIMEOUT, TimeUnit.SECONDS)) { return port.get(TIMEOUT, TimeUnit.SECONDS);
fail("Server failed to start");
}
return serverPort.get();
} }
} }