From 50430a20c65939d50722f525c6723c17101da52a Mon Sep 17 00:00:00 2001 From: Phillip Webb Date: Wed, 30 Sep 2015 16:58:55 -0700 Subject: [PATCH] Add Tomcat X-Forwarded-For header tests Update Abstract & Tomcat EmbeddedServletContainerFactoryTests to check that X-Forwarded-For headers work as expected. See gh-4018 --- ...tEmbeddedServletContainerFactoryTests.java | 34 +++++++++++++------ .../boot/context/embedded/ExampleServlet.java | 3 +- ...tEmbeddedServletContainerFactoryTests.java | 8 +++++ 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/spring-boot/src/test/java/org/springframework/boot/context/embedded/AbstractEmbeddedServletContainerFactoryTests.java b/spring-boot/src/test/java/org/springframework/boot/context/embedded/AbstractEmbeddedServletContainerFactoryTests.java index 3df99931e5..842c953669 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/embedded/AbstractEmbeddedServletContainerFactoryTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/embedded/AbstractEmbeddedServletContainerFactoryTests.java @@ -638,8 +638,9 @@ public abstract class AbstractEmbeddedServletContainerFactoryTests { return "http://localhost:" + port + resourcePath; } - protected String getResponse(String url) throws IOException, URISyntaxException { - ClientHttpResponse response = getClientResponse(url); + protected String getResponse(String url, String... headers) throws IOException, + URISyntaxException { + ClientHttpResponse response = getClientResponse(url, headers); try { return StreamUtils.copyToString(response.getBody(), Charset.forName("UTF-8")); } @@ -649,9 +650,9 @@ public abstract class AbstractEmbeddedServletContainerFactoryTests { } protected String getResponse(String url, - HttpComponentsClientHttpRequestFactory requestFactory) throws IOException, - URISyntaxException { - ClientHttpResponse response = getClientResponse(url, requestFactory); + HttpComponentsClientHttpRequestFactory requestFactory, String... headers) + throws IOException, URISyntaxException { + ClientHttpResponse response = getClientResponse(url, requestFactory, headers); try { return StreamUtils.copyToString(response.getBody(), Charset.forName("UTF-8")); } @@ -660,8 +661,8 @@ public abstract class AbstractEmbeddedServletContainerFactoryTests { } } - protected ClientHttpResponse getClientResponse(String url) throws IOException, - URISyntaxException { + protected ClientHttpResponse getClientResponse(String url, String... headers) + throws IOException, URISyntaxException { return getClientResponse(url, new HttpComponentsClientHttpRequestFactory() { @Override @@ -669,19 +670,32 @@ public abstract class AbstractEmbeddedServletContainerFactoryTests { return AbstractEmbeddedServletContainerFactoryTests.this.httpClientContext; } - }); + }, headers); } protected ClientHttpResponse getClientResponse(String url, - HttpComponentsClientHttpRequestFactory requestFactory) throws IOException, - URISyntaxException { + HttpComponentsClientHttpRequestFactory requestFactory, String... headers) + throws IOException, URISyntaxException { ClientHttpRequest request = requestFactory.createRequest(new URI(url), HttpMethod.GET); request.getHeaders().add("Cookie", "JSESSIONID=" + "123"); + for (String header : headers) { + String[] parts = header.split(":"); + request.getHeaders().add(parts[0], parts[1]); + } ClientHttpResponse response = request.execute(); return response; } + protected void assertForwardHeaderIsUsed(EmbeddedServletContainerFactory factory) + throws IOException, URISyntaxException { + this.container = factory.getEmbeddedServletContainer(new ServletRegistrationBean( + new ExampleServlet(true), "/hello")); + this.container.start(); + assertThat(getResponse(getLocalUrl("/hello"), "X-Forwarded-For:140.211.11.130"), + containsString("remoteaddr=140.211.11.130")); + } + protected abstract AbstractEmbeddedServletContainerFactory getFactory(); protected abstract Object getJspServlet(); diff --git a/spring-boot/src/test/java/org/springframework/boot/context/embedded/ExampleServlet.java b/spring-boot/src/test/java/org/springframework/boot/context/embedded/ExampleServlet.java index 6d625aa2a2..a33aa0b976 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/embedded/ExampleServlet.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/embedded/ExampleServlet.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2013 the original author or authors. + * Copyright 2012-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,6 +47,7 @@ public class ExampleServlet extends GenericServlet { String content = "Hello World"; if (this.echoRequestInfo) { content += " scheme=" + request.getScheme(); + content += " remoteaddr=" + request.getRemoteAddr(); } response.getWriter().write(content); } diff --git a/spring-boot/src/test/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainerFactoryTests.java b/spring-boot/src/test/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainerFactoryTests.java index ed59f39241..49de3bd665 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainerFactoryTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/embedded/tomcat/TomcatEmbeddedServletContainerFactoryTests.java @@ -35,6 +35,7 @@ import org.apache.catalina.Valve; import org.apache.catalina.Wrapper; import org.apache.catalina.connector.Connector; import org.apache.catalina.startup.Tomcat; +import org.apache.catalina.valves.RemoteIpValve; import org.apache.coyote.http11.AbstractHttp11JsseProtocol; import org.junit.Test; import org.mockito.InOrder; @@ -333,6 +334,13 @@ public class TomcatEmbeddedServletContainerFactoryTests extends assertThat(jspServlet.findInitParameter("a"), is(equalTo("alpha"))); } + @Test + public void useForwardHeaders() throws Exception { + TomcatEmbeddedServletContainerFactory factory = getFactory(); + factory.addContextValves(new RemoteIpValve()); + assertForwardHeaderIsUsed(factory); + } + @Override protected Wrapper getJspServlet() { Container context = ((TomcatEmbeddedServletContainer) this.container).getTomcat()