pull/26813/head
Phillip Webb 3 years ago
parent 87d35250a5
commit be23a29651

@ -30,12 +30,6 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
*/ */
public abstract class AbstractBeansOfTypeDatabaseInitializerDetector implements DatabaseInitializerDetector { public abstract class AbstractBeansOfTypeDatabaseInitializerDetector implements DatabaseInitializerDetector {
/**
* Returns the bean types that should be detected as being database initializers.
* @return the database initializer bean types
*/
protected abstract Set<Class<?>> getDatabaseInitializerBeanTypes();
@Override @Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) { public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
try { try {
@ -47,4 +41,10 @@ public abstract class AbstractBeansOfTypeDatabaseInitializerDetector implements
} }
} }
/**
* Returns the bean types that should be detected as being database initializers.
* @return the database initializer bean types
*/
protected abstract Set<Class<?>> getDatabaseInitializerBeanTypes();
} }

@ -32,13 +32,6 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
public abstract class AbstractBeansOfTypeDependsOnDatabaseInitializationDetector public abstract class AbstractBeansOfTypeDependsOnDatabaseInitializationDetector
implements DependsOnDatabaseInitializationDetector { implements DependsOnDatabaseInitializationDetector {
/**
* Returns the bean types that should be detected as depending on database
* initialization.
* @return the database initialization dependent bean types
*/
protected abstract Set<Class<?>> getDependsOnDatabaseInitializationBeanTypes();
@Override @Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) { public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
try { try {
@ -50,4 +43,11 @@ public abstract class AbstractBeansOfTypeDependsOnDatabaseInitializationDetector
} }
} }
/**
* Returns the bean types that should be detected as depending on database
* initialization.
* @return the database initialization dependent bean types
*/
protected abstract Set<Class<?>> getDependsOnDatabaseInitializationBeanTypes();
} }

@ -16,9 +16,11 @@
package org.springframework.boot.sql.init.dependency; package org.springframework.boot.sql.init.dependency;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -65,16 +67,23 @@ public class DatabaseInitializationDependencyConfigurer implements ImportBeanDef
@Override @Override
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) { public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
if (registry.containsBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class.getName())) { String name = DependsOnDatabaseInitializationPostProcessor.class.getName();
return; if (!registry.containsBeanDefinition(name)) {
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(
DependsOnDatabaseInitializationPostProcessor.class,
this::createDependsOnDatabaseInitializationPostProcessor);
registry.registerBeanDefinition(name, builder.getBeanDefinition());
}
} }
registry.registerBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class.getName(),
BeanDefinitionBuilder private DependsOnDatabaseInitializationPostProcessor createDependsOnDatabaseInitializationPostProcessor() {
.genericBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class, return new DependsOnDatabaseInitializationPostProcessor(this.environment);
() -> new DependsOnDatabaseInitializationPostProcessor(this.environment))
.getBeanDefinition());
} }
/**
* {@link BeanFactoryPostProcessor} used to configure database initialization
* dependency relationships.
*/
static class DependsOnDatabaseInitializationPostProcessor implements BeanFactoryPostProcessor { static class DependsOnDatabaseInitializationPostProcessor implements BeanFactoryPostProcessor {
private final Environment environment; private final Environment environment;
@ -85,58 +94,55 @@ public class DatabaseInitializationDependencyConfigurer implements ImportBeanDef
@Override @Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) { public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) {
Set<String> detectedDatabaseInitializers = detectDatabaseInitializers(beanFactory); Set<String> initializerBeanNames = detectInitializerBeanNames(beanFactory);
if (detectedDatabaseInitializers.isEmpty()) { if (initializerBeanNames.isEmpty()) {
return; return;
} }
for (String dependentDefinitionName : detectDependsOnDatabaseInitialization(beanFactory, for (String dependsOnInitializationBeanNames : detectDependsOnInitializationBeanNames(beanFactory)) {
this.environment)) { BeanDefinition definition = getBeanDefinition(dependsOnInitializationBeanNames, beanFactory);
BeanDefinition definition = getBeanDefinition(dependentDefinitionName, beanFactory); definition.setDependsOn(merge(definition.getDependsOn(), initializerBeanNames));
String[] dependencies = definition.getDependsOn();
for (String dependencyName : detectedDatabaseInitializers) {
dependencies = StringUtils.addStringToArray(dependencies, dependencyName);
} }
definition.setDependsOn(dependencies);
} }
private String[] merge(String[] source, Set<String> additional) {
Set<String> result = new LinkedHashSet<>((source != null) ? Arrays.asList(source) : Collections.emptySet());
result.addAll(additional);
return StringUtils.toStringArray(result);
} }
private Set<String> detectDatabaseInitializers(ConfigurableListableBeanFactory beanFactory) { private Set<String> detectInitializerBeanNames(ConfigurableListableBeanFactory beanFactory) {
List<DatabaseInitializerDetector> detectors = instantiateDetectors(beanFactory, this.environment, List<DatabaseInitializerDetector> detectors = getDetectors(beanFactory, DatabaseInitializerDetector.class);
DatabaseInitializerDetector.class); Set<String> beanNames = new HashSet<>();
Set<String> detected = new HashSet<>();
for (DatabaseInitializerDetector detector : detectors) { for (DatabaseInitializerDetector detector : detectors) {
for (String initializerName : detector.detect(beanFactory)) { for (String beanName : detector.detect(beanFactory)) {
detected.add(initializerName); BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
beanFactory.getBeanDefinition(initializerName) beanDefinition.setAttribute(DatabaseInitializerDetector.class.getName(),
.setAttribute(DatabaseInitializerDetector.class.getName(), detector.getClass().getName()); detector.getClass().getName());
beanNames.add(beanName);
} }
} }
detected = Collections.unmodifiableSet(detected); beanNames = Collections.unmodifiableSet(beanNames);
for (DatabaseInitializerDetector detector : detectors) { for (DatabaseInitializerDetector detector : detectors) {
detector.detectionComplete(beanFactory, detected); detector.detectionComplete(beanFactory, beanNames);
} }
return detected; return beanNames;
} }
private Collection<String> detectDependsOnDatabaseInitialization(ConfigurableListableBeanFactory beanFactory, private Collection<String> detectDependsOnInitializationBeanNames(ConfigurableListableBeanFactory beanFactory) {
Environment environment) { List<DependsOnDatabaseInitializationDetector> detectors = getDetectors(beanFactory,
List<DependsOnDatabaseInitializationDetector> detectors = instantiateDetectors(beanFactory, environment,
DependsOnDatabaseInitializationDetector.class); DependsOnDatabaseInitializationDetector.class);
Set<String> dependentUponDatabaseInitialization = new HashSet<>(); Set<String> beanNames = new HashSet<>();
for (DependsOnDatabaseInitializationDetector detector : detectors) { for (DependsOnDatabaseInitializationDetector detector : detectors) {
dependentUponDatabaseInitialization.addAll(detector.detect(beanFactory)); beanNames.addAll(detector.detect(beanFactory));
} }
return dependentUponDatabaseInitialization; return beanNames;
} }
private <T> List<T> instantiateDetectors(ConfigurableListableBeanFactory beanFactory, Environment environment, private <T> List<T> getDetectors(ConfigurableListableBeanFactory beanFactory, Class<T> type) {
Class<T> detectorType) { List<String> names = SpringFactoriesLoader.loadFactoryNames(type, beanFactory.getBeanClassLoader());
List<String> detectorNames = SpringFactoriesLoader.loadFactoryNames(detectorType, Instantiator<T> instantiator = new Instantiator<>(type,
beanFactory.getBeanClassLoader()); (availableParameters) -> availableParameters.add(Environment.class, this.environment));
Instantiator<T> instantiator = new Instantiator<>(detectorType, return instantiator.instantiate(names);
(availableParameters) -> availableParameters.add(Environment.class, environment));
List<T> detectors = instantiator.instantiate(detectorNames);
return detectors;
} }
private static BeanDefinition getBeanDefinition(String beanName, ConfigurableListableBeanFactory beanFactory) { private static BeanDefinition getBeanDefinition(String beanName, ConfigurableListableBeanFactory beanFactory) {

@ -33,7 +33,6 @@ import java.util.stream.Collectors;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mockito;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
@ -47,6 +46,7 @@ import org.springframework.mock.env.MockEnvironment;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -59,16 +59,12 @@ class DatabaseInitializationDependencyConfigurerTests {
private final ConfigurableEnvironment environment = new MockEnvironment(); private final ConfigurableEnvironment environment = new MockEnvironment();
DatabaseInitializerDetector databaseInitializerDetector = MockedDatabaseInitializerDetector.mock;
DependsOnDatabaseInitializationDetector dependsOnDatabaseInitializationDetector = MockedDependsOnDatabaseInitializationDetector.mock;
@TempDir @TempDir
File temp; File temp;
@BeforeEach @BeforeEach
void resetMocks() { void resetMocks() {
reset(MockedDatabaseInitializerDetector.mock, MockedDependsOnDatabaseInitializationDetector.mock); reset(MockDatabaseInitializerDetector.instance, MockedDependsOnDatabaseInitializationDetector.instance);
} }
@Test @Test
@ -89,19 +85,19 @@ class DatabaseInitializationDependencyConfigurerTests {
void whenDependenciesAreConfiguredThenBeansThatDependUponDatabaseInitializationDependUponDetectedDatabaseInitializers() { void whenDependenciesAreConfiguredThenBeansThatDependUponDatabaseInitializationDependUponDetectedDatabaseInitializers() {
BeanDefinition alpha = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); BeanDefinition alpha = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition();
BeanDefinition bravo = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition(); BeanDefinition bravo = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition();
performDetection(Arrays.asList(MockedDatabaseInitializerDetector.class, performDetection(Arrays.asList(MockDatabaseInitializerDetector.class,
MockedDependsOnDatabaseInitializationDetector.class), (context) -> { MockedDependsOnDatabaseInitializationDetector.class), (context) -> {
context.registerBeanDefinition("alpha", alpha); context.registerBeanDefinition("alpha", alpha);
context.registerBeanDefinition("bravo", bravo); context.registerBeanDefinition("bravo", bravo);
given(this.databaseInitializerDetector.detect(context.getBeanFactory())) given(MockDatabaseInitializerDetector.instance.detect(context.getBeanFactory()))
.willReturn(Collections.singleton("alpha")); .willReturn(Collections.singleton("alpha"));
given(this.dependsOnDatabaseInitializationDetector.detect(context.getBeanFactory())) given(MockedDependsOnDatabaseInitializationDetector.instance.detect(context.getBeanFactory()))
.willReturn(Collections.singleton("bravo")); .willReturn(Collections.singleton("bravo"));
context.refresh(); context.refresh();
assertThat(alpha.getAttribute(DatabaseInitializerDetector.class.getName())) assertThat(alpha.getAttribute(DatabaseInitializerDetector.class.getName()))
.isEqualTo(MockedDatabaseInitializerDetector.class.getName()); .isEqualTo(MockDatabaseInitializerDetector.class.getName());
assertThat(bravo.getAttribute(DatabaseInitializerDetector.class.getName())).isNull(); assertThat(bravo.getAttribute(DatabaseInitializerDetector.class.getName())).isNull();
verify(this.databaseInitializerDetector).detectionComplete(context.getBeanFactory(), verify(MockDatabaseInitializerDetector.instance).detectionComplete(context.getBeanFactory(),
Collections.singleton("alpha")); Collections.singleton("alpha"));
assertThat(bravo.getDependsOn()).containsExactly("alpha"); assertThat(bravo.getDependsOn()).containsExactly("alpha");
}); });
@ -156,31 +152,31 @@ class DatabaseInitializationDependencyConfigurerTests {
} }
static class MockedDatabaseInitializerDetector implements DatabaseInitializerDetector { static class MockDatabaseInitializerDetector implements DatabaseInitializerDetector {
private static DatabaseInitializerDetector mock = Mockito.mock(DatabaseInitializerDetector.class); private static DatabaseInitializerDetector instance = mock(DatabaseInitializerDetector.class);
@Override @Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) { public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
return MockedDatabaseInitializerDetector.mock.detect(beanFactory); return MockDatabaseInitializerDetector.instance.detect(beanFactory);
} }
@Override @Override
public void detectionComplete(ConfigurableListableBeanFactory beanFactory, public void detectionComplete(ConfigurableListableBeanFactory beanFactory,
Set<String> databaseInitializerNames) { Set<String> databaseInitializerNames) {
mock.detectionComplete(beanFactory, databaseInitializerNames); instance.detectionComplete(beanFactory, databaseInitializerNames);
} }
} }
static class MockedDependsOnDatabaseInitializationDetector implements DependsOnDatabaseInitializationDetector { static class MockedDependsOnDatabaseInitializationDetector implements DependsOnDatabaseInitializationDetector {
private static DependsOnDatabaseInitializationDetector mock = Mockito private static DependsOnDatabaseInitializationDetector instance = mock(
.mock(DependsOnDatabaseInitializationDetector.class); DependsOnDatabaseInitializationDetector.class);
@Override @Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) { public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
return MockedDependsOnDatabaseInitializationDetector.mock.detect(beanFactory); return instance.detect(beanFactory);
} }
} }

Loading…
Cancel
Save