From 447edd2c4eafde6d478690394a78cc936e9ef37c Mon Sep 17 00:00:00 2001 From: Phillip Webb Date: Mon, 14 Dec 2015 11:38:30 +0000 Subject: [PATCH] Allow gzip compression without `Content-Length` Ensure that gzip compression is applied when the `Content-Length` header is not specified. Prior to this commit Tomcat and Jetty would compress a response that didn't contain the header, but Undertow would not. Fixes gh-4769 --- .../UndertowEmbeddedServletContainer.java | 24 ++++++++++++++++++- ...tEmbeddedServletContainerFactoryTests.java | 24 ++++++++++++++++--- .../boot/context/embedded/ExampleServlet.java | 20 +++++++++++++--- 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/spring-boot/src/main/java/org/springframework/boot/context/embedded/undertow/UndertowEmbeddedServletContainer.java b/spring-boot/src/main/java/org/springframework/boot/context/embedded/undertow/UndertowEmbeddedServletContainer.java index 22be949363..fe77ee0679 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/embedded/undertow/UndertowEmbeddedServletContainer.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/embedded/undertow/UndertowEmbeddedServletContainer.java @@ -36,6 +36,7 @@ import io.undertow.server.handlers.encoding.ContentEncodingRepository; import io.undertow.server.handlers.encoding.EncodingHandler; import io.undertow.server.handlers.encoding.GzipEncodingProvider; import io.undertow.servlet.api.DeploymentManager; +import io.undertow.util.Headers; import io.undertow.util.HttpString; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -161,7 +162,7 @@ public class UndertowEmbeddedServletContainer implements EmbeddedServletContaine private Predicate[] getCompressionPredicates(Compression compression) { List predicates = new ArrayList(); - predicates.add(Predicates.maxContentSize(compression.getMinResponseSize())); + predicates.add(new MaxSizePredicate(compression.getMinResponseSize())); predicates.add(new CompressibleMimeTypePredicate(compression.getMimeTypes())); if (compression.getExcludedUserAgents() != null) { for (String agent : compression.getExcludedUserAgents()) { @@ -294,4 +295,25 @@ public class UndertowEmbeddedServletContainer implements EmbeddedServletContaine } + /** + * Predicate that returns true if the Content-Size of a request is above a given value + * or is missing. + */ + private static class MaxSizePredicate implements Predicate { + + private final Predicate maxContentSize; + + public MaxSizePredicate(int size) { + this.maxContentSize = Predicates.maxContentSize(size); + } + + @Override + public boolean resolve(HttpServerExchange value) { + if (value.getResponseHeaders().contains(Headers.CONTENT_LENGTH)) { + return this.maxContentSize.resolve(value); + } + return true; + } + + } } diff --git a/spring-boot/src/test/java/org/springframework/boot/context/embedded/AbstractEmbeddedServletContainerFactoryTests.java b/spring-boot/src/test/java/org/springframework/boot/context/embedded/AbstractEmbeddedServletContainerFactoryTests.java index a8106dcae2..ac997d3604 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/embedded/AbstractEmbeddedServletContainerFactoryTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/embedded/AbstractEmbeddedServletContainerFactoryTests.java @@ -360,7 +360,7 @@ public abstract class AbstractEmbeddedServletContainerFactoryTests { ssl.setEnabled(false); factory.setSsl(ssl); this.container = factory.getEmbeddedServletContainer( - new ServletRegistrationBean(new ExampleServlet(true), "/hello")); + new ServletRegistrationBean(new ExampleServlet(true, false), "/hello")); this.container.start(); SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory( new SSLContextBuilder() @@ -378,7 +378,7 @@ public abstract class AbstractEmbeddedServletContainerFactoryTests { AbstractEmbeddedServletContainerFactory factory = getFactory(); factory.setSsl(getSsl(null, "password", "src/test/resources/test.jks")); this.container = factory.getEmbeddedServletContainer( - new ServletRegistrationBean(new ExampleServlet(true), "/hello")); + new ServletRegistrationBean(new ExampleServlet(true, false), "/hello")); this.container.start(); SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory( new SSLContextBuilder() @@ -658,6 +658,24 @@ public abstract class AbstractEmbeddedServletContainerFactoryTests { assertFalse(doTestCompression(10000, null, new String[] { "testUserAgent" })); } + @Test + public void compressionWithoutContentSizeHeader() throws Exception { + AbstractEmbeddedServletContainerFactory factory = getFactory(); + Compression compression = new Compression(); + compression.setEnabled(true); + factory.setCompression(compression); + this.container = factory.getEmbeddedServletContainer( + new ServletRegistrationBean(new ExampleServlet(false, true), "/hello")); + this.container.start(); + TestGzipInputStreamFactory inputStreamFactory = new TestGzipInputStreamFactory(); + Map contentDecoderMap = Collections + .singletonMap("gzip", (InputStreamFactory) inputStreamFactory); + getResponse(getLocalUrl("/hello"), + new HttpComponentsClientHttpRequestFactory(HttpClientBuilder.create() + .setContentDecoderRegistry(contentDecoderMap).build())); + assertThat(inputStreamFactory.wasCompressionUsed(), equalTo(true)); + } + @Test public void mimeMappingsAreCorrectlyConfigured() throws Exception { AbstractEmbeddedServletContainerFactory factory = getFactory(); @@ -824,7 +842,7 @@ public abstract class AbstractEmbeddedServletContainerFactoryTests { protected void assertForwardHeaderIsUsed(EmbeddedServletContainerFactory factory) throws IOException, URISyntaxException { this.container = factory.getEmbeddedServletContainer( - new ServletRegistrationBean(new ExampleServlet(true), "/hello")); + new ServletRegistrationBean(new ExampleServlet(true, false), "/hello")); this.container.start(); assertThat(getResponse(getLocalUrl("/hello"), "X-Forwarded-For:140.211.11.130"), containsString("remoteaddr=140.211.11.130")); diff --git a/spring-boot/src/test/java/org/springframework/boot/context/embedded/ExampleServlet.java b/spring-boot/src/test/java/org/springframework/boot/context/embedded/ExampleServlet.java index a33aa0b976..de8c065886 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/embedded/ExampleServlet.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/embedded/ExampleServlet.java @@ -20,9 +20,12 @@ import java.io.IOException; import javax.servlet.GenericServlet; import javax.servlet.ServletException; +import javax.servlet.ServletOutputStream; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; +import org.springframework.util.StreamUtils; + /** * Simple example Servlet used for testing. * @@ -33,12 +36,15 @@ public class ExampleServlet extends GenericServlet { private final boolean echoRequestInfo; + private final boolean writeWithoutContentLength; + public ExampleServlet() { - this(false); + this(false, false); } - public ExampleServlet(boolean echoRequestInfo) { + public ExampleServlet(boolean echoRequestInfo, boolean writeWithoutContentLength) { this.echoRequestInfo = echoRequestInfo; + this.writeWithoutContentLength = writeWithoutContentLength; } @Override @@ -49,7 +55,15 @@ public class ExampleServlet extends GenericServlet { content += " scheme=" + request.getScheme(); content += " remoteaddr=" + request.getRemoteAddr(); } - response.getWriter().write(content); + if (this.writeWithoutContentLength) { + response.setContentType("text/plain"); + ServletOutputStream outputStream = response.getOutputStream(); + StreamUtils.copy(content.getBytes(), outputStream); + outputStream.flush(); + } + else { + response.getWriter().write(content); + } } }