diff --git a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java index 0deda7d67b..efa21c1d9a 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java @@ -56,7 +56,7 @@ import org.springframework.web.filter.OncePerRequestFilter; @Component @Order(Ordered.HIGHEST_PRECEDENCE) class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer implements - Filter, NonEmbeddedServletContainerFactory { +Filter, NonEmbeddedServletContainerFactory { private static Log logger = LogFactory.getLog(ErrorPageFilter.class); @@ -109,18 +109,21 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple int status = wrapped.getStatus(); if (status >= 400) { handleErrorStatus(request, response, status, wrapped.getMessage()); + response.flushBuffer(); + } + else if (!request.isAsyncStarted()) { + response.flushBuffer(); } } catch (Throwable ex) { handleException(request, response, wrapped, ex); + response.flushBuffer(); } - response.flushBuffer(); - } private void handleErrorStatus(HttpServletRequest request, HttpServletResponse response, int status, String message) - throws ServletException, IOException { + throws ServletException, IOException { String errorPath = getErrorPath(this.statuses, status); if (errorPath == null) { response.sendError(status, message); @@ -132,7 +135,7 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple private void handleException(HttpServletRequest request, HttpServletResponse response, ErrorWrapperResponse wrapped, Throwable ex) - throws IOException, ServletException { + throws IOException, ServletException { Class type = ex.getClass(); String errorPath = getErrorPath(type); if (errorPath == null) { diff --git a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java index 4660a90f87..d29127ffff 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java @@ -34,6 +34,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -59,6 +60,7 @@ public class ErrorPageFilterTests { assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request)); assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(), equalTo((ServletResponse) this.response)); + assertTrue(this.response.isCommitted()); } @Test @@ -79,6 +81,7 @@ public class ErrorPageFilterTests { equalTo((ServletResponse) this.response)); assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(), equalTo(400)); + assertTrue(this.response.isCommitted()); } @Test @@ -97,6 +100,7 @@ public class ErrorPageFilterTests { equalTo((ServletResponse) this.response)); assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(), equalTo(400)); + assertTrue(this.response.isCommitted()); } @Test @@ -199,6 +203,62 @@ public class ErrorPageFilterTests { equalTo((Object) "BAD")); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE), equalTo((Object) IllegalStateException.class.getName())); + assertTrue(this.response.isCommitted()); + } + + @Test + public void responseIsNotCommitedWhenRequestIsAsync() throws Exception { + this.request.setAsyncStarted(true); + + this.filter.doFilter(this.request, this.response, this.chain); + + assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request)); + assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(), + equalTo((ServletResponse) this.response)); + assertFalse(this.response.isCommitted()); + } + + @Test + public void responseIsCommitedWhenRequestIsAsyncAndExceptionIsThrown() + throws Exception { + this.filter.addErrorPages(new ErrorPage("/error")); + this.request.setAsyncStarted(true); + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + super.doFilter(request, response); + throw new RuntimeException("BAD"); + } + }; + + this.filter.doFilter(this.request, this.response, this.chain); + + assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request)); + assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(), + equalTo((ServletResponse) this.response)); + assertTrue(this.response.isCommitted()); + } + + @Test + public void responseIsCommitedWhenRequestIsAsyncAndStatusIs400Plus() throws Exception { + this.filter.addErrorPages(new ErrorPage("/error")); + this.request.setAsyncStarted(true); + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + super.doFilter(request, response); + ((HttpServletResponse) response).sendError(400, "BAD"); + } + }; + + this.filter.doFilter(this.request, this.response, this.chain); + + assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request)); + assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(), + equalTo((ServletResponse) this.response)); + assertTrue(this.response.isCommitted()); } }