diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index ff131a6767..c0a264d3c1 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2017 the original author or authors. + * Copyright 2012-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import org.springframework.beans.BeanUtils; import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; @@ -58,6 +59,7 @@ import org.springframework.web.util.UriTemplateHandler; * @author Stephane Nicoll * @author Phillip Webb * @author Andy Wilkinson + * @author Brian Clozel * @since 1.4.0 */ public class RestTemplateBuilder { @@ -81,7 +83,7 @@ public class RestTemplateBuilder { private final Set> messageConverters; - private final ClientHttpRequestFactory requestFactory; + private final Supplier requestFactorySupplier; private final UriTemplateHandler uriTemplateHandler; @@ -105,7 +107,7 @@ public class RestTemplateBuilder { this.detectRequestFactory = true; this.rootUri = null; this.messageConverters = null; - this.requestFactory = null; + this.requestFactorySupplier = null; this.uriTemplateHandler = null; this.errorHandler = null; this.basicAuthorization = null; @@ -117,7 +119,7 @@ public class RestTemplateBuilder { private RestTemplateBuilder(boolean detectRequestFactory, String rootUri, Set> messageConverters, - ClientHttpRequestFactory requestFactory, + Supplier requestFactorySupplier, UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler, BasicAuthorizationInterceptor basicAuthorization, Set restTemplateCustomizers, @@ -126,7 +128,7 @@ public class RestTemplateBuilder { this.detectRequestFactory = detectRequestFactory; this.rootUri = rootUri; this.messageConverters = messageConverters; - this.requestFactory = requestFactory; + this.requestFactorySupplier = requestFactorySupplier; this.uriTemplateHandler = uriTemplateHandler; this.errorHandler = errorHandler; this.basicAuthorization = basicAuthorization; @@ -144,7 +146,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder detectRequestFactory(boolean detectRequestFactory) { return new RestTemplateBuilder(detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, this.interceptors); } @@ -157,7 +159,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder rootUri(String rootUri) { return new RestTemplateBuilder(this.detectRequestFactory, rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, this.interceptors); } @@ -190,7 +192,7 @@ public class RestTemplateBuilder { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, Collections.unmodifiableSet( new LinkedHashSet>(messageConverters)), - this.requestFactory, this.uriTemplateHandler, this.errorHandler, + this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, this.interceptors); } @@ -219,7 +221,7 @@ public class RestTemplateBuilder { Collection> messageConverters) { Assert.notNull(messageConverters, "MessageConverters must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - append(this.messageConverters, messageConverters), this.requestFactory, + append(this.messageConverters, messageConverters), this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, this.interceptors); @@ -236,7 +238,7 @@ public class RestTemplateBuilder { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, Collections.unmodifiableSet( new LinkedHashSet<>(new RestTemplate().getMessageConverters())), - this.requestFactory, this.uriTemplateHandler, this.errorHandler, + this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, this.interceptors); } @@ -269,7 +271,7 @@ public class RestTemplateBuilder { Collection interceptors) { Assert.notNull(interceptors, "interceptors must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, Collections.unmodifiableSet(new LinkedHashSet<>(interceptors))); @@ -301,7 +303,7 @@ public class RestTemplateBuilder { Collection interceptors) { Assert.notNull(interceptors, "interceptors must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, append(this.interceptors, interceptors)); } @@ -315,7 +317,7 @@ public class RestTemplateBuilder { public RestTemplateBuilder requestFactory( Class requestFactory) { Assert.notNull(requestFactory, "RequestFactory must not be null"); - return requestFactory(createRequestFactory(requestFactory)); + return requestFactory(() -> createRequestFactory(requestFactory)); } private ClientHttpRequestFactory createRequestFactory( @@ -331,15 +333,17 @@ public class RestTemplateBuilder { } /** - * Set the {@link ClientHttpRequestFactory} that should be used with the - * {@link RestTemplate}. - * @param requestFactory the request factory to use + * Set the {@code Supplier} of {@link ClientHttpRequestFactory} + * that should be called each time we {@link #build()} a new + * {@link RestTemplate} instance. + * @param requestFactorySupplier the supplier for the request factory * @return a new builder instance + * @since 2.0.0 */ - public RestTemplateBuilder requestFactory(ClientHttpRequestFactory requestFactory) { - Assert.notNull(requestFactory, "RequestFactory must not be null"); + public RestTemplateBuilder requestFactory(Supplier requestFactorySupplier) { + Assert.notNull(requestFactorySupplier, "RequestFactory Supplier must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, requestFactory, this.uriTemplateHandler, + this.messageConverters, requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, this.interceptors); } @@ -353,7 +357,7 @@ public class RestTemplateBuilder { public RestTemplateBuilder uriTemplateHandler(UriTemplateHandler uriTemplateHandler) { Assert.notNull(uriTemplateHandler, "UriTemplateHandler must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, this.interceptors); } @@ -367,7 +371,7 @@ public class RestTemplateBuilder { public RestTemplateBuilder errorHandler(ResponseErrorHandler errorHandler) { Assert.notNull(errorHandler, "ErrorHandler must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, errorHandler, this.basicAuthorization, this.restTemplateCustomizers, this.requestFactoryCustomizers, this.interceptors); } @@ -381,7 +385,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder basicAuthorization(String username, String password) { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, new BasicAuthorizationInterceptor(username, password), this.restTemplateCustomizers, this.requestFactoryCustomizers, this.interceptors); @@ -417,7 +421,7 @@ public class RestTemplateBuilder { Assert.notNull(restTemplateCustomizers, "RestTemplateCustomizers must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, Collections.unmodifiableSet(new LinkedHashSet( restTemplateCustomizers)), @@ -451,7 +455,7 @@ public class RestTemplateBuilder { Collection customizers) { Assert.notNull(customizers, "RestTemplateCustomizers must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, append(this.restTemplateCustomizers, customizers), this.requestFactoryCustomizers, this.interceptors); @@ -465,7 +469,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder setConnectTimeout(int connectTimeout) { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, append(this.requestFactoryCustomizers, new ConnectTimeoutRequestFactoryCustomizer(connectTimeout)), @@ -480,7 +484,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder setReadTimeout(int readTimeout) { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, - this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, append(this.requestFactoryCustomizers, new ReadTimeoutRequestFactoryCustomizer(readTimeout)), @@ -547,8 +551,8 @@ public class RestTemplateBuilder { private void configureRequestFactory(RestTemplate restTemplate) { ClientHttpRequestFactory requestFactory = null; - if (this.requestFactory != null) { - requestFactory = this.requestFactory; + if (this.requestFactorySupplier != null) { + requestFactory = this.requestFactorySupplier.get(); } else if (this.detectRequestFactory) { requestFactory = detectRequestFactory(); diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java index ef3faf4dfe..191ee22bd6 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2017 the original author or authors. + * Copyright 2012-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.boot.web.client; import java.util.Collections; import java.util.Set; +import java.util.function.Supplier; import org.apache.http.client.config.RequestConfig; import org.junit.Before; @@ -273,16 +274,16 @@ public class RestTemplateBuilderTests { } @Test - public void requestFactoryWhenFactoryIsNullShouldThrowException() { + public void requestFactoryWhenSupplierIsNullShouldThrowException() { this.thrown.expect(IllegalArgumentException.class); - this.thrown.expectMessage("RequestFactory must not be null"); - this.builder.requestFactory((ClientHttpRequestFactory) null); + this.thrown.expectMessage("RequestFactory Supplier must not be null"); + this.builder.requestFactory((Supplier) null); } @Test public void requestFactoryShouldApply() { ClientHttpRequestFactory requestFactory = mock(ClientHttpRequestFactory.class); - RestTemplate template = this.builder.requestFactory(requestFactory).build(); + RestTemplate template = this.builder.requestFactory(() -> requestFactory).build(); assertThat(template.getRequestFactory()).isSameAs(requestFactory); } @@ -466,7 +467,7 @@ public class RestTemplateBuilderTests { @Test public void connectTimeoutCanBeConfiguredOnAWrappedRequestFactory() { SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); - this.builder.requestFactory(new BufferingClientHttpRequestFactory(requestFactory)) + this.builder.requestFactory(() -> new BufferingClientHttpRequestFactory(requestFactory)) .setConnectTimeout(1234).build(); assertThat(ReflectionTestUtils.getField(requestFactory, "connectTimeout")) .isEqualTo(1234); @@ -475,7 +476,7 @@ public class RestTemplateBuilderTests { @Test public void readTimeoutCanBeConfiguredOnAWrappedRequestFactory() { SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); - this.builder.requestFactory(new BufferingClientHttpRequestFactory(requestFactory)) + this.builder.requestFactory(() -> new BufferingClientHttpRequestFactory(requestFactory)) .setReadTimeout(1234).build(); assertThat(ReflectionTestUtils.getField(requestFactory, "readTimeout")) .isEqualTo(1234); @@ -485,7 +486,7 @@ public class RestTemplateBuilderTests { public void unwrappingDoesNotAffectRequestFactoryThatIsSetOnTheBuiltTemplate() { SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); RestTemplate template = this.builder - .requestFactory(new BufferingClientHttpRequestFactory(requestFactory)) + .requestFactory(() -> new BufferingClientHttpRequestFactory(requestFactory)) .build(); assertThat(template.getRequestFactory()) .isInstanceOf(BufferingClientHttpRequestFactory.class);