diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurer.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurer.java index 11cd99ec23..5b58f96416 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurer.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurer.java @@ -20,8 +20,10 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Set; import org.springframework.beans.factory.BeanFactory; @@ -38,6 +40,7 @@ import org.springframework.core.Ordered; import org.springframework.core.env.Environment; import org.springframework.core.io.support.SpringFactoriesLoader; import org.springframework.core.type.AnnotationMetadata; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; /** @@ -100,48 +103,49 @@ public class DatabaseInitializationDependencyConfigurer implements ImportBeanDef @Override public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) { - Set initializerBeanNames = detectInitializerBeanNames(beanFactory); + InitializerBeanNames initializerBeanNames = detectInitializerBeanNames(beanFactory); if (initializerBeanNames.isEmpty()) { return; } - String previousInitializerBeanName = null; - for (String initializerBeanName : initializerBeanNames) { - BeanDefinition beanDefinition = getBeanDefinition(initializerBeanName, beanFactory); - beanDefinition.setDependsOn(merge(beanDefinition.getDependsOn(), previousInitializerBeanName)); - previousInitializerBeanName = initializerBeanName; + Set previousInitializerBeanNamesBatch = null; + for (Set initializerBeanNamesBatch : initializerBeanNames.batchedBeanNames()) { + for (String initializerBeanName : initializerBeanNamesBatch) { + BeanDefinition beanDefinition = getBeanDefinition(initializerBeanName, beanFactory); + beanDefinition + .setDependsOn(merge(beanDefinition.getDependsOn(), previousInitializerBeanNamesBatch)); + } + previousInitializerBeanNamesBatch = initializerBeanNamesBatch; } for (String dependsOnInitializationBeanNames : detectDependsOnInitializationBeanNames(beanFactory)) { BeanDefinition beanDefinition = getBeanDefinition(dependsOnInitializationBeanNames, beanFactory); - beanDefinition.setDependsOn(merge(beanDefinition.getDependsOn(), initializerBeanNames)); + beanDefinition.setDependsOn(merge(beanDefinition.getDependsOn(), initializerBeanNames.beanNames())); } } - private String[] merge(String[] source, String additional) { - return merge(source, (additional != null) ? Collections.singleton(additional) : Collections.emptySet()); - } - private String[] merge(String[] source, Set additional) { + if (CollectionUtils.isEmpty(additional)) { + return source; + } Set result = new LinkedHashSet<>((source != null) ? Arrays.asList(source) : Collections.emptySet()); result.addAll(additional); return StringUtils.toStringArray(result); } - private Set detectInitializerBeanNames(ConfigurableListableBeanFactory beanFactory) { + private InitializerBeanNames detectInitializerBeanNames(ConfigurableListableBeanFactory beanFactory) { List detectors = getDetectors(beanFactory, DatabaseInitializerDetector.class); - Set beanNames = new LinkedHashSet<>(); + InitializerBeanNames initializerBeanNames = new InitializerBeanNames(); for (DatabaseInitializerDetector detector : detectors) { for (String beanName : detector.detect(beanFactory)) { BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName); beanDefinition.setAttribute(DatabaseInitializerDetector.class.getName(), detector.getClass().getName()); - beanNames.add(beanName); + initializerBeanNames.detected(detector, beanName); } } - beanNames = Collections.unmodifiableSet(beanNames); for (DatabaseInitializerDetector detector : detectors) { - detector.detectionComplete(beanFactory, beanNames); + detector.detectionComplete(beanFactory, initializerBeanNames.beanNames()); } - return beanNames; + return initializerBeanNames; } private Collection detectDependsOnInitializationBeanNames(ConfigurableListableBeanFactory beanFactory) { @@ -174,6 +178,31 @@ public class DatabaseInitializationDependencyConfigurer implements ImportBeanDef } } + static class InitializerBeanNames { + + private final Map> byDetectorBeanNames = new LinkedHashMap<>(); + + private final Set beanNames = new LinkedHashSet<>(); + + private void detected(DatabaseInitializerDetector detector, String beanName) { + this.byDetectorBeanNames.computeIfAbsent(detector, (key) -> new LinkedHashSet<>()).add(beanName); + this.beanNames.add(beanName); + } + + private boolean isEmpty() { + return this.beanNames.isEmpty(); + } + + private Iterable> batchedBeanNames() { + return this.byDetectorBeanNames.values(); + } + + private Set beanNames() { + return Collections.unmodifiableSet(this.beanNames); + } + + } + } } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurerTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurerTests.java index 4a8fc5859f..55427a21db 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurerTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/sql/init/dependency/DatabaseInitializationDependencyConfigurerTests.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.Enumeration; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Properties; @@ -71,7 +72,8 @@ class DatabaseInitializationDependencyConfigurerTests { @BeforeEach void resetMocks() { - reset(MockDatabaseInitializerDetector.instance, OrderedMockDatabaseInitializerDetector.instance, + reset(MockDatabaseInitializerDetector.instance, OrderedNearLowestMockDatabaseInitializerDetector.instance, + OrderedLowestMockDatabaseInitializerDetector.instance, MockedDependsOnDatabaseInitializationDetector.instance); } @@ -94,8 +96,7 @@ class DatabaseInitializationDependencyConfigurerTests { context.refresh(); assertThat(DependsOnCaptor.dependsOn).hasEntrySatisfying("bravo", (dependencies) -> assertThat(dependencies).containsExactly("alpha")); - assertThat(DependsOnCaptor.dependsOn).hasEntrySatisfying("alpha", - (dependencies) -> assertThat(dependencies).isEmpty()); + assertThat(DependsOnCaptor.dependsOn).doesNotContainKey("alpha"); }); } @@ -140,24 +141,34 @@ class DatabaseInitializationDependencyConfigurerTests { @Test void whenDependenciesAreConfiguredDetectedDatabaseInitializersAreInitializedInCorrectOrder() { BeanDefinition alpha = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); - BeanDefinition bravo = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); + BeanDefinition bravo1 = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); + BeanDefinition bravo2 = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); BeanDefinition charlie = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); - performDetection(Arrays.asList(MockDatabaseInitializerDetector.class, - OrderedMockDatabaseInitializerDetector.class, MockedDependsOnDatabaseInitializationDetector.class), + BeanDefinition delta = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); + performDetection( + Arrays.asList(MockDatabaseInitializerDetector.class, OrderedLowestMockDatabaseInitializerDetector.class, + OrderedNearLowestMockDatabaseInitializerDetector.class, + MockedDependsOnDatabaseInitializationDetector.class), (context) -> { given(MockDatabaseInitializerDetector.instance.detect(context.getBeanFactory())) .willReturn(Collections.singleton("alpha")); - given(OrderedMockDatabaseInitializerDetector.instance.detect(context.getBeanFactory())) - .willReturn(Collections.singleton("bravo")); + given(OrderedNearLowestMockDatabaseInitializerDetector.instance.detect(context.getBeanFactory())) + .willReturn(new LinkedHashSet<>(Arrays.asList("bravo1", "bravo2"))); + given(OrderedLowestMockDatabaseInitializerDetector.instance.detect(context.getBeanFactory())) + .willReturn(new LinkedHashSet<>(Arrays.asList("charlie"))); given(MockedDependsOnDatabaseInitializationDetector.instance.detect(context.getBeanFactory())) - .willReturn(Collections.singleton("charlie")); + .willReturn(Collections.singleton("delta")); context.registerBeanDefinition("alpha", alpha); - context.registerBeanDefinition("bravo", bravo); + context.registerBeanDefinition("bravo1", bravo1); + context.registerBeanDefinition("bravo2", bravo2); context.registerBeanDefinition("charlie", charlie); + context.registerBeanDefinition("delta", delta); context.register(DependencyConfigurerConfiguration.class); context.refresh(); - assertThat(charlie.getDependsOn()).containsExactly("alpha", "bravo"); - assertThat(bravo.getDependsOn()).containsExactly("alpha"); + assertThat(delta.getDependsOn()).containsExactlyInAnyOrder("alpha", "bravo1", "bravo2", "charlie"); + assertThat(charlie.getDependsOn()).containsExactly("bravo1", "bravo2"); + assertThat(bravo1.getDependsOn()).containsExactly("alpha"); + assertThat(bravo2.getDependsOn()).containsExactly("alpha"); assertThat(alpha.getDependsOn()).isNullOrEmpty(); }); } @@ -227,7 +238,7 @@ class DatabaseInitializationDependencyConfigurerTests { } - static class OrderedMockDatabaseInitializerDetector implements DatabaseInitializerDetector { + static class OrderedLowestMockDatabaseInitializerDetector implements DatabaseInitializerDetector { private static DatabaseInitializerDetector instance = mock(DatabaseInitializerDetector.class); @@ -243,6 +254,22 @@ class DatabaseInitializationDependencyConfigurerTests { } + static class OrderedNearLowestMockDatabaseInitializerDetector implements DatabaseInitializerDetector { + + private static DatabaseInitializerDetector instance = mock(DatabaseInitializerDetector.class); + + @Override + public Set detect(ConfigurableListableBeanFactory beanFactory) { + return instance.detect(beanFactory); + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE - 100; + } + + } + static class MockedDependsOnDatabaseInitializationDetector implements DependsOnDatabaseInitializationDetector { private static DependsOnDatabaseInitializationDetector instance = mock(