diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/AbstractBeansOfTypeDatabaseInitializerDetector.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/AbstractBeansOfTypeDatabaseInitializerDetector.java index 1f634cc75c..93ed501208 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/AbstractBeansOfTypeDatabaseInitializerDetector.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/AbstractBeansOfTypeDatabaseInitializerDetector.java @@ -30,12 +30,6 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; */ public abstract class AbstractBeansOfTypeDatabaseInitializerDetector implements DatabaseInitializerDetector { - /** - * Returns the bean types that should be detected as being database initializers. - * @return the database initializer bean types - */ - protected abstract Set> getDatabaseInitializerBeanTypes(); - @Override public Set detect(ConfigurableListableBeanFactory beanFactory) { try { @@ -47,4 +41,10 @@ public abstract class AbstractBeansOfTypeDatabaseInitializerDetector implements } } + /** + * Returns the bean types that should be detected as being database initializers. + * @return the database initializer bean types + */ + protected abstract Set> getDatabaseInitializerBeanTypes(); + } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/AbstractBeansOfTypeDependsOnDatabaseInitializationDetector.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/AbstractBeansOfTypeDependsOnDatabaseInitializationDetector.java index dc5722f186..2814036c63 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/AbstractBeansOfTypeDependsOnDatabaseInitializationDetector.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/AbstractBeansOfTypeDependsOnDatabaseInitializationDetector.java @@ -32,13 +32,6 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; public abstract class AbstractBeansOfTypeDependsOnDatabaseInitializationDetector implements DependsOnDatabaseInitializationDetector { - /** - * Returns the bean types that should be detected as depending on database - * initialization. - * @return the database initialization dependent bean types - */ - protected abstract Set> getDependsOnDatabaseInitializationBeanTypes(); - @Override public Set detect(ConfigurableListableBeanFactory beanFactory) { try { @@ -50,4 +43,11 @@ public abstract class AbstractBeansOfTypeDependsOnDatabaseInitializationDetector } } + /** + * Returns the bean types that should be detected as depending on database + * initialization. + * @return the database initialization dependent bean types + */ + protected abstract Set> getDependsOnDatabaseInitializationBeanTypes(); + } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurer.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurer.java index aa4de6cb59..41196a28a0 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurer.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurer.java @@ -16,9 +16,11 @@ package org.springframework.boot.sql.init.dependency; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Set; @@ -65,16 +67,23 @@ public class DatabaseInitializationDependencyConfigurer implements ImportBeanDef @Override public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) { - if (registry.containsBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class.getName())) { - return; + String name = DependsOnDatabaseInitializationPostProcessor.class.getName(); + if (!registry.containsBeanDefinition(name)) { + BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition( + DependsOnDatabaseInitializationPostProcessor.class, + this::createDependsOnDatabaseInitializationPostProcessor); + registry.registerBeanDefinition(name, builder.getBeanDefinition()); } - registry.registerBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class.getName(), - BeanDefinitionBuilder - .genericBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class, - () -> new DependsOnDatabaseInitializationPostProcessor(this.environment)) - .getBeanDefinition()); } + private DependsOnDatabaseInitializationPostProcessor createDependsOnDatabaseInitializationPostProcessor() { + return new DependsOnDatabaseInitializationPostProcessor(this.environment); + } + + /** + * {@link BeanFactoryPostProcessor} used to configure database initialization + * dependency relationships. + */ static class DependsOnDatabaseInitializationPostProcessor implements BeanFactoryPostProcessor { private final Environment environment; @@ -85,58 +94,55 @@ public class DatabaseInitializationDependencyConfigurer implements ImportBeanDef @Override public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) { - Set detectedDatabaseInitializers = detectDatabaseInitializers(beanFactory); - if (detectedDatabaseInitializers.isEmpty()) { + Set initializerBeanNames = detectInitializerBeanNames(beanFactory); + if (initializerBeanNames.isEmpty()) { return; } - for (String dependentDefinitionName : detectDependsOnDatabaseInitialization(beanFactory, - this.environment)) { - BeanDefinition definition = getBeanDefinition(dependentDefinitionName, beanFactory); - String[] dependencies = definition.getDependsOn(); - for (String dependencyName : detectedDatabaseInitializers) { - dependencies = StringUtils.addStringToArray(dependencies, dependencyName); - } - definition.setDependsOn(dependencies); + for (String dependsOnInitializationBeanNames : detectDependsOnInitializationBeanNames(beanFactory)) { + BeanDefinition definition = getBeanDefinition(dependsOnInitializationBeanNames, beanFactory); + definition.setDependsOn(merge(definition.getDependsOn(), initializerBeanNames)); } } - private Set detectDatabaseInitializers(ConfigurableListableBeanFactory beanFactory) { - List detectors = instantiateDetectors(beanFactory, this.environment, - DatabaseInitializerDetector.class); - Set detected = new HashSet<>(); + private String[] merge(String[] source, Set additional) { + Set result = new LinkedHashSet<>((source != null) ? Arrays.asList(source) : Collections.emptySet()); + result.addAll(additional); + return StringUtils.toStringArray(result); + } + + private Set detectInitializerBeanNames(ConfigurableListableBeanFactory beanFactory) { + List detectors = getDetectors(beanFactory, DatabaseInitializerDetector.class); + Set beanNames = new HashSet<>(); for (DatabaseInitializerDetector detector : detectors) { - for (String initializerName : detector.detect(beanFactory)) { - detected.add(initializerName); - beanFactory.getBeanDefinition(initializerName) - .setAttribute(DatabaseInitializerDetector.class.getName(), detector.getClass().getName()); + for (String beanName : detector.detect(beanFactory)) { + BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName); + beanDefinition.setAttribute(DatabaseInitializerDetector.class.getName(), + detector.getClass().getName()); + beanNames.add(beanName); } } - detected = Collections.unmodifiableSet(detected); + beanNames = Collections.unmodifiableSet(beanNames); for (DatabaseInitializerDetector detector : detectors) { - detector.detectionComplete(beanFactory, detected); + detector.detectionComplete(beanFactory, beanNames); } - return detected; + return beanNames; } - private Collection detectDependsOnDatabaseInitialization(ConfigurableListableBeanFactory beanFactory, - Environment environment) { - List detectors = instantiateDetectors(beanFactory, environment, + private Collection detectDependsOnInitializationBeanNames(ConfigurableListableBeanFactory beanFactory) { + List detectors = getDetectors(beanFactory, DependsOnDatabaseInitializationDetector.class); - Set dependentUponDatabaseInitialization = new HashSet<>(); + Set beanNames = new HashSet<>(); for (DependsOnDatabaseInitializationDetector detector : detectors) { - dependentUponDatabaseInitialization.addAll(detector.detect(beanFactory)); + beanNames.addAll(detector.detect(beanFactory)); } - return dependentUponDatabaseInitialization; + return beanNames; } - private List instantiateDetectors(ConfigurableListableBeanFactory beanFactory, Environment environment, - Class detectorType) { - List detectorNames = SpringFactoriesLoader.loadFactoryNames(detectorType, - beanFactory.getBeanClassLoader()); - Instantiator instantiator = new Instantiator<>(detectorType, - (availableParameters) -> availableParameters.add(Environment.class, environment)); - List detectors = instantiator.instantiate(detectorNames); - return detectors; + private List getDetectors(ConfigurableListableBeanFactory beanFactory, Class type) { + List names = SpringFactoriesLoader.loadFactoryNames(type, beanFactory.getBeanClassLoader()); + Instantiator instantiator = new Instantiator<>(type, + (availableParameters) -> availableParameters.add(Environment.class, this.environment)); + return instantiator.instantiate(names); } private static BeanDefinition getBeanDefinition(String beanName, ConfigurableListableBeanFactory beanFactory) { diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurerTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurerTests.java index 7f5c2a5609..7d40abd552 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurerTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurerTests.java @@ -33,7 +33,6 @@ import java.util.stream.Collectors; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.mockito.Mockito; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; @@ -47,6 +46,7 @@ import org.springframework.mock.env.MockEnvironment; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; @@ -59,16 +59,12 @@ class DatabaseInitializationDependencyConfigurerTests { private final ConfigurableEnvironment environment = new MockEnvironment(); - DatabaseInitializerDetector databaseInitializerDetector = MockedDatabaseInitializerDetector.mock; - - DependsOnDatabaseInitializationDetector dependsOnDatabaseInitializationDetector = MockedDependsOnDatabaseInitializationDetector.mock; - @TempDir File temp; @BeforeEach void resetMocks() { - reset(MockedDatabaseInitializerDetector.mock, MockedDependsOnDatabaseInitializationDetector.mock); + reset(MockDatabaseInitializerDetector.instance, MockedDependsOnDatabaseInitializationDetector.instance); } @Test @@ -89,19 +85,19 @@ class DatabaseInitializationDependencyConfigurerTests { void whenDependenciesAreConfiguredThenBeansThatDependUponDatabaseInitializationDependUponDetectedDatabaseInitializers() { BeanDefinition alpha = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); BeanDefinition bravo = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); - performDetection(Arrays.asList(MockedDatabaseInitializerDetector.class, + performDetection(Arrays.asList(MockDatabaseInitializerDetector.class, MockedDependsOnDatabaseInitializationDetector.class), (context) -> { context.registerBeanDefinition("alpha", alpha); context.registerBeanDefinition("bravo", bravo); - given(this.databaseInitializerDetector.detect(context.getBeanFactory())) + given(MockDatabaseInitializerDetector.instance.detect(context.getBeanFactory())) .willReturn(Collections.singleton("alpha")); - given(this.dependsOnDatabaseInitializationDetector.detect(context.getBeanFactory())) + given(MockedDependsOnDatabaseInitializationDetector.instance.detect(context.getBeanFactory())) .willReturn(Collections.singleton("bravo")); context.refresh(); assertThat(alpha.getAttribute(DatabaseInitializerDetector.class.getName())) - .isEqualTo(MockedDatabaseInitializerDetector.class.getName()); + .isEqualTo(MockDatabaseInitializerDetector.class.getName()); assertThat(bravo.getAttribute(DatabaseInitializerDetector.class.getName())).isNull(); - verify(this.databaseInitializerDetector).detectionComplete(context.getBeanFactory(), + verify(MockDatabaseInitializerDetector.instance).detectionComplete(context.getBeanFactory(), Collections.singleton("alpha")); assertThat(bravo.getDependsOn()).containsExactly("alpha"); }); @@ -156,31 +152,31 @@ class DatabaseInitializationDependencyConfigurerTests { } - static class MockedDatabaseInitializerDetector implements DatabaseInitializerDetector { + static class MockDatabaseInitializerDetector implements DatabaseInitializerDetector { - private static DatabaseInitializerDetector mock = Mockito.mock(DatabaseInitializerDetector.class); + private static DatabaseInitializerDetector instance = mock(DatabaseInitializerDetector.class); @Override public Set detect(ConfigurableListableBeanFactory beanFactory) { - return MockedDatabaseInitializerDetector.mock.detect(beanFactory); + return MockDatabaseInitializerDetector.instance.detect(beanFactory); } @Override public void detectionComplete(ConfigurableListableBeanFactory beanFactory, Set databaseInitializerNames) { - mock.detectionComplete(beanFactory, databaseInitializerNames); + instance.detectionComplete(beanFactory, databaseInitializerNames); } } static class MockedDependsOnDatabaseInitializationDetector implements DependsOnDatabaseInitializationDetector { - private static DependsOnDatabaseInitializationDetector mock = Mockito - .mock(DependsOnDatabaseInitializationDetector.class); + private static DependsOnDatabaseInitializationDetector instance = mock( + DependsOnDatabaseInitializationDetector.class); @Override public Set detect(ConfigurableListableBeanFactory beanFactory) { - return MockedDependsOnDatabaseInitializationDetector.mock.detect(beanFactory); + return instance.detect(beanFactory); } }