@ -16,6 +16,10 @@
package org.springframework.boot.security.servlet ;
package org.springframework.boot.security.servlet ;
import java.lang.Thread.UncaughtExceptionHandler ;
import java.util.ArrayList ;
import java.util.List ;
import java.util.concurrent.atomic.AtomicBoolean ;
import java.util.function.Supplier ;
import java.util.function.Supplier ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletRequest ;
@ -26,6 +30,7 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext ;
import org.springframework.context.ApplicationContext ;
import org.springframework.mock.web.MockHttpServletRequest ;
import org.springframework.mock.web.MockHttpServletRequest ;
import org.springframework.mock.web.MockServletContext ;
import org.springframework.mock.web.MockServletContext ;
import org.springframework.util.ReflectionUtils ;
import org.springframework.web.context.WebApplicationContext ;
import org.springframework.web.context.WebApplicationContext ;
import org.springframework.web.context.support.StaticWebApplicationContext ;
import org.springframework.web.context.support.StaticWebApplicationContext ;
@ -105,6 +110,31 @@ public class ApplicationContextRequestMatcherTests {
assertThat ( matcher . matches ( request ) ) . isFalse ( ) ;
assertThat ( matcher . matches ( request ) ) . isFalse ( ) ;
}
}
@Test // gh-18211
public void matchesWhenConcurrentlyCalledWaitsForInitialize ( ) {
ConcurrentApplicationContextRequestMatcher matcher = new ConcurrentApplicationContextRequestMatcher ( ) ;
StaticWebApplicationContext context = createWebApplicationContext ( ) ;
Runnable target = ( ) - > matcher . matches ( new MockHttpServletRequest ( context . getServletContext ( ) ) ) ;
List < Thread > threads = new ArrayList < > ( ) ;
AssertingUncaughtExceptionHandler exceptionHandler = new AssertingUncaughtExceptionHandler ( ) ;
for ( int i = 0 ; i < 2 ; i + + ) {
Thread thread = new Thread ( target ) ;
thread . setUncaughtExceptionHandler ( exceptionHandler ) ;
threads . add ( thread ) ;
}
threads . forEach ( Thread : : start ) ;
threads . forEach ( this : : join ) ;
exceptionHandler . assertNoExceptions ( ) ;
}
private void join ( Thread thread ) {
try {
thread . join ( 1000 ) ;
}
catch ( InterruptedException ex ) {
}
}
private StaticWebApplicationContext createWebApplicationContext ( ) {
private StaticWebApplicationContext createWebApplicationContext ( ) {
StaticWebApplicationContext context = new StaticWebApplicationContext ( ) ;
StaticWebApplicationContext context = new StaticWebApplicationContext ( ) ;
MockServletContext servletContext = new MockServletContext ( ) ;
MockServletContext servletContext = new MockServletContext ( ) ;
@ -160,4 +190,47 @@ public class ApplicationContextRequestMatcherTests {
}
}
static class ConcurrentApplicationContextRequestMatcher extends ApplicationContextRequestMatcher < Object > {
ConcurrentApplicationContextRequestMatcher ( ) {
super ( Object . class ) ;
}
private AtomicBoolean initialized = new AtomicBoolean ( ) ;
@Override
protected void initialized ( Supplier < Object > context ) {
try {
Thread . sleep ( 200 ) ;
}
catch ( InterruptedException ex ) {
}
this . initialized . set ( true ) ;
}
@Override
protected boolean matches ( HttpServletRequest request , Supplier < Object > context ) {
assertThat ( this . initialized . get ( ) ) . isTrue ( ) ;
return true ;
}
}
private static class AssertingUncaughtExceptionHandler implements UncaughtExceptionHandler {
private volatile Throwable ex ;
@Override
public void uncaughtException ( Thread thead , Throwable ex ) {
this . ex = ex ;
}
public void assertNoExceptions ( ) {
if ( this . ex ! = null ) {
ReflectionUtils . rethrowRuntimeException ( this . ex ) ;
}
}
}
}
}