Make sure ErrorPageFilter is only applied once per request

Fixes gh-1257
pull/1286/head
Dave Syer 10 years ago
parent 0c52817c88
commit 4a33ab5577

@ -38,6 +38,7 @@ import org.springframework.boot.context.embedded.ErrorPage;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
/**
* A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded
@ -76,21 +77,28 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple
private final Map<Class<?>, String> exceptions = new HashMap<Class<?>, String>();
private final Map<Class<?>, Class<?>> subtypes = new HashMap<Class<?>, Class<?>>();
private final OncePerRequestFilter delegate = new OncePerRequestFilter(
) {
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {
ErrorPageFilter.this.doFilter(request, response, chain);
}
};
@Override
public void init(FilterConfig filterConfig) throws ServletException {
delegate.init(filterConfig);
}
@Override
public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
if (request instanceof HttpServletRequest
&& response instanceof HttpServletResponse) {
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
else {
chain.doFilter(request, response);
}
delegate.doFilter(request, response, chain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response,

@ -16,6 +16,11 @@
package org.springframework.boot.context.web;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import javax.servlet.RequestDispatcher;
@ -29,13 +34,10 @@ import org.junit.Test;
import org.springframework.boot.context.embedded.ErrorPage;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
/**
* Tests for {@link ErrorPageFilter}.
*
@ -97,6 +99,21 @@ public class ErrorPageFilterTests {
equalTo(400));
}
@Test
public void oncePerRequest() throws Exception {
this.chain = new MockFilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
((HttpServletResponse) response).sendError(400, "BAD");
assertNotNull(request.getAttribute("FILTER.FILTERED"));
super.doFilter(request, response);
}
};
filter.init(new MockFilterConfig("FILTER"));
this.filter.doFilter(this.request, this.response, this.chain);
}
@Test
public void globalError() throws Exception {
this.filter.addErrorPages(new ErrorPage("/error"));

Loading…
Cancel
Save