mirror of https://github.com/grpc/grpc-java.git
Cancel server context when call is cancelled
Context was being closed when the server was closing the RPC, but not if the client cancelled.
This commit is contained in:
parent
6e94cf33db
commit
e475d388b9
|
@ -184,18 +184,10 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
|
|||
|
||||
@Override
|
||||
public void close(Status status, Metadata trailers) {
|
||||
try {
|
||||
checkState(!closeCalled, "call already closed");
|
||||
closeCalled = true;
|
||||
inboundHeaders = null;
|
||||
stream.close(status, trailers);
|
||||
} finally {
|
||||
if (status.getCode() == Status.Code.OK) {
|
||||
context.cancel(null);
|
||||
} else {
|
||||
context.cancel(status.getCause() != null ? status.getCause() : status.asRuntimeException());
|
||||
}
|
||||
}
|
||||
checkState(!closeCalled, "call already closed");
|
||||
closeCalled = true;
|
||||
inboundHeaders = null;
|
||||
stream.close(status, trailers);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -205,7 +197,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
|
|||
|
||||
ServerStreamListener newServerStreamListener(ServerCall.Listener<ReqT> listener,
|
||||
Future<?> timeout) {
|
||||
return new ServerStreamListenerImpl<ReqT>(this, listener, timeout);
|
||||
return new ServerStreamListenerImpl<ReqT>(this, listener, timeout, context);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -217,13 +209,16 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
|
|||
private final ServerCallImpl<ReqT, ?> call;
|
||||
private final ServerCall.Listener<ReqT> listener;
|
||||
private final Future<?> timeout;
|
||||
private final Context.CancellableContext context;
|
||||
private boolean messageReceived;
|
||||
|
||||
public ServerStreamListenerImpl(
|
||||
ServerCallImpl<ReqT, ?> call, ServerCall.Listener<ReqT> listener, Future<?> timeout) {
|
||||
ServerCallImpl<ReqT, ?> call, ServerCall.Listener<ReqT> listener, Future<?> timeout,
|
||||
Context.CancellableContext context) {
|
||||
this.call = checkNotNull(call, "call");
|
||||
this.listener = checkNotNull(listener, "listener must not be null");
|
||||
this.timeout = checkNotNull(timeout, "timeout");
|
||||
this.context = checkNotNull(context, "context");
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -263,11 +258,17 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
|
|||
@Override
|
||||
public void closed(Status status) {
|
||||
timeout.cancel(true);
|
||||
if (status.isOk()) {
|
||||
listener.onComplete();
|
||||
} else {
|
||||
call.cancelled = true;
|
||||
listener.onCancel();
|
||||
try {
|
||||
if (status.isOk()) {
|
||||
listener.onComplete();
|
||||
} else {
|
||||
call.cancelled = true;
|
||||
listener.onCancel();
|
||||
}
|
||||
} finally {
|
||||
// Cancel context after delivering RPC closure notification to allow the application to
|
||||
// clean up and update any state based on whether onComplete or onCancel was called.
|
||||
context.cancel(null);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,6 @@
|
|||
package io.grpc.internal;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
|
@ -180,20 +179,6 @@ public class ServerCallImplTest {
|
|||
assertTrue(call.isReady());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void closeWithOkCancelsContextWithNoCause() {
|
||||
call.close(Status.OK, new Metadata());
|
||||
assertTrue(context.isCancelled());
|
||||
assertNull(context.cause());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void closeWithErrorCancelsContextWithCause() {
|
||||
call.close(Status.CANCELLED, new Metadata());
|
||||
assertTrue(context.isCancelled());
|
||||
assertNotNull(context.cause());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setMessageCompression() {
|
||||
call.setMessageCompression(true);
|
||||
|
@ -204,7 +189,7 @@ public class ServerCallImplTest {
|
|||
@Test
|
||||
public void streamListener_halfClosed() {
|
||||
ServerStreamListenerImpl<Long> streamListener =
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout);
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout, context);
|
||||
|
||||
streamListener.halfClosed();
|
||||
|
||||
|
@ -214,7 +199,7 @@ public class ServerCallImplTest {
|
|||
@Test
|
||||
public void streamListener_halfClosed_onlyOnce() {
|
||||
ServerStreamListenerImpl<Long> streamListener =
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout);
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout, context);
|
||||
streamListener.halfClosed();
|
||||
// canceling the call should short circuit future halfClosed() calls.
|
||||
streamListener.closed(Status.CANCELLED);
|
||||
|
@ -227,29 +212,33 @@ public class ServerCallImplTest {
|
|||
@Test
|
||||
public void streamListener_closedOk() {
|
||||
ServerStreamListenerImpl<Long> streamListener =
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout);
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout, context);
|
||||
|
||||
streamListener.closed(Status.OK);
|
||||
|
||||
verify(callListener).onComplete();
|
||||
assertTrue(timeout.isCancelled());
|
||||
assertTrue(context.isCancelled());
|
||||
assertNull(context.cause());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void streamListener_closedCancelled() {
|
||||
ServerStreamListenerImpl<Long> streamListener =
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout);
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout, context);
|
||||
|
||||
streamListener.closed(Status.CANCELLED);
|
||||
|
||||
verify(callListener).onCancel();
|
||||
assertTrue(timeout.isCancelled());
|
||||
assertTrue(context.isCancelled());
|
||||
assertNull(context.cause());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void streamListener_onReady() {
|
||||
ServerStreamListenerImpl<Long> streamListener =
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout);
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout, context);
|
||||
|
||||
streamListener.onReady();
|
||||
|
||||
|
@ -259,7 +248,7 @@ public class ServerCallImplTest {
|
|||
@Test
|
||||
public void streamListener_onReady_onlyOnce() {
|
||||
ServerStreamListenerImpl<Long> streamListener =
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout);
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout, context);
|
||||
streamListener.onReady();
|
||||
// canceling the call should short circuit future halfClosed() calls.
|
||||
streamListener.closed(Status.CANCELLED);
|
||||
|
@ -272,7 +261,7 @@ public class ServerCallImplTest {
|
|||
@Test
|
||||
public void streamListener_messageRead() {
|
||||
ServerStreamListenerImpl<Long> streamListener =
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout);
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout, context);
|
||||
streamListener.messageRead(method.streamRequest(1234L));
|
||||
|
||||
verify(callListener).onMessage(1234L);
|
||||
|
@ -281,7 +270,7 @@ public class ServerCallImplTest {
|
|||
@Test
|
||||
public void streamListener_messageRead_unaryFailsOnMultiple() {
|
||||
ServerStreamListenerImpl<Long> streamListener =
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout);
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout, context);
|
||||
streamListener.messageRead(method.streamRequest(1234L));
|
||||
streamListener.messageRead(method.streamRequest(1234L));
|
||||
|
||||
|
@ -295,7 +284,7 @@ public class ServerCallImplTest {
|
|||
@Test
|
||||
public void streamListener_messageRead_onlyOnce() {
|
||||
ServerStreamListenerImpl<Long> streamListener =
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout);
|
||||
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, timeout, context);
|
||||
streamListener.messageRead(method.streamRequest(1234L));
|
||||
// canceling the call should short circuit future halfClosed() calls.
|
||||
streamListener.closed(Status.CANCELLED);
|
||||
|
|
Loading…
Reference in New Issue