Polish "Use StackWalker to deduce main application class"

See gh-31701
pull/31828/head
Andy Wilkinson 2 years ago
parent ea3fe95881
commit 38fedcff34

@ -16,6 +16,7 @@
package org.springframework.boot; package org.springframework.boot;
import java.lang.StackWalker.StackFrame;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -26,9 +27,11 @@ import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.Properties; import java.util.Properties;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -275,12 +278,15 @@ public class SpringApplication {
} }
private Class<?> deduceMainApplicationClass() { private Class<?> deduceMainApplicationClass() {
return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE) return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).walk(this::findMainClass)
.walk((s) -> s.filter(e -> Objects.equals(e.getMethodName(), "main")).findFirst()
.map(StackWalker.StackFrame::getDeclaringClass))
.orElse(null); .orElse(null);
} }
private Optional<Class<?>> findMainClass(Stream<StackFrame> stack) {
return stack.filter((frame) -> Objects.equals(frame.getMethodName(), "main")).findFirst()
.map(StackWalker.StackFrame::getDeclaringClass);
}
/** /**
* Run the Spring application, creating and refreshing a new * Run the Spring application, creating and refreshing a new
* {@link ApplicationContext}. * {@link ApplicationContext}.

@ -23,7 +23,6 @@ import java.util.Iterator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -1315,35 +1314,6 @@ class SpringApplicationTests {
.accepts(hints); .accepts(hints);
} }
@Test
void deduceMainApplicationClass() {
assertThat(
Objects.equals(deduceMainApplicationClassByStackWalker(), deduceMainApplicationClassByThrowException()))
.isTrue();
}
private Class<?> deduceMainApplicationClassByThrowException() {
try {
StackTraceElement[] stackTrace = new RuntimeException().getStackTrace();
for (StackTraceElement stackTraceElement : stackTrace) {
if ("main".equals(stackTraceElement.getMethodName())) {
return Class.forName(stackTraceElement.getClassName());
}
}
}
catch (ClassNotFoundException ex) {
// Swallow and continue
}
return null;
}
private Class<?> deduceMainApplicationClassByStackWalker() {
return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE)
.walk((s) -> s.filter(e -> Objects.equals(e.getMethodName(), "main")).findFirst()
.map(StackWalker.StackFrame::getDeclaringClass))
.orElse(null);
}
private <S extends AvailabilityState> ArgumentMatcher<ApplicationEvent> isAvailabilityChangeEventWithState( private <S extends AvailabilityState> ArgumentMatcher<ApplicationEvent> isAvailabilityChangeEventWithState(
S state) { S state) {
return (argument) -> (argument instanceof AvailabilityChangeEvent<?>) return (argument) -> (argument instanceof AvailabilityChangeEvent<?>)

Loading…
Cancel
Save