Use WebRequest rather than RequestAttributes in ErrorAttributes

This change aligns ErrorAttributes with ResponseEntityExceptionHandler
which takes a WebRequest as a parameter of its handleException method.
WebRequest extends RequestAttributes and provides access to much more
than just the request's attributes. For example request headers and
parameters are available from WebRequest.

Closes gh-7952
Closes gh-6555
pull/2836/merge
Andy Wilkinson 8 years ago
parent d6a3238e7d
commit 9192db692b

@ -24,7 +24,7 @@ import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletWebRequest;
/**
* Special {@link MvcEndpoint} for handling "/error" path when the management servlet is
@ -45,9 +45,8 @@ public class ManagementErrorEndpoint {
@RequestMapping("${server.error.path:${error.path:/error}}")
@ResponseBody
public Map<String, Object> invoke() {
return this.errorAttributes.getErrorAttributes(
RequestContextHolder.currentRequestAttributes(), false);
public Map<String, Object> invoke(ServletWebRequest request) {
return this.errorAttributes.getErrorAttributes(request, false);
}
}

@ -42,7 +42,7 @@ import org.springframework.boot.actuate.trace.TraceProperties.Include;
import org.springframework.boot.autoconfigure.web.servlet.error.ErrorAttributes;
import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.filter.OncePerRequestFilter;
/**
@ -149,7 +149,7 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order
if (isIncluded(Include.ERRORS) && exception != null
&& this.errorAttributes != null) {
trace.put("error", this.errorAttributes
.getErrorAttributes(new ServletRequestAttributes(request), true));
.getErrorAttributes(new ServletWebRequest(request), true));
}
return trace;
}

@ -27,8 +27,8 @@ import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.servlet.ModelAndView;
/**
@ -68,9 +68,8 @@ public abstract class AbstractErrorController implements ErrorController {
protected Map<String, Object> getErrorAttributes(HttpServletRequest request,
boolean includeStackTrace) {
RequestAttributes requestAttributes = new ServletRequestAttributes(request);
return this.errorAttributes.getErrorAttributes(requestAttributes,
includeStackTrace);
WebRequest webRequest = new ServletWebRequest(request);
return this.errorAttributes.getErrorAttributes(webRequest, includeStackTrace);
}
protected boolean getTraceParameter(HttpServletRequest request) {

@ -34,6 +34,7 @@ import org.springframework.validation.BindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.servlet.HandlerExceptionResolver;
import org.springframework.web.servlet.ModelAndView;
@ -100,13 +101,13 @@ public class DefaultErrorAttributes
}
@Override
public Map<String, Object> getErrorAttributes(RequestAttributes requestAttributes,
public Map<String, Object> getErrorAttributes(WebRequest webRequest,
boolean includeStackTrace) {
Map<String, Object> errorAttributes = new LinkedHashMap<>();
errorAttributes.put("timestamp", new Date());
addStatus(errorAttributes, requestAttributes);
addErrorDetails(errorAttributes, requestAttributes, includeStackTrace);
addPath(errorAttributes, requestAttributes);
addStatus(errorAttributes, webRequest);
addErrorDetails(errorAttributes, webRequest, includeStackTrace);
addPath(errorAttributes, webRequest);
return errorAttributes;
}
@ -130,8 +131,8 @@ public class DefaultErrorAttributes
}
private void addErrorDetails(Map<String, Object> errorAttributes,
RequestAttributes requestAttributes, boolean includeStackTrace) {
Throwable error = getError(requestAttributes);
WebRequest webRequest, boolean includeStackTrace) {
Throwable error = getError(webRequest);
if (error != null) {
while (error instanceof ServletException && error.getCause() != null) {
error = ((ServletException) error).getCause();
@ -144,7 +145,7 @@ public class DefaultErrorAttributes
addStackTrace(errorAttributes, error);
}
}
Object message = getAttribute(requestAttributes, "javax.servlet.error.message");
Object message = getAttribute(webRequest, "javax.servlet.error.message");
if ((!StringUtils.isEmpty(message) || errorAttributes.get("message") == null)
&& !(error instanceof BindingResult)) {
errorAttributes.put("message",
@ -195,10 +196,10 @@ public class DefaultErrorAttributes
}
@Override
public Throwable getError(RequestAttributes requestAttributes) {
Throwable exception = getAttribute(requestAttributes, ERROR_ATTRIBUTE);
public Throwable getError(WebRequest webRequest) {
Throwable exception = getAttribute(webRequest, ERROR_ATTRIBUTE);
if (exception == null) {
exception = getAttribute(requestAttributes, "javax.servlet.error.exception");
exception = getAttribute(webRequest, "javax.servlet.error.exception");
}
return exception;
}

@ -19,7 +19,7 @@ package org.springframework.boot.autoconfigure.web.servlet.error;
import java.util.Map;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.servlet.ModelAndView;
/**
@ -34,19 +34,19 @@ public interface ErrorAttributes {
/**
* Returns a {@link Map} of the error attributes. The map can be used as the model of
* an error page {@link ModelAndView}, or returned as a {@link ResponseBody}.
* @param requestAttributes the source request attributes
* @param webRequest the source request
* @param includeStackTrace if stack trace elements should be included
* @return a map of error attributes
*/
Map<String, Object> getErrorAttributes(RequestAttributes requestAttributes,
Map<String, Object> getErrorAttributes(WebRequest webRequest,
boolean includeStackTrace);
/**
* Return the underlying cause of the error or {@code null} if the error cannot be
* extracted.
* @param requestAttributes the source request attributes
* @param webRequest the source request
* @return the {@link Exception} that caused the error or {@code null}
*/
Throwable getError(RequestAttributes requestAttributes);
Throwable getError(WebRequest webRequest);
}

@ -31,8 +31,8 @@ import org.springframework.validation.BindingResult;
import org.springframework.validation.MapBindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.servlet.ModelAndView;
import static org.assertj.core.api.Assertions.assertThat;
@ -49,13 +49,12 @@ public class DefaultErrorAttributesTests {
private MockHttpServletRequest request = new MockHttpServletRequest();
private RequestAttributes requestAttributes = new ServletRequestAttributes(
this.request);
private WebRequest webRequest = new ServletWebRequest(this.request);
@Test
public void includeTimeStamp() throws Exception {
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
.getErrorAttributes(this.webRequest, false);
assertThat(attributes.get("timestamp")).isInstanceOf(Date.class);
}
@ -63,7 +62,7 @@ public class DefaultErrorAttributesTests {
public void specificStatusCode() throws Exception {
this.request.setAttribute("javax.servlet.error.status_code", 404);
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
.getErrorAttributes(this.webRequest, false);
assertThat(attributes.get("error"))
.isEqualTo(HttpStatus.NOT_FOUND.getReasonPhrase());
assertThat(attributes.get("status")).isEqualTo(404);
@ -72,7 +71,7 @@ public class DefaultErrorAttributesTests {
@Test
public void missingStatusCode() throws Exception {
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
.getErrorAttributes(this.webRequest, false);
assertThat(attributes.get("error")).isEqualTo("None");
assertThat(attributes.get("status")).isEqualTo(999);
}
@ -85,8 +84,8 @@ public class DefaultErrorAttributesTests {
this.request.setAttribute("javax.servlet.error.exception",
new RuntimeException("Ignored"));
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
assertThat(this.errorAttributes.getError(this.requestAttributes)).isSameAs(ex);
.getErrorAttributes(this.webRequest, false);
assertThat(this.errorAttributes.getError(this.webRequest)).isSameAs(ex);
assertThat(modelAndView).isNull();
assertThat(attributes.get("exception")).isNull();
assertThat(attributes.get("message")).isEqualTo("Test");
@ -97,8 +96,8 @@ public class DefaultErrorAttributesTests {
RuntimeException ex = new RuntimeException("Test");
this.request.setAttribute("javax.servlet.error.exception", ex);
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
assertThat(this.errorAttributes.getError(this.requestAttributes)).isSameAs(ex);
.getErrorAttributes(this.webRequest, false);
assertThat(this.errorAttributes.getError(this.webRequest)).isSameAs(ex);
assertThat(attributes.get("exception")).isNull();
assertThat(attributes.get("message")).isEqualTo("Test");
}
@ -107,7 +106,7 @@ public class DefaultErrorAttributesTests {
public void servletMessage() throws Exception {
this.request.setAttribute("javax.servlet.error.message", "Test");
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
.getErrorAttributes(this.webRequest, false);
assertThat(attributes.get("exception")).isNull();
assertThat(attributes.get("message")).isEqualTo("Test");
}
@ -118,7 +117,7 @@ public class DefaultErrorAttributesTests {
new RuntimeException());
this.request.setAttribute("javax.servlet.error.message", "Test");
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
.getErrorAttributes(this.webRequest, false);
assertThat(attributes.get("exception")).isNull();
assertThat(attributes.get("message")).isEqualTo("Test");
}
@ -129,9 +128,8 @@ public class DefaultErrorAttributesTests {
ServletException wrapped = new ServletException(new ServletException(ex));
this.request.setAttribute("javax.servlet.error.exception", wrapped);
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
assertThat(this.errorAttributes.getError(this.requestAttributes))
.isSameAs(wrapped);
.getErrorAttributes(this.webRequest, false);
assertThat(this.errorAttributes.getError(this.webRequest)).isSameAs(wrapped);
assertThat(attributes.get("exception")).isNull();
assertThat(attributes.get("message")).isEqualTo("Test");
}
@ -141,8 +139,8 @@ public class DefaultErrorAttributesTests {
Error error = new OutOfMemoryError("Test error");
this.request.setAttribute("javax.servlet.error.exception", error);
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
assertThat(this.errorAttributes.getError(this.requestAttributes)).isSameAs(error);
.getErrorAttributes(this.webRequest, false);
assertThat(this.errorAttributes.getError(this.webRequest)).isSameAs(error);
assertThat(attributes.get("exception")).isNull();
assertThat(attributes.get("message")).isEqualTo("Test error");
}
@ -169,7 +167,7 @@ public class DefaultErrorAttributesTests {
private void testBindingResult(BindingResult bindingResult, Exception ex) {
this.request.setAttribute("javax.servlet.error.exception", ex);
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
.getErrorAttributes(this.webRequest, false);
assertThat(attributes.get("message"))
.isEqualTo("Validation failed for object='objectName'. Error count: 1");
assertThat(attributes.get("errors")).isEqualTo(bindingResult.getAllErrors());
@ -181,7 +179,7 @@ public class DefaultErrorAttributesTests {
RuntimeException ex = new RuntimeException("Test");
this.request.setAttribute("javax.servlet.error.exception", ex);
Map<String, Object> attributes = errorAttributes
.getErrorAttributes(this.requestAttributes, false);
.getErrorAttributes(this.webRequest, false);
assertThat(attributes.get("exception"))
.isEqualTo(RuntimeException.class.getName());
assertThat(attributes.get("message")).isEqualTo("Test");
@ -192,7 +190,7 @@ public class DefaultErrorAttributesTests {
RuntimeException ex = new RuntimeException("Test");
this.request.setAttribute("javax.servlet.error.exception", ex);
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, true);
.getErrorAttributes(this.webRequest, true);
assertThat(attributes.get("trace").toString()).startsWith("java.lang");
}
@ -201,7 +199,7 @@ public class DefaultErrorAttributesTests {
RuntimeException ex = new RuntimeException("Test");
this.request.setAttribute("javax.servlet.error.exception", ex);
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
.getErrorAttributes(this.webRequest, false);
assertThat(attributes.get("trace")).isNull();
}
@ -209,7 +207,7 @@ public class DefaultErrorAttributesTests {
public void path() throws Exception {
this.request.setAttribute("javax.servlet.error.request_uri", "path");
Map<String, Object> attributes = this.errorAttributes
.getErrorAttributes(this.requestAttributes, false);
.getErrorAttributes(this.webRequest, false);
assertThat(attributes.get("path")).isEqualTo("path");
}

Loading…
Cancel
Save