diff --git a/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactory.java b/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactory.java index d9f83c54a4..651cb36681 100644 --- a/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactory.java +++ b/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactory.java @@ -17,13 +17,8 @@ package org.springframework.boot.web.servlet.server; import java.io.File; -import java.io.IOException; -import java.net.JarURLConnection; import java.net.URL; -import java.net.URLClassLoader; -import java.net.URLConnection; import java.nio.charset.Charset; -import java.security.CodeSource; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -31,13 +26,10 @@ import java.util.List; import java.util.Locale; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.jar.JarFile; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.boot.ApplicationHome; -import org.springframework.boot.ApplicationTemp; import org.springframework.boot.web.server.AbstractConfigurableWebServerFactory; import org.springframework.boot.web.server.MimeMappings; import org.springframework.boot.web.servlet.ServletContextInitializer; @@ -63,9 +55,6 @@ public abstract class AbstractServletWebServerFactory private static final int DEFAULT_SESSION_TIMEOUT = (int) TimeUnit.MINUTES .toSeconds(30); - private static final String[] COMMON_DOC_ROOTS = { "src/main/webapp", "public", - "static" }; - protected final Log logger = LogFactory.getLog(getClass()); private String contextPath = ""; @@ -76,20 +65,22 @@ public abstract class AbstractServletWebServerFactory private boolean persistSession; - private File sessionStoreDir; - private boolean registerDefaultServlet = true; private MimeMappings mimeMappings = new MimeMappings(MimeMappings.DEFAULT); - private File documentRoot; - private List initializers = new ArrayList<>(); private Jsp jsp = new Jsp(); private Map localeCharsetMappings = new HashMap<>(); + private final SessionStoreDirectory sessionStoreDir = new SessionStoreDirectory(); + + private final DocumentRoot documentRoot = new DocumentRoot(this.logger); + + private final StaticResourceJars staticResourceJars = new StaticResourceJars(); + /** * Create a new {@link AbstractServletWebServerFactory} instance. */ @@ -184,12 +175,12 @@ public abstract class AbstractServletWebServerFactory } public File getSessionStoreDir() { - return this.sessionStoreDir; + return this.sessionStoreDir.getDirectory(); } @Override public void setSessionStoreDir(File sessionStoreDir) { - this.sessionStoreDir = sessionStoreDir; + this.sessionStoreDir.setDirectory(sessionStoreDir); } /** @@ -224,12 +215,12 @@ public abstract class AbstractServletWebServerFactory * @return the document root */ public File getDocumentRoot() { - return this.documentRoot; + return this.documentRoot.getDirectory(); } @Override public void setDocumentRoot(File documentRoot) { - this.documentRoot = documentRoot; + this.documentRoot.setDirectory(documentRoot); } @Override @@ -298,171 +289,19 @@ public abstract class AbstractServletWebServerFactory * @return the valid document root */ protected final File getValidDocumentRoot() { - File file = getDocumentRoot(); - // If document root not explicitly set see if we are running from a war archive - file = file != null ? file : getWarFileDocumentRoot(); - // If not a war archive maybe it is an exploded war - file = file != null ? file : getExplodedWarFileDocumentRoot(); - // Or maybe there is a document root in a well-known location - file = file != null ? file : getCommonDocumentRoot(); - if (file == null && this.logger.isDebugEnabled()) { - this.logger - .debug("None of the document roots " + Arrays.asList(COMMON_DOC_ROOTS) - + " point to a directory and will be ignored."); - } - else if (this.logger.isDebugEnabled()) { - this.logger.debug("Document root: " + file); - } - return file; - } - - private File getExplodedWarFileDocumentRoot() { - return getExplodedWarFileDocumentRoot(getCodeSourceArchive()); - } - - protected List getUrlsOfJarsWithMetaInfResources() { - ClassLoader classLoader = getClass().getClassLoader(); - List staticResourceUrls = new ArrayList<>(); - if (classLoader instanceof URLClassLoader) { - for (URL url : ((URLClassLoader) classLoader).getURLs()) { - try { - if ("file".equals(url.getProtocol())) { - File file = new File(url.getFile()); - if (file.isDirectory() - && new File(file, "META-INF/resources").isDirectory()) { - staticResourceUrls.add(url); - } - else if (isResourcesJar(file)) { - staticResourceUrls.add(url); - } - } - else { - URLConnection connection = url.openConnection(); - if (connection instanceof JarURLConnection) { - if (isResourcesJar((JarURLConnection) connection)) { - staticResourceUrls.add(url); - } - } - } - } - catch (IOException ex) { - throw new IllegalStateException(ex); - } - } - } - return staticResourceUrls; + return this.documentRoot.getValidDirectory(); } - private boolean isResourcesJar(JarURLConnection connection) { - try { - return isResourcesJar(connection.getJarFile()); - } - catch (IOException ex) { - return false; - } - } - - private boolean isResourcesJar(File file) { - try { - return isResourcesJar(new JarFile(file)); - } - catch (IOException ex) { - return false; - } - } - - private boolean isResourcesJar(JarFile jar) throws IOException { - try { - return jar.getName().endsWith(".jar") - && (jar.getJarEntry("META-INF/resources") != null); - } - finally { - jar.close(); - } - } - - protected final File getExplodedWarFileDocumentRoot(File codeSourceFile) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Code archive: " + codeSourceFile); - } - if (codeSourceFile != null && codeSourceFile.exists()) { - String path = codeSourceFile.getAbsolutePath(); - int webInfPathIndex = path - .indexOf(File.separatorChar + "WEB-INF" + File.separatorChar); - if (webInfPathIndex >= 0) { - path = path.substring(0, webInfPathIndex); - return new File(path); - } - } - return null; - } - - private File getWarFileDocumentRoot() { - return getArchiveFileDocumentRoot(".war"); - } - - private File getArchiveFileDocumentRoot(String extension) { - File file = getCodeSourceArchive(); - if (this.logger.isDebugEnabled()) { - this.logger.debug("Code archive: " + file); - } - if (file != null && file.exists() && !file.isDirectory() - && file.getName().toLowerCase().endsWith(extension)) { - return file.getAbsoluteFile(); - } - return null; - } - - private File getCommonDocumentRoot() { - for (String commonDocRoot : COMMON_DOC_ROOTS) { - File root = new File(commonDocRoot); - if (root.exists() && root.isDirectory()) { - return root.getAbsoluteFile(); - } - } - return null; - } - - private File getCodeSourceArchive() { - try { - CodeSource codeSource = getClass().getProtectionDomain().getCodeSource(); - URL location = (codeSource == null ? null : codeSource.getLocation()); - if (location == null) { - return null; - } - String path = location.getPath(); - URLConnection connection = location.openConnection(); - if (connection instanceof JarURLConnection) { - path = ((JarURLConnection) connection).getJarFile().getName(); - } - if (path.indexOf("!/") != -1) { - path = path.substring(0, path.indexOf("!/")); - } - return new File(path); - } - catch (IOException ex) { - return null; - } + protected final List getUrlsOfJarsWithMetaInfResources() { + return this.staticResourceJars.getUrls(); } protected final File getValidSessionStoreDir() { - return getValidSessionStoreDir(true); + return this.sessionStoreDir.getValidDirectory(true); } protected final File getValidSessionStoreDir(boolean mkdirs) { - File dir = getSessionStoreDir(); - if (dir == null) { - return new ApplicationTemp().getDir("servlet-sessions"); - } - if (!dir.isAbsolute()) { - dir = new File(new ApplicationHome().getDir(), dir.getPath()); - } - if (!dir.exists() && mkdirs) { - dir.mkdirs(); - } - Assert.state(!mkdirs || dir.exists(), "Session dir " + dir + " does not exist"); - Assert.state(!dir.isFile(), "Session dir " + dir + " points to a file"); - return dir; + return this.sessionStoreDir.getValidDirectory(mkdirs); } } diff --git a/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/DocumentRoot.java b/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/DocumentRoot.java new file mode 100644 index 0000000000..484cf19804 --- /dev/null +++ b/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/DocumentRoot.java @@ -0,0 +1,148 @@ +/* + * Copyright 2012-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.servlet.server; + +import java.io.File; +import java.io.IOException; +import java.net.JarURLConnection; +import java.net.URL; +import java.net.URLConnection; +import java.security.CodeSource; +import java.util.Arrays; + +import org.apache.commons.logging.Log; + +/** + * Manages a {@link ServletWebServerFactory} document root. + * + * @author Phillip Webb + * @see AbstractServletWebServerFactory + */ +class DocumentRoot { + + private static final String[] COMMON_DOC_ROOTS = { "src/main/webapp", "public", + "static" }; + + private final Log logger; + + private File directory; + + DocumentRoot(Log logger) { + this.logger = logger; + } + + public File getDirectory() { + return this.directory; + } + + public void setDirectory(File directory) { + this.directory = directory; + } + + /** + * Returns the absolute document root when it points to a valid directory, logging a + * warning and returning {@code null} otherwise. + * @return the valid document root + */ + public final File getValidDirectory() { + File file = this.directory; + file = (file != null ? file : getWarFileDocumentRoot()); + file = (file != null ? file : getExplodedWarFileDocumentRoot()); + file = (file != null ? file : getCommonDocumentRoot()); + if (file == null && this.logger.isDebugEnabled()) { + logNoDocumentRoots(); + } + else if (this.logger.isDebugEnabled()) { + this.logger.debug("Document root: " + file); + } + return file; + } + + private File getWarFileDocumentRoot() { + return getArchiveFileDocumentRoot(".war"); + } + + private File getArchiveFileDocumentRoot(String extension) { + File file = getCodeSourceArchive(); + if (this.logger.isDebugEnabled()) { + this.logger.debug("Code archive: " + file); + } + if (file != null && file.exists() && !file.isDirectory() + && file.getName().toLowerCase().endsWith(extension)) { + return file.getAbsoluteFile(); + } + return null; + } + + private File getExplodedWarFileDocumentRoot() { + return getExplodedWarFileDocumentRoot(getCodeSourceArchive()); + } + + private File getCodeSourceArchive() { + try { + CodeSource codeSource = getClass().getProtectionDomain().getCodeSource(); + URL location = (codeSource == null ? null : codeSource.getLocation()); + if (location == null) { + return null; + } + String path = location.getPath(); + URLConnection connection = location.openConnection(); + if (connection instanceof JarURLConnection) { + path = ((JarURLConnection) connection).getJarFile().getName(); + } + if (path.indexOf("!/") != -1) { + path = path.substring(0, path.indexOf("!/")); + } + return new File(path); + } + catch (IOException ex) { + return null; + } + } + + public final File getExplodedWarFileDocumentRoot(File codeSourceFile) { + if (this.logger.isDebugEnabled()) { + this.logger.debug("Code archive: " + codeSourceFile); + } + if (codeSourceFile != null && codeSourceFile.exists()) { + String path = codeSourceFile.getAbsolutePath(); + int webInfPathIndex = path + .indexOf(File.separatorChar + "WEB-INF" + File.separatorChar); + if (webInfPathIndex >= 0) { + path = path.substring(0, webInfPathIndex); + return new File(path); + } + } + return null; + } + + private File getCommonDocumentRoot() { + for (String commonDocRoot : COMMON_DOC_ROOTS) { + File root = new File(commonDocRoot); + if (root.exists() && root.isDirectory()) { + return root.getAbsoluteFile(); + } + } + return null; + } + + private void logNoDocumentRoots() { + this.logger.debug("None of the document roots " + Arrays.asList(COMMON_DOC_ROOTS) + + " point to a directory and will be ignored."); + } + +} diff --git a/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/SessionStoreDirectory.java b/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/SessionStoreDirectory.java new file mode 100644 index 0000000000..4af80c19b6 --- /dev/null +++ b/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/SessionStoreDirectory.java @@ -0,0 +1,59 @@ +/* + * Copyright 2012-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.servlet.server; + +import java.io.File; + +import org.springframework.boot.ApplicationHome; +import org.springframework.boot.ApplicationTemp; +import org.springframework.util.Assert; + +/** + * Manages a session store directory. + * + * @author Phillip Webb + * @see AbstractServletWebServerFactory + */ +class SessionStoreDirectory { + + private File directory; + + public File getDirectory() { + return this.directory; + } + + public void setDirectory(File directory) { + this.directory = directory; + } + + public File getValidDirectory(boolean mkdirs) { + File dir = getDirectory(); + if (dir == null) { + return new ApplicationTemp().getDir("servlet-sessions"); + } + if (!dir.isAbsolute()) { + dir = new File(new ApplicationHome().getDir(), dir.getPath()); + } + if (!dir.exists() && mkdirs) { + dir.mkdirs(); + } + Assert.state(!mkdirs || dir.exists(), "Session dir " + dir + " does not exist"); + Assert.state(!dir.isFile(), "Session dir " + dir + " points to a file"); + return dir; + } + +} diff --git a/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/StaticResourceJars.java b/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/StaticResourceJars.java new file mode 100644 index 0000000000..618649051e --- /dev/null +++ b/spring-boot/src/main/java/org/springframework/boot/web/servlet/server/StaticResourceJars.java @@ -0,0 +1,108 @@ +/* + * Copyright 2012-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.servlet.server; + +import java.io.File; +import java.io.IOException; +import java.net.JarURLConnection; +import java.net.URL; +import java.net.URLClassLoader; +import java.net.URLConnection; +import java.util.ArrayList; +import java.util.List; +import java.util.jar.JarFile; + +/** + * Logic to extract URLs of static resource jars (those containing + * {@code "META-INF/resources"} directories). + * + * @author Andy Wilkinson + * @author Phillip Webb + */ +class StaticResourceJars { + + public final List getUrls() { + ClassLoader classLoader = getClass().getClassLoader(); + List urls = new ArrayList<>(); + if (classLoader instanceof URLClassLoader) { + for (URL url : ((URLClassLoader) classLoader).getURLs()) { + addUrl(urls, url); + } + } + return urls; + } + + private void addUrl(List urls, URL url) { + try { + if ("file".equals(url.getProtocol())) { + addUrlFile(urls, url, new File(url.getFile())); + } + else { + addUrlConnection(urls, url, url.openConnection()); + } + } + catch (IOException ex) { + throw new IllegalStateException(ex); + } + } + + private void addUrlFile(List urls, URL url, File file) { + if (file.isDirectory() && new File(file, "META-INF/resources").isDirectory()) { + urls.add(url); + } + else if (isResourcesJar(file)) { + urls.add(url); + } + } + + private void addUrlConnection(List urls, URL url, URLConnection connection) { + if (connection instanceof JarURLConnection) { + if (isResourcesJar((JarURLConnection) connection)) { + urls.add(url); + } + } + } + + private boolean isResourcesJar(JarURLConnection connection) { + try { + return isResourcesJar(connection.getJarFile()); + } + catch (IOException ex) { + return false; + } + } + + private boolean isResourcesJar(File file) { + try { + return isResourcesJar(new JarFile(file)); + } + catch (IOException ex) { + return false; + } + } + + private boolean isResourcesJar(JarFile jar) throws IOException { + try { + return jar.getName().endsWith(".jar") + && (jar.getJarEntry("META-INF/resources") != null); + } + finally { + jar.close(); + } + } + +} diff --git a/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java b/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java index 2a3af46077..bac2cca8db 100644 --- a/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java @@ -161,8 +161,7 @@ public abstract class AbstractServletWebServerFactoryTests { @Test public void startServlet() throws Exception { AbstractServletWebServerFactory factory = getFactory(); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); assertThat(getResponse(getLocalUrl("/hello"))).isEqualTo("Hello World"); } @@ -170,8 +169,7 @@ public abstract class AbstractServletWebServerFactoryTests { @Test public void startCalledTwice() throws Exception { AbstractServletWebServerFactory factory = getFactory(); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); int port = this.webServer.getPort(); this.webServer.start(); @@ -183,8 +181,7 @@ public abstract class AbstractServletWebServerFactoryTests { @Test public void stopCalledTwice() throws Exception { AbstractServletWebServerFactory factory = getFactory(); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); this.webServer.stop(); this.webServer.stop(); @@ -194,8 +191,7 @@ public abstract class AbstractServletWebServerFactoryTests { public void emptyServerWhenPortIsMinusOne() throws Exception { AbstractServletWebServerFactory factory = getFactory(); factory.setPort(-1); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); assertThat(this.webServer.getPort()).isLessThan(0); // Jetty is -2 } @@ -203,8 +199,7 @@ public abstract class AbstractServletWebServerFactoryTests { @Test public void stopServlet() throws Exception { AbstractServletWebServerFactory factory = getFactory(); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); int port = this.webServer.getPort(); this.webServer.stop(); @@ -217,8 +212,7 @@ public abstract class AbstractServletWebServerFactoryTests { @Test public void restartWithKeepAlive() throws Exception { AbstractServletWebServerFactory factory = getFactory(); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); HttpComponentsAsyncClientHttpRequestFactory clientHttpRequestFactory = new HttpComponentsAsyncClientHttpRequestFactory(); ListenableFuture response1 = clientHttpRequestFactory @@ -227,8 +221,7 @@ public abstract class AbstractServletWebServerFactoryTests { assertThat(response1.get(10, TimeUnit.SECONDS).getRawStatusCode()).isEqualTo(200); this.webServer.stop(); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); ListenableFuture response2 = clientHttpRequestFactory @@ -251,20 +244,18 @@ public abstract class AbstractServletWebServerFactoryTests { public void startBlocksUntilReadyToServe() throws Exception { AbstractServletWebServerFactory factory = getFactory(); final Date[] date = new Date[1]; - this.webServer = factory - .getWebServer(new ServletContextInitializer() { - @Override - public void onStartup(ServletContext servletContext) - throws ServletException { - try { - Thread.sleep(500); - date[0] = new Date(); - } - catch (InterruptedException ex) { - throw new ServletException(ex); - } - } - }); + this.webServer = factory.getWebServer(new ServletContextInitializer() { + @Override + public void onStartup(ServletContext servletContext) throws ServletException { + try { + Thread.sleep(500); + date[0] = new Date(); + } + catch (InterruptedException ex) { + throw new ServletException(ex); + } + } + }); this.webServer.start(); assertThat(date[0]).isNotNull(); } @@ -273,14 +264,12 @@ public abstract class AbstractServletWebServerFactoryTests { public void loadOnStartAfterContextIsInitialized() throws Exception { AbstractServletWebServerFactory factory = getFactory(); final InitCountingServlet servlet = new InitCountingServlet(); - this.webServer = factory - .getWebServer(new ServletContextInitializer() { - @Override - public void onStartup(ServletContext servletContext) - throws ServletException { - servletContext.addServlet("test", servlet).setLoadOnStartup(1); - } - }); + this.webServer = factory.getWebServer(new ServletContextInitializer() { + @Override + public void onStartup(ServletContext servletContext) throws ServletException { + servletContext.addServlet("test", servlet).setLoadOnStartup(1); + } + }); assertThat(servlet.getInitCount()).isEqualTo(0); this.webServer.start(); assertThat(servlet.getInitCount()).isEqualTo(1); @@ -291,8 +280,7 @@ public abstract class AbstractServletWebServerFactoryTests { AbstractServletWebServerFactory factory = getFactory(); int specificPort = SocketUtils.findAvailableTcpPort(41000); factory.setPort(specificPort); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); assertThat(getResponse("http://localhost:" + specificPort + "/hello")) .isEqualTo("Hello World"); @@ -303,8 +291,7 @@ public abstract class AbstractServletWebServerFactoryTests { public void specificContextRoot() throws Exception { AbstractServletWebServerFactory factory = getFactory(); factory.setContextPath("/say"); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); assertThat(getResponse(getLocalUrl("/say/hello"))).isEqualTo("Hello World"); } @@ -340,8 +327,7 @@ public abstract class AbstractServletWebServerFactoryTests { } factory.setInitializers(Arrays.asList(initializers[2], initializers[3])); factory.addInitializers(initializers[4], initializers[5]); - this.webServer = factory.getWebServer(initializers[0], - initializers[1]); + this.webServer = factory.getWebServer(initializers[0], initializers[1]); this.webServer.start(); InOrder ordered = inOrder((Object[]) initializers); for (ServletContextInitializer initializer : initializers) { @@ -450,13 +436,15 @@ public abstract class AbstractServletWebServerFactoryTests { @Test public void sslKeyAlias() throws Exception { AbstractServletWebServerFactory factory = getFactory(); - factory.setSsl(getSsl(null, "password", "test-alias", "src/test/resources/test.jks")); + factory.setSsl( + getSsl(null, "password", "test-alias", "src/test/resources/test.jks")); this.webServer = factory.getWebServer( new ServletRegistrationBean<>(new ExampleServlet(true, false), "/hello")); this.webServer.start(); SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory( - new SSLContextBuilder() - .loadTrustMaterial(null, new SerialNumberValidatingTrustSelfSignedStrategy("77e7c302")).build()); + new SSLContextBuilder().loadTrustMaterial(null, + new SerialNumberValidatingTrustSelfSignedStrategy("77e7c302")) + .build()); HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory) .build(); HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory( @@ -666,8 +654,7 @@ public abstract class AbstractServletWebServerFactoryTests { @Test public void cannotReadClassPathFiles() throws Exception { AbstractServletWebServerFactory factory = getFactory(); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); ClientHttpResponse response = getClientResponse( getLocalUrl("/org/springframework/boot/SpringApplication.class")); @@ -678,17 +665,20 @@ public abstract class AbstractServletWebServerFactoryTests { return getSsl(clientAuth, keyPassword, keyStore, null, null, null); } - private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyAlias, String keyStore) { + private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyAlias, + String keyStore) { return getSsl(clientAuth, keyPassword, keyAlias, keyStore, null, null, null); } private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyStore, String trustStore, String[] supportedProtocols, String[] ciphers) { - return getSsl(clientAuth, keyPassword, null, keyStore, trustStore, supportedProtocols, ciphers); + return getSsl(clientAuth, keyPassword, null, keyStore, trustStore, + supportedProtocols, ciphers); } - private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyAlias, String keyStore, - String trustStore, String[] supportedProtocols, String[] ciphers) { + private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyAlias, + String keyStore, String trustStore, String[] supportedProtocols, + String[] ciphers) { Ssl ssl = new Ssl(); ssl.setClientAuth(clientAuth); if (keyPassword != null) { @@ -751,14 +741,12 @@ public abstract class AbstractServletWebServerFactoryTests { public void persistSession() throws Exception { AbstractServletWebServerFactory factory = getFactory(); factory.setPersistSession(true); - this.webServer = factory - .getWebServer(sessionServletRegistration()); + this.webServer = factory.getWebServer(sessionServletRegistration()); this.webServer.start(); String s1 = getResponse(getLocalUrl("/session")); String s2 = getResponse(getLocalUrl("/session")); this.webServer.stop(); - this.webServer = factory - .getWebServer(sessionServletRegistration()); + this.webServer = factory.getWebServer(sessionServletRegistration()); this.webServer.start(); String s3 = getResponse(getLocalUrl("/session")); String message = "Session error s1=" + s1 + " s2=" + s2 + " s3=" + s3; @@ -772,8 +760,7 @@ public abstract class AbstractServletWebServerFactoryTests { File sessionStoreDir = this.temporaryFolder.newFolder(); factory.setPersistSession(true); factory.setSessionStoreDir(sessionStoreDir); - this.webServer = factory - .getWebServer(sessionServletRegistration()); + this.webServer = factory.getWebServer(sessionServletRegistration()); this.webServer.start(); getResponse(getLocalUrl("/session")); this.webServer.stop(); @@ -876,19 +863,17 @@ public abstract class AbstractServletWebServerFactoryTests { public void rootServletContextResource() throws Exception { AbstractServletWebServerFactory factory = getFactory(); final AtomicReference rootResource = new AtomicReference<>(); - this.webServer = factory - .getWebServer(new ServletContextInitializer() { - @Override - public void onStartup(ServletContext servletContext) - throws ServletException { - try { - rootResource.set(servletContext.getResource("/")); - } - catch (MalformedURLException ex) { - throw new ServletException(ex); - } - } - }); + this.webServer = factory.getWebServer(new ServletContextInitializer() { + @Override + public void onStartup(ServletContext servletContext) throws ServletException { + try { + rootResource.set(servletContext.getResource("/")); + } + catch (MalformedURLException ex) { + throw new ServletException(ex); + } + } + }); this.webServer.start(); assertThat(rootResource.get()).isNotNull(); } @@ -897,8 +882,7 @@ public abstract class AbstractServletWebServerFactoryTests { public void customServerHeader() throws Exception { AbstractServletWebServerFactory factory = getFactory(); factory.setServerHeader("MyServer"); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); ClientHttpResponse response = getClientResponse(getLocalUrl("/hello")); assertThat(response.getHeaders().getFirst("server")).isEqualTo("MyServer"); @@ -907,8 +891,7 @@ public abstract class AbstractServletWebServerFactoryTests { @Test public void serverHeaderIsDisabledByDefault() throws Exception { AbstractServletWebServerFactory factory = getFactory(); - this.webServer = factory - .getWebServer(exampleServletRegistration()); + this.webServer = factory.getWebServer(exampleServletRegistration()); this.webServer.start(); ClientHttpResponse response = getClientResponse(getLocalUrl("/hello")); assertThat(response.getHeaders().getFirst("server")).isNull(); @@ -995,24 +978,6 @@ public abstract class AbstractServletWebServerFactoryTests { assertThat(options.getDevelopment()).isEqualTo(false); } - @Test - public void explodedWarFileDocumentRootWhenRunningFromExplodedWar() throws Exception { - AbstractServletWebServerFactory factory = getFactory(); - File webInfClasses = this.temporaryFolder.newFolder("test.war", "WEB-INF", "lib", - "spring-boot.jar"); - File documentRoot = factory.getExplodedWarFileDocumentRoot(webInfClasses); - assertThat(documentRoot) - .isEqualTo(webInfClasses.getParentFile().getParentFile().getParentFile()); - } - - @Test - public void explodedWarFileDocumentRootWhenRunningFromPackagedWar() throws Exception { - AbstractServletWebServerFactory factory = getFactory(); - File codeSourceFile = this.temporaryFolder.newFile("test.war"); - File documentRoot = factory.getExplodedWarFileDocumentRoot(codeSourceFile); - assertThat(documentRoot).isNull(); - } - protected abstract void addConnector(int port, AbstractServletWebServerFactory factory); @@ -1287,10 +1252,10 @@ public abstract class AbstractServletWebServerFactoryTests { } /** - * {@link TrustSelfSignedStrategy} that also validates certificate serial - * number. + * {@link TrustSelfSignedStrategy} that also validates certificate serial number. */ - private static final class SerialNumberValidatingTrustSelfSignedStrategy extends TrustSelfSignedStrategy { + private static final class SerialNumberValidatingTrustSelfSignedStrategy + extends TrustSelfSignedStrategy { private final String serialNumber; @@ -1299,7 +1264,8 @@ public abstract class AbstractServletWebServerFactoryTests { } @Override - public boolean isTrusted(X509Certificate[] chain, String authType) throws CertificateException { + public boolean isTrusted(X509Certificate[] chain, String authType) + throws CertificateException { String hexSerialNumber = chain[0].getSerialNumber().toString(16); boolean isMatch = hexSerialNumber.equals(this.serialNumber); return super.isTrusted(chain, authType) && isMatch; diff --git a/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/DocumentRootTests.java b/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/DocumentRootTests.java new file mode 100644 index 0000000000..42b061c701 --- /dev/null +++ b/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/DocumentRootTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2012-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.servlet.server; + +import java.io.File; + +import org.apache.commons.logging.LogFactory; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link DocumentRoot}. + * + * @author Phillip Webb + */ +public class DocumentRootTests { + + @Rule + public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private DocumentRoot documentRoot = new DocumentRoot(LogFactory.getLog(getClass())); + + @Test + public void explodedWarFileDocumentRootWhenRunningFromExplodedWar() throws Exception { + File webInfClasses = this.temporaryFolder.newFolder("test.war", "WEB-INF", "lib", + "spring-boot.jar"); + File directory = this.documentRoot.getExplodedWarFileDocumentRoot(webInfClasses); + assertThat(directory) + .isEqualTo(webInfClasses.getParentFile().getParentFile().getParentFile()); + } + + @Test + public void explodedWarFileDocumentRootWhenRunningFromPackagedWar() throws Exception { + File codeSourceFile = this.temporaryFolder.newFile("test.war"); + File directory = this.documentRoot.getExplodedWarFileDocumentRoot(codeSourceFile); + assertThat(directory).isNull(); + } + +}