diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerJwkConfiguration.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerJwkConfiguration.java index 223fe26003..5f5cba160e 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerJwkConfiguration.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerJwkConfiguration.java @@ -61,6 +61,7 @@ import org.springframework.util.CollectionUtils; * @author HaiTao Zhang * @author Anastasiia Losieva * @author Mushtaq Ahmed + * @author Roman Golovin */ @Configuration(proxyBeanMethods = false) class ReactiveOAuth2ResourceServerJwkConfiguration { @@ -71,8 +72,12 @@ class ReactiveOAuth2ResourceServerJwkConfiguration { private final OAuth2ResourceServerProperties.Jwt properties; - JwtConfiguration(OAuth2ResourceServerProperties properties) { + private final List> additionalValidators; + + JwtConfiguration(OAuth2ResourceServerProperties properties, + ObjectProvider> additionalValidators) { this.properties = properties.getJwt(); + this.additionalValidators = additionalValidators.orderedStream().toList(); } @Bean @@ -98,13 +103,16 @@ class ReactiveOAuth2ResourceServerJwkConfiguration { private OAuth2TokenValidator getValidators(OAuth2TokenValidator defaultValidator) { List audiences = this.properties.getAudiences(); - if (CollectionUtils.isEmpty(audiences)) { + if (CollectionUtils.isEmpty(audiences) && this.additionalValidators.isEmpty()) { return defaultValidator; } List> validators = new ArrayList<>(); validators.add(defaultValidator); - validators.add(new JwtClaimValidator>(JwtClaimNames.AUD, - (aud) -> aud != null && !Collections.disjoint(aud, audiences))); + if (!CollectionUtils.isEmpty(audiences)) { + validators.add(new JwtClaimValidator>(JwtClaimNames.AUD, + (aud) -> aud != null && !Collections.disjoint(aud, audiences))); + } + validators.addAll(this.additionalValidators); return new DelegatingOAuth2TokenValidator<>(validators); } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerJwtConfiguration.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerJwtConfiguration.java index b039ecd6b4..84bafab99d 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerJwtConfiguration.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerJwtConfiguration.java @@ -62,6 +62,7 @@ import static org.springframework.security.config.Customizer.withDefaults; * @author Artsiom Yudovin * @author HaiTao Zhang * @author Mushtaq Ahmed + * @author Roman Golovin */ @Configuration(proxyBeanMethods = false) class OAuth2ResourceServerJwtConfiguration { @@ -72,8 +73,12 @@ class OAuth2ResourceServerJwtConfiguration { private final OAuth2ResourceServerProperties.Jwt properties; - JwtDecoderConfiguration(OAuth2ResourceServerProperties properties) { + private final List> additionalValidators; + + JwtDecoderConfiguration(OAuth2ResourceServerProperties properties, + ObjectProvider> additionalValidators) { this.properties = properties.getJwt(); + this.additionalValidators = additionalValidators.orderedStream().toList(); } @Bean @@ -98,13 +103,16 @@ class OAuth2ResourceServerJwtConfiguration { private OAuth2TokenValidator getValidators(OAuth2TokenValidator defaultValidator) { List audiences = this.properties.getAudiences(); - if (CollectionUtils.isEmpty(audiences)) { + if (CollectionUtils.isEmpty(audiences) && this.additionalValidators.isEmpty()) { return defaultValidator; } List> validators = new ArrayList<>(); validators.add(defaultValidator); - validators.add(new JwtClaimValidator>(JwtClaimNames.AUD, - (aud) -> aud != null && !Collections.disjoint(aud, audiences))); + if (!CollectionUtils.isEmpty(audiences)) { + validators.add(new JwtClaimValidator>(JwtClaimNames.AUD, + (aud) -> aud != null && !Collections.disjoint(aud, audiences))); + } + validators.addAll(this.additionalValidators); return new DelegatingOAuth2TokenValidator<>(validators); } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerAutoConfigurationTests.java index cc5a652d10..e8165ee189 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerAutoConfigurationTests.java @@ -17,14 +17,17 @@ package org.springframework.boot.autoconfigure.security.oauth2.resource.reactive; import java.io.IOException; -import java.net.MalformedURLException; +import java.net.URI; import java.net.URL; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Consumer; import java.util.stream.Stream; import com.fasterxml.jackson.core.JsonProcessingException; @@ -33,6 +36,7 @@ import com.nimbusds.jose.JWSAlgorithm; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.ThrowingConsumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.mockito.InOrder; @@ -87,6 +91,7 @@ import static org.springframework.security.config.Customizer.withDefaults; * @author HaiTao Zhang * @author Anastasiia Losieva * @author Mushtaq Ahmed + * @author Roman Golovin */ class ReactiveOAuth2ResourceServerAutoConfigurationTests { @@ -438,7 +443,6 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .run((context) -> assertThat(context).doesNotHaveBean(ReactiveOpaqueTokenIntrospector.class)); } - @SuppressWarnings("unchecked") @Test void autoConfigurationShouldConfigureResourceServerUsingJwkSetUriAndIssuerUri() throws Exception { this.server = new MockWebServer(); @@ -454,15 +458,11 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); - DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils - .getField(reactiveJwtDecoder, "jwtValidator"); - Collection> tokenValidators = (Collection>) ReflectionTestUtils - .getField(jwtValidator, "tokenValidators"); - assertThat(tokenValidators).hasAtLeastOneElementOfType(JwtIssuerValidator.class); + validate(jwt().claim("iss", issuer), reactiveJwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)); }); } - @SuppressWarnings("unchecked") @Test void autoConfigurationShouldNotConfigureIssuerUriAndAudienceJwtValidatorIfPropertyNotConfigured() throws Exception { this.server = new MockWebServer(); @@ -476,13 +476,8 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); - DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils - .getField(reactiveJwtDecoder, "jwtValidator"); - Collection> tokenValidators = (Collection>) ReflectionTestUtils - .getField(jwtValidator, "tokenValidators"); - assertThat(tokenValidators).hasExactlyElementsOfTypes(JwtTimestampValidator.class); - assertThat(tokenValidators).doesNotHaveAnyElementsOfTypes(JwtClaimValidator.class); - assertThat(tokenValidators).doesNotHaveAnyElementsOfTypes(JwtIssuerValidator.class); + validate(jwt(), reactiveJwtDecoder, (validators) -> assertThat(validators).singleElement() + .isInstanceOf(JwtTimestampValidator.class)); }); } @@ -502,39 +497,15 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); - validate(issuerUri, reactiveJwtDecoder); + validate( + jwt().claim("iss", URI.create(issuerUri).toURL()) + .claim("aud", List.of("https://test-audience.com")), + reactiveJwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class) + .satisfiesOnlyOnce(audClaimValidator())); }); } - @SuppressWarnings("unchecked") - private void validate(String issuerUri, ReactiveJwtDecoder jwtDecoder) throws MalformedURLException { - DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils - .getField(jwtDecoder, "jwtValidator"); - Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com")); - if (issuerUri != null) { - builder.claim("iss", new URL(issuerUri)); - } - Jwt jwt = builder.build(); - assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); - Collection> delegates = (Collection>) ReflectionTestUtils - .getField(jwtValidator, "tokenValidators"); - validateDelegates(issuerUri, delegates); - } - - @SuppressWarnings("unchecked") - private void validateDelegates(String issuerUri, Collection> delegates) { - assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class); - OAuth2TokenValidator delegatingValidator = delegates.stream() - .filter((v) -> v instanceof DelegatingOAuth2TokenValidator) - .findFirst() - .get(); - Collection> nestedDelegates = (Collection>) ReflectionTestUtils - .getField(delegatingValidator, "tokenValidators"); - if (issuerUri != null) { - assertThat(nestedDelegates).hasAtLeastOneElementOfType(JwtIssuerValidator.class); - } - } - @SuppressWarnings("unchecked") @Test void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndIssuerUri() throws Exception { @@ -552,7 +523,12 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { Mono jwtDecoderSupplier = (Mono) ReflectionTestUtils .getField(supplierJwtDecoderBean, "jwtDecoderMono"); ReactiveJwtDecoder jwtDecoder = jwtDecoderSupplier.block(); - validate(issuerUri, jwtDecoder); + validate( + jwt().claim("iss", URI.create(issuerUri).toURL()) + .claim("aud", List.of("https://test-audience.com")), + jwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class) + .satisfiesOnlyOnce(audClaimValidator())); }); } @@ -570,7 +546,33 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class); - validate(null, jwtDecoder); + validate(jwt().claim("aud", List.of("https://test-audience.com")), jwtDecoder, + (validators) -> assertThat(validators).satisfiesOnlyOnce(audClaimValidator())); + }); + } + + @SuppressWarnings("unchecked") + @Test + void autoConfigurationShouldConfigureCustomValidators() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + String path = "test"; + String issuer = this.server.url(path).toString(); + String cleanIssuerPath = cleanIssuerPath(issuer); + setupMockResponse(cleanIssuerPath); + String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path; + this.contextRunner + .withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com", + "spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri) + .withUserConfiguration(CustomJwtClaimValidatorConfig.class) + .run((context) -> { + assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); + ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); + OAuth2TokenValidator customValidator = (OAuth2TokenValidator) context + .getBean("customJwtClaimValidator"); + validate(jwt().claim("iss", URI.create(issuerUri).toURL()).claim("custom_claim", "custom_claim_value"), + reactiveJwtDecoder, (validators) -> assertThat(validators).contains(customValidator) + .hasAtLeastOneElementOfType(JwtIssuerValidator.class)); }); } @@ -600,6 +602,30 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { }); } + @SuppressWarnings("unchecked") + @Test + void customValidatorWhenInvalid() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + String path = "test"; + String issuer = this.server.url(path).toString(); + String cleanIssuerPath = cleanIssuerPath(issuer); + setupMockResponse(cleanIssuerPath); + String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path; + this.contextRunner + .withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com", + "spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri) + .withUserConfiguration(CustomJwtClaimValidatorConfig.class) + .run((context) -> { + assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); + ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class); + DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils + .getField(jwtDecoder, "jwtValidator"); + Jwt jwt = jwt().claim("iss", new URL(issuerUri)).claim("custom_claim", "invalid_value").build(); + assertThat(jwtValidator.validate(jwt).hasErrors()).isTrue(); + }); + } + private void assertFilterConfiguredWithJwtAuthenticationManager(AssertableReactiveWebApplicationContext context) { MatcherSecurityWebFilterChain filterChain = (MatcherSecurityWebFilterChain) context .getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN); @@ -683,6 +709,37 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .subject("mock-test-subject"); } + @SuppressWarnings("unchecked") + private void validate(Jwt.Builder builder, ReactiveJwtDecoder jwtDecoder, + ThrowingConsumer>> validatorsConsumer) { + DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils + .getField(jwtDecoder, "jwtValidator"); + assertThat(jwtValidator.validate(builder.build()).hasErrors()).isFalse(); + validatorsConsumer.accept(extractValidators(jwtValidator)); + } + + @SuppressWarnings("unchecked") + private List> extractValidators(DelegatingOAuth2TokenValidator delegatingValidator) { + Collection> delegates = (Collection>) ReflectionTestUtils + .getField(delegatingValidator, "tokenValidators"); + List> extracted = new ArrayList<>(); + for (OAuth2TokenValidator delegate : delegates) { + if (delegate instanceof DelegatingOAuth2TokenValidator delegatingDelegate) { + extracted.addAll(extractValidators(delegatingDelegate)); + } + else { + extracted.add(delegate); + } + } + return extracted; + } + + private Consumer> audClaimValidator() { + return (validator) -> assertThat(validator).isInstanceOf(JwtClaimValidator.class) + .extracting("claim") + .isEqualTo("aud"); + } + @EnableWebFluxSecurity static class TestConfig { @@ -740,4 +797,14 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { } + @Configuration(proxyBeanMethods = false) + static class CustomJwtClaimValidatorConfig { + + @Bean + JwtClaimValidator customJwtClaimValidator() { + return new JwtClaimValidator<>("custom_claim", "custom_claim_value"::equals); + } + + } + } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerAutoConfigurationTests.java index 878415ec9f..b7aa1a5ec6 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerAutoConfigurationTests.java @@ -16,14 +16,16 @@ package org.springframework.boot.autoconfigure.security.oauth2.resource.servlet; -import java.net.MalformedURLException; +import java.net.URI; import java.net.URL; import java.time.Instant; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import com.fasterxml.jackson.core.JsonProcessingException; @@ -33,6 +35,7 @@ import jakarta.servlet.Filter; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.ThrowingConsumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.mockito.InOrder; @@ -80,6 +83,7 @@ import static org.mockito.Mockito.mock; * @author Artsiom Yudovin * @author HaiTao Zhang * @author Mushtaq Ahmed + * @author Roman Golovin */ class OAuth2ResourceServerAutoConfigurationTests { @@ -190,8 +194,8 @@ class OAuth2ResourceServerAutoConfigurationTests { }); } - @Test @SuppressWarnings("unchecked") + @Test void autoConfigurationShouldConfigureResourceServerUsingOidcIssuerUri() throws Exception { this.server = new MockWebServer(); this.server.start(); @@ -215,8 +219,8 @@ class OAuth2ResourceServerAutoConfigurationTests { assertThat(this.server.getRequestCount()).isEqualTo(2); } - @Test @SuppressWarnings("unchecked") + @Test void autoConfigurationShouldConfigureResourceServerUsingOidcRfc8414IssuerUri() throws Exception { this.server = new MockWebServer(); this.server.start(); @@ -240,8 +244,8 @@ class OAuth2ResourceServerAutoConfigurationTests { assertThat(this.server.getRequestCount()).isEqualTo(3); } - @Test @SuppressWarnings("unchecked") + @Test void autoConfigurationShouldConfigureResourceServerUsingOAuthIssuerUri() throws Exception { this.server = new MockWebServer(); this.server.start(); @@ -472,9 +476,8 @@ class OAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); - assertThat(jwtDecoder).extracting("jwtValidator.tokenValidators") - .asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) - .hasAtLeastOneElementOfType(JwtIssuerValidator.class); + validate(jwt().claim("iss", issuer), jwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)); }); } @@ -491,11 +494,8 @@ class OAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); - assertThat(jwtDecoder).extracting("jwtValidator.tokenValidators") - .asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) - .hasExactlyElementsOfTypes(JwtTimestampValidator.class) - .doesNotHaveAnyElementsOfTypes(JwtClaimValidator.class) - .doesNotHaveAnyElementsOfTypes(JwtIssuerValidator.class); + validate(jwt(), jwtDecoder, (validators) -> assertThat(validators).singleElement() + .isInstanceOf(JwtTimestampValidator.class)); }); } @@ -515,7 +515,12 @@ class OAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); - validate(issuerUri, jwtDecoder); + validate( + jwt().claim("iss", URI.create(issuerUri).toURL()) + .claim("aud", List.of("https://test-audience.com")), + jwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class) + .satisfiesOnlyOnce(audClaimValidator())); }); } @@ -536,36 +541,39 @@ class OAuth2ResourceServerAutoConfigurationTests { Supplier jwtDecoderSupplier = (Supplier) ReflectionTestUtils .getField(supplierJwtDecoderBean, "delegate"); JwtDecoder jwtDecoder = jwtDecoderSupplier.get(); - validate(issuerUri, jwtDecoder); + validate( + jwt().claim("iss", URI.create(issuerUri).toURL()) + .claim("aud", List.of("https://test-audience.com")), + jwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class) + .satisfiesOnlyOnce(audClaimValidator())); }); } @SuppressWarnings("unchecked") - private void validate(String issuerUri, JwtDecoder jwtDecoder) throws MalformedURLException { - DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils - .getField(jwtDecoder, "jwtValidator"); - Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com")); - if (issuerUri != null) { - builder.claim("iss", new URL(issuerUri)); - } - Jwt jwt = builder.build(); - assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); - Collection> delegates = (Collection>) ReflectionTestUtils - .getField(jwtValidator, "tokenValidators"); - validateDelegates(issuerUri, delegates); - } - - private void validateDelegates(String issuerUri, Collection> delegates) { - assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class); - OAuth2TokenValidator delegatingValidator = delegates.stream() - .filter((v) -> v instanceof DelegatingOAuth2TokenValidator) - .findFirst() - .get(); - if (issuerUri != null) { - assertThat(delegatingValidator).extracting("tokenValidators") - .asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) - .hasAtLeastOneElementOfType(JwtIssuerValidator.class); - } + @Test + void autoConfigurationShouldConfigureCustomValidators() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + String path = "test"; + String issuer = this.server.url(path).toString(); + String cleanIssuerPath = cleanIssuerPath(issuer); + setupMockResponse(cleanIssuerPath); + String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path; + this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri) + .withUserConfiguration(CustomJwtClaimValidatorConfig.class) + .run((context) -> { + SupplierJwtDecoder supplierJwtDecoderBean = context.getBean(SupplierJwtDecoder.class); + Supplier jwtDecoderSupplier = (Supplier) ReflectionTestUtils + .getField(supplierJwtDecoderBean, "delegate"); + JwtDecoder jwtDecoder = jwtDecoderSupplier.get(); + assertThat(context).hasBean("customJwtClaimValidator"); + OAuth2TokenValidator customValidator = (OAuth2TokenValidator) context + .getBean("customJwtClaimValidator"); + validate(jwt().claim("iss", URI.create(issuerUri).toURL()).claim("custom_claim", "custom_claim_value"), + jwtDecoder, (validators) -> assertThat(validators).contains(customValidator) + .hasAtLeastOneElementOfType(JwtIssuerValidator.class)); + }); } @Test @@ -582,7 +590,8 @@ class OAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); - validate(null, jwtDecoder); + validate(jwt().claim("aud", List.of("https://test-audience.com")), jwtDecoder, + (validators) -> assertThat(validators).satisfiesOnlyOnce(audClaimValidator())); }); } @@ -692,6 +701,37 @@ class OAuth2ResourceServerAutoConfigurationTests { .subject("mock-test-subject"); } + @SuppressWarnings("unchecked") + private void validate(Jwt.Builder builder, JwtDecoder jwtDecoder, + ThrowingConsumer>> validatorsConsumer) { + DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils + .getField(jwtDecoder, "jwtValidator"); + assertThat(jwtValidator.validate(builder.build()).hasErrors()).isFalse(); + validatorsConsumer.accept(extractValidators(jwtValidator)); + } + + @SuppressWarnings("unchecked") + private List> extractValidators(DelegatingOAuth2TokenValidator delegatingValidator) { + Collection> delegates = (Collection>) ReflectionTestUtils + .getField(delegatingValidator, "tokenValidators"); + List> extracted = new ArrayList<>(); + for (OAuth2TokenValidator delegate : delegates) { + if (delegate instanceof DelegatingOAuth2TokenValidator delegatingDelegate) { + extracted.addAll(extractValidators(delegatingDelegate)); + } + else { + extracted.add(delegate); + } + } + return extracted; + } + + private Consumer> audClaimValidator() { + return (validator) -> assertThat(validator).isInstanceOf(JwtClaimValidator.class) + .extracting("claim") + .isEqualTo("aud"); + } + @Configuration(proxyBeanMethods = false) @EnableWebSecurity static class TestConfig { @@ -745,4 +785,14 @@ class OAuth2ResourceServerAutoConfigurationTests { } + @Configuration(proxyBeanMethods = false) + static class CustomJwtClaimValidatorConfig { + + @Bean + JwtClaimValidator customJwtClaimValidator() { + return new JwtClaimValidator<>("custom_claim", "custom_claim_value"::equals); + } + + } + }