diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/BootstrapContext.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/BootstrapContext.java index 4155ec535a..197620f1b3 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/BootstrapContext.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/BootstrapContext.java @@ -16,6 +16,8 @@ package org.springframework.boot; +import java.util.function.Supplier; + import org.springframework.context.ApplicationContext; import org.springframework.core.env.Environment; @@ -32,8 +34,8 @@ import org.springframework.core.env.Environment; public interface BootstrapContext { /** - * Return an instance from the context, creating it if it hasn't been accessed - * previously. + * Return an instance from the context if the type has been registered. The instance + * will be created it if it hasn't been accessed previously. * @param the instance type * @param type the instance type * @return the instance managed by the context @@ -41,4 +43,45 @@ public interface BootstrapContext { */ T get(Class type) throws IllegalStateException; + /** + * Return an instance from the context if the type has been registered. The instance + * will be created it if it hasn't been accessed previously. + * @param the instance type + * @param type the instance type + * @param other the instance to use if the type has not been registered + * @return the instance + */ + T getOrElse(Class type, T other); + + /** + * Return an instance from the context if the type has been registered. The instance + * will be created it if it hasn't been accessed previously. + * @param the instance type + * @param type the instance type + * @param other a supplier for the instance to use if the type has not been registered + * @return the instance + */ + T getOrElseSupply(Class type, Supplier other); + + /** + * Return an instance from the context if the type has been registered. The instance + * will be created it if it hasn't been accessed previously. + * @param the instance type + * @param the exception to throw if the type is not registered + * @param type the instance type + * @param exceptionSupplier the supplier which will return the exception to be thrown + * @return the instance managed by the context + * @throws X if the type has not been registered + * @throws IllegalStateException if the type has not been registered + */ + T getOrElseThrow(Class type, Supplier exceptionSupplier) throws X; + + /** + * Return if a registration exists for the given type. + * @param the instance type + * @param type the instance type + * @return {@code true} if the type has already been registered + */ + boolean isRegistered(Class type); + } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/DefaultBootstrapContext.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/DefaultBootstrapContext.java index 79495c0c95..cc73b8ec36 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/DefaultBootstrapContext.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/DefaultBootstrapContext.java @@ -18,6 +18,7 @@ package org.springframework.boot; import java.util.HashMap; import java.util.Map; +import java.util.function.Supplier; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationListener; @@ -83,18 +84,42 @@ public class DefaultBootstrapContext implements ConfigurableBootstrapContext { } @Override - @SuppressWarnings("unchecked") public T get(Class type) throws IllegalStateException { + return getOrElseThrow(type, () -> new IllegalStateException(type.getName() + " has not been registered")); + } + + @Override + public T getOrElse(Class type, T other) { + return getOrElseSupply(type, () -> other); + } + + @Override + public T getOrElseSupply(Class type, Supplier other) { synchronized (this.instanceSuppliers) { InstanceSupplier instanceSupplier = this.instanceSuppliers.get(type); - Assert.state(instanceSupplier != null, () -> type.getName() + " has not been registered"); - T instance = (T) this.instances.get(type); - if (instance == null) { - instance = (T) instanceSupplier.get(this); - this.instances.put(type, instance); + return (instanceSupplier != null) ? getInstance(type, instanceSupplier) : other.get(); + } + } + + @Override + public T getOrElseThrow(Class type, Supplier exceptionSupplier) throws X { + synchronized (this.instanceSuppliers) { + InstanceSupplier instanceSupplier = this.instanceSuppliers.get(type); + if (instanceSupplier == null) { + throw exceptionSupplier.get(); } - return instance; + return getInstance(type, instanceSupplier); + } + } + + @SuppressWarnings("unchecked") + private T getInstance(Class type, InstanceSupplier instanceSupplier) { + T instance = (T) this.instances.get(type); + if (instance == null) { + instance = (T) instanceSupplier.get(this); + this.instances.put(type, instance); } + return instance; } /** diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/DefaultBootstrapContextTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/DefaultBootstrapContextTests.java index e6003dc704..c727f554b6 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/DefaultBootstrapContextTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/DefaultBootstrapContextTests.java @@ -16,6 +16,7 @@ package org.springframework.boot; +import java.io.IOException; import java.util.concurrent.atomic.AtomicInteger; import org.assertj.core.api.AbstractAssert; @@ -29,6 +30,7 @@ import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.support.StaticApplicationContext; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIOException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; @@ -150,6 +152,63 @@ class DefaultBootstrapContextTests { assertThat(this.context.get(Integer.class)).isEqualTo(0); } + @Test + void getOrElseWhenNoRegistrationReturnsOther() { + this.context.register(Number.class, InstanceSupplier.of(1)); + assertThat(this.context.getOrElse(Long.class, -1L)).isEqualTo(-1); + } + + @Test + void getOrElseWhenRegisteredAsNullReturnsNull() { + this.context.register(Number.class, InstanceSupplier.of(null)); + assertThat(this.context.getOrElse(Number.class, -1)).isNull(); + } + + @Test + void getOrElseCreatesReturnsOnlyOneInstance() { + this.context.register(Integer.class, InstanceSupplier.from(this.counter::getAndIncrement)); + assertThat(this.context.getOrElse(Integer.class, -1)).isEqualTo(0); + assertThat(this.context.getOrElse(Integer.class, -1)).isEqualTo(0); + } + + @Test + void getOrElseSupplyWhenNoRegistrationReturnsSupplied() { + this.context.register(Number.class, InstanceSupplier.of(1)); + assertThat(this.context.getOrElseSupply(Long.class, () -> -1L)).isEqualTo(-1); + } + + @Test + void getOrElseSupplyWhenRegisteredAsNullReturnsNull() { + this.context.register(Number.class, InstanceSupplier.of(null)); + assertThat(this.context.getOrElseSupply(Number.class, () -> -1L)).isNull(); + } + + @Test + void getOrElseSupplyCreatesOnlyOneInstance() { + this.context.register(Integer.class, InstanceSupplier.from(this.counter::getAndIncrement)); + assertThat(this.context.getOrElseSupply(Integer.class, () -> -1)).isEqualTo(0); + assertThat(this.context.getOrElseSupply(Integer.class, () -> -1)).isEqualTo(0); + } + + @Test + void getOrElseThrowWhenNoRegistrationThrowsSuppliedException() { + this.context.register(Number.class, InstanceSupplier.of(1)); + assertThatIOException().isThrownBy(() -> this.context.getOrElseThrow(Long.class, IOException::new)); + } + + @Test + void getOrElseThrowWhenRegisteredAsNullReturnsNull() { + this.context.register(Number.class, InstanceSupplier.of(null)); + assertThat(this.context.getOrElseThrow(Number.class, RuntimeException::new)).isNull(); + } + + @Test + void getOrElseThrowCreatesOnlyOneInstance() { + this.context.register(Integer.class, InstanceSupplier.from(this.counter::getAndIncrement)); + assertThat(this.context.getOrElseThrow(Integer.class, RuntimeException::new)).isEqualTo(0); + assertThat(this.context.getOrElseThrow(Integer.class, RuntimeException::new)).isEqualTo(0); + } + @Test void closeMulticastsEventToListeners() { TestCloseListener listener = new TestCloseListener();