diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java index 60864e8a21..dde8c630d1 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java @@ -17,13 +17,11 @@ package org.springframework.boot.test.web.client; import java.io.IOException; -import java.lang.reflect.Field; import java.net.URI; import java.util.Arrays; import java.util.HashSet; import java.util.Map; import java.util.Set; -import java.util.function.Supplier; import org.apache.http.client.HttpClient; import org.apache.http.client.config.CookieSpecs; @@ -36,11 +34,6 @@ import org.apache.http.impl.client.HttpClients; import org.apache.http.protocol.HttpContext; import org.apache.http.ssl.SSLContextBuilder; -import org.springframework.beans.BeanInstantiationException; -import org.springframework.beans.BeanUtils; -import org.springframework.boot.web.client.BasicAuthentication; -import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory; -import org.springframework.boot.web.client.ClientHttpRequestFactorySupplier; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.boot.web.client.RootUriTemplateHandler; import org.springframework.core.ParameterizedTypeReference; @@ -49,13 +42,10 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; -import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.util.Assert; -import org.springframework.util.ReflectionUtils; import org.springframework.web.client.DefaultResponseErrorHandler; import org.springframework.web.client.RequestCallback; import org.springframework.web.client.ResponseExtractor; @@ -89,10 +79,12 @@ import org.springframework.web.util.UriTemplateHandler; */ public class TestRestTemplate { - private final RestTemplate restTemplate; + private final RestTemplateBuilder builder; private final HttpClientOption[] httpClientOptions; + private final RestTemplate restTemplate; + /** * Create a new {@link TestRestTemplate} instance. * @param restTemplateBuilder builder used to configure underlying @@ -124,64 +116,30 @@ public class TestRestTemplate { /** * Create a new {@link TestRestTemplate} instance with the specified credentials. - * @param restTemplateBuilder builder used to configure underlying - * {@link RestTemplate} + * @param builder builder used to configure underlying {@link RestTemplate} * @param username the username to use (or {@code null}) * @param password the password (or {@code null}) * @param httpClientOptions client options to use if the Apache HTTP Client is used * @since 2.0.0 */ - public TestRestTemplate(RestTemplateBuilder restTemplateBuilder, String username, - String password, HttpClientOption... httpClientOptions) { - this((restTemplateBuilder != null) ? restTemplateBuilder.build() : null, username, - password, httpClientOptions); - } - - private TestRestTemplate(RestTemplate restTemplate, String username, String password, + public TestRestTemplate(RestTemplateBuilder builder, String username, String password, HttpClientOption... httpClientOptions) { - Assert.notNull(restTemplate, "RestTemplate must not be null"); + Assert.notNull(builder, "Builder must not be null"); + this.builder = builder; this.httpClientOptions = httpClientOptions; - if (getRequestFactoryClass(restTemplate) - .isAssignableFrom(HttpComponentsClientHttpRequestFactory.class)) { - restTemplate.setRequestFactory( - new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions)); - } - addAuthentication(restTemplate, username, password); - restTemplate.setErrorHandler(new NoOpResponseErrorHandler()); - this.restTemplate = restTemplate; - } - - private Class getRequestFactoryClass( - RestTemplate restTemplate) { - return getRequestFactory(restTemplate).getClass(); - } - - private ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) { - ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); - while (requestFactory instanceof InterceptingClientHttpRequestFactory - || requestFactory instanceof BasicAuthenticationClientHttpRequestFactory) { - requestFactory = unwrapRequestFactory( - ((AbstractClientHttpRequestFactoryWrapper) requestFactory)); + if (httpClientOptions != null) { + ClientHttpRequestFactory requestFactory = builder.buildRequestFactory(); + if (requestFactory instanceof HttpComponentsClientHttpRequestFactory) { + builder = builder.requestFactory( + () -> new CustomHttpComponentsClientHttpRequestFactory( + httpClientOptions)); + } } - return requestFactory; - } - - private ClientHttpRequestFactory unwrapRequestFactory( - AbstractClientHttpRequestFactoryWrapper requestFactory) { - Field field = ReflectionUtils.findField( - AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); - ReflectionUtils.makeAccessible(field); - return (ClientHttpRequestFactory) ReflectionUtils.getField(field, requestFactory); - } - - private void addAuthentication(RestTemplate restTemplate, String username, - String password) { - if (username == null || password == null) { - return; + if (username != null || password != null) { + builder = builder.basicAuthentication(username, password); } - ClientHttpRequestFactory requestFactory = getRequestFactory(restTemplate); - restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory( - new BasicAuthentication(username, password), requestFactory)); + this.restTemplate = builder.build(); + this.restTemplate.setErrorHandler(new NoOpResponseErrorHandler()); } /** @@ -1038,25 +996,10 @@ public class TestRestTemplate { * @since 1.4.1 */ public TestRestTemplate withBasicAuth(String username, String password) { - RestTemplate restTemplate = new RestTemplateBuilder() - .requestFactory(getRequestFactorySupplier()) - .messageConverters(getRestTemplate().getMessageConverters()) - .interceptors(getRestTemplate().getInterceptors()) - .uriTemplateHandler(getRestTemplate().getUriTemplateHandler()).build(); - return new TestRestTemplate(restTemplate, username, password, + TestRestTemplate template = new TestRestTemplate(this.builder, username, password, this.httpClientOptions); - } - - private Supplier getRequestFactorySupplier() { - return () -> { - try { - return BeanUtils - .instantiateClass(getRequestFactoryClass(getRestTemplate())); - } - catch (BeanInstantiationException ex) { - return new ClientHttpRequestFactorySupplier().get(); - } - }; + template.setUriTemplateHandler(getRestTemplate().getUriTemplateHandler()); + return template; } @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -1078,7 +1021,7 @@ public class TestRestTemplate { } /** - * Options used to customize the Apache Http Client if it is used. + * Options used to customize the Apache HTTP Client. */ public enum HttpClientOption { diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java index e4993ed99f..e2feb5b009 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java @@ -20,13 +20,14 @@ import java.io.IOException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.net.URI; +import java.util.List; +import java.util.stream.Collectors; import org.apache.http.client.config.RequestConfig; import org.junit.jupiter.api.Test; import org.springframework.boot.test.web.client.TestRestTemplate.CustomHttpComponentsClientHttpRequestFactory; import org.springframework.boot.test.web.client.TestRestTemplate.HttpClientOption; -import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpEntity; @@ -99,25 +100,11 @@ public class TestRestTemplateTests { TestRestTemplate testRestTemplate = new TestRestTemplate(builder) .withBasicAuth("test", "test"); RestTemplate restTemplate = testRestTemplate.getRestTemplate(); + assertThat(restTemplate.getRequestFactory().getClass().getName()) + .contains("BasicAuth"); Object requestFactory = ReflectionTestUtils .getField(restTemplate.getRequestFactory(), "requestFactory"); - assertThat(requestFactory).isNotEqualTo(customFactory) - .hasSameClassAs(customFactory); - } - - @Test - public void withBasicAuthWhenRequestFactoryTypeCannotBeInstantiatedShouldFallback() { - TestClientHttpRequestFactory customFactory = new TestClientHttpRequestFactory( - "my-request-factory"); - RestTemplateBuilder builder = new RestTemplateBuilder() - .requestFactory(() -> customFactory); - TestRestTemplate testRestTemplate = new TestRestTemplate(builder) - .withBasicAuth("test", "test"); - RestTemplate restTemplate = testRestTemplate.getRestTemplate(); - Object requestFactory = ReflectionTestUtils - .getField(restTemplate.getRequestFactory(), "requestFactory"); - assertThat(requestFactory).isNotEqualTo(customFactory) - .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); + assertThat(requestFactory).isEqualTo(customFactory).hasSameClassAs(customFactory); } @Test @@ -145,9 +132,10 @@ public class TestRestTemplateTests { @Test public void authenticated() { - assertThat(new TestRestTemplate("user", "password").getRestTemplate() - .getRequestFactory()) - .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); + RestTemplate restTemplate = new TestRestTemplate("user", "password") + .getRestTemplate(); + ClientHttpRequestFactory factory = restTemplate.getRequestFactory(); + assertThat(factory.getClass().getName()).contains("BasicAuthentication"); } @Test @@ -225,22 +213,17 @@ public class TestRestTemplateTests { @Test public void withBasicAuthAddsBasicAuthClientFactoryWhenNotAlreadyPresent() { - TestRestTemplate originalTemplate = new TestRestTemplate(); - TestRestTemplate basicAuthTemplate = originalTemplate.withBasicAuth("user", - "password"); - assertThat(basicAuthTemplate.getRestTemplate().getMessageConverters()) - .containsExactlyElementsOf( - originalTemplate.getRestTemplate().getMessageConverters()); - assertThat(basicAuthTemplate.getRestTemplate().getRequestFactory()) - .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); + TestRestTemplate original = new TestRestTemplate(); + TestRestTemplate basicAuth = original.withBasicAuth("user", "password"); + assertThat(getConverterClasses(original)) + .containsExactlyElementsOf(getConverterClasses(basicAuth)); + assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName()) + .contains("BasicAuth"); assertThat(ReflectionTestUtils.getField( - basicAuthTemplate.getRestTemplate().getRequestFactory(), - "requestFactory")) + basicAuth.getRestTemplate().getRequestFactory(), "requestFactory")) .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); - assertThat(basicAuthTemplate.getRestTemplate().getUriTemplateHandler()) - .isSameAs(originalTemplate.getRestTemplate().getUriTemplateHandler()); - assertThat(basicAuthTemplate.getRestTemplate().getInterceptors()).isEmpty(); - assertBasicAuthorizationCredentials(basicAuthTemplate, "user", "password"); + assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty(); + assertBasicAuthorizationCredentials(basicAuth, "user", "password"); } @Test @@ -248,20 +231,22 @@ public class TestRestTemplateTests { TestRestTemplate original = new TestRestTemplate("foo", "bar") .withBasicAuth("replace", "replace"); TestRestTemplate basicAuth = original.withBasicAuth("user", "password"); - assertThat(basicAuth.getRestTemplate().getMessageConverters()) - .containsExactlyElementsOf( - original.getRestTemplate().getMessageConverters()); - assertThat(basicAuth.getRestTemplate().getRequestFactory()) - .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); + assertThat(getConverterClasses(basicAuth)) + .containsExactlyElementsOf(getConverterClasses(original)); + assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName()) + .contains("BasicAuth"); assertThat(ReflectionTestUtils.getField( basicAuth.getRestTemplate().getRequestFactory(), "requestFactory")) .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); - assertThat(basicAuth.getRestTemplate().getUriTemplateHandler()) - .isSameAs(original.getRestTemplate().getUriTemplateHandler()); assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty(); assertBasicAuthorizationCredentials(basicAuth, "user", "password"); } + private List> getConverterClasses(TestRestTemplate testRestTemplate) { + return testRestTemplate.getRestTemplate().getMessageConverters().stream() + .map(Object::getClass).collect(Collectors.toList()); + } + @Test public void withBasicAuthShouldUseNoOpErrorHandler() throws Exception { TestRestTemplate originalTemplate = new TestRestTemplate("foo", "bar"); diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java index c9b1f618f8..059499ccef 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java @@ -18,6 +18,8 @@ package org.springframework.boot.web.client; import java.nio.charset.Charset; +import org.springframework.http.HttpHeaders; +import org.springframework.http.client.ClientHttpRequest; import org.springframework.util.Assert; /** @@ -25,10 +27,9 @@ import org.springframework.util.Assert; * {@link BasicAuthenticationClientHttpRequestFactory}. * * @author Dmytro Nosan - * @since 2.2.0 * @see BasicAuthenticationClientHttpRequestFactory */ -public class BasicAuthentication { +class BasicAuthentication { private final String username; @@ -36,22 +37,7 @@ public class BasicAuthentication { private final Charset charset; - /** - * Create a new {@link BasicAuthentication}. - * @param username the username to use - * @param password the password to use - */ - public BasicAuthentication(String username, String password) { - this(username, password, null); - } - - /** - * Create a new {@link BasicAuthentication}. - * @param username the username to use - * @param password the password to use - * @param charset the charset to use - */ - public BasicAuthentication(String username, String password, Charset charset) { + BasicAuthentication(String username, String password, Charset charset) { Assert.notNull(username, "Username must not be null"); Assert.notNull(password, "Password must not be null"); this.username = username; @@ -59,28 +45,11 @@ public class BasicAuthentication { this.charset = charset; } - /** - * The username to use. - * @return the username, never {@code null} or {@code empty}. - */ - public String getUsername() { - return this.username; - } - - /** - * The password to use. - * @return the password, never {@code null} or {@code empty}. - */ - public String getPassword() { - return this.password; - } - - /** - * The charset to use. - * @return the charset, or {@code null}. - */ - public Charset getCharset() { - return this.charset; + void applyTo(ClientHttpRequest request) { + HttpHeaders headers = request.getHeaders(); + if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) { + headers.setBasicAuth(this.username, this.password, this.charset); + } } } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java index dffe789e01..016a2cb0fa 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java @@ -19,7 +19,6 @@ package org.springframework.boot.web.client; import java.io.IOException; import java.net.URI; -import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequest; @@ -31,20 +30,13 @@ import org.springframework.util.Assert; * username/password pair, unless a custom Authorization header has been set before. * * @author Dmytro Nosan - * @since 2.2.0 */ -public class BasicAuthenticationClientHttpRequestFactory +class BasicAuthenticationClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper { private final BasicAuthentication authentication; - /** - * Create a new {@link BasicAuthenticationClientHttpRequestFactory} which adds - * {@link HttpHeaders#AUTHORIZATION} header for the given authentication. - * @param authentication the authentication to use - * @param clientHttpRequestFactory the factory to use - */ - public BasicAuthenticationClientHttpRequestFactory(BasicAuthentication authentication, + BasicAuthenticationClientHttpRequestFactory(BasicAuthentication authentication, ClientHttpRequestFactory clientHttpRequestFactory) { super(clientHttpRequestFactory); Assert.notNull(authentication, "Authentication must not be null"); @@ -54,13 +46,8 @@ public class BasicAuthenticationClientHttpRequestFactory @Override protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) throws IOException { - BasicAuthentication authentication = this.authentication; ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod); - HttpHeaders headers = request.getHeaders(); - if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) { - headers.setBasicAuth(authentication.getUsername(), - authentication.getPassword(), authentication.getCharset()); - } + this.authentication.applyTo(request); return request; } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index 00c0cdfd50..784aacbb0c 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -19,12 +19,14 @@ package org.springframework.boot.web.client; import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.nio.charset.Charset; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashSet; +import java.util.List; import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; @@ -33,7 +35,6 @@ import org.springframework.beans.BeanUtils; import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; -import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -372,26 +373,32 @@ public class RestTemplateBuilder { } /** - * Add HTTP basic authentication to requests. See - * {@link BasicAuthenticationClientHttpRequestFactory} for details. + * Add HTTP Basic Authentication to requests with the given username/password pair, + * unless a custom Authorization header has been set before. * @param username the user name * @param password the password * @return a new builder instance * @since 2.1.0 + * @see #basicAuthentication(String, String, Charset) */ public RestTemplateBuilder basicAuthentication(String username, String password) { - return basicAuthentication(new BasicAuthentication(username, password)); + return basicAuthentication(username, password, null); } /** - * Add HTTP basic authentication to requests. See - * {@link BasicAuthenticationClientHttpRequestFactory} for details. - * @param basicAuthentication the authentication + * Add HTTP Basic Authentication to requests with the given username/password pair, + * unless a custom Authorization header has been set before. + * @param username the user name + * @param password the password + * @param charset the charset to use * @return a new builder instance * @since 2.2.0 + * @see #basicAuthentication(String, String) */ - public RestTemplateBuilder basicAuthentication( - BasicAuthentication basicAuthentication) { + public RestTemplateBuilder basicAuthentication(String username, String password, + Charset charset) { + BasicAuthentication basicAuthentication = new BasicAuthentication(username, + password, charset); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, basicAuthentication, @@ -518,7 +525,6 @@ public class RestTemplateBuilder { * @see RestTemplateBuilder#build() * @see #configure(RestTemplate) */ - public T build(Class restTemplateClass) { return configure(BeanUtils.instantiateClass(restTemplateClass)); } @@ -532,7 +538,13 @@ public class RestTemplateBuilder { * @see RestTemplateBuilder#build(Class) */ public T configure(T restTemplate) { - configureRequestFactory(restTemplate); + ClientHttpRequestFactory requestFactory = buildRequestFactory(); + if (requestFactory != null) { + restTemplate.setRequestFactory(requestFactory); + } + if (this.basicAuthentication != null) { + configureBasicAuthentication(restTemplate); + } if (!CollectionUtils.isEmpty(this.messageConverters)) { restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters)); } @@ -545,9 +557,6 @@ public class RestTemplateBuilder { if (this.rootUri != null) { RootUriTemplateHandler.addTo(restTemplate, this.rootUri); } - if (this.basicAuthentication != null) { - configureBasicAuthentication(restTemplate); - } restTemplate.getInterceptors().addAll(this.interceptors); if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) { for (RestTemplateCustomizer customizer : this.restTemplateCustomizers) { @@ -557,7 +566,13 @@ public class RestTemplateBuilder { return restTemplate; } - private void configureRequestFactory(RestTemplate restTemplate) { + /** + * Build a new {@link ClientHttpRequestFactory} instance using the settings of this + * builder. + * @return a {@link ClientHttpRequestFactory} or {@code null} + * @since 2.2.0 + */ + public ClientHttpRequestFactory buildRequestFactory() { ClientHttpRequestFactory requestFactory = null; if (this.requestFactorySupplier != null) { requestFactory = this.requestFactorySupplier.get(); @@ -569,27 +584,24 @@ public class RestTemplateBuilder { if (this.requestFactoryCustomizer != null) { this.requestFactoryCustomizer.accept(requestFactory); } - restTemplate.setRequestFactory(requestFactory); } + return requestFactory; } private void configureBasicAuthentication(RestTemplate restTemplate) { - ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); - while (requestFactory instanceof InterceptingClientHttpRequestFactory - || requestFactory instanceof BasicAuthenticationClientHttpRequestFactory) { - requestFactory = unwrapRequestFactory( - ((AbstractClientHttpRequestFactoryWrapper) requestFactory)); + List interceptors = null; + if (!restTemplate.getInterceptors().isEmpty()) { + // Stash and clear the interceptors so we can access the real factory + interceptors = new ArrayList<>(restTemplate.getInterceptors()); + restTemplate.getInterceptors().clear(); } + ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory( this.basicAuthentication, requestFactory)); - } - - private static ClientHttpRequestFactory unwrapRequestFactory( - AbstractClientHttpRequestFactoryWrapper requestFactory) { - Field field = ReflectionUtils.findField( - AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); - ReflectionUtils.makeAccessible(field); - return (ClientHttpRequestFactory) ReflectionUtils.getField(field, requestFactory); + // Restore the original interceptors + if (interceptors != null) { + restTemplate.getInterceptors().addAll(interceptors); + } } private Set append(Set set, Collection additions) { @@ -638,10 +650,16 @@ public class RestTemplateBuilder { private ClientHttpRequestFactory unwrapRequestFactoryIfNecessary( ClientHttpRequestFactory requestFactory) { + if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { + return requestFactory; + } + Field field = ReflectionUtils.findField( + AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); + ReflectionUtils.makeAccessible(field); ClientHttpRequestFactory unwrappedRequestFactory = requestFactory; while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper) { - unwrappedRequestFactory = unwrapRequestFactory( - ((AbstractClientHttpRequestFactoryWrapper) unwrappedRequestFactory)); + unwrappedRequestFactory = (ClientHttpRequestFactory) ReflectionUtils + .getField(field, unwrappedRequestFactory); } return unwrappedRequestFactory; } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java index 5d7d68a788..3562de4b4d 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java @@ -29,7 +29,7 @@ import org.springframework.http.client.ClientHttpRequestFactory; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.BDDMockito.when; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; /** @@ -42,7 +42,7 @@ public class BasicAuthenticationClientHttpRequestFactoryTests { private final HttpHeaders headers = new HttpHeaders(); private final BasicAuthentication authentication = new BasicAuthentication("spring", - "boot"); + "boot", null); private ClientHttpRequestFactory requestFactory; @@ -50,8 +50,8 @@ public class BasicAuthenticationClientHttpRequestFactoryTests { public void setUp() throws IOException { ClientHttpRequestFactory requestFactory = mock(ClientHttpRequestFactory.class); ClientHttpRequest request = mock(ClientHttpRequest.class); - when(requestFactory.createRequest(any(), any())).thenReturn(request); - when(request.getHeaders()).thenReturn(this.headers); + given(requestFactory.createRequest(any(), any())).willReturn(request); + given(request.getHeaders()).willReturn(this.headers); this.requestFactory = new BasicAuthenticationClientHttpRequestFactory( this.authentication, requestFactory); } @@ -74,7 +74,7 @@ public class BasicAuthenticationClientHttpRequestFactoryTests { } private ClientHttpRequest createRequest() throws IOException { - return this.requestFactory.createRequest(URI.create("http://localhost:8080"), + return this.requestFactory.createRequest(URI.create("https://localhost:8080"), HttpMethod.POST); } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java index 0dc4ac24ea..22b54cd147 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java @@ -324,13 +324,13 @@ public class RestTemplateBuilderTests { @Test public void basicAuthenticationShouldApply() { - BasicAuthentication basicAuthentication = new BasicAuthentication("spring", - "boot", StandardCharsets.UTF_8); - RestTemplate template = this.builder.basicAuthentication(basicAuthentication) - .build(); + RestTemplate template = this.builder + .basicAuthentication("spring", "boot", StandardCharsets.UTF_8).build(); ClientHttpRequestFactory requestFactory = template.getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("authentication", - basicAuthentication); + Object authentication = ReflectionTestUtils.getField(requestFactory, + "authentication"); + assertThat(authentication).extracting("username", "password", "charset") + .containsExactly("spring", "boot", StandardCharsets.UTF_8); } @Test @@ -413,13 +413,15 @@ public class RestTemplateBuilderTests { assertThat(restTemplate.getUriTemplateHandler()) .isInstanceOf(RootUriTemplateHandler.class); assertThat(restTemplate.getErrorHandler()).isEqualTo(errorHandler); - ClientHttpRequestFactory interceptingRequestFactory = restTemplate + ClientHttpRequestFactory actualRequestFactory = restTemplate .getRequestFactory(); - assertThat(interceptingRequestFactory) + assertThat(actualRequestFactory) .isInstanceOf(InterceptingClientHttpRequestFactory.class); - Object basicAuthRequestFactory = ReflectionTestUtils - .getField(interceptingRequestFactory, "requestFactory"); - assertThat(basicAuthRequestFactory).hasFieldOrPropertyWithValue( + ClientHttpRequestFactory authRequestFactory = (ClientHttpRequestFactory) ReflectionTestUtils + .getField(actualRequestFactory, "requestFactory"); + assertThat(authRequestFactory).isInstanceOf( + BasicAuthenticationClientHttpRequestFactory.class); + assertThat(authRequestFactory).hasFieldOrPropertyWithValue( "requestFactory", requestFactory); }).build(); }