Add missing registration convenience methods

Update `BootstrapContext` with convenience methods that help if the
type has not been registered.

Closes gh-23438
pull/23456/head
Phillip Webb 4 years ago
parent 0df37302af
commit fde2e440bb

@ -16,6 +16,8 @@
package org.springframework.boot; package org.springframework.boot;
import java.util.function.Supplier;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
@ -32,8 +34,8 @@ import org.springframework.core.env.Environment;
public interface BootstrapContext { public interface BootstrapContext {
/** /**
* Return an instance from the context, creating it if it hasn't been accessed * Return an instance from the context if the type has been registered. The instance
* previously. * will be created it if it hasn't been accessed previously.
* @param <T> the instance type * @param <T> the instance type
* @param type the instance type * @param type the instance type
* @return the instance managed by the context * @return the instance managed by the context
@ -41,4 +43,45 @@ public interface BootstrapContext {
*/ */
<T> T get(Class<T> type) throws IllegalStateException; <T> T get(Class<T> type) throws IllegalStateException;
/**
* Return an instance from the context if the type has been registered. The instance
* will be created it if it hasn't been accessed previously.
* @param <T> the instance type
* @param type the instance type
* @param other the instance to use if the type has not been registered
* @return the instance
*/
<T> T getOrElse(Class<T> type, T other);
/**
* Return an instance from the context if the type has been registered. The instance
* will be created it if it hasn't been accessed previously.
* @param <T> the instance type
* @param type the instance type
* @param other a supplier for the instance to use if the type has not been registered
* @return the instance
*/
<T> T getOrElseSupply(Class<T> type, Supplier<T> other);
/**
* Return an instance from the context if the type has been registered. The instance
* will be created it if it hasn't been accessed previously.
* @param <T> the instance type
* @param <X> the exception to throw if the type is not registered
* @param type the instance type
* @param exceptionSupplier the supplier which will return the exception to be thrown
* @return the instance managed by the context
* @throws X if the type has not been registered
* @throws IllegalStateException if the type has not been registered
*/
<T, X extends Throwable> T getOrElseThrow(Class<T> type, Supplier<? extends X> exceptionSupplier) throws X;
/**
* Return if a registration exists for the given type.
* @param <T> the instance type
* @param type the instance type
* @return {@code true} if the type has already been registered
*/
<T> boolean isRegistered(Class<T> type);
} }

@ -18,6 +18,7 @@ package org.springframework.boot;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Supplier;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationListener; import org.springframework.context.ApplicationListener;
@ -83,18 +84,42 @@ public class DefaultBootstrapContext implements ConfigurableBootstrapContext {
} }
@Override @Override
@SuppressWarnings("unchecked")
public <T> T get(Class<T> type) throws IllegalStateException { public <T> T get(Class<T> type) throws IllegalStateException {
return getOrElseThrow(type, () -> new IllegalStateException(type.getName() + " has not been registered"));
}
@Override
public <T> T getOrElse(Class<T> type, T other) {
return getOrElseSupply(type, () -> other);
}
@Override
public <T> T getOrElseSupply(Class<T> type, Supplier<T> other) {
synchronized (this.instanceSuppliers) { synchronized (this.instanceSuppliers) {
InstanceSupplier<?> instanceSupplier = this.instanceSuppliers.get(type); InstanceSupplier<?> instanceSupplier = this.instanceSuppliers.get(type);
Assert.state(instanceSupplier != null, () -> type.getName() + " has not been registered"); return (instanceSupplier != null) ? getInstance(type, instanceSupplier) : other.get();
T instance = (T) this.instances.get(type); }
if (instance == null) { }
instance = (T) instanceSupplier.get(this);
this.instances.put(type, instance); @Override
public <T, X extends Throwable> T getOrElseThrow(Class<T> type, Supplier<? extends X> exceptionSupplier) throws X {
synchronized (this.instanceSuppliers) {
InstanceSupplier<?> instanceSupplier = this.instanceSuppliers.get(type);
if (instanceSupplier == null) {
throw exceptionSupplier.get();
} }
return instance; return getInstance(type, instanceSupplier);
}
}
@SuppressWarnings("unchecked")
private <T> T getInstance(Class<T> type, InstanceSupplier<?> instanceSupplier) {
T instance = (T) this.instances.get(type);
if (instance == null) {
instance = (T) instanceSupplier.get(this);
this.instances.put(type, instance);
} }
return instance;
} }
/** /**

@ -16,6 +16,7 @@
package org.springframework.boot; package org.springframework.boot;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import org.assertj.core.api.AbstractAssert; import org.assertj.core.api.AbstractAssert;
@ -29,6 +30,7 @@ import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.support.StaticApplicationContext; import org.springframework.context.support.StaticApplicationContext;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIOException;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
@ -150,6 +152,63 @@ class DefaultBootstrapContextTests {
assertThat(this.context.get(Integer.class)).isEqualTo(0); assertThat(this.context.get(Integer.class)).isEqualTo(0);
} }
@Test
void getOrElseWhenNoRegistrationReturnsOther() {
this.context.register(Number.class, InstanceSupplier.of(1));
assertThat(this.context.getOrElse(Long.class, -1L)).isEqualTo(-1);
}
@Test
void getOrElseWhenRegisteredAsNullReturnsNull() {
this.context.register(Number.class, InstanceSupplier.of(null));
assertThat(this.context.getOrElse(Number.class, -1)).isNull();
}
@Test
void getOrElseCreatesReturnsOnlyOneInstance() {
this.context.register(Integer.class, InstanceSupplier.from(this.counter::getAndIncrement));
assertThat(this.context.getOrElse(Integer.class, -1)).isEqualTo(0);
assertThat(this.context.getOrElse(Integer.class, -1)).isEqualTo(0);
}
@Test
void getOrElseSupplyWhenNoRegistrationReturnsSupplied() {
this.context.register(Number.class, InstanceSupplier.of(1));
assertThat(this.context.getOrElseSupply(Long.class, () -> -1L)).isEqualTo(-1);
}
@Test
void getOrElseSupplyWhenRegisteredAsNullReturnsNull() {
this.context.register(Number.class, InstanceSupplier.of(null));
assertThat(this.context.getOrElseSupply(Number.class, () -> -1L)).isNull();
}
@Test
void getOrElseSupplyCreatesOnlyOneInstance() {
this.context.register(Integer.class, InstanceSupplier.from(this.counter::getAndIncrement));
assertThat(this.context.getOrElseSupply(Integer.class, () -> -1)).isEqualTo(0);
assertThat(this.context.getOrElseSupply(Integer.class, () -> -1)).isEqualTo(0);
}
@Test
void getOrElseThrowWhenNoRegistrationThrowsSuppliedException() {
this.context.register(Number.class, InstanceSupplier.of(1));
assertThatIOException().isThrownBy(() -> this.context.getOrElseThrow(Long.class, IOException::new));
}
@Test
void getOrElseThrowWhenRegisteredAsNullReturnsNull() {
this.context.register(Number.class, InstanceSupplier.of(null));
assertThat(this.context.getOrElseThrow(Number.class, RuntimeException::new)).isNull();
}
@Test
void getOrElseThrowCreatesOnlyOneInstance() {
this.context.register(Integer.class, InstanceSupplier.from(this.counter::getAndIncrement));
assertThat(this.context.getOrElseThrow(Integer.class, RuntimeException::new)).isEqualTo(0);
assertThat(this.context.getOrElseThrow(Integer.class, RuntimeException::new)).isEqualTo(0);
}
@Test @Test
void closeMulticastsEventToListeners() { void closeMulticastsEventToListeners() {
TestCloseListener listener = new TestCloseListener(); TestCloseListener listener = new TestCloseListener();

Loading…
Cancel
Save