From 4feaa28fd1ce83e251aefb1aac1a42272054b9bc Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Wed, 5 Jul 2023 14:01:08 +0100 Subject: [PATCH] Polish "Support custom token validators for OAuth2" See gh-35874 --- ...eOAuth2ResourceServerJwkConfiguration.java | 16 +- .../OAuth2ResourceServerJwtConfiguration.java | 16 +- ...2ResourceServerAutoConfigurationTests.java | 187 ++++++++++-------- ...2ResourceServerAutoConfigurationTests.java | 126 ++++++------ 4 files changed, 181 insertions(+), 164 deletions(-) diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerJwkConfiguration.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerJwkConfiguration.java index e57b824e93..5f5cba160e 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerJwkConfiguration.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerJwkConfiguration.java @@ -72,12 +72,12 @@ class ReactiveOAuth2ResourceServerJwkConfiguration { private final OAuth2ResourceServerProperties.Jwt properties; - private final List> customOAuth2TokenValidators; + private final List> additionalValidators; JwtConfiguration(OAuth2ResourceServerProperties properties, - List> customOAuth2TokenValidators) { + ObjectProvider> additionalValidators) { this.properties = properties.getJwt(); - this.customOAuth2TokenValidators = customOAuth2TokenValidators; + this.additionalValidators = additionalValidators.orderedStream().toList(); } @Bean @@ -102,17 +102,17 @@ class ReactiveOAuth2ResourceServerJwkConfiguration { } private OAuth2TokenValidator getValidators(OAuth2TokenValidator defaultValidator) { + List audiences = this.properties.getAudiences(); + if (CollectionUtils.isEmpty(audiences) && this.additionalValidators.isEmpty()) { + return defaultValidator; + } List> validators = new ArrayList<>(); validators.add(defaultValidator); - validators.addAll(this.customOAuth2TokenValidators); - List audiences = this.properties.getAudiences(); if (!CollectionUtils.isEmpty(audiences)) { validators.add(new JwtClaimValidator>(JwtClaimNames.AUD, (aud) -> aud != null && !Collections.disjoint(aud, audiences))); } - if (validators.size() == 1) { - return validators.get(0); - } + validators.addAll(this.additionalValidators); return new DelegatingOAuth2TokenValidator<>(validators); } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerJwtConfiguration.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerJwtConfiguration.java index bb32ea0149..84bafab99d 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerJwtConfiguration.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerJwtConfiguration.java @@ -73,12 +73,12 @@ class OAuth2ResourceServerJwtConfiguration { private final OAuth2ResourceServerProperties.Jwt properties; - private final List> customOAuth2TokenValidators; + private final List> additionalValidators; JwtDecoderConfiguration(OAuth2ResourceServerProperties properties, - List> customOAuth2TokenValidators) { + ObjectProvider> additionalValidators) { this.properties = properties.getJwt(); - this.customOAuth2TokenValidators = customOAuth2TokenValidators; + this.additionalValidators = additionalValidators.orderedStream().toList(); } @Bean @@ -102,17 +102,17 @@ class OAuth2ResourceServerJwtConfiguration { } private OAuth2TokenValidator getValidators(OAuth2TokenValidator defaultValidator) { + List audiences = this.properties.getAudiences(); + if (CollectionUtils.isEmpty(audiences) && this.additionalValidators.isEmpty()) { + return defaultValidator; + } List> validators = new ArrayList<>(); validators.add(defaultValidator); - validators.addAll(this.customOAuth2TokenValidators); - List audiences = this.properties.getAudiences(); if (!CollectionUtils.isEmpty(audiences)) { validators.add(new JwtClaimValidator>(JwtClaimNames.AUD, (aud) -> aud != null && !Collections.disjoint(aud, audiences))); } - if (validators.size() == 1) { - return validators.get(0); - } + validators.addAll(this.additionalValidators); return new DelegatingOAuth2TokenValidator<>(validators); } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerAutoConfigurationTests.java index 65c522417f..e8165ee189 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerAutoConfigurationTests.java @@ -17,16 +17,17 @@ package org.springframework.boot.autoconfigure.security.oauth2.resource.reactive; import java.io.IOException; -import java.net.MalformedURLException; +import java.net.URI; import java.net.URL; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; +import java.util.function.Consumer; import java.util.stream.Stream; import com.fasterxml.jackson.core.JsonProcessingException; @@ -35,6 +36,7 @@ import com.nimbusds.jose.JWSAlgorithm; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.ThrowingConsumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.mockito.InOrder; @@ -441,7 +443,6 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .run((context) -> assertThat(context).doesNotHaveBean(ReactiveOpaqueTokenIntrospector.class)); } - @SuppressWarnings("unchecked") @Test void autoConfigurationShouldConfigureResourceServerUsingJwkSetUriAndIssuerUri() throws Exception { this.server = new MockWebServer(); @@ -457,15 +458,11 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); - DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils - .getField(reactiveJwtDecoder, "jwtValidator"); - Collection> tokenValidators = (Collection>) ReflectionTestUtils - .getField(jwtValidator, "tokenValidators"); - assertThat(tokenValidators).hasAtLeastOneElementOfType(JwtIssuerValidator.class); + validate(jwt().claim("iss", issuer), reactiveJwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)); }); } - @SuppressWarnings("unchecked") @Test void autoConfigurationShouldNotConfigureIssuerUriAndAudienceJwtValidatorIfPropertyNotConfigured() throws Exception { this.server = new MockWebServer(); @@ -479,13 +476,8 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); - DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils - .getField(reactiveJwtDecoder, "jwtValidator"); - Collection> tokenValidators = (Collection>) ReflectionTestUtils - .getField(jwtValidator, "tokenValidators"); - assertThat(tokenValidators).hasExactlyElementsOfTypes(JwtTimestampValidator.class); - assertThat(tokenValidators).doesNotHaveAnyElementsOfTypes(JwtClaimValidator.class); - assertThat(tokenValidators).doesNotHaveAnyElementsOfTypes(JwtIssuerValidator.class); + validate(jwt(), reactiveJwtDecoder, (validators) -> assertThat(validators).singleElement() + .isInstanceOf(JwtTimestampValidator.class)); }); } @@ -505,13 +497,18 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); - validate(issuerUri, reactiveJwtDecoder, null); + validate( + jwt().claim("iss", URI.create(issuerUri).toURL()) + .claim("aud", List.of("https://test-audience.com")), + reactiveJwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class) + .satisfiesOnlyOnce(audClaimValidator())); }); } @SuppressWarnings("unchecked") @Test - void autoConfigurationShouldConfigureAudienceAndCustomValidatorsIfPropertyProvidedAndIssuerUri() throws Exception { + void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndIssuerUri() throws Exception { this.server = new MockWebServer(); this.server.start(); String path = "test"; @@ -519,98 +516,63 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { String cleanIssuerPath = cleanIssuerPath(issuer); setupMockResponse(cleanIssuerPath); String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path; - this.contextRunner.withPropertyValues( - "spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com", - "spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri, + this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri, "spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com") - .withUserConfiguration(CustomTokenValidatorsConfig.class) .run((context) -> { - assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); - ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); - assertThat(context).hasBean("customJwtClaimValidator"); - OAuth2TokenValidator customValidator = (OAuth2TokenValidator) context - .getBean("customJwtClaimValidator"); - validate(issuerUri, reactiveJwtDecoder, customValidator); + SupplierReactiveJwtDecoder supplierJwtDecoderBean = context.getBean(SupplierReactiveJwtDecoder.class); + Mono jwtDecoderSupplier = (Mono) ReflectionTestUtils + .getField(supplierJwtDecoderBean, "jwtDecoderMono"); + ReactiveJwtDecoder jwtDecoder = jwtDecoderSupplier.block(); + validate( + jwt().claim("iss", URI.create(issuerUri).toURL()) + .claim("aud", List.of("https://test-audience.com")), + jwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class) + .satisfiesOnlyOnce(audClaimValidator())); }); } - @SuppressWarnings("unchecked") - private void validate(String issuerUri, ReactiveJwtDecoder jwtDecoder, OAuth2TokenValidator customValidator) - throws MalformedURLException { - DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils - .getField(jwtDecoder, "jwtValidator"); - Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com")); - if (issuerUri != null) { - builder.claim("iss", new URL(issuerUri)); - } - if (customValidator != null) { - builder.claim("custom_claim", "custom_claim_value"); - } - Jwt jwt = builder.build(); - assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); - Collection> delegates = (Collection>) ReflectionTestUtils - .getField(jwtValidator, "tokenValidators"); - validateDelegates(issuerUri, delegates, customValidator); - } - - private void validateDelegates(String issuerUri, Collection> delegates, - OAuth2TokenValidator customValidator) { - assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class); - OAuth2TokenValidator delegatingValidator = delegates.stream() - .filter((v) -> v instanceof DelegatingOAuth2TokenValidator) - .findFirst() - .get(); - if (issuerUri != null) { - assertThat(delegatingValidator).extracting("tokenValidators") - .asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) - .hasAtLeastOneElementOfType(JwtIssuerValidator.class); - } - List> claimValidators = delegates.stream() - .filter((d) -> d instanceof JwtClaimValidator) - .collect(Collectors.toList()); - assertThat(claimValidators).anyMatch((v) -> "aud".equals(ReflectionTestUtils.getField(v, "claim"))); - if (customValidator != null) { - assertThat(claimValidators) - .anyMatch((v) -> "custom_claim".equals(ReflectionTestUtils.getField(v, "claim"))); - } - } - - @SuppressWarnings("unchecked") @Test - void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndIssuerUri() throws Exception { + void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndPublicKey() throws Exception { this.server = new MockWebServer(); this.server.start(); String path = "test"; String issuer = this.server.url(path).toString(); String cleanIssuerPath = cleanIssuerPath(issuer); setupMockResponse(cleanIssuerPath); - String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path; - this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri, + this.contextRunner.withPropertyValues( + "spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location", "spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com") .run((context) -> { - SupplierReactiveJwtDecoder supplierJwtDecoderBean = context.getBean(SupplierReactiveJwtDecoder.class); - Mono jwtDecoderSupplier = (Mono) ReflectionTestUtils - .getField(supplierJwtDecoderBean, "jwtDecoderMono"); - ReactiveJwtDecoder jwtDecoder = jwtDecoderSupplier.block(); - validate(issuerUri, jwtDecoder, null); + assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); + ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class); + validate(jwt().claim("aud", List.of("https://test-audience.com")), jwtDecoder, + (validators) -> assertThat(validators).satisfiesOnlyOnce(audClaimValidator())); }); } + @SuppressWarnings("unchecked") @Test - void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndPublicKey() throws Exception { + void autoConfigurationShouldConfigureCustomValidators() throws Exception { this.server = new MockWebServer(); this.server.start(); String path = "test"; String issuer = this.server.url(path).toString(); String cleanIssuerPath = cleanIssuerPath(issuer); setupMockResponse(cleanIssuerPath); - this.contextRunner.withPropertyValues( - "spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location", - "spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com") + String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path; + this.contextRunner + .withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com", + "spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri) + .withUserConfiguration(CustomJwtClaimValidatorConfig.class) .run((context) -> { assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); - ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class); - validate(null, jwtDecoder, null); + ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); + OAuth2TokenValidator customValidator = (OAuth2TokenValidator) context + .getBean("customJwtClaimValidator"); + validate(jwt().claim("iss", URI.create(issuerUri).toURL()).claim("custom_claim", "custom_claim_value"), + reactiveJwtDecoder, (validators) -> assertThat(validators).contains(customValidator) + .hasAtLeastOneElementOfType(JwtIssuerValidator.class)); }); } @@ -640,6 +602,30 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { }); } + @SuppressWarnings("unchecked") + @Test + void customValidatorWhenInvalid() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + String path = "test"; + String issuer = this.server.url(path).toString(); + String cleanIssuerPath = cleanIssuerPath(issuer); + setupMockResponse(cleanIssuerPath); + String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path; + this.contextRunner + .withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com", + "spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri) + .withUserConfiguration(CustomJwtClaimValidatorConfig.class) + .run((context) -> { + assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); + ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class); + DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils + .getField(jwtDecoder, "jwtValidator"); + Jwt jwt = jwt().claim("iss", new URL(issuerUri)).claim("custom_claim", "invalid_value").build(); + assertThat(jwtValidator.validate(jwt).hasErrors()).isTrue(); + }); + } + private void assertFilterConfiguredWithJwtAuthenticationManager(AssertableReactiveWebApplicationContext context) { MatcherSecurityWebFilterChain filterChain = (MatcherSecurityWebFilterChain) context .getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN); @@ -723,6 +709,37 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { .subject("mock-test-subject"); } + @SuppressWarnings("unchecked") + private void validate(Jwt.Builder builder, ReactiveJwtDecoder jwtDecoder, + ThrowingConsumer>> validatorsConsumer) { + DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils + .getField(jwtDecoder, "jwtValidator"); + assertThat(jwtValidator.validate(builder.build()).hasErrors()).isFalse(); + validatorsConsumer.accept(extractValidators(jwtValidator)); + } + + @SuppressWarnings("unchecked") + private List> extractValidators(DelegatingOAuth2TokenValidator delegatingValidator) { + Collection> delegates = (Collection>) ReflectionTestUtils + .getField(delegatingValidator, "tokenValidators"); + List> extracted = new ArrayList<>(); + for (OAuth2TokenValidator delegate : delegates) { + if (delegate instanceof DelegatingOAuth2TokenValidator delegatingDelegate) { + extracted.addAll(extractValidators(delegatingDelegate)); + } + else { + extracted.add(delegate); + } + } + return extracted; + } + + private Consumer> audClaimValidator() { + return (validator) -> assertThat(validator).isInstanceOf(JwtClaimValidator.class) + .extracting("claim") + .isEqualTo("aud"); + } + @EnableWebFluxSecurity static class TestConfig { @@ -781,7 +798,7 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests { } @Configuration(proxyBeanMethods = false) - static class CustomTokenValidatorsConfig { + static class CustomJwtClaimValidatorConfig { @Bean JwtClaimValidator customJwtClaimValidator() { diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerAutoConfigurationTests.java index 42ec6ca08d..b7aa1a5ec6 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerAutoConfigurationTests.java @@ -16,16 +16,17 @@ package org.springframework.boot.autoconfigure.security.oauth2.resource.servlet; -import java.net.MalformedURLException; +import java.net.URI; import java.net.URL; import java.time.Instant; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; -import java.util.stream.Collectors; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -34,6 +35,7 @@ import jakarta.servlet.Filter; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.ThrowingConsumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.mockito.InOrder; @@ -192,8 +194,8 @@ class OAuth2ResourceServerAutoConfigurationTests { }); } - @Test @SuppressWarnings("unchecked") + @Test void autoConfigurationShouldConfigureResourceServerUsingOidcIssuerUri() throws Exception { this.server = new MockWebServer(); this.server.start(); @@ -217,8 +219,8 @@ class OAuth2ResourceServerAutoConfigurationTests { assertThat(this.server.getRequestCount()).isEqualTo(2); } - @Test @SuppressWarnings("unchecked") + @Test void autoConfigurationShouldConfigureResourceServerUsingOidcRfc8414IssuerUri() throws Exception { this.server = new MockWebServer(); this.server.start(); @@ -242,8 +244,8 @@ class OAuth2ResourceServerAutoConfigurationTests { assertThat(this.server.getRequestCount()).isEqualTo(3); } - @Test @SuppressWarnings("unchecked") + @Test void autoConfigurationShouldConfigureResourceServerUsingOAuthIssuerUri() throws Exception { this.server = new MockWebServer(); this.server.start(); @@ -474,9 +476,8 @@ class OAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); - assertThat(jwtDecoder).extracting("jwtValidator.tokenValidators") - .asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) - .hasAtLeastOneElementOfType(JwtIssuerValidator.class); + validate(jwt().claim("iss", issuer), jwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)); }); } @@ -493,11 +494,8 @@ class OAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); - assertThat(jwtDecoder).extracting("jwtValidator.tokenValidators") - .asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) - .hasExactlyElementsOfTypes(JwtTimestampValidator.class) - .doesNotHaveAnyElementsOfTypes(JwtClaimValidator.class) - .doesNotHaveAnyElementsOfTypes(JwtIssuerValidator.class); + validate(jwt(), jwtDecoder, (validators) -> assertThat(validators).singleElement() + .isInstanceOf(JwtTimestampValidator.class)); }); } @@ -517,7 +515,12 @@ class OAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); - validate(issuerUri, jwtDecoder, null); + validate( + jwt().claim("iss", URI.create(issuerUri).toURL()) + .claim("aud", List.of("https://test-audience.com")), + jwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class) + .satisfiesOnlyOnce(audClaimValidator())); }); } @@ -538,13 +541,18 @@ class OAuth2ResourceServerAutoConfigurationTests { Supplier jwtDecoderSupplier = (Supplier) ReflectionTestUtils .getField(supplierJwtDecoderBean, "delegate"); JwtDecoder jwtDecoder = jwtDecoderSupplier.get(); - validate(issuerUri, jwtDecoder, null); + validate( + jwt().claim("iss", URI.create(issuerUri).toURL()) + .claim("aud", List.of("https://test-audience.com")), + jwtDecoder, + (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class) + .satisfiesOnlyOnce(audClaimValidator())); }); } @SuppressWarnings("unchecked") @Test - void autoConfigurationShouldConfigureAudienceAndCustomValidatorsIfPropertyProvidedAndIssuerUri() throws Exception { + void autoConfigurationShouldConfigureCustomValidators() throws Exception { this.server = new MockWebServer(); this.server.start(); String path = "test"; @@ -552,9 +560,8 @@ class OAuth2ResourceServerAutoConfigurationTests { String cleanIssuerPath = cleanIssuerPath(issuer); setupMockResponse(cleanIssuerPath); String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path; - this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri, - "spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com") - .withUserConfiguration(CustomTokenValidatorsConfig.class) + this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri) + .withUserConfiguration(CustomJwtClaimValidatorConfig.class) .run((context) -> { SupplierJwtDecoder supplierJwtDecoderBean = context.getBean(SupplierJwtDecoder.class); Supplier jwtDecoderSupplier = (Supplier) ReflectionTestUtils @@ -563,51 +570,12 @@ class OAuth2ResourceServerAutoConfigurationTests { assertThat(context).hasBean("customJwtClaimValidator"); OAuth2TokenValidator customValidator = (OAuth2TokenValidator) context .getBean("customJwtClaimValidator"); - validate(issuerUri, jwtDecoder, customValidator); + validate(jwt().claim("iss", URI.create(issuerUri).toURL()).claim("custom_claim", "custom_claim_value"), + jwtDecoder, (validators) -> assertThat(validators).contains(customValidator) + .hasAtLeastOneElementOfType(JwtIssuerValidator.class)); }); } - @SuppressWarnings("unchecked") - private void validate(String issuerUri, JwtDecoder jwtDecoder, OAuth2TokenValidator customValidator) - throws MalformedURLException { - DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils - .getField(jwtDecoder, "jwtValidator"); - Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com")); - if (issuerUri != null) { - builder.claim("iss", new URL(issuerUri)); - } - if (customValidator != null) { - builder.claim("custom_claim", "custom_claim_value"); - } - Jwt jwt = builder.build(); - assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); - Collection> delegates = (Collection>) ReflectionTestUtils - .getField(jwtValidator, "tokenValidators"); - validateDelegates(issuerUri, delegates, customValidator); - } - - private void validateDelegates(String issuerUri, Collection> delegates, - OAuth2TokenValidator customValidator) { - assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class); - OAuth2TokenValidator delegatingValidator = delegates.stream() - .filter((v) -> v instanceof DelegatingOAuth2TokenValidator) - .findFirst() - .get(); - if (issuerUri != null) { - assertThat(delegatingValidator).extracting("tokenValidators") - .asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) - .hasAtLeastOneElementOfType(JwtIssuerValidator.class); - } - List> claimValidators = delegates.stream() - .filter((d) -> d instanceof JwtClaimValidator) - .collect(Collectors.toList()); - assertThat(claimValidators).anyMatch((v) -> "aud".equals(ReflectionTestUtils.getField(v, "claim"))); - if (customValidator != null) { - assertThat(claimValidators) - .anyMatch((v) -> "custom_claim".equals(ReflectionTestUtils.getField(v, "claim"))); - } - } - @Test void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndPublicKey() throws Exception { this.server = new MockWebServer(); @@ -622,7 +590,8 @@ class OAuth2ResourceServerAutoConfigurationTests { .run((context) -> { assertThat(context).hasSingleBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); - validate(null, jwtDecoder, null); + validate(jwt().claim("aud", List.of("https://test-audience.com")), jwtDecoder, + (validators) -> assertThat(validators).satisfiesOnlyOnce(audClaimValidator())); }); } @@ -732,6 +701,37 @@ class OAuth2ResourceServerAutoConfigurationTests { .subject("mock-test-subject"); } + @SuppressWarnings("unchecked") + private void validate(Jwt.Builder builder, JwtDecoder jwtDecoder, + ThrowingConsumer>> validatorsConsumer) { + DelegatingOAuth2TokenValidator jwtValidator = (DelegatingOAuth2TokenValidator) ReflectionTestUtils + .getField(jwtDecoder, "jwtValidator"); + assertThat(jwtValidator.validate(builder.build()).hasErrors()).isFalse(); + validatorsConsumer.accept(extractValidators(jwtValidator)); + } + + @SuppressWarnings("unchecked") + private List> extractValidators(DelegatingOAuth2TokenValidator delegatingValidator) { + Collection> delegates = (Collection>) ReflectionTestUtils + .getField(delegatingValidator, "tokenValidators"); + List> extracted = new ArrayList<>(); + for (OAuth2TokenValidator delegate : delegates) { + if (delegate instanceof DelegatingOAuth2TokenValidator delegatingDelegate) { + extracted.addAll(extractValidators(delegatingDelegate)); + } + else { + extracted.add(delegate); + } + } + return extracted; + } + + private Consumer> audClaimValidator() { + return (validator) -> assertThat(validator).isInstanceOf(JwtClaimValidator.class) + .extracting("claim") + .isEqualTo("aud"); + } + @Configuration(proxyBeanMethods = false) @EnableWebSecurity static class TestConfig { @@ -786,7 +786,7 @@ class OAuth2ResourceServerAutoConfigurationTests { } @Configuration(proxyBeanMethods = false) - static class CustomTokenValidatorsConfig { + static class CustomJwtClaimValidatorConfig { @Bean JwtClaimValidator customJwtClaimValidator() {