Polish 'Choose SAML party based on entity ID rather than always using first'

See gh-35902
pull/36620/head
Phillip Webb 1 year ago
parent 864af59adc
commit b6990940b1

@ -20,6 +20,7 @@ import java.io.InputStream;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
@ -63,6 +64,7 @@ import org.springframework.util.StringUtils;
* @author Madhura Bhave
* @author Phillip Webb
* @author Moritz Halbritter
* @author Lasse Lindqvist
*/
@Configuration(proxyBeanMethods = false)
@Conditional(RegistrationConfiguredCondition.class)
@ -88,14 +90,8 @@ class Saml2RelyingPartyRegistrationConfiguration {
private RelyingPartyRegistration asRegistration(String id, Registration properties) {
AssertingPartyProperties assertingParty = new AssertingPartyProperties(properties, id);
boolean usingMetadata = StringUtils.hasText(assertingParty.getMetadataUri());
Builder builder = (usingMetadata) ? RelyingPartyRegistrations
.collectionFromMetadataLocation(properties.getAssertingparty().getMetadataUri())
.stream()
.filter(b -> entityIdsMatch(properties, b))
.findFirst()
.orElseThrow(() -> new IllegalStateException(
"No relying party with entity-id " + properties.getEntityId() + " found."))
.registrationId(id) : RelyingPartyRegistration.withRegistrationId(id);
Builder builder = (!usingMetadata) ? RelyingPartyRegistration.withRegistrationId(id)
: createBuilderUsingMetadata(id, assertingParty).registrationId(id);
builder.assertionConsumerServiceLocation(properties.getAcs().getLocation());
builder.assertionConsumerServiceBinding(properties.getAcs().getBinding());
builder.assertingPartyDetails(mapAssertingParty(properties, id, usingMetadata));
@ -124,17 +120,23 @@ class Saml2RelyingPartyRegistrationConfiguration {
return registration;
}
/**
* Tests if the builder would have the correct entity-id. If no entity-id is given in
* properties, any builder passes the test.
* @param properties the properties
* @param b the builder
* @return true if the builder passes the test
*/
private boolean entityIdsMatch(Registration properties, Builder b) {
RelyingPartyRegistration rpr = b.build();
return properties.getAssertingparty().getEntityId() == null
|| properties.getAssertingparty().getEntityId().equals(rpr.getAssertingPartyDetails().getEntityId());
private RelyingPartyRegistration.Builder createBuilderUsingMetadata(String id,
AssertingPartyProperties properties) {
String requiredEntityId = properties.getEntityId();
Collection<Builder> candidates = RelyingPartyRegistrations
.collectionFromMetadataLocation(properties.getMetadataUri());
for (RelyingPartyRegistration.Builder candidate : candidates) {
if (requiredEntityId == null || requiredEntityId.equals(getEntityId(candidate))) {
return candidate;
}
}
throw new IllegalStateException("No relying party with Entity ID '" + requiredEntityId + "' found");
}
private Object getEntityId(RelyingPartyRegistration.Builder candidate) {
String[] result = new String[1];
candidate.assertingPartyDetails((builder) -> result[0] = builder.build().getEntityId());
return result[0];
}
private Consumer<AssertingPartyDetails.Builder> mapAssertingParty(Registration registration, String id,

@ -16,6 +16,7 @@
package org.springframework.boot.autoconfigure.security.saml2;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
@ -55,6 +56,7 @@ import static org.mockito.Mockito.mock;
*
* @author Madhura Bhave
* @author Moritz Halbritter
* @author Lasse Lindqvist
*/
class Saml2RelyingPartyAutoConfigurationTests {
@ -404,38 +406,34 @@ class Saml2RelyingPartyAutoConfigurationTests {
}
@Test
void autoconfigurationShouldWorkWithMultipleProvidersWithNoEntityId() throws Exception {
try (MockWebServer server = new MockWebServer()) {
server.start();
String metadataUrl = server.url("").toString();
setupMockResponse(server, new ClassPathResource("saml/idp-metadata-with-multiple-providers"));
this.contextRunner.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl)
.run((context) -> {
assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class);
assertThat(server.getRequestCount()).isOne();
RelyingPartyRegistrationRepository repository = context.getBean(RelyingPartyRegistrationRepository.class);
RelyingPartyRegistration registration = repository.findByRegistrationId("foo");
assertThat(registration.getAssertingPartyDetails().getEntityId())
.isEqualTo("https://idp.example.com/idp/shibboleth");
});
}
void autoconfigurationWhenMultipleProvidersAndNoSpecifiedEntityId() throws Exception {
testMultipleProviders(null, "https://idp.example.com/idp/shibboleth");
}
@Test
void autoconfigurationShouldWorkWithMultipleProviders() throws Exception {
void autoconfigurationWhenMultipleProvidersAndSpecifiedEntityId() throws Exception {
testMultipleProviders("https://idp.example.com/idp/shibboleth", "https://idp.example.com/idp/shibboleth");
testMultipleProviders("https://idp2.example.com/idp/shibboleth", "https://idp2.example.com/idp/shibboleth");
}
private void testMultipleProviders(String specifiedEntityId, String expected) throws IOException, Exception {
try (MockWebServer server = new MockWebServer()) {
server.start();
String metadataUrl = server.url("").toString();
setupMockResponse(server, new ClassPathResource("saml/idp-metadata-with-multiple-providers"));
this.contextRunner.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl,
PREFIX + ".foo.assertingparty.entity-id=https://idp2.example.com/idp/shibboleth")
.run((context) -> {
WebApplicationContextRunner contextRunner = this.contextRunner
.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl);
if (specifiedEntityId != null) {
contextRunner = contextRunner
.withPropertyValues(PREFIX + ".foo.assertingparty.entity-id=" + specifiedEntityId);
}
contextRunner.run((context) -> {
assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class);
assertThat(server.getRequestCount()).isOne();
RelyingPartyRegistrationRepository repository = context.getBean(RelyingPartyRegistrationRepository.class);
RelyingPartyRegistrationRepository repository = context
.getBean(RelyingPartyRegistrationRepository.class);
RelyingPartyRegistration registration = repository.findByRegistrationId("foo");
assertThat(registration.getAssertingPartyDetails().getEntityId())
.isEqualTo("https://idp2.example.com/idp/shibboleth");
assertThat(registration.getAssertingPartyDetails().getEntityId()).isEqualTo(expected);
});
}
}

Loading…
Cancel
Save