Polish "Use StackWalker to deduce main application class"

See gh-31701
pull/31828/head
Andy Wilkinson 2 years ago
parent ea3fe95881
commit 38fedcff34

@ -16,6 +16,7 @@
package org.springframework.boot;
import java.lang.StackWalker.StackFrame;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
@ -26,9 +27,11 @@ import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -275,12 +278,15 @@ public class SpringApplication {
}
private Class<?> deduceMainApplicationClass() {
return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE)
.walk((s) -> s.filter(e -> Objects.equals(e.getMethodName(), "main")).findFirst()
.map(StackWalker.StackFrame::getDeclaringClass))
return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).walk(this::findMainClass)
.orElse(null);
}
private Optional<Class<?>> findMainClass(Stream<StackFrame> stack) {
return stack.filter((frame) -> Objects.equals(frame.getMethodName(), "main")).findFirst()
.map(StackWalker.StackFrame::getDeclaringClass);
}
/**
* Run the Spring application, creating and refreshing a new
* {@link ApplicationContext}.

@ -23,7 +23,6 @@ import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@ -1315,35 +1314,6 @@ class SpringApplicationTests {
.accepts(hints);
}
@Test
void deduceMainApplicationClass() {
assertThat(
Objects.equals(deduceMainApplicationClassByStackWalker(), deduceMainApplicationClassByThrowException()))
.isTrue();
}
private Class<?> deduceMainApplicationClassByThrowException() {
try {
StackTraceElement[] stackTrace = new RuntimeException().getStackTrace();
for (StackTraceElement stackTraceElement : stackTrace) {
if ("main".equals(stackTraceElement.getMethodName())) {
return Class.forName(stackTraceElement.getClassName());
}
}
}
catch (ClassNotFoundException ex) {
// Swallow and continue
}
return null;
}
private Class<?> deduceMainApplicationClassByStackWalker() {
return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE)
.walk((s) -> s.filter(e -> Objects.equals(e.getMethodName(), "main")).findFirst()
.map(StackWalker.StackFrame::getDeclaringClass))
.orElse(null);
}
private <S extends AvailabilityState> ArgumentMatcher<ApplicationEvent> isAvailabilityChangeEventWithState(
S state) {
return (argument) -> (argument instanceof AvailabilityChangeEvent<?>)

Loading…
Cancel
Save