Align servlet container error handling with executable jar/war behavior

Previously, when an exception was thrown by a Controller in an
application deployed to a servlet container the exception that was
handled would be Spring Framework’s NestedServletException rather than
the exception thrown by the application. Furthermore, when an exception
was thrown or the response was used to send an error, the
javax.servlet.error.request_uri request attribute would not be set. This
differed from the behaviour in an executable jar/war where the exception
would be the one thrown by the application, and the request_uri
attribute would be set.

This commit updates ErrorPageFilter, which is only involved in a servlet
container, to unwrap a NestedServletException so that it’s the
application’s exception that’s handled, and to set the request_uri
attribute in the event of an exception being thrown or an error being
sent.

Closes gh-3249
pull/3885/head
Andy Wilkinson 10 years ago
parent 5ed2a9632b
commit 6fd3042462

@ -39,6 +39,7 @@ import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.NestedServletException;
/** /**
* A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded * A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded
@ -69,6 +70,8 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine
private static final String ERROR_MESSAGE = "javax.servlet.error.message"; private static final String ERROR_MESSAGE = "javax.servlet.error.message";
public static final String ERROR_REQUEST_URI = "javax.servlet.error.request_uri";
private static final String ERROR_STATUS_CODE = "javax.servlet.error.status_code"; private static final String ERROR_STATUS_CODE = "javax.servlet.error.status_code";
private String global; private String global;
@ -121,7 +124,11 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine
} }
} }
catch (Throwable ex) { catch (Throwable ex) {
handleException(request, response, wrapped, ex); Throwable exceptionToHandle = ex;
if (ex instanceof NestedServletException) {
exceptionToHandle = ((NestedServletException) ex).getRootCause();
}
handleException(request, response, wrapped, exceptionToHandle);
response.flushBuffer(); response.flushBuffer();
} }
} }
@ -225,9 +232,10 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine
return this.global; return this.global;
} }
private void setErrorAttributes(ServletRequest request, int status, String message) { private void setErrorAttributes(HttpServletRequest request, int status, String message) {
request.setAttribute(ERROR_STATUS_CODE, status); request.setAttribute(ERROR_STATUS_CODE, status);
request.setAttribute(ERROR_MESSAGE, message); request.setAttribute(ERROR_MESSAGE, message);
request.setAttribute(ERROR_REQUEST_URI, request.getRequestURI());
} }
private void rethrow(Throwable ex) throws IOException, ServletException { private void rethrow(Throwable ex) throws IOException, ServletException {

@ -38,6 +38,7 @@ import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
import org.springframework.web.context.request.async.WebAsyncManager; import org.springframework.web.context.request.async.WebAsyncManager;
import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.util.NestedServletException;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@ -62,7 +63,8 @@ public class ErrorPageFilterTests {
private ErrorPageFilter filter = new ErrorPageFilter(); private ErrorPageFilter filter = new ErrorPageFilter();
private MockHttpServletRequest request = new MockHttpServletRequest(); private MockHttpServletRequest request = new MockHttpServletRequest("GET",
"/test/path");
private MockHttpServletResponse response = new MockHttpServletResponse(); private MockHttpServletResponse response = new MockHttpServletResponse();
@ -199,6 +201,9 @@ public class ErrorPageFilterTests {
equalTo((Object) 400)); equalTo((Object) 400));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE), assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE),
equalTo((Object) "BAD")); equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted()); assertTrue(this.response.isCommitted());
assertThat(this.response.getForwardedUrl(), equalTo("/error")); assertThat(this.response.getForwardedUrl(), equalTo("/error"));
} }
@ -221,6 +226,8 @@ public class ErrorPageFilterTests {
equalTo((Object) 400)); equalTo((Object) 400));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE), assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE),
equalTo((Object) "BAD")); equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted()); assertTrue(this.response.isCommitted());
assertThat(this.response.getForwardedUrl(), equalTo("/400")); assertThat(this.response.getForwardedUrl(), equalTo("/400"));
} }
@ -264,6 +271,8 @@ public class ErrorPageFilterTests {
equalTo((Object) "BAD")); equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE), assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
equalTo((Object) RuntimeException.class.getName())); equalTo((Object) RuntimeException.class.getName()));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted()); assertTrue(this.response.isCommitted());
assertThat(this.response.getForwardedUrl(), equalTo("/500")); assertThat(this.response.getForwardedUrl(), equalTo("/500"));
} }
@ -319,6 +328,8 @@ public class ErrorPageFilterTests {
equalTo((Object) "BAD")); equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE), assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
equalTo((Object) IllegalStateException.class.getName())); equalTo((Object) IllegalStateException.class.getName()));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted()); assertTrue(this.response.isCommitted());
} }
@ -465,6 +476,32 @@ public class ErrorPageFilterTests {
assertThat(this.output.toString(), containsString("request [/test/alpha]")); assertThat(this.output.toString(), containsString("request [/test/alpha]"));
} }
@Test
public void nestedServletExceptionIsUnwrapped() throws Exception {
this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500"));
this.chain = new MockFilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
super.doFilter(request, response);
throw new NestedServletException("Wrapper", new RuntimeException("BAD"));
}
};
this.filter.doFilter(this.request, this.response, this.chain);
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(),
equalTo(500));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE),
equalTo((Object) 500));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE),
equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
equalTo((Object) RuntimeException.class.getName()));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted());
assertThat(this.response.getForwardedUrl(), equalTo("/500"));
}
private void setUpAsyncDispatch() throws Exception { private void setUpAsyncDispatch() throws Exception {
this.request.setAsyncSupported(true); this.request.setAsyncSupported(true);
this.request.setAsyncStarted(true); this.request.setAsyncStarted(true);

Loading…
Cancel
Save