diff --git a/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java b/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java index 268b2dee11..a4b56afa22 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java @@ -16,20 +16,10 @@ package org.springframework.boot.context.embedded; -import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; import java.util.EventListener; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; import javax.servlet.Filter; -import javax.servlet.MultipartConfigElement; import javax.servlet.Servlet; import javax.servlet.ServletConfig; import javax.servlet.ServletContext; @@ -41,7 +31,6 @@ import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextException; -import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.core.io.Resource; import org.springframework.util.StringUtils; import org.springframework.web.context.ContextLoader; @@ -94,7 +83,7 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext * default. To change the default behaviour you can use a * {@link ServletRegistrationBean} or a different bean name. */ - public static final String DISPATCHER_SERVLET_NAME = "dispatcherServlet"; + public static final String DISPATCHER_SERVLET_NAME = ServletContextInitializerBeans.DISPATCHER_SERVLET_NAME; private EmbeddedServletContainer embeddedServletContainer; @@ -220,108 +209,11 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext /** * Returns {@link ServletContextInitializer}s that should be used with the embedded * Servlet context. By default this method will first attempt to find - * {@link ServletContextInitializer} beans, if none are found it will instead search - * for {@link Servlet} and {@link Filter} beans. + * {@link ServletContextInitializer}, {@link Servlet}, {@link Filter} and certain + * {@link EventListener} beans. */ protected Collection getServletContextInitializerBeans() { - - List filters = new ArrayList(); - List servlets = new ArrayList(); - List listeners = new ArrayList(); - List other = new ArrayList(); - Set servletRegistrations = new LinkedHashSet(); - Set filterRegistrations = new LinkedHashSet(); - Set listenerRegistrations = new LinkedHashSet(); - - for (Entry initializerBean : getOrderedBeansOfType(ServletContextInitializer.class)) { - ServletContextInitializer initializer = initializerBean.getValue(); - if (initializer instanceof ServletRegistrationBean) { - servlets.add(initializer); - ServletRegistrationBean servlet = (ServletRegistrationBean) initializer; - servletRegistrations.add(servlet.getServlet()); - } - else if (initializer instanceof FilterRegistrationBean) { - filters.add(initializer); - FilterRegistrationBean filter = (FilterRegistrationBean) initializer; - filterRegistrations.add(filter.getFilter()); - } - else if (initializer instanceof ServletListenerRegistrationBean) { - listeners.add(initializer); - listenerRegistrations - .add(((ServletListenerRegistrationBean) initializer) - .getListener()); - } - else { - other.add(initializer); - } - } - - List> servletBeans = getOrderedBeansOfType(Servlet.class); - for (Entry servletBean : servletBeans) { - final String name = servletBean.getKey(); - Servlet servlet = servletBean.getValue(); - if (!servletRegistrations.contains(servlet)) { - String url = (servletBeans.size() == 1 ? "/" : "/" + name + "/"); - if (name.equals(DISPATCHER_SERVLET_NAME)) { - url = "/"; // always map the main dispatcherServlet to "/" - } - ServletRegistrationBean registration = new ServletRegistrationBean( - servlet, url); - registration.setName(name); - registration.setMultipartConfig(getMultipartConfig()); - registration.setOrder(CustomOrderAwareComparator.INSTANCE - .getOrder(servlet)); - servlets.add(registration); - } - } - - for (Entry filterBean : getOrderedBeansOfType(Filter.class)) { - String name = filterBean.getKey(); - Filter filter = filterBean.getValue(); - if (!filterRegistrations.contains(filter)) { - FilterRegistrationBean registration = new FilterRegistrationBean(filter); - registration.setName(name); - registration.setOrder(CustomOrderAwareComparator.INSTANCE - .getOrder(filter)); - filters.add(registration); - } - } - - Set> listenerTypes = ServletListenerRegistrationBean.getSupportedTypes(); - for (Class type : listenerTypes) { - for (Entry listenerBean : getOrderedBeansOfType(type)) { - String name = listenerBean.getKey(); - EventListener listener = (EventListener) listenerBean.getValue(); - if (ServletListenerRegistrationBean.isSupportedType(listener) - && !filterRegistrations.contains(listener)) { - ServletListenerRegistrationBean registration = new ServletListenerRegistrationBean( - listener); - registration.setName(name); - registration.setOrder(CustomOrderAwareComparator.INSTANCE - .getOrder(listener)); - listeners.add(registration); - } - } - } - AnnotationAwareOrderComparator.sort(filters); - AnnotationAwareOrderComparator.sort(servlets); - AnnotationAwareOrderComparator.sort(listeners); - AnnotationAwareOrderComparator.sort(other); - - List list = new ArrayList( - filters); - list.addAll(servlets); - list.addAll(listeners); - list.addAll(other); - return list; - } - - private MultipartConfigElement getMultipartConfig() { - List> beans = getOrderedBeansOfType(MultipartConfigElement.class); - if (beans.isEmpty()) { - return null; - } - return beans.get(0).getValue(); + return new ServletContextInitializerBeans(getBeanFactory()); } /** @@ -375,25 +267,6 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext } } - private List> getOrderedBeansOfType(Class type) { - List> beans = new ArrayList>(); - Comparator> comparator = new Comparator>() { - @Override - public int compare(Entry o1, Entry o2) { - return AnnotationAwareOrderComparator.INSTANCE.compare(o1.getValue(), - o2.getValue()); - } - }; - String[] names = getBeanFactory().getBeanNamesForType(type, true, false); - Map map = new LinkedHashMap(); - for (String name : names) { - map.put(name, getBeanFactory().getBean(name, type)); - } - beans.addAll(map.entrySet()); - Collections.sort(beans, comparator); - return beans; - } - private void startEmbeddedServletContainer() { if (this.embeddedServletContainer != null) { this.embeddedServletContainer.start(); @@ -448,15 +321,4 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext return this.embeddedServletContainer; } - private static class CustomOrderAwareComparator extends - AnnotationAwareOrderComparator { - - public static CustomOrderAwareComparator INSTANCE = new CustomOrderAwareComparator(); - - @Override - protected int getOrder(Object obj) { - return super.getOrder(obj); - } - } - } diff --git a/spring-boot/src/main/java/org/springframework/boot/context/embedded/ServletContextInitializerBeans.java b/spring-boot/src/main/java/org/springframework/boot/context/embedded/ServletContextInitializerBeans.java new file mode 100644 index 0000000000..24a98504cb --- /dev/null +++ b/spring-boot/src/main/java/org/springframework/boot/context/embedded/ServletContextInitializerBeans.java @@ -0,0 +1,258 @@ +/* + * Copyright 2012-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.context.embedded; + +import java.util.AbstractCollection; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.EventListener; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import javax.servlet.Filter; +import javax.servlet.MultipartConfigElement; +import javax.servlet.Servlet; + +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * A collection {@link ServletContextInitializer}s obtained from a + * {@link ListableBeanFactory}. Includes all {@link ServletContextInitializer} beans and + * also adapts {@link Servlet}, {@link Filter} and certain {@link EventListener} beans. + *

+ * Items are sorted so that adapted beans are top ({@link Servlet}, {@link Filter} then + * {@link EventListener}) and direct {@link ServletContextInitializer} beans are at the + * end. Further sorting is applied within these groups using the + * {@link AnnotationAwareOrderComparator}. + * + * + * @author Dave Syer + * @author Phillip Webb + */ +class ServletContextInitializerBeans extends + AbstractCollection { + + static final String DISPATCHER_SERVLET_NAME = "dispatcherServlet"; + + private final Set seen = new HashSet(); + + private final MultiValueMap, ServletContextInitializer> initializers; + + private List sortedList; + + public ServletContextInitializerBeans(ListableBeanFactory beanFactory) { + this.initializers = new LinkedMultiValueMap, ServletContextInitializer>(); + addServletContextInitializerBeans(beanFactory); + addAdaptableBeans(beanFactory); + List sortedInitializers = new ArrayList(); + for (Map.Entry> entry : this.initializers + .entrySet()) { + AnnotationAwareOrderComparator.sort(entry.getValue()); + sortedInitializers.addAll(entry.getValue()); + } + this.sortedList = Collections.unmodifiableList(sortedInitializers); + } + + private void addServletContextInitializerBeans(ListableBeanFactory beanFactory) { + for (Entry initializerBean : getOrderedBeansOfType( + beanFactory, ServletContextInitializer.class)) { + addServletContextInitializerBean(initializerBean.getValue()); + } + } + + private void addServletContextInitializerBean(ServletContextInitializer initializer) { + if (initializer instanceof ServletRegistrationBean) { + addServletContextInitializerBean(Servlet.class, initializer, + ((ServletRegistrationBean) initializer).getServlet()); + } + else if (initializer instanceof FilterRegistrationBean) { + addServletContextInitializerBean(Filter.class, initializer, + ((FilterRegistrationBean) initializer).getFilter()); + } + else if (initializer instanceof ServletListenerRegistrationBean) { + addServletContextInitializerBean(EventListener.class, initializer, + ((ServletListenerRegistrationBean) initializer).getListener()); + } + else { + addServletContextInitializerBean(ServletContextInitializer.class, + initializer, null); + } + } + + private void addServletContextInitializerBean(Class type, + ServletContextInitializer initializer, Object source) { + this.initializers.add(type, initializer); + if (source != null) { + // Mark the underlying source as seen in case it wraps an existing bean + this.seen.add(source); + } + } + + @SuppressWarnings("unchecked") + private void addAdaptableBeans(ListableBeanFactory beanFactory) { + MultipartConfigElement multipartConfig = getMultipartConfig(beanFactory); + addAsRegistrationBean(beanFactory, Servlet.class, + new ServletRegistrationBeanAdapter(multipartConfig)); + addAsRegistrationBean(beanFactory, Filter.class, + new FilterRegistrationBeanAdapter()); + for (Class listenerType : ServletListenerRegistrationBean.getSupportedTypes()) { + addAsRegistrationBean(beanFactory, EventListener.class, + (Class) listenerType, + new ServletListenerRegistrationBeanAdapter()); + } + } + + private MultipartConfigElement getMultipartConfig(ListableBeanFactory beanFactory) { + List> beans = getOrderedBeansOfType( + beanFactory, MultipartConfigElement.class); + return (beans.isEmpty() ? null : beans.get(0).getValue()); + } + + private void addAsRegistrationBean(ListableBeanFactory beanFactory, + Class type, RegistrationBeanAdapter adapter) { + addAsRegistrationBean(beanFactory, type, type, adapter); + } + + private void addAsRegistrationBean(ListableBeanFactory beanFactory, + Class type, Class beanType, RegistrationBeanAdapter adapter) { + List> beans = getOrderedBeansOfType(beanFactory, beanType); + for (Entry bean : beans) { + if (this.seen.add(bean.getValue())) { + // One that we haven't already seen + RegistrationBean registration = adapter.createRegistrationBean( + bean.getKey(), bean.getValue(), beans.size()); + registration.setName(bean.getKey()); + registration.setOrder(getOrder(bean.getValue())); + this.initializers.add(type, registration); + } + } + } + + private int getOrder(Object value) { + return new AnnotationAwareOrderComparator() { + @Override + public int getOrder(Object obj) { + return super.getOrder(obj); + } + }.getOrder(value); + } + + private List> getOrderedBeansOfType( + ListableBeanFactory beanFactory, Class type) { + List> beans = new ArrayList>(); + Comparator> comparator = new Comparator>() { + @Override + public int compare(Entry o1, Entry o2) { + return AnnotationAwareOrderComparator.INSTANCE.compare(o1.getValue(), + o2.getValue()); + } + }; + String[] names = beanFactory.getBeanNamesForType(type, true, false); + Map map = new LinkedHashMap(); + for (String name : names) { + map.put(name, beanFactory.getBean(name, type)); + } + beans.addAll(map.entrySet()); + Collections.sort(beans, comparator); + return beans; + } + + @Override + public Iterator iterator() { + return this.sortedList.iterator(); + } + + @Override + public int size() { + return this.sortedList.size(); + } + + /** + * Adapter to convert a given Bean type into a {@link RegistrationBean} (and hence a + * {@link ServletContextInitializer}. + */ + private static interface RegistrationBeanAdapter { + + RegistrationBean createRegistrationBean(String name, T source, + int totalNumberOfSourceBeans); + + } + + /** + * {@link RegistrationBeanAdapter} for {@link Servlet} beans. + */ + private static class ServletRegistrationBeanAdapter implements + RegistrationBeanAdapter { + + private final MultipartConfigElement multipartConfig; + + public ServletRegistrationBeanAdapter(MultipartConfigElement multipartConfig) { + this.multipartConfig = multipartConfig; + } + + @Override + public RegistrationBean createRegistrationBean(String name, Servlet source, + int totalNumberOfSourceBeans) { + String url = (totalNumberOfSourceBeans == 1 ? "/" : "/" + name + "/"); + if (name.equals(DISPATCHER_SERVLET_NAME)) { + url = "/"; // always map the main dispatcherServlet to "/" + } + ServletRegistrationBean bean = new ServletRegistrationBean(source, url); + bean.setMultipartConfig(this.multipartConfig); + return bean; + } + + } + + /** + * {@link RegistrationBeanAdapter} for {@link Filter} beans. + */ + private static class FilterRegistrationBeanAdapter implements + RegistrationBeanAdapter { + + @Override + public RegistrationBean createRegistrationBean(String name, Filter source, + int totalNumberOfSourceBeans) { + return new FilterRegistrationBean(source); + } + + } + + /** + * {@link RegistrationBeanAdapter} for certain {@link EventListener} beans. + */ + private static class ServletListenerRegistrationBeanAdapter implements + RegistrationBeanAdapter { + + @Override + public RegistrationBean createRegistrationBean(String name, EventListener source, + int totalNumberOfSourceBeans) { + return new ServletListenerRegistrationBean(source); + } + + } + +} diff --git a/spring-boot/src/main/java/org/springframework/boot/logging/LoggingApplicationListener.java b/spring-boot/src/main/java/org/springframework/boot/logging/LoggingApplicationListener.java index c47bef4c57..ba35a9b385 100644 --- a/spring-boot/src/main/java/org/springframework/boot/logging/LoggingApplicationListener.java +++ b/spring-boot/src/main/java/org/springframework/boot/logging/LoggingApplicationListener.java @@ -141,7 +141,19 @@ public class LoggingApplicationListener implements SmartApplicationListener { * {@link Environment} and the classpath. */ protected void initialize(ConfigurableEnvironment environment, ClassLoader classLoader) { + initializeEarlyLoggingLevel(environment); + cleanLogTempProperty(); + LoggingSystem system = LoggingSystem.get(classLoader); + boolean systemEnvironmentChanged = mapSystemPropertiesFromSpring(environment); + if (systemEnvironmentChanged) { + // Re-initialize the defaults in case the system Environment changed + system.beforeInitialize(); + } + initializeSystem(environment, system); + initializeFinalLoggingLevels(environment, system); + } + private void initializeEarlyLoggingLevel(ConfigurableEnvironment environment) { if (this.parseArgs && this.springBootLogging == null) { if (environment.containsProperty("debug")) { this.springBootLogging = LogLevel.DEBUG; @@ -150,7 +162,9 @@ public class LoggingApplicationListener implements SmartApplicationListener { this.springBootLogging = LogLevel.TRACE; } } + } + private void cleanLogTempProperty() { // Logback won't read backslashes so add a clean path for it to use if (!StringUtils.hasLength(System.getProperty("LOG_TEMP"))) { String path = System.getProperty("java.io.tmpdir"); @@ -160,24 +174,24 @@ public class LoggingApplicationListener implements SmartApplicationListener { } System.setProperty("LOG_TEMP", path); } + } - boolean environmentChanged = false; + private boolean mapSystemPropertiesFromSpring(Environment environment) { + boolean changed = false; for (Map.Entry mapping : ENVIRONMENT_SYSTEM_PROPERTY_MAPPING .entrySet()) { - if (environment.containsProperty(mapping.getKey())) { - System.setProperty(mapping.getValue(), - environment.getProperty(mapping.getKey())); - environmentChanged = true; + String springName = mapping.getKey(); + String systemName = mapping.getValue(); + if (environment.containsProperty(springName)) { + System.setProperty(systemName, environment.getProperty(springName)); + changed = true; } } + return changed; + } - LoggingSystem system = LoggingSystem.get(classLoader); - - if (environmentChanged) { - // Re-initialize the defaults in case the Environment changed - system.beforeInitialize(); - } - // User specified configuration + private void initializeSystem(ConfigurableEnvironment environment, + LoggingSystem system) { if (environment.containsProperty("logging.config")) { String value = environment.getProperty("logging.config"); try { @@ -185,22 +199,23 @@ public class LoggingApplicationListener implements SmartApplicationListener { system.initialize(value); } catch (Exception ex) { - this.logger - .warn("Logging environment value '" - + value - + "' cannot be opened and will be ignored (using default location instead)"); + this.logger.warn("Logging environment value '" + value + + "' cannot be opened and will be ignored " + + "(using default location instead)"); system.initialize(); } } else { system.initialize(); } + } + private void initializeFinalLoggingLevels(ConfigurableEnvironment environment, + LoggingSystem system) { if (this.springBootLogging != null) { initializeLogLevel(system, this.springBootLogging); } setLogLevels(system, environment); - } public void setLogLevels(LoggingSystem system, Environment environment) { diff --git a/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java b/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java index 29d54dbbb4..37cbb57924 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java @@ -457,4 +457,5 @@ public class EmbeddedWebApplicationContextTests { } } + }