Polish 'Choose SAML party based on entity ID rather than always using first'

See gh-35902
pull/36620/head
Phillip Webb 1 year ago
parent 864af59adc
commit b6990940b1

@ -20,6 +20,7 @@ import java.io.InputStream;
import java.security.cert.CertificateFactory; import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey; import java.security.interfaces.RSAPrivateKey;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
@ -63,6 +64,7 @@ import org.springframework.util.StringUtils;
* @author Madhura Bhave * @author Madhura Bhave
* @author Phillip Webb * @author Phillip Webb
* @author Moritz Halbritter * @author Moritz Halbritter
* @author Lasse Lindqvist
*/ */
@Configuration(proxyBeanMethods = false) @Configuration(proxyBeanMethods = false)
@Conditional(RegistrationConfiguredCondition.class) @Conditional(RegistrationConfiguredCondition.class)
@ -88,14 +90,8 @@ class Saml2RelyingPartyRegistrationConfiguration {
private RelyingPartyRegistration asRegistration(String id, Registration properties) { private RelyingPartyRegistration asRegistration(String id, Registration properties) {
AssertingPartyProperties assertingParty = new AssertingPartyProperties(properties, id); AssertingPartyProperties assertingParty = new AssertingPartyProperties(properties, id);
boolean usingMetadata = StringUtils.hasText(assertingParty.getMetadataUri()); boolean usingMetadata = StringUtils.hasText(assertingParty.getMetadataUri());
Builder builder = (usingMetadata) ? RelyingPartyRegistrations Builder builder = (!usingMetadata) ? RelyingPartyRegistration.withRegistrationId(id)
.collectionFromMetadataLocation(properties.getAssertingparty().getMetadataUri()) : createBuilderUsingMetadata(id, assertingParty).registrationId(id);
.stream()
.filter(b -> entityIdsMatch(properties, b))
.findFirst()
.orElseThrow(() -> new IllegalStateException(
"No relying party with entity-id " + properties.getEntityId() + " found."))
.registrationId(id) : RelyingPartyRegistration.withRegistrationId(id);
builder.assertionConsumerServiceLocation(properties.getAcs().getLocation()); builder.assertionConsumerServiceLocation(properties.getAcs().getLocation());
builder.assertionConsumerServiceBinding(properties.getAcs().getBinding()); builder.assertionConsumerServiceBinding(properties.getAcs().getBinding());
builder.assertingPartyDetails(mapAssertingParty(properties, id, usingMetadata)); builder.assertingPartyDetails(mapAssertingParty(properties, id, usingMetadata));
@ -124,17 +120,23 @@ class Saml2RelyingPartyRegistrationConfiguration {
return registration; return registration;
} }
/** private RelyingPartyRegistration.Builder createBuilderUsingMetadata(String id,
* Tests if the builder would have the correct entity-id. If no entity-id is given in AssertingPartyProperties properties) {
* properties, any builder passes the test. String requiredEntityId = properties.getEntityId();
* @param properties the properties Collection<Builder> candidates = RelyingPartyRegistrations
* @param b the builder .collectionFromMetadataLocation(properties.getMetadataUri());
* @return true if the builder passes the test for (RelyingPartyRegistration.Builder candidate : candidates) {
*/ if (requiredEntityId == null || requiredEntityId.equals(getEntityId(candidate))) {
private boolean entityIdsMatch(Registration properties, Builder b) { return candidate;
RelyingPartyRegistration rpr = b.build(); }
return properties.getAssertingparty().getEntityId() == null }
|| properties.getAssertingparty().getEntityId().equals(rpr.getAssertingPartyDetails().getEntityId()); throw new IllegalStateException("No relying party with Entity ID '" + requiredEntityId + "' found");
}
private Object getEntityId(RelyingPartyRegistration.Builder candidate) {
String[] result = new String[1];
candidate.assertingPartyDetails((builder) -> result[0] = builder.build().getEntityId());
return result[0];
} }
private Consumer<AssertingPartyDetails.Builder> mapAssertingParty(Registration registration, String id, private Consumer<AssertingPartyDetails.Builder> mapAssertingParty(Registration registration, String id,

@ -16,6 +16,7 @@
package org.springframework.boot.autoconfigure.security.saml2; package org.springframework.boot.autoconfigure.security.saml2;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.List; import java.util.List;
@ -55,6 +56,7 @@ import static org.mockito.Mockito.mock;
* *
* @author Madhura Bhave * @author Madhura Bhave
* @author Moritz Halbritter * @author Moritz Halbritter
* @author Lasse Lindqvist
*/ */
class Saml2RelyingPartyAutoConfigurationTests { class Saml2RelyingPartyAutoConfigurationTests {
@ -404,39 +406,35 @@ class Saml2RelyingPartyAutoConfigurationTests {
} }
@Test @Test
void autoconfigurationShouldWorkWithMultipleProvidersWithNoEntityId() throws Exception { void autoconfigurationWhenMultipleProvidersAndNoSpecifiedEntityId() throws Exception {
try (MockWebServer server = new MockWebServer()) { testMultipleProviders(null, "https://idp.example.com/idp/shibboleth");
server.start();
String metadataUrl = server.url("").toString();
setupMockResponse(server, new ClassPathResource("saml/idp-metadata-with-multiple-providers"));
this.contextRunner.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl)
.run((context) -> {
assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class);
assertThat(server.getRequestCount()).isOne();
RelyingPartyRegistrationRepository repository = context.getBean(RelyingPartyRegistrationRepository.class);
RelyingPartyRegistration registration = repository.findByRegistrationId("foo");
assertThat(registration.getAssertingPartyDetails().getEntityId())
.isEqualTo("https://idp.example.com/idp/shibboleth");
});
}
} }
@Test @Test
void autoconfigurationShouldWorkWithMultipleProviders() throws Exception { void autoconfigurationWhenMultipleProvidersAndSpecifiedEntityId() throws Exception {
testMultipleProviders("https://idp.example.com/idp/shibboleth", "https://idp.example.com/idp/shibboleth");
testMultipleProviders("https://idp2.example.com/idp/shibboleth", "https://idp2.example.com/idp/shibboleth");
}
private void testMultipleProviders(String specifiedEntityId, String expected) throws IOException, Exception {
try (MockWebServer server = new MockWebServer()) { try (MockWebServer server = new MockWebServer()) {
server.start(); server.start();
String metadataUrl = server.url("").toString(); String metadataUrl = server.url("").toString();
setupMockResponse(server, new ClassPathResource("saml/idp-metadata-with-multiple-providers")); setupMockResponse(server, new ClassPathResource("saml/idp-metadata-with-multiple-providers"));
this.contextRunner.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl, WebApplicationContextRunner contextRunner = this.contextRunner
PREFIX + ".foo.assertingparty.entity-id=https://idp2.example.com/idp/shibboleth") .withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl);
.run((context) -> { if (specifiedEntityId != null) {
assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class); contextRunner = contextRunner
assertThat(server.getRequestCount()).isOne(); .withPropertyValues(PREFIX + ".foo.assertingparty.entity-id=" + specifiedEntityId);
RelyingPartyRegistrationRepository repository = context.getBean(RelyingPartyRegistrationRepository.class); }
RelyingPartyRegistration registration = repository.findByRegistrationId("foo"); contextRunner.run((context) -> {
assertThat(registration.getAssertingPartyDetails().getEntityId()) assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class);
.isEqualTo("https://idp2.example.com/idp/shibboleth"); assertThat(server.getRequestCount()).isOne();
}); RelyingPartyRegistrationRepository repository = context
.getBean(RelyingPartyRegistrationRepository.class);
RelyingPartyRegistration registration = repository.findByRegistrationId("foo");
assertThat(registration.getAssertingPartyDetails().getEntityId()).isEqualTo(expected);
});
} }
} }

Loading…
Cancel
Save