Ensure matches is not called before initialization

Update `ApplicationContextRequestMatcher` to ensure that the `matches`
method is never called before `initialized`. This fixes an issue
accidentally introduced in commit 5938ca78 where concurrent calls
to `matches` could trigger unexpected errors due to the fact that the
second call proceeded before the `initialized` method had returned.

Fixes gh-18211
pull/18464/head
Phillip Webb 5 years ago
parent 5427526bcc
commit 1ceb96f9f2

@ -16,7 +16,6 @@
package org.springframework.boot.security.servlet; package org.springframework.boot.security.servlet;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier; import java.util.function.Supplier;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -44,7 +43,9 @@ public abstract class ApplicationContextRequestMatcher<C> implements RequestMatc
private final Class<? extends C> contextClass; private final Class<? extends C> contextClass;
private final AtomicBoolean initialized = new AtomicBoolean(false); private volatile boolean initialized;
private final Object initializeLock = new Object();
public ApplicationContextRequestMatcher(Class<? extends C> contextClass) { public ApplicationContextRequestMatcher(Class<? extends C> contextClass) {
Assert.notNull(contextClass, "Context class must not be null"); Assert.notNull(contextClass, "Context class must not be null");
@ -59,8 +60,13 @@ public abstract class ApplicationContextRequestMatcher<C> implements RequestMatc
return false; return false;
} }
Supplier<C> context = () -> getContext(webApplicationContext); Supplier<C> context = () -> getContext(webApplicationContext);
if (this.initialized.compareAndSet(false, true)) { if (!this.initialized) {
initialized(context); synchronized (this.initializeLock) {
if (!this.initialized) {
initialized(context);
this.initialized = true;
}
}
} }
return matches(request, context); return matches(request, context);
} }

@ -16,6 +16,10 @@
package org.springframework.boot.security.servlet; package org.springframework.boot.security.servlet;
import java.lang.Thread.UncaughtExceptionHandler;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier; import java.util.function.Supplier;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -26,6 +30,7 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext; import org.springframework.mock.web.MockServletContext;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.context.support.StaticWebApplicationContext;
@ -105,6 +110,31 @@ public class ApplicationContextRequestMatcherTests {
assertThat(matcher.matches(request)).isFalse(); assertThat(matcher.matches(request)).isFalse();
} }
@Test // gh-18211
public void matchesWhenConcurrentlyCalledWaitsForInitialize() {
ConcurrentApplicationContextRequestMatcher matcher = new ConcurrentApplicationContextRequestMatcher();
StaticWebApplicationContext context = createWebApplicationContext();
Runnable target = () -> matcher.matches(new MockHttpServletRequest(context.getServletContext()));
List<Thread> threads = new ArrayList<>();
AssertingUncaughtExceptionHandler exceptionHandler = new AssertingUncaughtExceptionHandler();
for (int i = 0; i < 2; i++) {
Thread thread = new Thread(target);
thread.setUncaughtExceptionHandler(exceptionHandler);
threads.add(thread);
}
threads.forEach(Thread::start);
threads.forEach(this::join);
exceptionHandler.assertNoExceptions();
}
private void join(Thread thread) {
try {
thread.join(1000);
}
catch (InterruptedException ex) {
}
}
private StaticWebApplicationContext createWebApplicationContext() { private StaticWebApplicationContext createWebApplicationContext() {
StaticWebApplicationContext context = new StaticWebApplicationContext(); StaticWebApplicationContext context = new StaticWebApplicationContext();
MockServletContext servletContext = new MockServletContext(); MockServletContext servletContext = new MockServletContext();
@ -160,4 +190,47 @@ public class ApplicationContextRequestMatcherTests {
} }
static class ConcurrentApplicationContextRequestMatcher extends ApplicationContextRequestMatcher<Object> {
ConcurrentApplicationContextRequestMatcher() {
super(Object.class);
}
private AtomicBoolean initialized = new AtomicBoolean();
@Override
protected void initialized(Supplier<Object> context) {
try {
Thread.sleep(200);
}
catch (InterruptedException ex) {
}
this.initialized.set(true);
}
@Override
protected boolean matches(HttpServletRequest request, Supplier<Object> context) {
assertThat(this.initialized.get()).isTrue();
return true;
}
}
private static class AssertingUncaughtExceptionHandler implements UncaughtExceptionHandler {
private volatile Throwable ex;
@Override
public void uncaughtException(Thread thead, Throwable ex) {
this.ex = ex;
}
public void assertNoExceptions() {
if (this.ex != null) {
ReflectionUtils.rethrowRuntimeException(this.ex);
}
}
}
} }

Loading…
Cancel
Save