From 18a0a7a2e87ad7f73ced7a71391c48ef1313b377 Mon Sep 17 00:00:00 2001 From: Dmytro Nosan Date: Wed, 29 May 2019 18:50:52 +0300 Subject: [PATCH 1/2] Use request factory to support Basic Authentication Update `RestTemplateBuilder` to use a custom request factory to add authentication headers rather than an interceptor. Prior to this commit, the use of the `BasicAuthenticationInterceptor` interceptor could cause `OutOfMemoryError` whenever a large file is uploaded. See gh-17010 --- .../test/web/client/TestRestTemplate.java | 47 +++++----- .../web/client/TestRestTemplateTests.java | 43 ++++------ .../boot/web/client/BasicAuthentication.java | 86 +++++++++++++++++++ ...uthenticationClientHttpRequestFactory.java | 67 +++++++++++++++ .../boot/web/client/RestTemplateBuilder.java | 58 +++++++++---- ...ticationClientHttpRequestFactoryTests.java | 81 +++++++++++++++++ .../web/client/RestTemplateBuilderTests.java | 25 +++--- 7 files changed, 331 insertions(+), 76 deletions(-) create mode 100644 spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java create mode 100644 spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java create mode 100644 spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java index 4adea0980d..60864e8a21 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java @@ -19,11 +19,8 @@ package org.springframework.boot.test.web.client; import java.io.IOException; import java.lang.reflect.Field; import java.net.URI; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Supplier; @@ -41,6 +38,8 @@ import org.apache.http.ssl.SSLContextBuilder; import org.springframework.beans.BeanInstantiationException; import org.springframework.beans.BeanUtils; +import org.springframework.boot.web.client.BasicAuthentication; +import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory; import org.springframework.boot.web.client.ClientHttpRequestFactorySupplier; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.boot.web.client.RootUriTemplateHandler; @@ -50,12 +49,11 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; +import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; -import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.InterceptingClientHttpRequestFactory; -import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.web.client.DefaultResponseErrorHandler; @@ -86,6 +84,7 @@ import org.springframework.web.util.UriTemplateHandler; * @author Phillip Webb * @author Andy Wilkinson * @author Kristine Jetzke + * @author Dmytro Nosan * @since 1.4.0 */ public class TestRestTemplate { @@ -154,31 +153,35 @@ public class TestRestTemplate { private Class getRequestFactoryClass( RestTemplate restTemplate) { + return getRequestFactory(restTemplate).getClass(); + } + + private ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) { ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); - if (InterceptingClientHttpRequestFactory.class - .isAssignableFrom(requestFactory.getClass())) { - Field requestFactoryField = ReflectionUtils.findField(RestTemplate.class, - "requestFactory"); - ReflectionUtils.makeAccessible(requestFactoryField); - requestFactory = (ClientHttpRequestFactory) ReflectionUtils - .getField(requestFactoryField, restTemplate); + while (requestFactory instanceof InterceptingClientHttpRequestFactory + || requestFactory instanceof BasicAuthenticationClientHttpRequestFactory) { + requestFactory = unwrapRequestFactory( + ((AbstractClientHttpRequestFactoryWrapper) requestFactory)); } - return requestFactory.getClass(); + return requestFactory; + } + + private ClientHttpRequestFactory unwrapRequestFactory( + AbstractClientHttpRequestFactoryWrapper requestFactory) { + Field field = ReflectionUtils.findField( + AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); + ReflectionUtils.makeAccessible(field); + return (ClientHttpRequestFactory) ReflectionUtils.getField(field, requestFactory); } private void addAuthentication(RestTemplate restTemplate, String username, String password) { - if (username == null) { + if (username == null || password == null) { return; } - List interceptors = restTemplate.getInterceptors(); - if (interceptors == null) { - interceptors = Collections.emptyList(); - } - interceptors = new ArrayList<>(interceptors); - interceptors.removeIf(BasicAuthenticationInterceptor.class::isInstance); - interceptors.add(new BasicAuthenticationInterceptor(username, password)); - restTemplate.setInterceptors(interceptors); + ClientHttpRequestFactory requestFactory = getRequestFactory(restTemplate); + restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory( + new BasicAuthentication(username, password), requestFactory)); } /** diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java index efc1efeece..e4993ed99f 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java @@ -20,13 +20,13 @@ import java.io.IOException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.net.URI; -import java.util.List; import org.apache.http.client.config.RequestConfig; import org.junit.jupiter.api.Test; import org.springframework.boot.test.web.client.TestRestTemplate.CustomHttpComponentsClientHttpRequestFactory; import org.springframework.boot.test.web.client.TestRestTemplate.HttpClientOption; +import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpEntity; @@ -35,12 +35,9 @@ import org.springframework.http.HttpStatus; import org.springframework.http.RequestEntity; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; -import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; -import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.mock.env.MockEnvironment; import org.springframework.mock.http.client.MockClientHttpRequest; import org.springframework.mock.http.client.MockClientHttpResponse; @@ -150,7 +147,7 @@ public class TestRestTemplateTests { public void authenticated() { assertThat(new TestRestTemplate("user", "password").getRestTemplate() .getRequestFactory()) - .isInstanceOf(InterceptingClientHttpRequestFactory.class); + .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); } @Test @@ -227,7 +224,7 @@ public class TestRestTemplateTests { } @Test - public void withBasicAuthAddsBasicAuthInterceptorWhenNotAlreadyPresent() { + public void withBasicAuthAddsBasicAuthClientFactoryWhenNotAlreadyPresent() { TestRestTemplate originalTemplate = new TestRestTemplate(); TestRestTemplate basicAuthTemplate = originalTemplate.withBasicAuth("user", "password"); @@ -235,20 +232,19 @@ public class TestRestTemplateTests { .containsExactlyElementsOf( originalTemplate.getRestTemplate().getMessageConverters()); assertThat(basicAuthTemplate.getRestTemplate().getRequestFactory()) - .isInstanceOf(InterceptingClientHttpRequestFactory.class); + .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); assertThat(ReflectionTestUtils.getField( basicAuthTemplate.getRestTemplate().getRequestFactory(), "requestFactory")) .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); assertThat(basicAuthTemplate.getRestTemplate().getUriTemplateHandler()) .isSameAs(originalTemplate.getRestTemplate().getUriTemplateHandler()); - assertThat(basicAuthTemplate.getRestTemplate().getInterceptors()).hasSize(1); - assertBasicAuthorizationInterceptorCredentials(basicAuthTemplate, "user", - "password"); + assertThat(basicAuthTemplate.getRestTemplate().getInterceptors()).isEmpty(); + assertBasicAuthorizationCredentials(basicAuthTemplate, "user", "password"); } @Test - public void withBasicAuthReplacesBasicAuthInterceptorWhenAlreadyPresent() { + public void withBasicAuthReplacesBasicAuthClientFactoryWhenAlreadyPresent() { TestRestTemplate original = new TestRestTemplate("foo", "bar") .withBasicAuth("replace", "replace"); TestRestTemplate basicAuth = original.withBasicAuth("user", "password"); @@ -256,14 +252,14 @@ public class TestRestTemplateTests { .containsExactlyElementsOf( original.getRestTemplate().getMessageConverters()); assertThat(basicAuth.getRestTemplate().getRequestFactory()) - .isInstanceOf(InterceptingClientHttpRequestFactory.class); + .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); assertThat(ReflectionTestUtils.getField( basicAuth.getRestTemplate().getRequestFactory(), "requestFactory")) .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); assertThat(basicAuth.getRestTemplate().getUriTemplateHandler()) .isSameAs(original.getRestTemplate().getUriTemplateHandler()); - assertThat(basicAuth.getRestTemplate().getInterceptors()).hasSize(1); - assertBasicAuthorizationInterceptorCredentials(basicAuth, "user", "password"); + assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty(); + assertBasicAuthorizationCredentials(basicAuth, "user", "password"); } @Test @@ -394,17 +390,14 @@ public class TestRestTemplateTests { verify(requestFactory).createRequest(eq(absoluteUri), any(HttpMethod.class)); } - private void assertBasicAuthorizationInterceptorCredentials( - TestRestTemplate testRestTemplate, String username, String password) { - @SuppressWarnings("unchecked") - List requestFactoryInterceptors = (List) ReflectionTestUtils - .getField(testRestTemplate.getRestTemplate().getRequestFactory(), - "interceptors"); - assertThat(requestFactoryInterceptors).hasSize(1); - ClientHttpRequestInterceptor interceptor = requestFactoryInterceptors.get(0); - assertThat(interceptor).isInstanceOf(BasicAuthenticationInterceptor.class); - assertThat(interceptor).hasFieldOrPropertyWithValue("username", username); - assertThat(interceptor).hasFieldOrPropertyWithValue("password", password); + private void assertBasicAuthorizationCredentials(TestRestTemplate testRestTemplate, + String username, String password) { + ClientHttpRequestFactory requestFactory = testRestTemplate.getRestTemplate() + .getRequestFactory(); + Object authentication = ReflectionTestUtils.getField(requestFactory, + "authentication"); + assertThat(authentication).hasFieldOrPropertyWithValue("username", username); + assertThat(authentication).hasFieldOrPropertyWithValue("password", password); } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java new file mode 100644 index 0000000000..c9b1f618f8 --- /dev/null +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java @@ -0,0 +1,86 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.nio.charset.Charset; + +import org.springframework.util.Assert; + +/** + * Basic authentication properties which are used by + * {@link BasicAuthenticationClientHttpRequestFactory}. + * + * @author Dmytro Nosan + * @since 2.2.0 + * @see BasicAuthenticationClientHttpRequestFactory + */ +public class BasicAuthentication { + + private final String username; + + private final String password; + + private final Charset charset; + + /** + * Create a new {@link BasicAuthentication}. + * @param username the username to use + * @param password the password to use + */ + public BasicAuthentication(String username, String password) { + this(username, password, null); + } + + /** + * Create a new {@link BasicAuthentication}. + * @param username the username to use + * @param password the password to use + * @param charset the charset to use + */ + public BasicAuthentication(String username, String password, Charset charset) { + Assert.notNull(username, "Username must not be null"); + Assert.notNull(password, "Password must not be null"); + this.username = username; + this.password = password; + this.charset = charset; + } + + /** + * The username to use. + * @return the username, never {@code null} or {@code empty}. + */ + public String getUsername() { + return this.username; + } + + /** + * The password to use. + * @return the password, never {@code null} or {@code empty}. + */ + public String getPassword() { + return this.password; + } + + /** + * The charset to use. + * @return the charset, or {@code null}. + */ + public Charset getCharset() { + return this.charset; + } + +} diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java new file mode 100644 index 0000000000..dffe789e01 --- /dev/null +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java @@ -0,0 +1,67 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.io.IOException; +import java.net.URI; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.util.Assert; + +/** + * {@link ClientHttpRequestFactory} to apply a given HTTP Basic Authentication + * username/password pair, unless a custom Authorization header has been set before. + * + * @author Dmytro Nosan + * @since 2.2.0 + */ +public class BasicAuthenticationClientHttpRequestFactory + extends AbstractClientHttpRequestFactoryWrapper { + + private final BasicAuthentication authentication; + + /** + * Create a new {@link BasicAuthenticationClientHttpRequestFactory} which adds + * {@link HttpHeaders#AUTHORIZATION} header for the given authentication. + * @param authentication the authentication to use + * @param clientHttpRequestFactory the factory to use + */ + public BasicAuthenticationClientHttpRequestFactory(BasicAuthentication authentication, + ClientHttpRequestFactory clientHttpRequestFactory) { + super(clientHttpRequestFactory); + Assert.notNull(authentication, "Authentication must not be null"); + this.authentication = authentication; + } + + @Override + protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, + ClientHttpRequestFactory requestFactory) throws IOException { + BasicAuthentication authentication = this.authentication; + ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod); + HttpHeaders headers = request.getHeaders(); + if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) { + headers.setBasicAuth(authentication.getUsername(), + authentication.getPassword(), authentication.getCharset()); + } + return request; + } + +} diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index 4a5a09820e..00c0cdfd50 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -33,7 +33,7 @@ import org.springframework.beans.BeanUtils; import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; -import org.springframework.http.client.support.BasicAuthenticationInterceptor; +import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -58,6 +58,7 @@ import org.springframework.web.util.UriTemplateHandler; * @author Phillip Webb * @author Andy Wilkinson * @author Brian Clozel + * @author Dmytro Nosan * @since 1.4.0 */ public class RestTemplateBuilder { @@ -74,7 +75,7 @@ public class RestTemplateBuilder { private final ResponseErrorHandler errorHandler; - private final BasicAuthenticationInterceptor basicAuthentication; + private final BasicAuthentication basicAuthentication; private final Set restTemplateCustomizers; @@ -106,7 +107,7 @@ public class RestTemplateBuilder { Set> messageConverters, Supplier requestFactorySupplier, UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler, - BasicAuthenticationInterceptor basicAuthentication, + BasicAuthentication basicAuthentication, Set restTemplateCustomizers, RequestFactoryCustomizer requestFactoryCustomizer, Set interceptors) { @@ -372,17 +373,28 @@ public class RestTemplateBuilder { /** * Add HTTP basic authentication to requests. See - * {@link BasicAuthenticationInterceptor} for details. + * {@link BasicAuthenticationClientHttpRequestFactory} for details. * @param username the user name * @param password the password * @return a new builder instance * @since 2.1.0 */ public RestTemplateBuilder basicAuthentication(String username, String password) { + return basicAuthentication(new BasicAuthentication(username, password)); + } + + /** + * Add HTTP basic authentication to requests. See + * {@link BasicAuthenticationClientHttpRequestFactory} for details. + * @param basicAuthentication the authentication + * @return a new builder instance + * @since 2.2.0 + */ + public RestTemplateBuilder basicAuthentication( + BasicAuthentication basicAuthentication) { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters, this.requestFactorySupplier, - this.uriTemplateHandler, this.errorHandler, - new BasicAuthenticationInterceptor(username, password), + this.uriTemplateHandler, this.errorHandler, basicAuthentication, this.restTemplateCustomizers, this.requestFactoryCustomizer, this.interceptors); } @@ -534,7 +546,7 @@ public class RestTemplateBuilder { RootUriTemplateHandler.addTo(restTemplate, this.rootUri); } if (this.basicAuthentication != null) { - restTemplate.getInterceptors().add(this.basicAuthentication); + configureBasicAuthentication(restTemplate); } restTemplate.getInterceptors().addAll(this.interceptors); if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) { @@ -561,6 +573,25 @@ public class RestTemplateBuilder { } } + private void configureBasicAuthentication(RestTemplate restTemplate) { + ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); + while (requestFactory instanceof InterceptingClientHttpRequestFactory + || requestFactory instanceof BasicAuthenticationClientHttpRequestFactory) { + requestFactory = unwrapRequestFactory( + ((AbstractClientHttpRequestFactoryWrapper) requestFactory)); + } + restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory( + this.basicAuthentication, requestFactory)); + } + + private static ClientHttpRequestFactory unwrapRequestFactory( + AbstractClientHttpRequestFactoryWrapper requestFactory) { + Field field = ReflectionUtils.findField( + AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); + ReflectionUtils.makeAccessible(field); + return (ClientHttpRequestFactory) ReflectionUtils.getField(field, requestFactory); + } + private Set append(Set set, Collection additions) { Set result = new LinkedHashSet<>((set != null) ? set : Collections.emptySet()); result.addAll(additions); @@ -607,18 +638,11 @@ public class RestTemplateBuilder { private ClientHttpRequestFactory unwrapRequestFactoryIfNecessary( ClientHttpRequestFactory requestFactory) { - if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { - return requestFactory; - } ClientHttpRequestFactory unwrappedRequestFactory = requestFactory; - Field field = ReflectionUtils.findField( - AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); - ReflectionUtils.makeAccessible(field); - do { - unwrappedRequestFactory = (ClientHttpRequestFactory) ReflectionUtils - .getField(field, unwrappedRequestFactory); + while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper) { + unwrappedRequestFactory = unwrapRequestFactory( + ((AbstractClientHttpRequestFactoryWrapper) unwrappedRequestFactory)); } - while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper); return unwrappedRequestFactory; } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000..5d7d68a788 --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.io.IOException; +import java.net.URI; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.when; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link BasicAuthenticationClientHttpRequestFactory}. + * + * @author Dmytro Nosan + */ +public class BasicAuthenticationClientHttpRequestFactoryTests { + + private final HttpHeaders headers = new HttpHeaders(); + + private final BasicAuthentication authentication = new BasicAuthentication("spring", + "boot"); + + private ClientHttpRequestFactory requestFactory; + + @Before + public void setUp() throws IOException { + ClientHttpRequestFactory requestFactory = mock(ClientHttpRequestFactory.class); + ClientHttpRequest request = mock(ClientHttpRequest.class); + when(requestFactory.createRequest(any(), any())).thenReturn(request); + when(request.getHeaders()).thenReturn(this.headers); + this.requestFactory = new BasicAuthenticationClientHttpRequestFactory( + this.authentication, requestFactory); + } + + @Test + public void shouldAddAuthorizationHeader() throws IOException { + ClientHttpRequest request = createRequest(); + assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)) + .containsExactly("Basic c3ByaW5nOmJvb3Q="); + } + + @Test + public void shouldNotAddAuthorizationHeaderAuthorizationAlreadySet() + throws IOException { + this.headers.setBasicAuth("boot", "spring"); + ClientHttpRequest request = createRequest(); + assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)) + .doesNotContain("Basic c3ByaW5nOmJvb3Q="); + + } + + private ClientHttpRequest createRequest() throws IOException { + return this.requestFactory.createRequest(URI.create("http://localhost:8080"), + HttpMethod.POST); + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java index f2070c7f9a..0dc4ac24ea 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java @@ -16,6 +16,7 @@ package org.springframework.boot.web.client; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Collections; import java.util.Set; @@ -35,7 +36,6 @@ import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; -import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.ResourceHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; @@ -324,12 +324,13 @@ public class RestTemplateBuilderTests { @Test public void basicAuthenticationShouldApply() { - RestTemplate template = this.builder.basicAuthentication("spring", "boot") + BasicAuthentication basicAuthentication = new BasicAuthentication("spring", + "boot", StandardCharsets.UTF_8); + RestTemplate template = this.builder.basicAuthentication(basicAuthentication) .build(); - ClientHttpRequestInterceptor interceptor = template.getInterceptors().get(0); - assertThat(interceptor).isInstanceOf(BasicAuthenticationInterceptor.class); - assertThat(interceptor).extracting("username").containsExactly("spring"); - assertThat(interceptor).extracting("password").containsExactly("boot"); + ClientHttpRequestFactory requestFactory = template.getRequestFactory(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("authentication", + basicAuthentication); } @Test @@ -406,19 +407,19 @@ public class RestTemplateBuilderTests { .messageConverters(this.messageConverter).rootUri("http://localhost:8080") .errorHandler(errorHandler).basicAuthentication("spring", "boot") .requestFactory(() -> requestFactory).customizers((restTemplate) -> { - assertThat(restTemplate.getInterceptors()).hasSize(2) - .contains(this.interceptor).anyMatch( - (ic) -> ic instanceof BasicAuthenticationInterceptor); + assertThat(restTemplate.getInterceptors()).hasSize(1); assertThat(restTemplate.getMessageConverters()) .contains(this.messageConverter); assertThat(restTemplate.getUriTemplateHandler()) .isInstanceOf(RootUriTemplateHandler.class); assertThat(restTemplate.getErrorHandler()).isEqualTo(errorHandler); - ClientHttpRequestFactory actualRequestFactory = restTemplate + ClientHttpRequestFactory interceptingRequestFactory = restTemplate .getRequestFactory(); - assertThat(actualRequestFactory) + assertThat(interceptingRequestFactory) .isInstanceOf(InterceptingClientHttpRequestFactory.class); - assertThat(actualRequestFactory).hasFieldOrPropertyWithValue( + Object basicAuthRequestFactory = ReflectionTestUtils + .getField(interceptingRequestFactory, "requestFactory"); + assertThat(basicAuthRequestFactory).hasFieldOrPropertyWithValue( "requestFactory", requestFactory); }).build(); } From 76e075ddd02d011b1c02cb53ad98fb042c8259ea Mon Sep 17 00:00:00 2001 From: Phillip Webb Date: Fri, 31 May 2019 16:45:34 -0700 Subject: [PATCH 2/2] Polish "Use request factory to support Basic Authentication" Reduce the surface area of the public API by making the `BasicAuthentication` and `BasicAuthenticationClientHttpRequestFactory` class package private. This commit also attempts to simplify `TestRestTemplate` by keeping the `RestTemplateBuilder` and reusing it, rather than needing to deal only with a `RestTemplate` instance. See gh-17010 --- .../test/web/client/TestRestTemplate.java | 101 ++++-------------- .../web/client/TestRestTemplateTests.java | 69 +++++------- .../boot/web/client/BasicAuthentication.java | 49 ++------- ...uthenticationClientHttpRequestFactory.java | 19 +--- .../boot/web/client/RestTemplateBuilder.java | 80 ++++++++------ ...ticationClientHttpRequestFactoryTests.java | 10 +- .../web/client/RestTemplateBuilderTests.java | 24 +++-- 7 files changed, 128 insertions(+), 224 deletions(-) diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java index 60864e8a21..dde8c630d1 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java @@ -17,13 +17,11 @@ package org.springframework.boot.test.web.client; import java.io.IOException; -import java.lang.reflect.Field; import java.net.URI; import java.util.Arrays; import java.util.HashSet; import java.util.Map; import java.util.Set; -import java.util.function.Supplier; import org.apache.http.client.HttpClient; import org.apache.http.client.config.CookieSpecs; @@ -36,11 +34,6 @@ import org.apache.http.impl.client.HttpClients; import org.apache.http.protocol.HttpContext; import org.apache.http.ssl.SSLContextBuilder; -import org.springframework.beans.BeanInstantiationException; -import org.springframework.beans.BeanUtils; -import org.springframework.boot.web.client.BasicAuthentication; -import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory; -import org.springframework.boot.web.client.ClientHttpRequestFactorySupplier; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.boot.web.client.RootUriTemplateHandler; import org.springframework.core.ParameterizedTypeReference; @@ -49,13 +42,10 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; -import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.util.Assert; -import org.springframework.util.ReflectionUtils; import org.springframework.web.client.DefaultResponseErrorHandler; import org.springframework.web.client.RequestCallback; import org.springframework.web.client.ResponseExtractor; @@ -89,10 +79,12 @@ import org.springframework.web.util.UriTemplateHandler; */ public class TestRestTemplate { - private final RestTemplate restTemplate; + private final RestTemplateBuilder builder; private final HttpClientOption[] httpClientOptions; + private final RestTemplate restTemplate; + /** * Create a new {@link TestRestTemplate} instance. * @param restTemplateBuilder builder used to configure underlying @@ -124,64 +116,30 @@ public class TestRestTemplate { /** * Create a new {@link TestRestTemplate} instance with the specified credentials. - * @param restTemplateBuilder builder used to configure underlying - * {@link RestTemplate} + * @param builder builder used to configure underlying {@link RestTemplate} * @param username the username to use (or {@code null}) * @param password the password (or {@code null}) * @param httpClientOptions client options to use if the Apache HTTP Client is used * @since 2.0.0 */ - public TestRestTemplate(RestTemplateBuilder restTemplateBuilder, String username, - String password, HttpClientOption... httpClientOptions) { - this((restTemplateBuilder != null) ? restTemplateBuilder.build() : null, username, - password, httpClientOptions); - } - - private TestRestTemplate(RestTemplate restTemplate, String username, String password, + public TestRestTemplate(RestTemplateBuilder builder, String username, String password, HttpClientOption... httpClientOptions) { - Assert.notNull(restTemplate, "RestTemplate must not be null"); + Assert.notNull(builder, "Builder must not be null"); + this.builder = builder; this.httpClientOptions = httpClientOptions; - if (getRequestFactoryClass(restTemplate) - .isAssignableFrom(HttpComponentsClientHttpRequestFactory.class)) { - restTemplate.setRequestFactory( - new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions)); - } - addAuthentication(restTemplate, username, password); - restTemplate.setErrorHandler(new NoOpResponseErrorHandler()); - this.restTemplate = restTemplate; - } - - private Class getRequestFactoryClass( - RestTemplate restTemplate) { - return getRequestFactory(restTemplate).getClass(); - } - - private ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) { - ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); - while (requestFactory instanceof InterceptingClientHttpRequestFactory - || requestFactory instanceof BasicAuthenticationClientHttpRequestFactory) { - requestFactory = unwrapRequestFactory( - ((AbstractClientHttpRequestFactoryWrapper) requestFactory)); + if (httpClientOptions != null) { + ClientHttpRequestFactory requestFactory = builder.buildRequestFactory(); + if (requestFactory instanceof HttpComponentsClientHttpRequestFactory) { + builder = builder.requestFactory( + () -> new CustomHttpComponentsClientHttpRequestFactory( + httpClientOptions)); + } } - return requestFactory; - } - - private ClientHttpRequestFactory unwrapRequestFactory( - AbstractClientHttpRequestFactoryWrapper requestFactory) { - Field field = ReflectionUtils.findField( - AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); - ReflectionUtils.makeAccessible(field); - return (ClientHttpRequestFactory) ReflectionUtils.getField(field, requestFactory); - } - - private void addAuthentication(RestTemplate restTemplate, String username, - String password) { - if (username == null || password == null) { - return; + if (username != null || password != null) { + builder = builder.basicAuthentication(username, password); } - ClientHttpRequestFactory requestFactory = getRequestFactory(restTemplate); - restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory( - new BasicAuthentication(username, password), requestFactory)); + this.restTemplate = builder.build(); + this.restTemplate.setErrorHandler(new NoOpResponseErrorHandler()); } /** @@ -1038,25 +996,10 @@ public class TestRestTemplate { * @since 1.4.1 */ public TestRestTemplate withBasicAuth(String username, String password) { - RestTemplate restTemplate = new RestTemplateBuilder() - .requestFactory(getRequestFactorySupplier()) - .messageConverters(getRestTemplate().getMessageConverters()) - .interceptors(getRestTemplate().getInterceptors()) - .uriTemplateHandler(getRestTemplate().getUriTemplateHandler()).build(); - return new TestRestTemplate(restTemplate, username, password, + TestRestTemplate template = new TestRestTemplate(this.builder, username, password, this.httpClientOptions); - } - - private Supplier getRequestFactorySupplier() { - return () -> { - try { - return BeanUtils - .instantiateClass(getRequestFactoryClass(getRestTemplate())); - } - catch (BeanInstantiationException ex) { - return new ClientHttpRequestFactorySupplier().get(); - } - }; + template.setUriTemplateHandler(getRestTemplate().getUriTemplateHandler()); + return template; } @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -1078,7 +1021,7 @@ public class TestRestTemplate { } /** - * Options used to customize the Apache Http Client if it is used. + * Options used to customize the Apache HTTP Client. */ public enum HttpClientOption { diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java index e4993ed99f..e2feb5b009 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java @@ -20,13 +20,14 @@ import java.io.IOException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.net.URI; +import java.util.List; +import java.util.stream.Collectors; import org.apache.http.client.config.RequestConfig; import org.junit.jupiter.api.Test; import org.springframework.boot.test.web.client.TestRestTemplate.CustomHttpComponentsClientHttpRequestFactory; import org.springframework.boot.test.web.client.TestRestTemplate.HttpClientOption; -import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpEntity; @@ -99,25 +100,11 @@ public class TestRestTemplateTests { TestRestTemplate testRestTemplate = new TestRestTemplate(builder) .withBasicAuth("test", "test"); RestTemplate restTemplate = testRestTemplate.getRestTemplate(); + assertThat(restTemplate.getRequestFactory().getClass().getName()) + .contains("BasicAuth"); Object requestFactory = ReflectionTestUtils .getField(restTemplate.getRequestFactory(), "requestFactory"); - assertThat(requestFactory).isNotEqualTo(customFactory) - .hasSameClassAs(customFactory); - } - - @Test - public void withBasicAuthWhenRequestFactoryTypeCannotBeInstantiatedShouldFallback() { - TestClientHttpRequestFactory customFactory = new TestClientHttpRequestFactory( - "my-request-factory"); - RestTemplateBuilder builder = new RestTemplateBuilder() - .requestFactory(() -> customFactory); - TestRestTemplate testRestTemplate = new TestRestTemplate(builder) - .withBasicAuth("test", "test"); - RestTemplate restTemplate = testRestTemplate.getRestTemplate(); - Object requestFactory = ReflectionTestUtils - .getField(restTemplate.getRequestFactory(), "requestFactory"); - assertThat(requestFactory).isNotEqualTo(customFactory) - .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); + assertThat(requestFactory).isEqualTo(customFactory).hasSameClassAs(customFactory); } @Test @@ -145,9 +132,10 @@ public class TestRestTemplateTests { @Test public void authenticated() { - assertThat(new TestRestTemplate("user", "password").getRestTemplate() - .getRequestFactory()) - .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); + RestTemplate restTemplate = new TestRestTemplate("user", "password") + .getRestTemplate(); + ClientHttpRequestFactory factory = restTemplate.getRequestFactory(); + assertThat(factory.getClass().getName()).contains("BasicAuthentication"); } @Test @@ -225,22 +213,17 @@ public class TestRestTemplateTests { @Test public void withBasicAuthAddsBasicAuthClientFactoryWhenNotAlreadyPresent() { - TestRestTemplate originalTemplate = new TestRestTemplate(); - TestRestTemplate basicAuthTemplate = originalTemplate.withBasicAuth("user", - "password"); - assertThat(basicAuthTemplate.getRestTemplate().getMessageConverters()) - .containsExactlyElementsOf( - originalTemplate.getRestTemplate().getMessageConverters()); - assertThat(basicAuthTemplate.getRestTemplate().getRequestFactory()) - .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); + TestRestTemplate original = new TestRestTemplate(); + TestRestTemplate basicAuth = original.withBasicAuth("user", "password"); + assertThat(getConverterClasses(original)) + .containsExactlyElementsOf(getConverterClasses(basicAuth)); + assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName()) + .contains("BasicAuth"); assertThat(ReflectionTestUtils.getField( - basicAuthTemplate.getRestTemplate().getRequestFactory(), - "requestFactory")) + basicAuth.getRestTemplate().getRequestFactory(), "requestFactory")) .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); - assertThat(basicAuthTemplate.getRestTemplate().getUriTemplateHandler()) - .isSameAs(originalTemplate.getRestTemplate().getUriTemplateHandler()); - assertThat(basicAuthTemplate.getRestTemplate().getInterceptors()).isEmpty(); - assertBasicAuthorizationCredentials(basicAuthTemplate, "user", "password"); + assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty(); + assertBasicAuthorizationCredentials(basicAuth, "user", "password"); } @Test @@ -248,20 +231,22 @@ public class TestRestTemplateTests { TestRestTemplate original = new TestRestTemplate("foo", "bar") .withBasicAuth("replace", "replace"); TestRestTemplate basicAuth = original.withBasicAuth("user", "password"); - assertThat(basicAuth.getRestTemplate().getMessageConverters()) - .containsExactlyElementsOf( - original.getRestTemplate().getMessageConverters()); - assertThat(basicAuth.getRestTemplate().getRequestFactory()) - .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); + assertThat(getConverterClasses(basicAuth)) + .containsExactlyElementsOf(getConverterClasses(original)); + assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName()) + .contains("BasicAuth"); assertThat(ReflectionTestUtils.getField( basicAuth.getRestTemplate().getRequestFactory(), "requestFactory")) .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); - assertThat(basicAuth.getRestTemplate().getUriTemplateHandler()) - .isSameAs(original.getRestTemplate().getUriTemplateHandler()); assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty(); assertBasicAuthorizationCredentials(basicAuth, "user", "password"); } + private List> getConverterClasses(TestRestTemplate testRestTemplate) { + return testRestTemplate.getRestTemplate().getMessageConverters().stream() + .map(Object::getClass).collect(Collectors.toList()); + } + @Test public void withBasicAuthShouldUseNoOpErrorHandler() throws Exception { TestRestTemplate originalTemplate = new TestRestTemplate("foo", "bar"); diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java index c9b1f618f8..059499ccef 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java @@ -18,6 +18,8 @@ package org.springframework.boot.web.client; import java.nio.charset.Charset; +import org.springframework.http.HttpHeaders; +import org.springframework.http.client.ClientHttpRequest; import org.springframework.util.Assert; /** @@ -25,10 +27,9 @@ import org.springframework.util.Assert; * {@link BasicAuthenticationClientHttpRequestFactory}. * * @author Dmytro Nosan - * @since 2.2.0 * @see BasicAuthenticationClientHttpRequestFactory */ -public class BasicAuthentication { +class BasicAuthentication { private final String username; @@ -36,22 +37,7 @@ public class BasicAuthentication { private final Charset charset; - /** - * Create a new {@link BasicAuthentication}. - * @param username the username to use - * @param password the password to use - */ - public BasicAuthentication(String username, String password) { - this(username, password, null); - } - - /** - * Create a new {@link BasicAuthentication}. - * @param username the username to use - * @param password the password to use - * @param charset the charset to use - */ - public BasicAuthentication(String username, String password, Charset charset) { + BasicAuthentication(String username, String password, Charset charset) { Assert.notNull(username, "Username must not be null"); Assert.notNull(password, "Password must not be null"); this.username = username; @@ -59,28 +45,11 @@ public class BasicAuthentication { this.charset = charset; } - /** - * The username to use. - * @return the username, never {@code null} or {@code empty}. - */ - public String getUsername() { - return this.username; - } - - /** - * The password to use. - * @return the password, never {@code null} or {@code empty}. - */ - public String getPassword() { - return this.password; - } - - /** - * The charset to use. - * @return the charset, or {@code null}. - */ - public Charset getCharset() { - return this.charset; + void applyTo(ClientHttpRequest request) { + HttpHeaders headers = request.getHeaders(); + if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) { + headers.setBasicAuth(this.username, this.password, this.charset); + } } } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java index dffe789e01..016a2cb0fa 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java @@ -19,7 +19,6 @@ package org.springframework.boot.web.client; import java.io.IOException; import java.net.URI; -import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequest; @@ -31,20 +30,13 @@ import org.springframework.util.Assert; * username/password pair, unless a custom Authorization header has been set before. * * @author Dmytro Nosan - * @since 2.2.0 */ -public class BasicAuthenticationClientHttpRequestFactory +class BasicAuthenticationClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper { private final BasicAuthentication authentication; - /** - * Create a new {@link BasicAuthenticationClientHttpRequestFactory} which adds - * {@link HttpHeaders#AUTHORIZATION} header for the given authentication. - * @param authentication the authentication to use - * @param clientHttpRequestFactory the factory to use - */ - public BasicAuthenticationClientHttpRequestFactory(BasicAuthentication authentication, + BasicAuthenticationClientHttpRequestFactory(BasicAuthentication authentication, ClientHttpRequestFactory clientHttpRequestFactory) { super(clientHttpRequestFactory); Assert.notNull(authentication, "Authentication must not be null"); @@ -54,13 +46,8 @@ public class BasicAuthenticationClientHttpRequestFactory @Override protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) throws IOException { - BasicAuthentication authentication = this.authentication; ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod); - HttpHeaders headers = request.getHeaders(); - if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) { - headers.setBasicAuth(authentication.getUsername(), - authentication.getPassword(), authentication.getCharset()); - } + this.authentication.applyTo(request); return request; } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index 00c0cdfd50..784aacbb0c 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -19,12 +19,14 @@ package org.springframework.boot.web.client; import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.nio.charset.Charset; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashSet; +import java.util.List; import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; @@ -33,7 +35,6 @@ import org.springframework.beans.BeanUtils; import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; -import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -372,26 +373,32 @@ public class RestTemplateBuilder { } /** - * Add HTTP basic authentication to requests. See - * {@link BasicAuthenticationClientHttpRequestFactory} for details. + * Add HTTP Basic Authentication to requests with the given username/password pair, + * unless a custom Authorization header has been set before. * @param username the user name * @param password the password * @return a new builder instance * @since 2.1.0 + * @see #basicAuthentication(String, String, Charset) */ public RestTemplateBuilder basicAuthentication(String username, String password) { - return basicAuthentication(new BasicAuthentication(username, password)); + return basicAuthentication(username, password, null); } /** - * Add HTTP basic authentication to requests. See - * {@link BasicAuthenticationClientHttpRequestFactory} for details. - * @param basicAuthentication the authentication + * Add HTTP Basic Authentication to requests with the given username/password pair, + * unless a custom Authorization header has been set before. + * @param username the user name + * @param password the password + * @param charset the charset to use * @return a new builder instance * @since 2.2.0 + * @see #basicAuthentication(String, String) */ - public RestTemplateBuilder basicAuthentication( - BasicAuthentication basicAuthentication) { + public RestTemplateBuilder basicAuthentication(String username, String password, + Charset charset) { + BasicAuthentication basicAuthentication = new BasicAuthentication(username, + password, charset); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters, this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, basicAuthentication, @@ -518,7 +525,6 @@ public class RestTemplateBuilder { * @see RestTemplateBuilder#build() * @see #configure(RestTemplate) */ - public T build(Class restTemplateClass) { return configure(BeanUtils.instantiateClass(restTemplateClass)); } @@ -532,7 +538,13 @@ public class RestTemplateBuilder { * @see RestTemplateBuilder#build(Class) */ public T configure(T restTemplate) { - configureRequestFactory(restTemplate); + ClientHttpRequestFactory requestFactory = buildRequestFactory(); + if (requestFactory != null) { + restTemplate.setRequestFactory(requestFactory); + } + if (this.basicAuthentication != null) { + configureBasicAuthentication(restTemplate); + } if (!CollectionUtils.isEmpty(this.messageConverters)) { restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters)); } @@ -545,9 +557,6 @@ public class RestTemplateBuilder { if (this.rootUri != null) { RootUriTemplateHandler.addTo(restTemplate, this.rootUri); } - if (this.basicAuthentication != null) { - configureBasicAuthentication(restTemplate); - } restTemplate.getInterceptors().addAll(this.interceptors); if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) { for (RestTemplateCustomizer customizer : this.restTemplateCustomizers) { @@ -557,7 +566,13 @@ public class RestTemplateBuilder { return restTemplate; } - private void configureRequestFactory(RestTemplate restTemplate) { + /** + * Build a new {@link ClientHttpRequestFactory} instance using the settings of this + * builder. + * @return a {@link ClientHttpRequestFactory} or {@code null} + * @since 2.2.0 + */ + public ClientHttpRequestFactory buildRequestFactory() { ClientHttpRequestFactory requestFactory = null; if (this.requestFactorySupplier != null) { requestFactory = this.requestFactorySupplier.get(); @@ -569,27 +584,24 @@ public class RestTemplateBuilder { if (this.requestFactoryCustomizer != null) { this.requestFactoryCustomizer.accept(requestFactory); } - restTemplate.setRequestFactory(requestFactory); } + return requestFactory; } private void configureBasicAuthentication(RestTemplate restTemplate) { - ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); - while (requestFactory instanceof InterceptingClientHttpRequestFactory - || requestFactory instanceof BasicAuthenticationClientHttpRequestFactory) { - requestFactory = unwrapRequestFactory( - ((AbstractClientHttpRequestFactoryWrapper) requestFactory)); + List interceptors = null; + if (!restTemplate.getInterceptors().isEmpty()) { + // Stash and clear the interceptors so we can access the real factory + interceptors = new ArrayList<>(restTemplate.getInterceptors()); + restTemplate.getInterceptors().clear(); } + ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory( this.basicAuthentication, requestFactory)); - } - - private static ClientHttpRequestFactory unwrapRequestFactory( - AbstractClientHttpRequestFactoryWrapper requestFactory) { - Field field = ReflectionUtils.findField( - AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); - ReflectionUtils.makeAccessible(field); - return (ClientHttpRequestFactory) ReflectionUtils.getField(field, requestFactory); + // Restore the original interceptors + if (interceptors != null) { + restTemplate.getInterceptors().addAll(interceptors); + } } private Set append(Set set, Collection additions) { @@ -638,10 +650,16 @@ public class RestTemplateBuilder { private ClientHttpRequestFactory unwrapRequestFactoryIfNecessary( ClientHttpRequestFactory requestFactory) { + if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { + return requestFactory; + } + Field field = ReflectionUtils.findField( + AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); + ReflectionUtils.makeAccessible(field); ClientHttpRequestFactory unwrappedRequestFactory = requestFactory; while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper) { - unwrappedRequestFactory = unwrapRequestFactory( - ((AbstractClientHttpRequestFactoryWrapper) unwrappedRequestFactory)); + unwrappedRequestFactory = (ClientHttpRequestFactory) ReflectionUtils + .getField(field, unwrappedRequestFactory); } return unwrappedRequestFactory; } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java index 5d7d68a788..3562de4b4d 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java @@ -29,7 +29,7 @@ import org.springframework.http.client.ClientHttpRequestFactory; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.BDDMockito.when; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; /** @@ -42,7 +42,7 @@ public class BasicAuthenticationClientHttpRequestFactoryTests { private final HttpHeaders headers = new HttpHeaders(); private final BasicAuthentication authentication = new BasicAuthentication("spring", - "boot"); + "boot", null); private ClientHttpRequestFactory requestFactory; @@ -50,8 +50,8 @@ public class BasicAuthenticationClientHttpRequestFactoryTests { public void setUp() throws IOException { ClientHttpRequestFactory requestFactory = mock(ClientHttpRequestFactory.class); ClientHttpRequest request = mock(ClientHttpRequest.class); - when(requestFactory.createRequest(any(), any())).thenReturn(request); - when(request.getHeaders()).thenReturn(this.headers); + given(requestFactory.createRequest(any(), any())).willReturn(request); + given(request.getHeaders()).willReturn(this.headers); this.requestFactory = new BasicAuthenticationClientHttpRequestFactory( this.authentication, requestFactory); } @@ -74,7 +74,7 @@ public class BasicAuthenticationClientHttpRequestFactoryTests { } private ClientHttpRequest createRequest() throws IOException { - return this.requestFactory.createRequest(URI.create("http://localhost:8080"), + return this.requestFactory.createRequest(URI.create("https://localhost:8080"), HttpMethod.POST); } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java index 0dc4ac24ea..22b54cd147 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java @@ -324,13 +324,13 @@ public class RestTemplateBuilderTests { @Test public void basicAuthenticationShouldApply() { - BasicAuthentication basicAuthentication = new BasicAuthentication("spring", - "boot", StandardCharsets.UTF_8); - RestTemplate template = this.builder.basicAuthentication(basicAuthentication) - .build(); + RestTemplate template = this.builder + .basicAuthentication("spring", "boot", StandardCharsets.UTF_8).build(); ClientHttpRequestFactory requestFactory = template.getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("authentication", - basicAuthentication); + Object authentication = ReflectionTestUtils.getField(requestFactory, + "authentication"); + assertThat(authentication).extracting("username", "password", "charset") + .containsExactly("spring", "boot", StandardCharsets.UTF_8); } @Test @@ -413,13 +413,15 @@ public class RestTemplateBuilderTests { assertThat(restTemplate.getUriTemplateHandler()) .isInstanceOf(RootUriTemplateHandler.class); assertThat(restTemplate.getErrorHandler()).isEqualTo(errorHandler); - ClientHttpRequestFactory interceptingRequestFactory = restTemplate + ClientHttpRequestFactory actualRequestFactory = restTemplate .getRequestFactory(); - assertThat(interceptingRequestFactory) + assertThat(actualRequestFactory) .isInstanceOf(InterceptingClientHttpRequestFactory.class); - Object basicAuthRequestFactory = ReflectionTestUtils - .getField(interceptingRequestFactory, "requestFactory"); - assertThat(basicAuthRequestFactory).hasFieldOrPropertyWithValue( + ClientHttpRequestFactory authRequestFactory = (ClientHttpRequestFactory) ReflectionTestUtils + .getField(actualRequestFactory, "requestFactory"); + assertThat(authRequestFactory).isInstanceOf( + BasicAuthenticationClientHttpRequestFactory.class); + assertThat(authRequestFactory).hasFieldOrPropertyWithValue( "requestFactory", requestFactory); }).build(); }