diff --git a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java index af5f9eed75..52122c5e8b 100644 --- a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java @@ -166,7 +166,7 @@ public final class BinderServerBuilder ObjectPool executorPool = serverImplBuilder.getExecutorPool(); Executor executor = executorPool.getObject(); BinderTransportSecurity.installAuthInterceptor(this, executor); - internalBuilder.setShutdownListener(() -> executorPool.returnObject(executor)); + internalBuilder.setTerminationListener(() -> executorPool.returnObject(executor)); return super.build(); } } diff --git a/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java b/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java new file mode 100644 index 0000000000..ad41018648 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java @@ -0,0 +1,110 @@ +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkState; + +import io.grpc.Attributes; +import io.grpc.Metadata; +import io.grpc.internal.ServerListener; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; +import javax.annotation.concurrent.GuardedBy; + +/** + * Tracks which {@link BinderTransport.BinderServerTransport} are currently active and allows + * invoking a {@link Runnable} only once all transports are terminated. + */ +final class ActiveTransportTracker implements ServerListener { + private final ServerListener delegate; + private final Runnable terminationListener; + + @GuardedBy("this") + private boolean shutdown = false; + + @GuardedBy("this") + private int activeTransportCount = 0; + + /** + * @param delegate the original server listener that this object decorates. Usually passed to + * {@link BinderServer#start(ServerListener)}. + * @param terminationListener invoked only once the server has started shutdown ({@link + * #serverShutdown()} AND the last active transport is terminated. + */ + ActiveTransportTracker(ServerListener delegate, Runnable terminationListener) { + this.delegate = delegate; + this.terminationListener = terminationListener; + } + + @Override + public ServerTransportListener transportCreated(ServerTransport transport) { + synchronized (this) { + checkState(!shutdown, "Illegal transportCreated() after serverShutdown()"); + activeTransportCount++; + } + ServerTransportListener originalListener = delegate.transportCreated(transport); + return new TrackedTransportListener(originalListener); + } + + private void untrack() { + Runnable maybeTerminationListener; + synchronized (this) { + activeTransportCount--; + maybeTerminationListener = getListenerIfTerminated(); + } + // Prefer running the listener outside of the synchronization lock to release it sooner, since + // we don't know how the callback is implemented nor how long it will take. This should + // minimize the possibility of deadlocks. + if (maybeTerminationListener != null) { + maybeTerminationListener.run(); + } + } + + @Override + public void serverShutdown() { + delegate.serverShutdown(); + Runnable maybeTerminationListener; + synchronized (this) { + shutdown = true; + maybeTerminationListener = getListenerIfTerminated(); + } + // We may be able to shutdown immediately if there are no active transports. + // + // Executed outside of the lock. See "untrack()" above. + if (maybeTerminationListener != null) { + maybeTerminationListener.run(); + } + } + + @GuardedBy("this") + private Runnable getListenerIfTerminated() { + return (shutdown && activeTransportCount == 0) ? terminationListener : null; + } + + /** + * Wraps a {@link ServerTransportListener}, unregistering it from the parent tracker once the + * transport terminates. + */ + private final class TrackedTransportListener implements ServerTransportListener { + private final ServerTransportListener delegate; + + TrackedTransportListener(ServerTransportListener delegate) { + this.delegate = delegate; + } + + @Override + public void streamCreated(ServerStream stream, String method, Metadata headers) { + delegate.streamCreated(stream, method, headers); + } + + @Override + public Attributes transportReady(Attributes attributes) { + return delegate.transportReady(attributes); + } + + @Override + public void transportTerminated() { + delegate.transportTerminated(); + untrack(); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java index 260410b75d..d3580dbd13 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java @@ -68,7 +68,7 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. private final LeakSafeOneWayBinder hostServiceBinder; private final BinderTransportSecurity.ServerPolicyChecker serverPolicyChecker; private final InboundParcelablePolicy inboundParcelablePolicy; - private final BinderTransportSecurity.ShutdownListener transportSecurityShutdownListener; + private final Runnable terminationListener; @GuardedBy("this") private ServerListener listener; @@ -86,7 +86,7 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. ImmutableList.copyOf(checkNotNull(builder.streamTracerFactories, "streamTracerFactories")); this.serverPolicyChecker = BinderInternal.createPolicyChecker(builder.serverSecurityPolicy); this.inboundParcelablePolicy = builder.inboundParcelablePolicy; - this.transportSecurityShutdownListener = builder.shutdownListener; + this.terminationListener = builder.terminationListener; hostServiceBinder = new LeakSafeOneWayBinder(this); } @@ -97,7 +97,7 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. @Override public synchronized void start(ServerListener serverListener) throws IOException { - this.listener = serverListener; + listener = new ActiveTransportTracker(serverListener, terminationListener); executorService = executorServicePool.getObject(); } @@ -130,7 +130,6 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. hostServiceBinder.setHandler(GoAwayHandler.INSTANCE); listener.serverShutdown(); executorService = executorServicePool.returnObject(executorService); - transportSecurityShutdownListener.onServerShutdown(); } } @@ -208,7 +207,7 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); ServerSecurityPolicy serverSecurityPolicy = SecurityPolicies.serverInternalOnly(); InboundParcelablePolicy inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; - BinderTransportSecurity.ShutdownListener shutdownListener = () -> {}; + Runnable terminationListener = () -> {}; public BinderServer build() { return new BinderServer(this); @@ -269,12 +268,13 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. } /** - * Installs a callback that will be invoked when this server is {@link #shutdown()} + * Installs a callback that will be invoked when this server is {@link #shutdown()} and all of + * its transports are terminated. * *

Optional. */ - public Builder setShutdownListener(BinderTransportSecurity.ShutdownListener shutdownListener) { - this.shutdownListener = checkNotNull(shutdownListener, "shutdownListener"); + public Builder setTerminationListener(Runnable terminationListener) { + this.terminationListener = checkNotNull(terminationListener, "terminationListener"); return this; } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java index 56464d58a4..72a02c92ff 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java @@ -238,12 +238,4 @@ public final class BinderTransportSecurity { */ ListenableFuture checkAuthorizationForServiceAsync(int uid, String serviceName); } - - /** - * A listener invoked when the {@link io.grpc.binder.internal.BinderServer} shuts down, allowing - * resources to be potentially cleaned up. - */ - public interface ShutdownListener { - void onServerShutdown(); - } } diff --git a/binder/src/test/java/io/grpc/binder/internal/ActiveTransportTrackerTest.java b/binder/src/test/java/io/grpc/binder/internal/ActiveTransportTrackerTest.java new file mode 100644 index 0000000000..099756075f --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/ActiveTransportTrackerTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import io.grpc.internal.ServerListener; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.robolectric.RobolectricTestRunner; + +@RunWith(RobolectricTestRunner.class) +public final class ActiveTransportTrackerTest { + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + private ActiveTransportTracker tracker; + + @Mock Runnable mockShutdownListener; + @Mock ServerListener mockServerListener; + @Mock ServerTransportListener mockServerTransportListener; + @Mock ServerTransport mockServerTransport; + + @Before + public void setUp() { + when(mockServerListener.transportCreated(any())).thenReturn(mockServerTransportListener); + tracker = new ActiveTransportTracker(mockServerListener, mockShutdownListener); + } + + @Test + public void testServerShutdown_onlyNotifiesAfterAllTransportAreTerminated() { + ServerTransportListener wrapperListener1 = registerNewTransport(); + ServerTransportListener wrapperListener2 = registerNewTransport(); + + tracker.serverShutdown(); + // 2 active transports, notification scheduled + verifyNoInteractions(mockShutdownListener); + + wrapperListener1.transportTerminated(); + // 1 active transport remaining, notification still pending + verifyNoInteractions(mockShutdownListener); + + wrapperListener2.transportTerminated(); + // No more active transports, shutdown notified + verify(mockShutdownListener).run(); + } + + @Test + public void testServerShutdown_noActiveTransport_notifiesTerminationImmediately() { + verifyNoInteractions(mockShutdownListener); + + tracker.serverShutdown(); + + verify(mockShutdownListener).run(); + } + + @Test + public void testLastTransportTerminated_serverNotShutdownYet_doesNotNotify() { + ServerTransportListener wrapperListener = registerNewTransport(); + verifyNoInteractions(mockShutdownListener); + + wrapperListener.transportTerminated(); + + verifyNoInteractions(mockShutdownListener); + } + + @Test + public void testTransportCreation_afterServerShutdown_throws() { + tracker.serverShutdown(); + + assertThrows(IllegalStateException.class, this::registerNewTransport); + } + + @Test + public void testServerListenerCallbacks_invokesDelegates() { + ServerTransportListener listener = tracker.transportCreated(mockServerTransport); + verify(mockServerListener).transportCreated(mockServerTransport); + + listener.transportTerminated(); + verify(mockServerTransportListener).transportTerminated(); + + tracker.serverShutdown(); + verify(mockServerListener).serverShutdown(); + } + + private ServerTransportListener registerNewTransport() { + return tracker.transportCreated(mockServerTransport); + } +}