From b6990940b12bd431376f866959ac86e2b76d9395 Mon Sep 17 00:00:00 2001 From: Phillip Webb Date: Sun, 2 Jul 2023 15:28:05 +0100 Subject: [PATCH] Polish 'Choose SAML party based on entity ID rather than always using first' See gh-35902 --- ...RelyingPartyRegistrationConfiguration.java | 40 +++++++------- ...ml2RelyingPartyAutoConfigurationTests.java | 54 +++++++++---------- 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyRegistrationConfiguration.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyRegistrationConfiguration.java index 886ad9a5ac..7e9ca5120b 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyRegistrationConfiguration.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyRegistrationConfiguration.java @@ -20,6 +20,7 @@ import java.io.InputStream; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import java.security.interfaces.RSAPrivateKey; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.function.Consumer; @@ -63,6 +64,7 @@ import org.springframework.util.StringUtils; * @author Madhura Bhave * @author Phillip Webb * @author Moritz Halbritter + * @author Lasse Lindqvist */ @Configuration(proxyBeanMethods = false) @Conditional(RegistrationConfiguredCondition.class) @@ -88,14 +90,8 @@ class Saml2RelyingPartyRegistrationConfiguration { private RelyingPartyRegistration asRegistration(String id, Registration properties) { AssertingPartyProperties assertingParty = new AssertingPartyProperties(properties, id); boolean usingMetadata = StringUtils.hasText(assertingParty.getMetadataUri()); - Builder builder = (usingMetadata) ? RelyingPartyRegistrations - .collectionFromMetadataLocation(properties.getAssertingparty().getMetadataUri()) - .stream() - .filter(b -> entityIdsMatch(properties, b)) - .findFirst() - .orElseThrow(() -> new IllegalStateException( - "No relying party with entity-id " + properties.getEntityId() + " found.")) - .registrationId(id) : RelyingPartyRegistration.withRegistrationId(id); + Builder builder = (!usingMetadata) ? RelyingPartyRegistration.withRegistrationId(id) + : createBuilderUsingMetadata(id, assertingParty).registrationId(id); builder.assertionConsumerServiceLocation(properties.getAcs().getLocation()); builder.assertionConsumerServiceBinding(properties.getAcs().getBinding()); builder.assertingPartyDetails(mapAssertingParty(properties, id, usingMetadata)); @@ -124,17 +120,23 @@ class Saml2RelyingPartyRegistrationConfiguration { return registration; } - /** - * Tests if the builder would have the correct entity-id. If no entity-id is given in - * properties, any builder passes the test. - * @param properties the properties - * @param b the builder - * @return true if the builder passes the test - */ - private boolean entityIdsMatch(Registration properties, Builder b) { - RelyingPartyRegistration rpr = b.build(); - return properties.getAssertingparty().getEntityId() == null - || properties.getAssertingparty().getEntityId().equals(rpr.getAssertingPartyDetails().getEntityId()); + private RelyingPartyRegistration.Builder createBuilderUsingMetadata(String id, + AssertingPartyProperties properties) { + String requiredEntityId = properties.getEntityId(); + Collection candidates = RelyingPartyRegistrations + .collectionFromMetadataLocation(properties.getMetadataUri()); + for (RelyingPartyRegistration.Builder candidate : candidates) { + if (requiredEntityId == null || requiredEntityId.equals(getEntityId(candidate))) { + return candidate; + } + } + throw new IllegalStateException("No relying party with Entity ID '" + requiredEntityId + "' found"); + } + + private Object getEntityId(RelyingPartyRegistration.Builder candidate) { + String[] result = new String[1]; + candidate.assertingPartyDetails((builder) -> result[0] = builder.build().getEntityId()); + return result[0]; } private Consumer mapAssertingParty(Registration registration, String id, diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyAutoConfigurationTests.java index 37d407d3ed..1df102b853 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyAutoConfigurationTests.java @@ -16,6 +16,7 @@ package org.springframework.boot.autoconfigure.security.saml2; +import java.io.IOException; import java.io.InputStream; import java.util.List; @@ -55,6 +56,7 @@ import static org.mockito.Mockito.mock; * * @author Madhura Bhave * @author Moritz Halbritter + * @author Lasse Lindqvist */ class Saml2RelyingPartyAutoConfigurationTests { @@ -402,41 +404,37 @@ class Saml2RelyingPartyAutoConfigurationTests { this.contextRunner.withPropertyValues(getPropertyValues(false)) .run((context) -> assertThat(hasFilter(context, Saml2LogoutRequestFilter.class)).isTrue()); } - + @Test - void autoconfigurationShouldWorkWithMultipleProvidersWithNoEntityId() throws Exception { - try (MockWebServer server = new MockWebServer()) { - server.start(); - String metadataUrl = server.url("").toString(); - setupMockResponse(server, new ClassPathResource("saml/idp-metadata-with-multiple-providers")); - this.contextRunner.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl) - .run((context) -> { - assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class); - assertThat(server.getRequestCount()).isOne(); - RelyingPartyRegistrationRepository repository = context.getBean(RelyingPartyRegistrationRepository.class); - RelyingPartyRegistration registration = repository.findByRegistrationId("foo"); - assertThat(registration.getAssertingPartyDetails().getEntityId()) - .isEqualTo("https://idp.example.com/idp/shibboleth"); - }); - } + void autoconfigurationWhenMultipleProvidersAndNoSpecifiedEntityId() throws Exception { + testMultipleProviders(null, "https://idp.example.com/idp/shibboleth"); } - + @Test - void autoconfigurationShouldWorkWithMultipleProviders() throws Exception { + void autoconfigurationWhenMultipleProvidersAndSpecifiedEntityId() throws Exception { + testMultipleProviders("https://idp.example.com/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); + testMultipleProviders("https://idp2.example.com/idp/shibboleth", "https://idp2.example.com/idp/shibboleth"); + } + + private void testMultipleProviders(String specifiedEntityId, String expected) throws IOException, Exception { try (MockWebServer server = new MockWebServer()) { server.start(); String metadataUrl = server.url("").toString(); setupMockResponse(server, new ClassPathResource("saml/idp-metadata-with-multiple-providers")); - this.contextRunner.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl, - PREFIX + ".foo.assertingparty.entity-id=https://idp2.example.com/idp/shibboleth") - .run((context) -> { - assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class); - assertThat(server.getRequestCount()).isOne(); - RelyingPartyRegistrationRepository repository = context.getBean(RelyingPartyRegistrationRepository.class); - RelyingPartyRegistration registration = repository.findByRegistrationId("foo"); - assertThat(registration.getAssertingPartyDetails().getEntityId()) - .isEqualTo("https://idp2.example.com/idp/shibboleth"); - }); + WebApplicationContextRunner contextRunner = this.contextRunner + .withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl); + if (specifiedEntityId != null) { + contextRunner = contextRunner + .withPropertyValues(PREFIX + ".foo.assertingparty.entity-id=" + specifiedEntityId); + } + contextRunner.run((context) -> { + assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class); + assertThat(server.getRequestCount()).isOne(); + RelyingPartyRegistrationRepository repository = context + .getBean(RelyingPartyRegistrationRepository.class); + RelyingPartyRegistration registration = repository.findByRegistrationId("foo"); + assertThat(registration.getAssertingPartyDetails().getEntityId()).isEqualTo(expected); + }); } }