diff --git a/spring-boot/src/main/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainer.java b/spring-boot/src/main/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainer.java index 78d537f49a..1d98374456 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainer.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainer.java @@ -16,12 +16,16 @@ package org.springframework.boot.context.embedded.tomcat; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import org.apache.catalina.Container; import org.apache.catalina.Engine; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; +import org.apache.catalina.Server; +import org.apache.catalina.Service; import org.apache.catalina.connector.Connector; import org.apache.catalina.startup.Tomcat; import org.apache.commons.logging.Log; @@ -47,6 +51,8 @@ public class TomcatEmbeddedServletContainer implements EmbeddedServletContainer private final Tomcat tomcat; + private final Map serviceConnectors = new HashMap(); + private final boolean autoStart; /** @@ -71,12 +77,25 @@ public class TomcatEmbeddedServletContainer implements EmbeddedServletContainer private synchronized void initialize() throws EmbeddedServletContainerException { try { + Server server = this.tomcat.getServer(); int instanceId = containerCounter.incrementAndGet(); if (instanceId > 0) { Engine engine = this.tomcat.getEngine(); engine.setName(engine.getName() + "-" + instanceId); } + + // Remove service connectors to that protocol binding doesn't happen yet + for (Service service : server.findServices()) { + Connector[] connectors = service.findConnectors().clone(); + this.serviceConnectors.put(service, connectors); + for (Connector connector : connectors) { + service.removeConnector(connector); + } + } + + // Start the server to trigger initialization listeners this.tomcat.start(); + Container[] children = this.tomcat.getHost().findChildren(); for (Container container : children) { if (container instanceof TomcatEmbeddedContext) { @@ -87,16 +106,7 @@ public class TomcatEmbeddedServletContainer implements EmbeddedServletContainer } } } - try { - // Allow the server to start so the ServletContext is available, but stop - // the connector to prevent requests from being handled before the Spring - // context is ready: - Connector connector = this.tomcat.getConnector(); - connector.getProtocolHandler().stop(); - } - catch (Exception ex) { - this.logger.error("Cannot pause connector: ", ex); - } + // Unlike Jetty, all Tomcat threads are daemon threads. We create a // blocking non-daemon to stop immediate shutdown Thread awaitThread = new Thread("container-" + (containerCounter.get())) { @@ -120,6 +130,20 @@ public class TomcatEmbeddedServletContainer implements EmbeddedServletContainer @Override public void start() throws EmbeddedServletContainerException { + // Add the previously removed connectors (also starting them) + Service[] services = this.tomcat.getServer().findServices(); + for (Service service : services) { + Connector[] connectors = this.serviceConnectors.get(service); + if (connectors != null) { + for (Connector connector : connectors) { + service.addConnector(connector); + if (!this.autoStart) { + unbind(connector); + } + } + this.serviceConnectors.remove(service); + } + } Connector connector = this.tomcat.getConnector(); if (connector != null && this.autoStart) { try { @@ -139,6 +163,19 @@ public class TomcatEmbeddedServletContainer implements EmbeddedServletContainer } } + private void unbind(Connector connector) { + try { + connector.getProtocolHandler().stop(); + } + catch (Exception ex) { + this.logger.error("Cannot pause connector: ", ex); + } + } + + Map getServiceConnectors() { + return this.serviceConnectors; + } + private void logPorts() { StringBuilder ports = new StringBuilder(); for (Connector additionalConnector : this.tomcat.getService().findConnectors()) { diff --git a/spring-boot/src/test/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainerFactoryTests.java b/spring-boot/src/test/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainerFactoryTests.java index 5a101b4e63..6c0964491f 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainerFactoryTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainerFactoryTests.java @@ -17,11 +17,14 @@ package org.springframework.boot.context.embedded.tomcat; import java.util.Arrays; +import java.util.Map; import java.util.concurrent.TimeUnit; import org.apache.catalina.Context; import org.apache.catalina.LifecycleEvent; import org.apache.catalina.LifecycleListener; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.Service; import org.apache.catalina.Valve; import org.apache.catalina.connector.Connector; import org.apache.catalina.startup.Tomcat; @@ -34,6 +37,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; +import static org.mockito.BDDMockito.given; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyObject; import static org.mockito.Mockito.inOrder; @@ -127,11 +131,16 @@ public class TomcatEmbeddedServletContainerFactoryTests extends TomcatEmbeddedServletContainerFactory factory = getFactory(); Connector[] listeners = new Connector[4]; for (int i = 0; i < listeners.length; i++) { - listeners[i] = mock(Connector.class); + Connector connector = mock(Connector.class); + given(connector.getState()).willReturn(LifecycleState.STOPPED); + listeners[i] = connector; } factory.addAdditionalTomcatConnectors(listeners); this.container = factory.getEmbeddedServletContainer(); - assertEquals(listeners.length, factory.getAdditionalTomcatConnectors().size()); + Map connectors = ((TomcatEmbeddedServletContainer) this.container) + .getServiceConnectors(); + assertThat(connectors.values().iterator().next().length, + equalTo(listeners.length + 1)); } @Test