Implement SSL hot reload

pull/37808/head
Moritz Halbritter 1 year ago
parent 7b1059a4b5
commit c9e45952b0

@ -7,6 +7,6 @@ org.gradle.jvmargs=-Xmx2g -Dfile.encoding=UTF-8
kotlinVersion=1.9.10 kotlinVersion=1.9.10
nativeBuildToolsVersion=0.9.27 nativeBuildToolsVersion=0.9.27
springFrameworkVersion=6.1.0-SNAPSHOT springFrameworkVersion=6.1.0-SNAPSHOT
tomcatVersion=10.1.13 tomcatVersion=10.1.14
kotlin.stdlib.default.dependency=false kotlin.stdlib.default.dependency=false

@ -0,0 +1,302 @@
/*
* Copyright 2012-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.autoconfigure.ssl;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardWatchEventKinds;
import java.nio.file.WatchEvent;
import java.nio.file.WatchEvent.Kind;
import java.nio.file.WatchKey;
import java.nio.file.WatchService;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.util.Assert;
/**
* Watches files and directories and triggers a callback on change.
*
* @author Moritz Halbritter
*/
class FileWatcher implements AutoCloseable {
private static final Log logger = LogFactory.getLog(FileWatcher.class);
private final String threadName;
private final Duration quietPeriod;
private final Object lifecycleLock = new Object();
private final Map<WatchKey, List<Registration>> registrations = new ConcurrentHashMap<>();
private volatile WatchService watchService;
private Thread thread;
private boolean running = false;
FileWatcher(String threadName, Duration quietPeriod) {
Assert.notNull(threadName, "threadName must not be null");
Assert.notNull(quietPeriod, "quietPeriod must not be null");
this.threadName = threadName;
this.quietPeriod = quietPeriod;
}
void watch(Set<Path> paths, Callback callback) {
Assert.notNull(callback, "callback must not be null");
Assert.notNull(paths, "paths must not be null");
if (paths.isEmpty()) {
return;
}
startIfNecessary();
try {
registerWatchables(callback, paths, this.watchService);
}
catch (IOException ex) {
throw new UncheckedIOException("Failed to register paths for watching: " + paths, ex);
}
}
void stop() {
synchronized (this.lifecycleLock) {
if (!this.running) {
return;
}
this.running = false;
this.thread.interrupt();
try {
this.thread.join();
}
catch (InterruptedException ex) {
Thread.currentThread().interrupt();
}
this.thread = null;
this.watchService = null;
this.registrations.clear();
}
}
private void startIfNecessary() {
synchronized (this.lifecycleLock) {
if (this.running) {
return;
}
CountDownLatch started = new CountDownLatch(1);
this.thread = new Thread(() -> this.threadMain(started));
this.thread.setName(this.threadName);
this.thread.setDaemon(true);
this.thread.setUncaughtExceptionHandler(this::onThreadException);
this.running = true;
this.thread.start();
try {
started.await();
}
catch (InterruptedException ex) {
Thread.currentThread().interrupt();
}
}
}
private void threadMain(CountDownLatch started) {
logger.debug("Watch thread started");
try (WatchService watcher = FileSystems.getDefault().newWatchService()) {
this.watchService = watcher;
started.countDown();
Map<Registration, List<Change>> accumulatedChanges = new HashMap<>();
while (this.running) {
try {
WatchKey key = watcher.poll(this.quietPeriod.toMillis(), TimeUnit.MILLISECONDS);
if (key == null) {
// WatchService returned without any changes
if (!accumulatedChanges.isEmpty()) {
// We have queued changes, that means there were no changes
// since the quiet period
fireCallback(accumulatedChanges);
accumulatedChanges.clear();
}
}
else {
accumulateChanges(key, accumulatedChanges);
}
}
catch (InterruptedException ex) {
Thread.currentThread().interrupt();
}
}
logger.debug("Watch thread stopped");
}
catch (IOException ex) {
throw new UncheckedIOException(ex);
}
}
private void accumulateChanges(WatchKey key, Map<Registration, List<Change>> accumulatedChanges)
throws IOException {
List<Registration> registrations = this.registrations.get(key);
Path directory = (Path) key.watchable();
for (WatchEvent<?> event : key.pollEvents()) {
Path file = directory.resolve((Path) event.context());
for (Registration registration : registrations) {
if (registration.affectsFile(file)) {
accumulatedChanges.computeIfAbsent(registration, (ignore) -> new ArrayList<>())
.add(new Change(file, Type.from(event.kind())));
}
}
}
key.reset();
}
private void fireCallback(Map<Registration, List<Change>> accumulatedChanges) {
for (Entry<Registration, List<Change>> entry : accumulatedChanges.entrySet()) {
Changes changes = new Changes(entry.getValue());
if (!changes.isEmpty()) {
entry.getKey().callback().onChange(changes);
}
}
}
private void onThreadException(Thread thread, Throwable throwable) {
logger.error("Uncaught exception in file watcher thread", throwable);
}
private void registerWatchables(Callback callback, Set<Path> paths, WatchService watchService) throws IOException {
Set<WatchKey> watchKeys = new HashSet<>();
Set<Path> directories = new HashSet<>();
Set<Path> files = new HashSet<>();
for (Path path : paths) {
Path realPath = path.toRealPath();
if (Files.isDirectory(realPath)) {
directories.add(realPath);
watchKeys.add(registerDirectory(realPath, watchService));
}
else if (Files.isRegularFile(realPath)) {
files.add(realPath);
watchKeys.add(registerFile(realPath, watchService));
}
else {
throw new IOException("'%s' is neither a file nor a directory".formatted(realPath));
}
}
Registration registration = new Registration(callback, directories, files);
for (WatchKey watchKey : watchKeys) {
this.registrations.computeIfAbsent(watchKey, (ignore) -> new CopyOnWriteArrayList<>()).add(registration);
}
}
private WatchKey registerFile(Path file, WatchService watchService) throws IOException {
return register(file.getParent(), watchService);
}
private WatchKey registerDirectory(Path directory, WatchService watchService) throws IOException {
return register(directory, watchService);
}
private WatchKey register(Path directory, WatchService watchService) throws IOException {
logger.debug(LogMessage.format("Registering '%s'", directory));
return directory.register(watchService, StandardWatchEventKinds.ENTRY_CREATE,
StandardWatchEventKinds.ENTRY_MODIFY, StandardWatchEventKinds.ENTRY_DELETE);
}
@Override
public void close() {
stop();
}
private record Registration(Callback callback, Set<Path> directories, Set<Path> files) {
boolean affectsFile(Path file) {
return this.files.contains(file) || isInDirectories(file);
}
private boolean isInDirectories(Path file) {
for (Path directory : this.directories) {
if (file.startsWith(directory)) {
return true;
}
}
return false;
}
}
enum Type {
CREATE, MODIFY, DELETE;
private static Type from(Kind<?> kind) {
if (kind == StandardWatchEventKinds.ENTRY_CREATE) {
return CREATE;
}
if (kind == StandardWatchEventKinds.ENTRY_DELETE) {
return DELETE;
}
if (kind == StandardWatchEventKinds.ENTRY_MODIFY) {
return MODIFY;
}
throw new IllegalArgumentException("Unknown kind: " + kind);
}
}
record Change(Path path, Type type) {
}
static class Changes implements Iterable<Change> {
private final List<Change> changes;
Changes(List<Change> changes) {
this.changes = changes;
}
@Override
public Iterator<Change> iterator() {
return this.changes.iterator();
}
boolean isEmpty() {
return this.changes.isEmpty();
}
}
@FunctionalInterface
interface Callback {
void onChange(Changes changes);
}
}

@ -38,6 +38,11 @@ public class JksSslBundleProperties extends SslBundleProperties {
*/ */
private final Store truststore = new Store(); private final Store truststore = new Store();
/**
* Whether to reload the SSL bundle.
*/
private boolean reloadOnUpdate;
public Store getKeystore() { public Store getKeystore() {
return this.keystore; return this.keystore;
} }
@ -46,6 +51,14 @@ public class JksSslBundleProperties extends SslBundleProperties {
return this.truststore; return this.truststore;
} }
public boolean isReloadOnUpdate() {
return this.reloadOnUpdate;
}
public void setReloadOnUpdate(boolean reloadOnUpdate) {
this.reloadOnUpdate = reloadOnUpdate;
}
/** /**
* Store properties. * Store properties.
*/ */

@ -17,6 +17,7 @@
package org.springframework.boot.autoconfigure.ssl; package org.springframework.boot.autoconfigure.ssl;
import org.springframework.boot.ssl.pem.PemSslStoreBundle; import org.springframework.boot.ssl.pem.PemSslStoreBundle;
import org.springframework.boot.ssl.pem.PemSslStoreDetails;
/** /**
* {@link SslBundleProperties} for PEM-encoded certificates and private keys. * {@link SslBundleProperties} for PEM-encoded certificates and private keys.
@ -39,6 +40,11 @@ public class PemSslBundleProperties extends SslBundleProperties {
*/ */
private final Store truststore = new Store(); private final Store truststore = new Store();
/**
* Whether to reload the SSL bundle.
*/
private boolean reloadOnUpdate;
/** /**
* Whether to verify that the private key matches the public key. * Whether to verify that the private key matches the public key.
*/ */
@ -52,6 +58,14 @@ public class PemSslBundleProperties extends SslBundleProperties {
return this.truststore; return this.truststore;
} }
public boolean isReloadOnUpdate() {
return this.reloadOnUpdate;
}
public void setReloadOnUpdate(boolean reloadOnUpdate) {
this.reloadOnUpdate = reloadOnUpdate;
}
public boolean isVerifyKeys() { public boolean isVerifyKeys() {
return this.verifyKeys; return this.verifyKeys;
} }
@ -117,6 +131,10 @@ public class PemSslBundleProperties extends SslBundleProperties {
this.privateKeyPassword = privateKeyPassword; this.privateKeyPassword = privateKeyPassword;
} }
PemSslStoreDetails asPemSslStoreDetails() {
return new PemSslStoreDetails(this.type, this.certificate, this.privateKey, this.privateKeyPassword);
}
} }
} }

@ -16,8 +16,7 @@
package org.springframework.boot.autoconfigure.ssl; package org.springframework.boot.autoconfigure.ssl;
import java.util.List; import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@ -37,19 +36,28 @@ import org.springframework.context.annotation.Bean;
@EnableConfigurationProperties(SslProperties.class) @EnableConfigurationProperties(SslProperties.class)
public class SslAutoConfiguration { public class SslAutoConfiguration {
SslAutoConfiguration() { private final SslProperties sslProperties;
SslAutoConfiguration(SslProperties sslProperties) {
this.sslProperties = sslProperties;
}
@Bean
FileWatcher fileWatcher() {
return new FileWatcher("ssl-bundle-watcher",
this.sslProperties.getBundle().getWatch().getFile().getQuietPeriod());
} }
@Bean @Bean
public SslPropertiesBundleRegistrar sslPropertiesSslBundleRegistrar(SslProperties sslProperties) { SslPropertiesBundleRegistrar sslPropertiesSslBundleRegistrar(FileWatcher fileWatcher) {
return new SslPropertiesBundleRegistrar(sslProperties); return new SslPropertiesBundleRegistrar(this.sslProperties, fileWatcher);
} }
@Bean @Bean
@ConditionalOnMissingBean({ SslBundleRegistry.class, SslBundles.class }) @ConditionalOnMissingBean({ SslBundleRegistry.class, SslBundles.class })
public DefaultSslBundleRegistry sslBundleRegistry(List<SslBundleRegistrar> sslBundleRegistrars) { DefaultSslBundleRegistry sslBundleRegistry(ObjectProvider<SslBundleRegistrar> sslBundleRegistrars) {
DefaultSslBundleRegistry registry = new DefaultSslBundleRegistry(); DefaultSslBundleRegistry registry = new DefaultSslBundleRegistry();
sslBundleRegistrars.forEach((registrar) -> registrar.registerBundles(registry)); sslBundleRegistrars.orderedStream().forEach((registrar) -> registrar.registerBundles(registry));
return registry; return registry;
} }

@ -36,7 +36,7 @@ public abstract class SslBundleProperties {
private final Key key = new Key(); private final Key key = new Key();
/** /**
* Options for the SLL connection. * Options for the SSL connection.
*/ */
private final Options options = new Options(); private final Options options = new Options();

@ -16,6 +16,7 @@
package org.springframework.boot.autoconfigure.ssl; package org.springframework.boot.autoconfigure.ssl;
import java.time.Duration;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
@ -25,6 +26,7 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
* Properties for centralized SSL trust material configuration. * Properties for centralized SSL trust material configuration.
* *
* @author Scott Frederick * @author Scott Frederick
* @author Moritz Halbritter
* @since 3.1.0 * @since 3.1.0
*/ */
@ConfigurationProperties(prefix = "spring.ssl") @ConfigurationProperties(prefix = "spring.ssl")
@ -54,6 +56,11 @@ public class SslProperties {
*/ */
private final Map<String, JksSslBundleProperties> jks = new LinkedHashMap<>(); private final Map<String, JksSslBundleProperties> jks = new LinkedHashMap<>();
/**
* Trust material watching.
*/
private final Watch watch = new Watch();
public Map<String, PemSslBundleProperties> getPem() { public Map<String, PemSslBundleProperties> getPem() {
return this.pem; return this.pem;
} }
@ -62,6 +69,40 @@ public class SslProperties {
return this.jks; return this.jks;
} }
public Watch getWatch() {
return this.watch;
}
public static class Watch {
/**
* File watching.
*/
private final File file = new File();
public File getFile() {
return this.file;
}
public static class File {
/**
* Quiet period, after which changes are detected.
*/
private Duration quietPeriod = Duration.ofSeconds(10);
public Duration getQuietPeriod() {
return this.quietPeriod;
}
public void setQuietPeriod(Duration quietPeriod) {
this.quietPeriod = quietPeriod;
}
}
}
} }
} }

@ -16,11 +16,19 @@
package org.springframework.boot.autoconfigure.ssl; package org.springframework.boot.autoconfigure.ssl;
import java.io.FileNotFoundException;
import java.io.UncheckedIOException;
import java.net.URL;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.function.Function; import java.util.Map.Entry;
import java.util.Set;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundleRegistry; import org.springframework.boot.ssl.SslBundleRegistry;
import org.springframework.boot.ssl.pem.PemSslStoreDetails;
import org.springframework.boot.ssl.pem.PemSslStoreDetails.Type;
import org.springframework.util.ResourceUtils;
/** /**
* A {@link SslBundleRegistrar} that registers SSL bundles based * A {@link SslBundleRegistrar} that registers SSL bundles based
@ -28,25 +36,104 @@ import org.springframework.boot.ssl.SslBundleRegistry;
* *
* @author Scott Frederick * @author Scott Frederick
* @author Phillip Webb * @author Phillip Webb
* @author Moritz Halbritter
*/ */
class SslPropertiesBundleRegistrar implements SslBundleRegistrar { class SslPropertiesBundleRegistrar implements SslBundleRegistrar {
private final SslProperties.Bundles properties; private final SslProperties.Bundles properties;
SslPropertiesBundleRegistrar(SslProperties properties) { private final FileWatcher fileWatcher;
SslPropertiesBundleRegistrar(SslProperties properties, FileWatcher fileWatcher) {
this.properties = properties.getBundle(); this.properties = properties.getBundle();
this.fileWatcher = fileWatcher;
} }
@Override @Override
public void registerBundles(SslBundleRegistry registry) { public void registerBundles(SslBundleRegistry registry) {
registerBundles(registry, this.properties.getPem(), PropertiesSslBundle::get); registerPemBundles(registry, this.properties.getPem());
registerBundles(registry, this.properties.getJks(), PropertiesSslBundle::get); registerJksBundles(registry, this.properties.getJks());
}
private void registerJksBundles(SslBundleRegistry registry, Map<String, JksSslBundleProperties> bundles) {
for (Entry<String, JksSslBundleProperties> bundle : bundles.entrySet()) {
String bundleName = bundle.getKey();
JksSslBundleProperties properties = bundle.getValue();
registry.registerBundle(bundleName, PropertiesSslBundle.get(properties));
if (properties.isReloadOnUpdate()) {
Set<Path> locations = getPathsToWatch(properties);
this.fileWatcher.watch(locations,
(changes) -> registry.updateBundle(bundleName, PropertiesSslBundle.get(properties)));
}
}
}
private void registerPemBundles(SslBundleRegistry registry, Map<String, PemSslBundleProperties> bundles) {
for (Entry<String, PemSslBundleProperties> bundle : bundles.entrySet()) {
String bundleName = bundle.getKey();
PemSslBundleProperties properties = bundle.getValue();
registry.registerBundle(bundleName, PropertiesSslBundle.get(properties));
if (properties.isReloadOnUpdate()) {
Set<Path> locations = getPathsToWatch(properties);
this.fileWatcher.watch(locations,
(changes) -> registry.updateBundle(bundleName, PropertiesSslBundle.get(properties)));
}
}
}
private Set<Path> getPathsToWatch(JksSslBundleProperties properties) {
Set<Path> result = new HashSet<>();
if (properties.getKeystore().getLocation() != null) {
result.add(toPath(properties.getKeystore().getLocation()));
}
if (properties.getTruststore().getLocation() != null) {
result.add(toPath(properties.getTruststore().getLocation()));
}
return result;
} }
private <P extends SslBundleProperties> void registerBundles(SslBundleRegistry registry, Map<String, P> properties, private Set<Path> getPathsToWatch(PemSslBundleProperties properties) {
Function<P, SslBundle> bundleFactory) { PemSslStoreDetails keystore = properties.getKeystore().asPemSslStoreDetails();
properties.forEach((bundleName, bundleProperties) -> registry.registerBundle(bundleName, PemSslStoreDetails truststore = properties.getTruststore().asPemSslStoreDetails();
bundleFactory.apply(bundleProperties))); Set<Path> result = new HashSet<>();
if (keystore.privateKey() != null) {
if (keystore.getPrivateKeyType() != Type.URL) {
throw new IllegalStateException("Keystore private key is not a URL and can't be watched");
}
result.add(toPath(keystore.privateKey()));
}
if (keystore.certificate() != null) {
if (keystore.getCertificateType() != Type.URL) {
throw new IllegalStateException("Keystore certificate is not a URL and can't be watched");
}
result.add(toPath(keystore.certificate()));
}
if (truststore.privateKey() != null) {
if (truststore.getPrivateKeyType() != Type.URL) {
throw new IllegalStateException("Truststore private key is not a URL and can't be watched");
}
result.add(toPath(truststore.privateKey()));
}
if (truststore.certificate() != null) {
if (truststore.getCertificateType() != Type.URL) {
throw new IllegalStateException("Truststore certificate is not a URL and can't be watched");
}
result.add(toPath(truststore.certificate()));
}
return result;
}
private Path toPath(String location) {
try {
URL url = ResourceUtils.getURL(location);
if (!"file".equals(url.getProtocol())) {
throw new IllegalStateException("Location '%s' doesn't point to a file".formatted(location));
}
return Path.of(url.getFile()).toAbsolutePath();
}
catch (FileNotFoundException ex) {
throw new UncheckedIOException("Failed to get URI to location '%s'".formatted(location), ex);
}
} }
} }

@ -0,0 +1,179 @@
/*
* Copyright 2012-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.autoconfigure.ssl;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.Arrays;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.springframework.boot.autoconfigure.ssl.FileWatcher.Callback;
import org.springframework.boot.autoconfigure.ssl.FileWatcher.Change;
import org.springframework.boot.autoconfigure.ssl.FileWatcher.Changes;
import org.springframework.boot.autoconfigure.ssl.FileWatcher.Type;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.fail;
/**
* Tests for {@link FileWatcher}.
*
* @author Moritz Halbritter
*/
class FileWatcherTests {
private FileWatcher fileWatcher;
@BeforeEach
void setUp() {
this.fileWatcher = new FileWatcher("filewatcher-test-", Duration.ofMillis(10));
}
@AfterEach
void tearDown() {
this.fileWatcher.close();
}
@Test
void shouldTriggerOnFileCreation(@TempDir Path tempDir) throws Exception {
Path newFile = tempDir.resolve("new-file.txt");
WaitingCallback callback = new WaitingCallback();
this.fileWatcher.watch(Set.of(tempDir), callback);
Files.createFile(newFile);
Set<Change> changes = callback.waitForChanges();
assertThatHasChanges(changes, new Change(newFile, Type.CREATE));
}
@Test
void shouldTriggerOnFileDeletion(@TempDir Path tempDir) throws Exception {
Path deletedFile = tempDir.resolve("deleted-file.txt");
Files.createFile(deletedFile);
WaitingCallback callback = new WaitingCallback();
this.fileWatcher.watch(Set.of(tempDir), callback);
Files.delete(deletedFile);
Set<Change> changes = callback.waitForChanges();
assertThatHasChanges(changes, new Change(deletedFile, Type.DELETE));
}
@Test
void shouldTriggerOnFileModification(@TempDir Path tempDir) throws Exception {
Path deletedFile = tempDir.resolve("modified-file.txt");
Files.createFile(deletedFile);
WaitingCallback callback = new WaitingCallback();
this.fileWatcher.watch(Set.of(tempDir), callback);
Files.writeString(deletedFile, "Some content");
Set<Change> changes = callback.waitForChanges();
assertThatHasChanges(changes, new Change(deletedFile, Type.MODIFY));
}
@Test
void shouldWatchFile(@TempDir Path tempDir) throws Exception {
Path watchedFile = tempDir.resolve("watched.txt");
Files.createFile(watchedFile);
WaitingCallback callback = new WaitingCallback();
this.fileWatcher.watch(Set.of(watchedFile), callback);
Files.writeString(watchedFile, "Some content");
Set<Change> changes = callback.waitForChanges();
assertThatHasChanges(changes, new Change(watchedFile, Type.MODIFY));
}
@Test
void shouldIgnoreNotWatchedFiles(@TempDir Path tempDir) throws Exception {
Path watchedFile = tempDir.resolve("watched.txt");
Path notWatchedFile = tempDir.resolve("not-watched.txt");
Files.createFile(watchedFile);
Files.createFile(notWatchedFile);
WaitingCallback callback = new WaitingCallback();
this.fileWatcher.watch(Set.of(watchedFile), callback);
Files.writeString(notWatchedFile, "Some content");
callback.expectNoChanges();
}
@Test
void shouldFailIfDirectoryOrFileDoesntExist(@TempDir Path tempDir) {
Path directory = tempDir.resolve("dir1");
assertThatThrownBy(() -> this.fileWatcher.watch(Set.of(directory), new WaitingCallback()))
.isInstanceOf(UncheckedIOException.class)
.hasMessageMatching("Failed to register paths for watching: \\[.+/dir1]");
}
@Test
void shouldNotFailIfDirectoryIsRegisteredMultipleTimes(@TempDir Path tempDir) {
WaitingCallback callback = new WaitingCallback();
assertThatCode(() -> {
this.fileWatcher.watch(Set.of(tempDir), callback);
this.fileWatcher.watch(Set.of(tempDir), callback);
}).doesNotThrowAnyException();
}
@Test
void shouldNotFailIfStoppedMultipleTimes(@TempDir Path tempDir) {
WaitingCallback callback = new WaitingCallback();
this.fileWatcher.watch(Set.of(tempDir), callback);
assertThatCode(() -> {
this.fileWatcher.stop();
this.fileWatcher.stop();
}).doesNotThrowAnyException();
}
private void assertThatHasChanges(Set<Change> candidates, Change... changes) {
assertThat(candidates).containsAll(Arrays.asList(changes));
}
private static class WaitingCallback implements Callback {
private final CountDownLatch latch = new CountDownLatch(1);
private final Set<Change> changes = new ConcurrentHashSet<>();
@Override
public void onChange(Changes changes) {
for (Change change : changes) {
this.changes.add(change);
}
this.latch.countDown();
}
Set<Change> waitForChanges() throws InterruptedException {
if (!this.latch.await(10, TimeUnit.SECONDS)) {
fail("Timeout while waiting for changes");
}
return this.changes;
}
void expectNoChanges() throws InterruptedException {
if (!this.latch.await(100, TimeUnit.MILLISECONDS)) {
return;
}
assertThat(this.changes).isEmpty();
}
}
}

@ -0,0 +1,172 @@
/*
* Copyright 2012-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.autoconfigure.ssl;
import java.nio.file.Path;
import java.util.Set;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.boot.ssl.SslBundleRegistry;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.then;
import static org.mockito.Mockito.times;
/**
* Tests for {@link SslPropertiesBundleRegistrar}.
*
* @author Moritz Halbritter
*/
class SslPropertiesBundleRegistrarTests {
private SslPropertiesBundleRegistrar registrar;
private FileWatcher fileWatcher;
private SslProperties properties;
private SslBundleRegistry registry;
@BeforeEach
void setUp() {
this.properties = new SslProperties();
this.fileWatcher = Mockito.mock(FileWatcher.class);
this.registrar = new SslPropertiesBundleRegistrar(this.properties, this.fileWatcher);
this.registry = Mockito.mock(SslBundleRegistry.class);
}
@Test
void shouldWatchJksBundles() {
JksSslBundleProperties jks = new JksSslBundleProperties();
jks.setReloadOnUpdate(true);
jks.getKeystore().setLocation("classpath:test.jks");
jks.getKeystore().setPassword("secret");
jks.getTruststore().setLocation("classpath:test.jks");
jks.getTruststore().setPassword("secret");
this.properties.getBundle().getJks().put("bundle1", jks);
this.registrar.registerBundles(this.registry);
then(this.registry).should(times(1)).registerBundle(eq("bundle1"), any());
then(this.fileWatcher).should().watch(assertArg((set) -> pathEndingWith(set, "test.jks")), any());
}
@Test
void shouldWatchPemBundles() {
PemSslBundleProperties pem = new PemSslBundleProperties();
pem.setReloadOnUpdate(true);
pem.getKeystore().setCertificate("classpath:org/springframework/boot/autoconfigure/ssl/rsa-cert.pem");
pem.getKeystore().setPrivateKey("classpath:org/springframework/boot/autoconfigure/ssl/rsa-key.pem");
pem.getTruststore().setCertificate("classpath:org/springframework/boot/autoconfigure/ssl/ed25519-cert.pem");
pem.getTruststore().setPrivateKey("classpath:org/springframework/boot/autoconfigure/ssl/ed25519-key.pem");
this.properties.getBundle().getPem().put("bundle1", pem);
this.registrar.registerBundles(this.registry);
then(this.registry).should(times(1)).registerBundle(eq("bundle1"), any());
then(this.fileWatcher).should()
.watch(assertArg((set) -> pathEndingWith(set, "rsa-cert.pem", "rsa-key.pem")), any());
}
@Test
void shouldFailIfPemKeystoreCertificateIsEmbedded() {
PemSslBundleProperties pem = new PemSslBundleProperties();
pem.setReloadOnUpdate(true);
pem.getKeystore().setCertificate("""
-----BEGIN CERTIFICATE-----
MIICCzCCAb2gAwIBAgIUZbDi7G5czH+Yi0k2EMWxdf00XagwBQYDK2VwMHsxCzAJ
BgNVBAYTAlhYMRIwEAYDVQQIDAlTdGF0ZU5hbWUxETAPBgNVBAcMCENpdHlOYW1l
MRQwEgYDVQQKDAtDb21wYW55TmFtZTEbMBkGA1UECwwSQ29tcGFueVNlY3Rpb25O
YW1lMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMjMwOTExMTIxNDMwWhcNMzMwOTA4
MTIxNDMwWjB7MQswCQYDVQQGEwJYWDESMBAGA1UECAwJU3RhdGVOYW1lMREwDwYD
VQQHDAhDaXR5TmFtZTEUMBIGA1UECgwLQ29tcGFueU5hbWUxGzAZBgNVBAsMEkNv
bXBhbnlTZWN0aW9uTmFtZTESMBAGA1UEAwwJbG9jYWxob3N0MCowBQYDK2VwAyEA
Q/DDA4BSgZ+Hx0DUxtIRjVjN+OcxXVURwAWc3Gt9GUyjUzBRMB0GA1UdDgQWBBSv
EdpoaBMBoxgO96GFbf03k07DSTAfBgNVHSMEGDAWgBSvEdpoaBMBoxgO96GFbf03
k07DSTAPBgNVHRMBAf8EBTADAQH/MAUGAytlcANBAHMXDkGd57d4F4cRk/8UjhxD
7OtRBZfdfznSvlhJIMNfH5q0zbC2eO3hWCB3Hrn/vIeswGP8Ov4AJ6eXeX44BQM=
-----END CERTIFICATE-----
""".strip());
this.properties.getBundle().getPem().put("bundle1", pem);
assertThatIllegalStateException().isThrownBy(() -> this.registrar.registerBundles(this.registry))
.withMessage("Keystore certificate is not a URL and can't be watched");
}
@Test
void shouldFailIfPemKeystorePrivateKeyIsEmbedded() {
PemSslBundleProperties pem = new PemSslBundleProperties();
pem.setReloadOnUpdate(true);
pem.getKeystore().setCertificate("classpath:org/springframework/boot/autoconfigure/ssl/ed25519-cert.pem");
pem.getKeystore().setPrivateKey("""
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIC29RnMVTcyqXEAIO1b/6p7RdbM6TiqvnztVQ4IxYxUh
-----END PRIVATE KEY-----
""".strip());
this.properties.getBundle().getPem().put("bundle1", pem);
assertThatIllegalStateException().isThrownBy(() -> this.registrar.registerBundles(this.registry))
.withMessage("Keystore private key is not a URL and can't be watched");
}
@Test
void shouldFailIfPemTruststoreCertificateIsEmbedded() {
PemSslBundleProperties pem = new PemSslBundleProperties();
pem.setReloadOnUpdate(true);
pem.getTruststore().setCertificate("""
-----BEGIN CERTIFICATE-----
MIICCzCCAb2gAwIBAgIUZbDi7G5czH+Yi0k2EMWxdf00XagwBQYDK2VwMHsxCzAJ
BgNVBAYTAlhYMRIwEAYDVQQIDAlTdGF0ZU5hbWUxETAPBgNVBAcMCENpdHlOYW1l
MRQwEgYDVQQKDAtDb21wYW55TmFtZTEbMBkGA1UECwwSQ29tcGFueVNlY3Rpb25O
YW1lMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMjMwOTExMTIxNDMwWhcNMzMwOTA4
MTIxNDMwWjB7MQswCQYDVQQGEwJYWDESMBAGA1UECAwJU3RhdGVOYW1lMREwDwYD
VQQHDAhDaXR5TmFtZTEUMBIGA1UECgwLQ29tcGFueU5hbWUxGzAZBgNVBAsMEkNv
bXBhbnlTZWN0aW9uTmFtZTESMBAGA1UEAwwJbG9jYWxob3N0MCowBQYDK2VwAyEA
Q/DDA4BSgZ+Hx0DUxtIRjVjN+OcxXVURwAWc3Gt9GUyjUzBRMB0GA1UdDgQWBBSv
EdpoaBMBoxgO96GFbf03k07DSTAfBgNVHSMEGDAWgBSvEdpoaBMBoxgO96GFbf03
k07DSTAPBgNVHRMBAf8EBTADAQH/MAUGAytlcANBAHMXDkGd57d4F4cRk/8UjhxD
7OtRBZfdfznSvlhJIMNfH5q0zbC2eO3hWCB3Hrn/vIeswGP8Ov4AJ6eXeX44BQM=
-----END CERTIFICATE-----
""".strip());
this.properties.getBundle().getPem().put("bundle1", pem);
assertThatIllegalStateException().isThrownBy(() -> this.registrar.registerBundles(this.registry))
.withMessage("Truststore certificate is not a URL and can't be watched");
}
@Test
void shouldFailIfPemTruststorePrivateKeyIsEmbedded() {
PemSslBundleProperties pem = new PemSslBundleProperties();
pem.setReloadOnUpdate(true);
pem.getTruststore().setCertificate("classpath:org/springframework/boot/autoconfigure/ssl/ed25519-cert.pem");
pem.getTruststore().setPrivateKey("""
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIC29RnMVTcyqXEAIO1b/6p7RdbM6TiqvnztVQ4IxYxUh
-----END PRIVATE KEY-----
""".strip());
this.properties.getBundle().getPem().put("bundle1", pem);
assertThatIllegalStateException().isThrownBy(() -> this.registrar.registerBundles(this.registry))
.withMessage("Truststore private key is not a URL and can't be watched");
}
private void pathEndingWith(Set<Path> paths, String... suffixes) {
for (String suffix : suffixes) {
assertThat(paths).anyMatch((path) -> path.getFileName().toString().endsWith(suffix));
}
}
}

@ -16,21 +16,37 @@
package org.springframework.boot.ssl; package org.springframework.boot.ssl;
import java.util.Collections;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
* Default {@link SslBundleRegistry} implementation. * Default {@link SslBundleRegistry} implementation.
* *
* @author Scott Frederick * @author Scott Frederick
* @author Moritz Halbritter
* @since 3.1.0 * @since 3.1.0
*/ */
public class DefaultSslBundleRegistry implements SslBundleRegistry, SslBundles { public class DefaultSslBundleRegistry implements SslBundleRegistry, SslBundles {
private static final Log logger = LogFactory.getLog(DefaultSslBundleRegistry.class);
private final Map<String, SslBundle> bundles = new ConcurrentHashMap<>(); private final Map<String, SslBundle> bundles = new ConcurrentHashMap<>();
private final Map<String, List<Consumer<SslBundle>>> listeners = new ConcurrentHashMap<>();
private final Set<String> bundlesWithoutListeners = ConcurrentHashMap.newKeySet();
public DefaultSslBundleRegistry() { public DefaultSslBundleRegistry() {
} }
@ -48,12 +64,55 @@ public class DefaultSslBundleRegistry implements SslBundleRegistry, SslBundles {
@Override @Override
public SslBundle getBundle(String name) { public SslBundle getBundle(String name) {
return getBundle(name, null);
}
@Override
public SslBundle getBundle(String name, Consumer<SslBundle> onUpdate) throws NoSuchSslBundleException {
Assert.notNull(name, "Name must not be null"); Assert.notNull(name, "Name must not be null");
SslBundle bundle = this.bundles.get(name); SslBundle bundle = this.bundles.get(name);
if (bundle == null) { if (bundle == null) {
throw new NoSuchSslBundleException(name, "SSL bundle name '%s' cannot be found".formatted(name)); throw new NoSuchSslBundleException(name, "SSL bundle name '%s' cannot be found".formatted(name));
} }
addListener(name, onUpdate);
return bundle; return bundle;
} }
@Override
public void updateBundle(String name, SslBundle sslBundle) {
Assert.notNull(name, "Name must not be null");
SslBundle bundle = this.bundles.get(name);
if (bundle == null) {
throw new NoSuchSslBundleException(name, "SSL bundle name '%s' cannot be found".formatted(name));
}
this.bundles.put(name, sslBundle);
notifyListeners(name, sslBundle);
logMissingListeners(name);
}
private void notifyListeners(String name, SslBundle sslBundle) {
List<Consumer<SslBundle>> listeners = this.listeners.getOrDefault(name, Collections.emptyList());
for (Consumer<SslBundle> listener : listeners) {
listener.accept(sslBundle);
}
}
private void addListener(String name, Consumer<SslBundle> onUpdate) {
if (onUpdate == null) {
this.bundlesWithoutListeners.add(name);
}
else {
this.listeners.computeIfAbsent(name, (ignore) -> new CopyOnWriteArrayList<>()).add(onUpdate);
}
}
private void logMissingListeners(String name) {
if (logger.isWarnEnabled()) {
if (this.bundlesWithoutListeners.contains(name)) {
logger.warn(LogMessage.format("SSL bundle '%s' has been updated, but not all consumers are updateable",
name));
}
}
}
} }

@ -20,6 +20,7 @@ package org.springframework.boot.ssl;
* Interface that can be used to register an {@link SslBundle} for a given name. * Interface that can be used to register an {@link SslBundle} for a given name.
* *
* @author Scott Frederick * @author Scott Frederick
* @author Moritz Halbritter
* @since 3.1.0 * @since 3.1.0
*/ */
public interface SslBundleRegistry { public interface SslBundleRegistry {
@ -31,4 +32,12 @@ public interface SslBundleRegistry {
*/ */
void registerBundle(String name, SslBundle bundle); void registerBundle(String name, SslBundle bundle);
/**
* Updates an {@link SslBundle}.
* @param name the bundle name
* @param sslBundle the updated bundle
* @since 3.2.0
*/
void updateBundle(String name, SslBundle sslBundle);
} }

@ -16,10 +16,13 @@
package org.springframework.boot.ssl; package org.springframework.boot.ssl;
import java.util.function.Consumer;
/** /**
* A managed set of {@link SslBundle} instances that can be retrieved by name. * A managed set of {@link SslBundle} instances that can be retrieved by name.
* *
* @author Scott Frederick * @author Scott Frederick
* @author Moritz Halbritter
* @since 3.1.0 * @since 3.1.0
*/ */
public interface SslBundles { public interface SslBundles {
@ -32,4 +35,15 @@ public interface SslBundles {
*/ */
SslBundle getBundle(String bundleName) throws NoSuchSslBundleException; SslBundle getBundle(String bundleName) throws NoSuchSslBundleException;
/**
* Return an {@link SslBundle} with the provided name.
* @param bundleName the bundle name
* @param onUpdate the callback, which is called when the bundle is updated or
* {@code null}
* @return the bundle
* @throws NoSuchSslBundleException if a bundle with the provided name does not exist
* @since 3.2.0
*/
SslBundle getBundle(String bundleName, Consumer<SslBundle> onUpdate) throws NoSuchSslBundleException;
} }

@ -1,64 +0,0 @@
/*
* Copyright 2012-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.ssl.pem;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.regex.Pattern;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.ResourceUtils;
/**
* Utility to load PEM content.
*
* @author Scott Frederick
* @author Phillip Webb
*/
final class PemContent {
private static final Pattern PEM_HEADER = Pattern.compile("-+BEGIN\\s+[^-]*-+", Pattern.CASE_INSENSITIVE);
private static final Pattern PEM_FOOTER = Pattern.compile("-+END\\s+[^-]*-+", Pattern.CASE_INSENSITIVE);
private PemContent() {
}
static String load(String content) {
if (content == null || isPemContent(content)) {
return content;
}
try {
URL url = ResourceUtils.getURL(content);
try (Reader reader = new InputStreamReader(url.openStream(), StandardCharsets.UTF_8)) {
return FileCopyUtils.copyToString(reader);
}
}
catch (IOException ex) {
throw new IllegalStateException(
"Error reading certificate or key from file '" + content + "':" + ex.getMessage(), ex);
}
}
private static boolean isPemContent(String content) {
return content != null && PEM_HEADER.matcher(content).find() && PEM_FOOTER.matcher(content).find();
}
}

@ -17,6 +17,10 @@
package org.springframework.boot.ssl.pem; package org.springframework.boot.ssl.pem;
import java.io.IOException; import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.KeyStore; import java.security.KeyStore;
import java.security.KeyStoreException; import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
@ -26,7 +30,10 @@ import java.security.cert.X509Certificate;
import org.springframework.boot.ssl.SslStoreBundle; import org.springframework.boot.ssl.SslStoreBundle;
import org.springframework.boot.ssl.pem.KeyVerifier.Result; import org.springframework.boot.ssl.pem.KeyVerifier.Result;
import org.springframework.boot.ssl.pem.PemSslStoreDetails.Type;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.ResourceUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
/** /**
@ -149,12 +156,14 @@ public class PemSslStoreBundle implements SslStoreBundle {
} }
private static PrivateKey loadPrivateKey(PemSslStoreDetails details) { private static PrivateKey loadPrivateKey(PemSslStoreDetails details) {
String privateKeyContent = PemContent.load(details.privateKey()); String privateKeyContent = (details.getPrivateKeyType() == Type.PEM) ? details.privateKey()
: load(details.privateKey());
return PemPrivateKeyParser.parse(privateKeyContent, details.privateKeyPassword()); return PemPrivateKeyParser.parse(privateKeyContent, details.privateKeyPassword());
} }
private static X509Certificate[] loadCertificates(PemSslStoreDetails details) { private static X509Certificate[] loadCertificates(PemSslStoreDetails details) {
String certificateContent = PemContent.load(details.certificate()); String certificateContent = (details.getCertificateType() == Type.PEM) ? details.certificate()
: load(details.certificate());
X509Certificate[] certificates = PemCertificateParser.parse(certificateContent); X509Certificate[] certificates = PemCertificateParser.parse(certificateContent);
Assert.state(certificates != null && certificates.length > 0, "Loaded certificates are empty"); Assert.state(certificates != null && certificates.length > 0, "Loaded certificates are empty");
return certificates; return certificates;
@ -180,4 +189,20 @@ public class PemSslStoreBundle implements SslStoreBundle {
} }
} }
private static String load(String location) {
if (location == null) {
return null;
}
try {
URL url = ResourceUtils.getURL(location);
try (Reader reader = new InputStreamReader(url.openStream(), StandardCharsets.UTF_8)) {
return FileCopyUtils.copyToString(reader);
}
}
catch (IOException ex) {
throw new IllegalStateException(
"Error reading certificate or key from file '" + location + "':" + ex.getMessage(), ex);
}
}
} }

@ -17,6 +17,7 @@
package org.springframework.boot.ssl.pem; package org.springframework.boot.ssl.pem;
import java.security.KeyStore; import java.security.KeyStore;
import java.util.regex.Pattern;
import org.springframework.util.ResourceUtils; import org.springframework.util.ResourceUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -37,6 +38,10 @@ import org.springframework.util.StringUtils;
*/ */
public record PemSslStoreDetails(String type, String certificate, String privateKey, String privateKeyPassword) { public record PemSslStoreDetails(String type, String certificate, String privateKey, String privateKeyPassword) {
private static final Pattern PEM_HEADER = Pattern.compile("-+BEGIN\\s+[^-]*-+", Pattern.CASE_INSENSITIVE);
private static final Pattern PEM_FOOTER = Pattern.compile("-+END\\s+[^-]*-+", Pattern.CASE_INSENSITIVE);
public PemSslStoreDetails(String type, String certificate, String privateKey) { public PemSslStoreDetails(String type, String certificate, String privateKey) {
this(type, certificate, privateKey, null); this(type, certificate, privateKey, null);
} }
@ -59,6 +64,24 @@ public record PemSslStoreDetails(String type, String certificate, String private
return new PemSslStoreDetails(this.type, this.certificate, this.privateKey, password); return new PemSslStoreDetails(this.type, this.certificate, this.privateKey, password);
} }
/**
* Returns the type of the private key.
* @return the type of the private key
* @since 3.2.0
*/
public Type getPrivateKeyType() {
return (isPemContent(this.privateKey)) ? Type.PEM : Type.URL;
}
/**
* Returns the type of the certificate.
* @return the type of the certificate
* @since 3.2.0
*/
public Type getCertificateType() {
return (isPemContent(this.certificate)) ? Type.PEM : Type.URL;
}
boolean isEmpty() { boolean isEmpty() {
return isEmpty(this.type) && isEmpty(this.certificate) && isEmpty(this.privateKey); return isEmpty(this.type) && isEmpty(this.certificate) && isEmpty(this.privateKey);
} }
@ -77,4 +100,24 @@ public record PemSslStoreDetails(String type, String certificate, String private
return new PemSslStoreDetails(null, certificate, null); return new PemSslStoreDetails(null, certificate, null);
} }
private static boolean isPemContent(String content) {
return content != null && PEM_HEADER.matcher(content).find() && PEM_FOOTER.matcher(content).find();
}
/**
* Type of key or certificate.
*/
public enum Type {
/**
* URL loadable by {@link ResourceUtils#getURL}.
*/
URL,
/**
* PEM content.
*/
PEM
}
} }

@ -25,9 +25,12 @@ import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.netty.http.HttpProtocol; import reactor.netty.http.HttpProtocol;
import reactor.netty.http.server.HttpServer; import reactor.netty.http.server.HttpServer;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactory; import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactory;
import org.springframework.boot.web.reactive.server.ReactiveWebServerFactory; import org.springframework.boot.web.reactive.server.ReactiveWebServerFactory;
import org.springframework.boot.web.server.Shutdown; import org.springframework.boot.web.server.Shutdown;
@ -42,10 +45,13 @@ import org.springframework.util.Assert;
* {@link ReactiveWebServerFactory} that can be used to create {@link NettyWebServer}s. * {@link ReactiveWebServerFactory} that can be used to create {@link NettyWebServer}s.
* *
* @author Brian Clozel * @author Brian Clozel
* @author Moritz Halbritter
* @since 2.0.0 * @since 2.0.0
*/ */
public class NettyReactiveWebServerFactory extends AbstractReactiveWebServerFactory { public class NettyReactiveWebServerFactory extends AbstractReactiveWebServerFactory {
private static final Log logger = LogFactory.getLog(NettyReactiveWebServerFactory.class);
private Set<NettyServerCustomizer> serverCustomizers = new LinkedHashSet<>(); private Set<NettyServerCustomizer> serverCustomizers = new LinkedHashSet<>();
private final List<NettyRouteProvider> routeProviders = new ArrayList<>(); private final List<NettyRouteProvider> routeProviders = new ArrayList<>();
@ -170,7 +176,14 @@ public class NettyReactiveWebServerFactory extends AbstractReactiveWebServerFact
} }
private HttpServer customizeSslConfiguration(HttpServer httpServer) { private HttpServer customizeSslConfiguration(HttpServer httpServer) {
return new SslServerCustomizer(getHttp2(), getSsl().getClientAuth(), getSslBundle()).apply(httpServer); SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(getHttp2(), getSsl().getClientAuth());
SslBundle sslBundle = getSslBundle((updatedBundle) -> {
logger.debug("SSL Bundle has been updated, reloading SSL configuration");
sslServerCustomizer.setSslBundle(updatedBundle);
sslServerCustomizer.reload();
});
sslServerCustomizer.setSslBundle(sslBundle);
return sslServerCustomizer.apply(httpServer);
} }
private HttpProtocol[] listProtocols() { private HttpProtocol[] listProtocols() {

@ -106,7 +106,6 @@ public class NettyWebServer implements WebServer {
* @param resourceFactory the factory for the server's {@link LoopResources loop * @param resourceFactory the factory for the server's {@link LoopResources loop
* resources}, may be {@code null} * resources}, may be {@code null}
* @since 3.2.0 * @since 3.2.0
* {@link #NettyWebServer(HttpServer, ReactorHttpHandlerAdapter, Duration, Shutdown, ReactorResourceFactory)}
*/ */
public NettyWebServer(HttpServer httpServer, ReactorHttpHandlerAdapter handlerAdapter, Duration lifecycleTimeout, public NettyWebServer(HttpServer httpServer, ReactorHttpHandlerAdapter handlerAdapter, Duration lifecycleTimeout,
Shutdown shutdown, ReactorResourceFactory resourceFactory) { Shutdown shutdown, ReactorResourceFactory resourceFactory) {
@ -149,7 +148,7 @@ public class NettyWebServer implements WebServer {
StringBuilder message = new StringBuilder(); StringBuilder message = new StringBuilder();
tryAppend(message, "port %s", server::port); tryAppend(message, "port %s", server::port);
tryAppend(message, "path %s", server::path); tryAppend(message, "path %s", server::path);
return (message.length() > 0) ? "Netty started on " + message : "Netty started"; return (!message.isEmpty()) ? "Netty started on " + message : "Netty started";
} }
protected String getStartedLogMessage() { protected String getStartedLogMessage() {
@ -159,10 +158,11 @@ public class NettyWebServer implements WebServer {
private void tryAppend(StringBuilder message, String format, Supplier<Object> supplier) { private void tryAppend(StringBuilder message, String format, Supplier<Object> supplier) {
try { try {
Object value = supplier.get(); Object value = supplier.get();
message.append((message.length() != 0) ? " " : ""); message.append((!message.isEmpty()) ? " " : "");
message.append(String.format(format, value)); message.append(String.format(format, value));
} }
catch (UnsupportedOperationException ex) { catch (UnsupportedOperationException ex) {
// Ignore
} }
} }

@ -21,11 +21,13 @@ import reactor.netty.http.Http11SslContextSpec;
import reactor.netty.http.Http2SslContextSpec; import reactor.netty.http.Http2SslContextSpec;
import reactor.netty.http.server.HttpServer; import reactor.netty.http.server.HttpServer;
import reactor.netty.tcp.AbstractProtocolSslContextSpec; import reactor.netty.tcp.AbstractProtocolSslContextSpec;
import reactor.netty.tcp.SslProvider;
import org.springframework.boot.ssl.SslBundle; import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslOptions; import org.springframework.boot.ssl.SslOptions;
import org.springframework.boot.web.server.Http2; import org.springframework.boot.web.server.Http2;
import org.springframework.boot.web.server.Ssl; import org.springframework.boot.web.server.Ssl;
import org.springframework.util.Assert;
/** /**
* {@link NettyServerCustomizer} that configures SSL for the given Reactor Netty server * {@link NettyServerCustomizer} that configures SSL for the given Reactor Netty server
@ -36,6 +38,7 @@ import org.springframework.boot.web.server.Ssl;
* @author Chris Bono * @author Chris Bono
* @author Cyril Dangerville * @author Cyril Dangerville
* @author Scott Frederick * @author Scott Frederick
* @author Moritz Halbritter
* @since 2.0.0 * @since 2.0.0
*/ */
public class SslServerCustomizer implements NettyServerCustomizer { public class SslServerCustomizer implements NettyServerCustomizer {
@ -44,7 +47,13 @@ public class SslServerCustomizer implements NettyServerCustomizer {
private final Ssl.ClientAuth clientAuth; private final Ssl.ClientAuth clientAuth;
private final SslBundle sslBundle; private volatile SslBundle sslBundle;
private volatile SslProvider currentSslProvider;
SslServerCustomizer(Http2 http2, Ssl.ClientAuth clientAuth) {
this(http2, clientAuth, null);
}
public SslServerCustomizer(Http2 http2, Ssl.ClientAuth clientAuth, SslBundle sslBundle) { public SslServerCustomizer(Http2 http2, Ssl.ClientAuth clientAuth, SslBundle sslBundle) {
this.http2 = http2; this.http2 = http2;
@ -52,13 +61,25 @@ public class SslServerCustomizer implements NettyServerCustomizer {
this.sslBundle = sslBundle; this.sslBundle = sslBundle;
} }
void setSslBundle(SslBundle sslBundle) {
this.sslBundle = sslBundle;
}
@Override @Override
public HttpServer apply(HttpServer server) { public HttpServer apply(HttpServer server) {
AbstractProtocolSslContextSpec<?> sslContextSpec = createSslContextSpec(); AbstractProtocolSslContextSpec<?> sslContextSpec = createSslContextSpec();
return server.secure((spec) -> spec.sslContext(sslContextSpec)); this.currentSslProvider = SslProvider.builder().sslContext(sslContextSpec).build();
return server.secure((spec) -> spec.sslContext(sslContextSpec)
.setSniAsyncMappings((domainName, promise) -> promise.setSuccess(this.currentSslProvider)));
}
void reload() {
AbstractProtocolSslContextSpec<?> sslContextSpec = createSslContextSpec();
this.currentSslProvider = SslProvider.builder().sslContext(sslContextSpec).build();
} }
protected AbstractProtocolSslContextSpec<?> createSslContextSpec() { protected AbstractProtocolSslContextSpec<?> createSslContextSpec() {
Assert.notNull(this.sslBundle, "sslBundle must not be null");
AbstractProtocolSslContextSpec<?> sslContextSpec = (this.http2 != null && this.http2.isEnabled()) AbstractProtocolSslContextSpec<?> sslContextSpec = (this.http2 != null && this.http2.isEnabled())
? Http2SslContextSpec.forServer(this.sslBundle.getManagers().getKeyManagerFactory()) ? Http2SslContextSpec.forServer(this.sslBundle.getManagers().getKeyManagerFactory())
: Http11SslContextSpec.forServer(this.sslBundle.getManagers().getKeyManagerFactory()); : Http11SslContextSpec.forServer(this.sslBundle.getManagers().getKeyManagerFactory());

@ -39,6 +39,7 @@ import org.springframework.util.StringUtils;
* @author Andy Wilkinson * @author Andy Wilkinson
* @author Scott Frederick * @author Scott Frederick
* @author Cyril Dangerville * @author Cyril Dangerville
* @author Moritz Halbritter
*/ */
class SslConnectorCustomizer implements TomcatConnectorCustomizer { class SslConnectorCustomizer implements TomcatConnectorCustomizer {
@ -66,15 +67,19 @@ class SslConnectorCustomizer implements TomcatConnectorCustomizer {
* @param protocol the protocol * @param protocol the protocol
*/ */
void configureSsl(AbstractHttp11JsseProtocol<?> protocol) { void configureSsl(AbstractHttp11JsseProtocol<?> protocol) {
SslBundleKey key = this.sslBundle.getKey();
SslStoreBundle stores = this.sslBundle.getStores();
SslOptions options = this.sslBundle.getOptions();
protocol.setSSLEnabled(true); protocol.setSSLEnabled(true);
SSLHostConfig sslHostConfig = new SSLHostConfig(); SSLHostConfig sslHostConfig = new SSLHostConfig();
sslHostConfig.setHostName(protocol.getDefaultSSLHostConfigName()); sslHostConfig.setHostName(protocol.getDefaultSSLHostConfigName());
sslHostConfig.setSslProtocol(this.sslBundle.getProtocol());
protocol.addSslHostConfig(sslHostConfig);
configureSslClientAuth(sslHostConfig); configureSslClientAuth(sslHostConfig);
applySslBundle(protocol, sslHostConfig);
protocol.addSslHostConfig(sslHostConfig, true);
}
private void applySslBundle(AbstractHttp11JsseProtocol<?> protocol, SSLHostConfig sslHostConfig) {
SslBundleKey key = this.sslBundle.getKey();
SslStoreBundle stores = this.sslBundle.getStores();
SslOptions options = this.sslBundle.getOptions();
sslHostConfig.setSslProtocol(this.sslBundle.getProtocol());
SSLHostConfigCertificate certificate = new SSLHostConfigCertificate(sslHostConfig, Type.UNDEFINED); SSLHostConfigCertificate certificate = new SSLHostConfigCertificate(sslHostConfig, Type.UNDEFINED);
String keystorePassword = (stores.getKeyStorePassword() != null) ? stores.getKeyStorePassword() : ""; String keystorePassword = (stores.getKeyStorePassword() != null) ? stores.getKeyStorePassword() : "";
certificate.setCertificateKeystorePassword(keystorePassword); certificate.setCertificateKeystorePassword(keystorePassword);
@ -89,30 +94,26 @@ class SslConnectorCustomizer implements TomcatConnectorCustomizer {
String ciphers = StringUtils.arrayToCommaDelimitedString(options.getCiphers()); String ciphers = StringUtils.arrayToCommaDelimitedString(options.getCiphers());
sslHostConfig.setCiphers(ciphers); sslHostConfig.setCiphers(ciphers);
} }
configureEnabledProtocols(protocol); configureSslStoreProvider(protocol, sslHostConfig, certificate, stores);
configureSslStoreProvider(protocol, sslHostConfig, certificate); configureEnabledProtocols(sslHostConfig, options);
} }
private void configureEnabledProtocols(AbstractHttp11JsseProtocol<?> protocol) { private void configureEnabledProtocols(SSLHostConfig sslHostConfig, SslOptions options) {
SslOptions options = this.sslBundle.getOptions();
if (options.getEnabledProtocols() != null) { if (options.getEnabledProtocols() != null) {
String enabledProtocols = StringUtils.arrayToDelimitedString(options.getEnabledProtocols(), "+"); String enabledProtocols = StringUtils.arrayToDelimitedString(options.getEnabledProtocols(), "+");
for (SSLHostConfig sslHostConfig : protocol.findSslHostConfigs()) {
sslHostConfig.setProtocols(enabledProtocols); sslHostConfig.setProtocols(enabledProtocols);
} }
} }
}
private void configureSslClientAuth(SSLHostConfig config) { private void configureSslClientAuth(SSLHostConfig config) {
config.setCertificateVerification(ClientAuth.map(this.clientAuth, "none", "optional", "required")); config.setCertificateVerification(ClientAuth.map(this.clientAuth, "none", "optional", "required"));
} }
protected void configureSslStoreProvider(AbstractHttp11JsseProtocol<?> protocol, SSLHostConfig sslHostConfig, private void configureSslStoreProvider(AbstractHttp11JsseProtocol<?> protocol, SSLHostConfig sslHostConfig,
SSLHostConfigCertificate certificate) { SSLHostConfigCertificate certificate, SslStoreBundle stores) {
Assert.isInstanceOf(Http11NioProtocol.class, protocol, Assert.isInstanceOf(Http11NioProtocol.class, protocol,
"SslStoreProvider can only be used with Http11NioProtocol"); "SslStoreProvider can only be used with Http11NioProtocol");
try { try {
SslStoreBundle stores = this.sslBundle.getStores();
if (stores.getKeyStore() != null) { if (stores.getKeyStore() != null) {
certificate.setCertificateKeystore(stores.getKeyStore()); certificate.setCertificateKeystore(stores.getKeyStore());
} }

@ -35,12 +35,15 @@ import org.apache.catalina.connector.Connector;
import org.apache.catalina.core.AprLifecycleListener; import org.apache.catalina.core.AprLifecycleListener;
import org.apache.catalina.loader.WebappLoader; import org.apache.catalina.loader.WebappLoader;
import org.apache.catalina.startup.Tomcat; import org.apache.catalina.startup.Tomcat;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.coyote.AbstractProtocol; import org.apache.coyote.AbstractProtocol;
import org.apache.coyote.ProtocolHandler; import org.apache.coyote.ProtocolHandler;
import org.apache.coyote.http2.Http2Protocol; import org.apache.coyote.http2.Http2Protocol;
import org.apache.tomcat.util.modeler.Registry; import org.apache.tomcat.util.modeler.Registry;
import org.apache.tomcat.util.scan.StandardJarScanFilter; import org.apache.tomcat.util.scan.StandardJarScanFilter;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.util.LambdaSafe; import org.springframework.boot.util.LambdaSafe;
import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactory; import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactory;
import org.springframework.boot.web.reactive.server.ReactiveWebServerFactory; import org.springframework.boot.web.reactive.server.ReactiveWebServerFactory;
@ -57,11 +60,14 @@ import org.springframework.util.StringUtils;
* *
* @author Brian Clozel * @author Brian Clozel
* @author HaiTao Zhang * @author HaiTao Zhang
* @author Moritz Halbritter
* @since 2.0.0 * @since 2.0.0
*/ */
public class TomcatReactiveWebServerFactory extends AbstractReactiveWebServerFactory public class TomcatReactiveWebServerFactory extends AbstractReactiveWebServerFactory
implements ConfigurableTomcatWebServerFactory { implements ConfigurableTomcatWebServerFactory {
private static final Log logger = LogFactory.getLog(TomcatReactiveWebServerFactory.class);
private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;
/** /**
@ -224,7 +230,15 @@ public class TomcatReactiveWebServerFactory extends AbstractReactiveWebServerFac
} }
private void customizeSsl(Connector connector) { private void customizeSsl(Connector connector) {
new SslConnectorCustomizer(getSsl().getClientAuth(), getSslBundle()).customize(connector); SslBundle sslBundle = getSslBundle((updatedBundle) -> {
logger.debug("SSL Bundle has been updated, reloading SSL configuration");
customizeSsl(connector, updatedBundle);
});
customizeSsl(connector, sslBundle);
}
private void customizeSsl(Connector connector, SslBundle sslBundle) {
new SslConnectorCustomizer(getSsl().getClientAuth(), sslBundle).customize(connector);
} }
@Override @Override

@ -62,6 +62,8 @@ import org.apache.catalina.util.SessionConfig;
import org.apache.catalina.webresources.AbstractResourceSet; import org.apache.catalina.webresources.AbstractResourceSet;
import org.apache.catalina.webresources.EmptyResource; import org.apache.catalina.webresources.EmptyResource;
import org.apache.catalina.webresources.StandardRoot; import org.apache.catalina.webresources.StandardRoot;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.coyote.AbstractProtocol; import org.apache.coyote.AbstractProtocol;
import org.apache.coyote.ProtocolHandler; import org.apache.coyote.ProtocolHandler;
import org.apache.coyote.http2.Http2Protocol; import org.apache.coyote.http2.Http2Protocol;
@ -69,6 +71,7 @@ import org.apache.tomcat.util.http.Rfc6265CookieProcessor;
import org.apache.tomcat.util.modeler.Registry; import org.apache.tomcat.util.modeler.Registry;
import org.apache.tomcat.util.scan.StandardJarScanFilter; import org.apache.tomcat.util.scan.StandardJarScanFilter;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.util.LambdaSafe; import org.springframework.boot.util.LambdaSafe;
import org.springframework.boot.web.server.Cookie.SameSite; import org.springframework.boot.web.server.Cookie.SameSite;
import org.springframework.boot.web.server.ErrorPage; import org.springframework.boot.web.server.ErrorPage;
@ -103,6 +106,7 @@ import org.springframework.util.StringUtils;
* @author Eddú Meléndez * @author Eddú Meléndez
* @author Christoffer Sawicki * @author Christoffer Sawicki
* @author Dawid Antecki * @author Dawid Antecki
* @author Moritz Halbritter
* @since 2.0.0 * @since 2.0.0
* @see #setPort(int) * @see #setPort(int)
* @see #setContextLifecycleListeners(Collection) * @see #setContextLifecycleListeners(Collection)
@ -111,6 +115,8 @@ import org.springframework.util.StringUtils;
public class TomcatServletWebServerFactory extends AbstractServletWebServerFactory public class TomcatServletWebServerFactory extends AbstractServletWebServerFactory
implements ConfigurableTomcatWebServerFactory, ResourceLoaderAware { implements ConfigurableTomcatWebServerFactory, ResourceLoaderAware {
private static final Log logger = LogFactory.getLog(TomcatServletWebServerFactory.class);
private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;
private static final Set<Class<?>> NO_CLASSES = Collections.emptySet(); private static final Set<Class<?>> NO_CLASSES = Collections.emptySet();
@ -366,7 +372,15 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto
} }
private void customizeSsl(Connector connector) { private void customizeSsl(Connector connector) {
new SslConnectorCustomizer(getSsl().getClientAuth(), getSslBundle()).customize(connector); SslBundle sslBundle = getSslBundle((updatedBundle) -> {
logger.debug("SSL Bundle has been updated, reloading SSL configuration");
customizeSsl(connector, updatedBundle);
});
customizeSsl(connector, sslBundle);
}
private void customizeSsl(Connector connector, SslBundle sslBundle) {
new SslConnectorCustomizer(getSsl().getClientAuth(), sslBundle).customize(connector);
} }
/** /**

@ -23,6 +23,7 @@ import java.nio.file.Files;
import java.util.Arrays; import java.util.Arrays;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer;
import org.springframework.boot.ssl.SslBundle; import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundles; import org.springframework.boot.ssl.SslBundles;
@ -216,6 +217,17 @@ public abstract class AbstractConfigurableWebServerFactory implements Configurab
return WebServerSslBundle.get(this.ssl, this.sslBundles, this.sslStoreProvider); return WebServerSslBundle.get(this.ssl, this.sslBundles, this.sslStoreProvider);
} }
/**
* Return the {@link SslBundle} that should be used with this server, registering a
* callback for bundle updates.
* @param onUpdate the callback for bundle updates
* @return the SSL bundle
*/
@SuppressWarnings("removal")
protected final SslBundle getSslBundle(Consumer<SslBundle> onUpdate) {
return WebServerSslBundle.get(this.ssl, this.sslBundles, this.sslStoreProvider, onUpdate);
}
/** /**
* Return the absolute temp dir for given web server. * Return the absolute temp dir for given web server.
* @param prefix server name * @param prefix server name

@ -17,6 +17,7 @@
package org.springframework.boot.web.server; package org.springframework.boot.web.server;
import java.security.KeyStore; import java.security.KeyStore;
import java.util.function.Consumer;
import org.springframework.boot.ssl.NoSuchSslBundleException; import org.springframework.boot.ssl.NoSuchSslBundleException;
import org.springframework.boot.ssl.SslBundle; import org.springframework.boot.ssl.SslBundle;
@ -137,6 +138,25 @@ public final class WebServerSslBundle implements SslBundle {
@Deprecated(since = "3.1.0", forRemoval = true) @Deprecated(since = "3.1.0", forRemoval = true)
@SuppressWarnings("removal") @SuppressWarnings("removal")
public static SslBundle get(Ssl ssl, SslBundles sslBundles, SslStoreProvider sslStoreProvider) { public static SslBundle get(Ssl ssl, SslBundles sslBundles, SslStoreProvider sslStoreProvider) {
return get(ssl, sslBundles, sslStoreProvider, null);
}
/**
* Get the {@link SslBundle} that should be used for the given {@link Ssl} and
* {@link SslStoreProvider} instances.
* @param ssl the source {@link Ssl} instance
* @param sslBundles the bundles that should be used when {@link Ssl#getBundle()} is
* set
* @param sslStoreProvider the {@link SslStoreProvider} to use or {@code null}
* @param onUpdate the callback, which is called when the {@link SslBundle} is updated
* @return a {@link SslBundle} instance
* @throws NoSuchSslBundleException if a bundle lookup fails
* @deprecated since 3.1.0 for removal in 3.3.0 along with {@link SslStoreProvider}
*/
@Deprecated(since = "3.1.0", forRemoval = true)
@SuppressWarnings("removal")
public static SslBundle get(Ssl ssl, SslBundles sslBundles, SslStoreProvider sslStoreProvider,
Consumer<SslBundle> onUpdate) {
Assert.state(Ssl.isEnabled(ssl), "SSL is not enabled"); Assert.state(Ssl.isEnabled(ssl), "SSL is not enabled");
String keyPassword = (sslStoreProvider != null) ? sslStoreProvider.getKeyPassword() : null; String keyPassword = (sslStoreProvider != null) ? sslStoreProvider.getKeyPassword() : null;
keyPassword = (keyPassword != null) ? keyPassword : ssl.getKeyPassword(); keyPassword = (keyPassword != null) ? keyPassword : ssl.getKeyPassword();
@ -149,7 +169,7 @@ public final class WebServerSslBundle implements SslBundle {
Assert.state(sslBundles != null, Assert.state(sslBundles != null,
() -> "SSL bundle '%s' was requested but no SslBundles instance was provided" () -> "SSL bundle '%s' was requested but no SslBundles instance was provided"
.formatted(bundleName)); .formatted(bundleName));
return sslBundles.getBundle(bundleName); return sslBundles.getBundle(bundleName, onUpdate);
} }
SslStoreBundle stores = createStoreBundle(ssl); SslStoreBundle stores = createStoreBundle(ssl);
return new WebServerSslBundle(stores, keyPassword, ssl); return new WebServerSslBundle(stores, keyPassword, ssl);

@ -16,26 +16,43 @@
package org.springframework.boot.ssl; package org.springframework.boot.ssl;
import java.util.concurrent.atomic.AtomicReference;
import org.awaitility.Awaitility;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.boot.testsupport.system.CapturedOutput;
import org.springframework.boot.testsupport.system.OutputCaptureExtension;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
/** /**
* Tests for {@link DefaultSslBundleRegistry}. * Tests for {@link DefaultSslBundleRegistry}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Moritz Halbritter
*/ */
@ExtendWith(OutputCaptureExtension.class)
class DefaultSslBundleRegistryTests { class DefaultSslBundleRegistryTests {
private SslBundle bundle1 = mock(SslBundle.class); private final SslBundle bundle1 = mock(SslBundle.class);
private SslBundle bundle2 = mock(SslBundle.class); private final SslBundle bundle2 = mock(SslBundle.class);
private DefaultSslBundleRegistry registry = new DefaultSslBundleRegistry(); private DefaultSslBundleRegistry registry;
@BeforeEach
void setUp() {
this.registry = new DefaultSslBundleRegistry();
}
@Test @Test
void createWithNameAndBundleRegistersBundle() { void createWithNameAndBundleRegistersBundle() {
@ -89,4 +106,28 @@ class DefaultSslBundleRegistryTests {
assertThat(this.registry.getBundle("test2")).isSameAs(this.bundle2); assertThat(this.registry.getBundle("test2")).isSameAs(this.bundle2);
} }
@Test
void updateBundleShouldNotifyListeners() {
AtomicReference<SslBundle> updatedBundle = new AtomicReference<>();
this.registry.registerBundle("test1", this.bundle1);
this.registry.getBundle("test1", updatedBundle::set);
this.registry.updateBundle("test1", this.bundle2);
Awaitility.await().untilAtomic(updatedBundle, Matchers.equalTo(this.bundle2));
}
@Test
void shouldFailIfUpdatingNonRegisteredBundle() {
assertThatThrownBy(() -> this.registry.updateBundle("dummy", this.bundle1))
.isInstanceOf(NoSuchSslBundleException.class)
.hasMessageContaining("'dummy'");
}
@Test
void shouldLogIfUpdatingBundleWithoutListeners(CapturedOutput output) {
this.registry.registerBundle("test1", this.bundle1);
this.registry.getBundle("test1");
this.registry.updateBundle("test1", this.bundle2);
assertThat(output).contains("SSL bundle 'test1' has been updated");
}
} }

@ -35,6 +35,62 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
*/ */
class PemSslStoreBundleTests { class PemSslStoreBundleTests {
private static final String CERTIFICATE = """
-----BEGIN CERTIFICATE-----
MIIDqzCCApOgAwIBAgIIFMqbpqvipw0wDQYJKoZIhvcNAQELBQAwbDELMAkGA1UE
BhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEjAQBgNVBAcTCVBhbG8gQWx0bzEP
MA0GA1UEChMGVk13YXJlMQ8wDQYDVQQLEwZTcHJpbmcxEjAQBgNVBAMTCWxvY2Fs
aG9zdDAgFw0yMzA1MDUxMTI2NThaGA8yMTIzMDQxMTExMjY1OFowbDELMAkGA1UE
BhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEjAQBgNVBAcTCVBhbG8gQWx0bzEP
MA0GA1UEChMGVk13YXJlMQ8wDQYDVQQLEwZTcHJpbmcxEjAQBgNVBAMTCWxvY2Fs
aG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAPwHWxoE3xjRmNdD
+m+e/aFlr5wEGQUdWSDD613OB1w7kqO/audEp3c6HxDB3GPcEL0amJwXgY6CQMYu
sythuZX/EZSc2HdilTBu/5T+mbdWe5JkKThpiA0RYeucQfKuB7zv4ypioa4wiR4D
nPsZXjg95OF8pCzYEssv8wT49v+M3ohWUgfF0FPlMFCSo0YVTuzB1mhDlWKq/jhQ
11WpTmk/dQX+l6ts6bYIcJt4uItG+a68a4FutuSjZdTAE0f5SOYRBpGH96mjLwEP
fW8ZjzvKb9g4R2kiuoPxvCDs1Y/8V2yvKqLyn5Tx9x/DjFmOi0DRK/TgELvNceCb
UDJmhXMCAwEAAaNPME0wHQYDVR0OBBYEFMBIGU1nwix5RS3O5hGLLoMdR1+NMCwG
A1UdEQQlMCOCCWxvY2FsaG9zdIcQAAAAAAAAAAAAAAAAAAAAAYcEfwAAATANBgkq
hkiG9w0BAQsFAAOCAQEAhepfJgTFvqSccsT97XdAZfvB0noQx5NSynRV8NWmeOld
hHP6Fzj6xCxHSYvlUfmX8fVP9EOAuChgcbbuTIVJBu60rnDT21oOOnp8FvNonCV6
gJ89sCL7wZ77dw2RKIeUFjXXEV3QJhx2wCOVmLxnJspDoKFIEVjfLyiPXKxqe/6b
dG8zzWDZ6z+M2JNCtVoOGpljpHqMPCmbDktncv6H3dDTZ83bmLj1nbpOU587gAJ8
fl1PiUDyPRIl2cnOJd+wCHKsyym/FL7yzk0OSEZ81I92LpGd/0b2Ld3m/bpe+C4Z
ILzLXTnC6AhrLcDc9QN/EO+BiCL52n7EplNLtSn1LQ==
-----END CERTIFICATE-----
""".strip();
private static final String PRIVATE_KEY = """
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD8B1saBN8Y0ZjX
Q/pvnv2hZa+cBBkFHVkgw+tdzgdcO5Kjv2rnRKd3Oh8Qwdxj3BC9GpicF4GOgkDG
LrMrYbmV/xGUnNh3YpUwbv+U/pm3VnuSZCk4aYgNEWHrnEHyrge87+MqYqGuMIke
A5z7GV44PeThfKQs2BLLL/ME+Pb/jN6IVlIHxdBT5TBQkqNGFU7swdZoQ5Viqv44
UNdVqU5pP3UF/perbOm2CHCbeLiLRvmuvGuBbrbko2XUwBNH+UjmEQaRh/epoy8B
D31vGY87ym/YOEdpIrqD8bwg7NWP/Fdsryqi8p+U8fcfw4xZjotA0Sv04BC7zXHg
m1AyZoVzAgMBAAECggEAfEqiZqANaF+BqXQIb4Dw42ZTJzWsIyYYnPySOGZRoe5t
QJ03uwtULYv34xtANe1DQgd6SMyc46ugBzzjtprQ3ET5Jhn99U6kdcjf+dpf85dO
hOEppP0CkDNI39nleinSfh6uIOqYgt/D143/nqQhn8oCdSOzkbwT9KnWh1bC9T7I
vFjGfElvt1/xl88qYgrWgYLgXaencNGgiv/4/M0FNhiHEGsVC7SCu6kapC/WIQpE
5IdV+HR+tiLoGZhXlhqorY7QC4xKC4wwafVSiFxqDOQAuK+SMD4TCEv0Aop+c+SE
YBigVTmgVeJkjK7IkTEhKkAEFmRF5/5w+bZD9FhTNQKBgQD+4fNG1ChSU8RdizZT
5dPlDyAxpETSCEXFFVGtPPh2j93HDWn7XugNyjn5FylTH507QlabC+5wZqltdIjK
GRB5MIinQ9/nR2fuwGc9s+0BiSEwNOUB1MWm7wWL/JUIiKq6sTi6sJIfsYg79zco
qxl5WE94aoINx9Utq1cdWhwJTQKBgQD9IjPksd4Jprz8zMrGLzR8k1gqHyhv24qY
EJ7jiHKKAP6xllTUYwh1IBSL6w2j5lfZPpIkb4Jlk2KUoX6fN81pWkBC/fTBUSIB
EHM9bL51+yKEYUbGIy/gANuRbHXsWg3sjUsFTNPN4hGTFk3w2xChCyl/f5us8Lo8
Z633SNdpvwKBgQCGyDU9XzNzVZihXtx7wS0sE7OSjKtX5cf/UCbA1V0OVUWR3SYO
J0HPCQFfF0BjFHSwwYPKuaR9C8zMdLNhK5/qdh/NU7czNi9fsZ7moh7SkRFbzJzN
OxbKD9t/CzJEMQEXeF/nWTfsSpUgILqqZtAxuuFLbAcaAnJYlCKdAumQgQKBgQCK
mqjJh68pn7gJwGUjoYNe1xtGbSsqHI9F9ovZ0MPO1v6e5M7sQJHH+Fnnxzv/y8e8
d6tz8e73iX1IHymDKv35uuZHCGF1XOR+qrA/KQUc+vcKf21OXsP/JtkTRs1HLoRD
S5aRf2DWcfvniyYARSNU2xTM8GWgi2ueWbMDHUp+ZwKBgA/swC+K+Jg5DEWm6Sau
e6y+eC6S+SoXEKkI3wf7m9aKoZo0y+jh8Gas6gratlc181pSM8O3vZG0n19b493I
apCFomMLE56zEzvyzfpsNhFhk5MBMCn0LPyzX6MiynRlGyWIj0c99fbHI3pOMufP
WgmVLTZ8uDcSW1MbdUCwFSk5
-----END PRIVATE KEY-----
""".strip();
private static final char[] EMPTY_KEY_PASSWORD = new char[] {}; private static final char[] EMPTY_KEY_PASSWORD = new char[] {};
@Test @Test
@ -99,6 +155,16 @@ class PemSslStoreBundleTests {
assertThat(bundle.getTrustStore()).satisfies(storeContainingCertAndKey("ssl")); assertThat(bundle.getTrustStore()).satisfies(storeContainingCertAndKey("ssl"));
} }
@Test
void whenHasEmbeddedKeyStoreDetailsAndTrustStoreDetails() {
PemSslStoreDetails keyStoreDetails = PemSslStoreDetails.forCertificate(CERTIFICATE).withPrivateKey(PRIVATE_KEY);
PemSslStoreDetails trustStoreDetails = PemSslStoreDetails.forCertificate(CERTIFICATE)
.withPrivateKey(PRIVATE_KEY);
PemSslStoreBundle bundle = new PemSslStoreBundle(keyStoreDetails, trustStoreDetails);
assertThat(bundle.getKeyStore()).satisfies(storeContainingCertAndKey("ssl"));
assertThat(bundle.getTrustStore()).satisfies(storeContainingCertAndKey("ssl"));
}
@Test @Test
void whenHasKeyStoreDetailsAndTrustStoreDetailsAndAlias() { void whenHasKeyStoreDetailsAndTrustStoreDetailsAndAlias() {
PemSslStoreDetails keyStoreDetails = PemSslStoreDetails.forCertificate("classpath:test-cert.pem") PemSslStoreDetails keyStoreDetails = PemSslStoreDetails.forCertificate("classpath:test-cert.pem")

@ -16,29 +16,21 @@
package org.springframework.boot.ssl.pem; package org.springframework.boot.ssl.pem;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.core.io.ClassPathResource; import org.springframework.boot.ssl.pem.PemSslStoreDetails.Type;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* Tests for {@link PemContent}. * Tests for {@link PemSslStoreDetails}.
* *
* @author Phillip Webb * @author Moritz Halbritter
*/ */
class PemContentTests { class PemSslStoreDetailsTests {
@Test
void loadWhenContentIsNullReturnsNull() {
assertThat(PemContent.load(null)).isNull();
}
@Test @Test
void loadWhenContentIsPemContentReturnsContent() { void pemContent() {
String content = """ String content = """
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
MIICpDCCAYwCCQCDOqHKPjAhCTANBgkqhkiG9w0BAQUFADAUMRIwEAYDVQQDDAls MIICpDCCAYwCCQCDOqHKPjAhCTANBgkqhkiG9w0BAQUFADAUMRIwEAYDVQQDDAls
@ -57,21 +49,24 @@ class PemContentTests {
+lGuHKdhNOVW9CmqPD1y76o6c8PQKuF7KZEoY2jvy3GeIfddBvqXgZ4PbWvFz1jO +lGuHKdhNOVW9CmqPD1y76o6c8PQKuF7KZEoY2jvy3GeIfddBvqXgZ4PbWvFz1jO
32C9XWHwRA4= 32C9XWHwRA4=
-----END CERTIFICATE-----"""; -----END CERTIFICATE-----""";
assertThat(PemContent.load(content)).isEqualTo(content); PemSslStoreDetails details = new PemSslStoreDetails("JKS", content, content);
assertThat(details.getCertificateType()).isEqualTo(Type.PEM);
assertThat(details.getPrivateKeyType()).isEqualTo(Type.PEM);
} }
@Test @Test
void loadWhenClasspathLocationReturnsContent() throws IOException { void location() {
String actual = PemContent.load("classpath:test-cert.pem"); PemSslStoreDetails details = new PemSslStoreDetails("JKS", "classpath:certificate.pem", "file:privatekey.pem");
String expected = new ClassPathResource("test-cert.pem").getContentAsString(StandardCharsets.UTF_8); assertThat(details.getCertificateType()).isEqualTo(Type.URL);
assertThat(actual).isEqualTo(expected); assertThat(details.getPrivateKeyType()).isEqualTo(Type.URL);
} }
@Test @Test
void loadWhenFileLocationReturnsContent() throws IOException { void empty() {
String actual = PemContent.load("src/test/resources/test-cert.pem"); PemSslStoreDetails details = new PemSslStoreDetails(null, null, null);
String expected = new ClassPathResource("test-cert.pem").getContentAsString(StandardCharsets.UTF_8); assertThat(details.getCertificateType()).isEqualTo(Type.URL);
assertThat(actual).isEqualTo(expected); assertThat(details.getPrivateKeyType()).isEqualTo(Type.URL);
assertThat(details.isEmpty()).isTrue();
} }
} }

@ -34,6 +34,11 @@ import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer; import reactor.netty.http.server.HttpServer;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import org.springframework.boot.ssl.DefaultSslBundleRegistry;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.boot.ssl.pem.PemSslStoreBundle;
import org.springframework.boot.ssl.pem.PemSslStoreDetails;
import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactory; import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactory;
import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactoryTests; import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactoryTests;
import org.springframework.boot.web.server.PortInUseException; import org.springframework.boot.web.server.PortInUseException;
@ -59,6 +64,7 @@ import static org.mockito.Mockito.mock;
* *
* @author Brian Clozel * @author Brian Clozel
* @author Chris Bono * @author Chris Bono
* @author Moritz Halbritter
*/ */
class NettyReactiveWebServerFactoryTests extends AbstractReactiveWebServerFactoryTests { class NettyReactiveWebServerFactoryTests extends AbstractReactiveWebServerFactoryTests {
@ -132,6 +138,16 @@ class NettyReactiveWebServerFactoryTests extends AbstractReactiveWebServerFactor
StepVerifier.create(result).expectNext("Hello World").expectComplete().verify(Duration.ofSeconds(30)); StepVerifier.create(result).expectNext("Hello World").expectComplete().verify(Duration.ofSeconds(30));
} }
@Test
void whenSslBundleIsUpdatedThenSslIsReloaded() {
DefaultSslBundleRegistry bundles = new DefaultSslBundleRegistry("bundle1", createSslBundle("1.key", "1.crt"));
Mono<String> result = testSslWithBundle(bundles, "bundle1");
StepVerifier.create(result).expectNext("Hello World").expectComplete().verify(Duration.ofSeconds(30));
bundles.updateBundle("bundle1", createSslBundle("2.key", "2.crt"));
Mono<String> result2 = executeSslRequest();
StepVerifier.create(result2).expectNext("Hello World").expectComplete().verify(Duration.ofSeconds(30));
}
@Test @Test
void whenServerIsShuttingDownGracefullyThenNewConnectionsCannotBeMade() { void whenServerIsShuttingDownGracefullyThenNewConnectionsCannotBeMade() {
NettyReactiveWebServerFactory factory = getFactory(); NettyReactiveWebServerFactory factory = getFactory();
@ -161,7 +177,7 @@ class NettyReactiveWebServerFactoryTests extends AbstractReactiveWebServerFactor
protected void startedLogMessageWithMultiplePorts() { protected void startedLogMessageWithMultiplePorts() {
} }
protected Mono<String> testSslWithAlias(String alias) { private Mono<String> testSslWithAlias(String alias) {
String keyStore = "classpath:test.jks"; String keyStore = "classpath:test.jks";
String keyPassword = "password"; String keyPassword = "password";
NettyReactiveWebServerFactory factory = getFactory(); NettyReactiveWebServerFactory factory = getFactory();
@ -172,6 +188,19 @@ class NettyReactiveWebServerFactoryTests extends AbstractReactiveWebServerFactor
factory.setSsl(ssl); factory.setSsl(ssl);
this.webServer = factory.getWebServer(new EchoHandler()); this.webServer = factory.getWebServer(new EchoHandler());
this.webServer.start(); this.webServer.start();
return executeSslRequest();
}
private Mono<String> testSslWithBundle(SslBundles sslBundles, String bundle) {
NettyReactiveWebServerFactory factory = getFactory();
factory.setSslBundles(sslBundles);
factory.setSsl(Ssl.forBundle(bundle));
this.webServer = factory.getWebServer(new EchoHandler());
this.webServer.start();
return executeSslRequest();
}
private Mono<String> executeSslRequest() {
ReactorClientHttpConnector connector = buildTrustAllSslConnector(); ReactorClientHttpConnector connector = buildTrustAllSslConnector();
WebClient client = WebClient.builder() WebClient client = WebClient.builder()
.baseUrl("https://localhost:" + this.webServer.getPort()) .baseUrl("https://localhost:" + this.webServer.getPort())
@ -200,6 +229,13 @@ class NettyReactiveWebServerFactoryTests extends AbstractReactiveWebServerFactor
throw new UnsupportedOperationException("Reactor Netty does not support multiple ports"); throw new UnsupportedOperationException("Reactor Netty does not support multiple ports");
} }
private static SslBundle createSslBundle(String key, String certificate) {
return SslBundle.of(new PemSslStoreBundle(
new PemSslStoreDetails(null, "classpath:org/springframework/boot/web/embedded/netty/" + certificate,
"classpath:org/springframework/boot/web/embedded/netty/" + key),
null));
}
static class NoPortNettyReactiveWebServerFactory extends NettyReactiveWebServerFactory { static class NoPortNettyReactiveWebServerFactory extends NettyReactiveWebServerFactory {
NoPortNettyReactiveWebServerFactory(int port) { NoPortNettyReactiveWebServerFactory(int port) {

@ -32,6 +32,9 @@ import java.util.concurrent.atomic.AtomicReference;
import javax.naming.InitialContext; import javax.naming.InitialContext;
import javax.naming.NamingException; import javax.naming.NamingException;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import jakarta.servlet.MultipartConfigElement; import jakarta.servlet.MultipartConfigElement;
import jakarta.servlet.ServletContext; import jakarta.servlet.ServletContext;
@ -60,8 +63,11 @@ import org.apache.coyote.http11.AbstractHttp11Protocol;
import org.apache.hc.client5.http.HttpHostConnectException; import org.apache.hc.client5.http.HttpHostConnectException;
import org.apache.hc.client5.http.classic.HttpClient; import org.apache.hc.client5.http.classic.HttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClients; import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.client5.http.ssl.SSLConnectionSocketFactory;
import org.apache.hc.client5.http.ssl.TrustSelfSignedStrategy;
import org.apache.hc.core5.http.HttpResponse; import org.apache.hc.core5.http.HttpResponse;
import org.apache.hc.core5.http.NoHttpResponseException; import org.apache.hc.core5.http.NoHttpResponseException;
import org.apache.hc.core5.ssl.SSLContextBuilder;
import org.apache.jasper.servlet.JspServlet; import org.apache.jasper.servlet.JspServlet;
import org.apache.tomcat.JarScanFilter; import org.apache.tomcat.JarScanFilter;
import org.apache.tomcat.JarScanType; import org.apache.tomcat.JarScanType;
@ -73,9 +79,11 @@ import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.InOrder; import org.mockito.InOrder;
import org.springframework.boot.ssl.DefaultSslBundleRegistry;
import org.springframework.boot.testsupport.system.CapturedOutput; import org.springframework.boot.testsupport.system.CapturedOutput;
import org.springframework.boot.web.server.PortInUseException; import org.springframework.boot.web.server.PortInUseException;
import org.springframework.boot.web.server.Shutdown; import org.springframework.boot.web.server.Shutdown;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.WebServerException; import org.springframework.boot.web.server.WebServerException;
import org.springframework.boot.web.servlet.server.AbstractServletWebServerFactory; import org.springframework.boot.web.servlet.server.AbstractServletWebServerFactory;
import org.springframework.boot.web.servlet.server.AbstractServletWebServerFactoryTests; import org.springframework.boot.web.servlet.server.AbstractServletWebServerFactoryTests;
@ -87,6 +95,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.util.FileSystemUtils; import org.springframework.util.FileSystemUtils;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
@ -107,6 +116,7 @@ import static org.mockito.Mockito.mock;
* @author Phillip Webb * @author Phillip Webb
* @author Dave Syer * @author Dave Syer
* @author Stephane Nicoll * @author Stephane Nicoll
* @author Moritz Halbritter
*/ */
class TomcatServletWebServerFactoryTests extends AbstractServletWebServerFactoryTests { class TomcatServletWebServerFactoryTests extends AbstractServletWebServerFactoryTests {
@ -636,6 +646,30 @@ class TomcatServletWebServerFactoryTests extends AbstractServletWebServerFactory
this.webServer.stop(); this.webServer.stop();
} }
@Test
void shouldUpdateSslWhenReloadingSslBundles() throws Exception {
TomcatServletWebServerFactory factory = getFactory();
addTestTxtFile(factory);
DefaultSslBundleRegistry bundles = new DefaultSslBundleRegistry("test",
createPemSslBundle("classpath:org/springframework/boot/web/embedded/tomcat/1.crt",
"classpath:org/springframework/boot/web/embedded/tomcat/1.key"));
factory.setSslBundles(bundles);
factory.setSsl(Ssl.forBundle("test"));
this.webServer = factory.getWebServer();
this.webServer.start();
RememberingHostnameVerifier verifier = new RememberingHostnameVerifier();
SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
new SSLContextBuilder().loadTrustMaterial(null, new TrustSelfSignedStrategy()).build(), verifier);
HttpComponentsClientHttpRequestFactory requestFactory = createHttpComponentsRequestFactory(socketFactory);
assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory)).isEqualTo("test");
assertThat(verifier.getLastPrincipal()).isEqualTo("CN=1");
requestFactory = createHttpComponentsRequestFactory(socketFactory);
bundles.updateBundle("test", createPemSslBundle("classpath:org/springframework/boot/web/embedded/tomcat/2.crt",
"classpath:org/springframework/boot/web/embedded/tomcat/2.key"));
assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory)).isEqualTo("test");
assertThat(verifier.getLastPrincipal()).isEqualTo("CN=2");
}
@Override @Override
protected JspServlet getJspServlet() throws ServletException { protected JspServlet getJspServlet() throws ServletException {
Tomcat tomcat = ((TomcatWebServer) this.webServer).getTomcat(); Tomcat tomcat = ((TomcatWebServer) this.webServer).getTomcat();
@ -694,4 +728,25 @@ class TomcatServletWebServerFactoryTests extends AbstractServletWebServerFactory
return ((TomcatWebServer) this.webServer).getStartedLogMessage(); return ((TomcatWebServer) this.webServer).getStartedLogMessage();
} }
private static class RememberingHostnameVerifier implements HostnameVerifier {
private volatile String lastPrincipal;
@Override
public boolean verify(String hostname, SSLSession session) {
try {
this.lastPrincipal = session.getPeerPrincipal().getName();
}
catch (SSLPeerUnverifiedException ex) {
throw new RuntimeException(ex);
}
return true;
}
String getLastPrincipal() {
return this.lastPrincipal;
}
}
} }

@ -789,7 +789,7 @@ public abstract class AbstractServletWebServerFactoryTests {
return new JksSslStoreDetails(getStoreType(location), null, location, "secret"); return new JksSslStoreDetails(getStoreType(location), null, location, "secret");
} }
private SslBundle createPemSslBundle(String cert, String privateKey) { protected SslBundle createPemSslBundle(String cert, String privateKey) {
PemSslStoreDetails keyStoreDetails = PemSslStoreDetails.forCertificate(cert).withPrivateKey(privateKey); PemSslStoreDetails keyStoreDetails = PemSslStoreDetails.forCertificate(cert).withPrivateKey(privateKey);
PemSslStoreDetails trustStoreDetails = PemSslStoreDetails.forCertificate(cert); PemSslStoreDetails trustStoreDetails = PemSslStoreDetails.forCertificate(cert);
SslStoreBundle stores = new PemSslStoreBundle(keyStoreDetails, trustStoreDetails); SslStoreBundle stores = new PemSslStoreBundle(keyStoreDetails, trustStoreDetails);
@ -807,14 +807,13 @@ public abstract class AbstractServletWebServerFactoryTests {
assertThat(getResponse(getLocalUrl("https", "/hello"), requestFactory)).contains("scheme=https"); assertThat(getResponse(getLocalUrl("https", "/hello"), requestFactory)).contains("scheme=https");
} }
private HttpComponentsClientHttpRequestFactory createHttpComponentsRequestFactory( protected HttpComponentsClientHttpRequestFactory createHttpComponentsRequestFactory(
SSLConnectionSocketFactory socketFactory) { SSLConnectionSocketFactory socketFactory) {
PoolingHttpClientConnectionManager connectionManager = PoolingHttpClientConnectionManagerBuilder.create() PoolingHttpClientConnectionManager connectionManager = PoolingHttpClientConnectionManagerBuilder.create()
.setSSLSocketFactory(socketFactory) .setSSLSocketFactory(socketFactory)
.build(); .build();
HttpClient httpClient = this.httpClientBuilder.get().setConnectionManager(connectionManager).build(); HttpClient httpClient = this.httpClientBuilder.get().setConnectionManager(connectionManager).build();
HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient); return new HttpComponentsClientHttpRequestFactory(httpClient);
return requestFactory;
} }
private String getStoreType(String keyStore) { private String getStoreType(String keyStore) {
@ -1457,7 +1456,7 @@ public abstract class AbstractServletWebServerFactoryTests {
protected abstract Charset getCharset(Locale locale); protected abstract Charset getCharset(Locale locale);
private void addTestTxtFile(AbstractServletWebServerFactory factory) throws IOException { protected void addTestTxtFile(AbstractServletWebServerFactory factory) throws IOException {
FileCopyUtils.copy("test", new FileWriter(new File(this.tempDir, "test.txt"))); FileCopyUtils.copy("test", new FileWriter(new File(this.tempDir, "test.txt")));
factory.setDocumentRoot(this.tempDir); factory.setDocumentRoot(this.tempDir);
factory.setRegisterDefaultServlet(true); factory.setRegisterDefaultServlet(true);

@ -0,0 +1,9 @@
-----BEGIN CERTIFICATE-----
MIIBLjCB4aADAgECAhQ25wrNnapZEkFc8kgf5NDHXKxnTzAFBgMrZXAwDDEKMAgG
A1UEAwwBMTAgFw0yMzEwMTAwODU1MTJaGA8yMTIzMDkxNjA4NTUxMlowDDEKMAgG
A1UEAwwBMTAqMAUGAytlcAMhAOyxNxHzcNj7xTkcjVLI09sYUGUGIvdV5s0YWXT8
XAiwo1MwUTAdBgNVHQ4EFgQUmm23oLIu5MgdBb/snZSuE+MrRZ0wHwYDVR0jBBgw
FoAUmm23oLIu5MgdBb/snZSuE+MrRZ0wDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXAD
QQA2KMpIyySC8u4onW2MVW1iK2dJJZbMRaNMLlQuE+ZIHQLwflYW4sH/Pp76pboc
QhqKXcO7xH7f2tD5hE2izcUB
-----END CERTIFICATE-----

@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIJb1A+i5bmilBD9mUbhk1oFVI6FAZQGnhduv7xV6WWEc
-----END PRIVATE KEY-----

@ -0,0 +1,9 @@
-----BEGIN CERTIFICATE-----
MIIBLjCB4aADAgECAhR4TMDk3qg5sKREp16lEHR3bV3M9zAFBgMrZXAwDDEKMAgG
A1UEAwwBMjAgFw0yMzEwMTAwODU1MjBaGA8yMTIzMDkxNjA4NTUyMFowDDEKMAgG
A1UEAwwBMjAqMAUGAytlcAMhADPft6hzyCjHCe5wSprChuuO/CuPIJ2t+l4roS1D
43/wo1MwUTAdBgNVHQ4EFgQUfrRibAWml4Ous4kpnBIggM2xnLcwHwYDVR0jBBgw
FoAUfrRibAWml4Ous4kpnBIggM2xnLcwDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXAD
QQC/MOclal2Cp0B3kmaLbK0M8mapclIOJa78hzBkqPA3URClAF2GmF187wHqi7qV
+xZ+KWv26pLJR46vk8Kc6ZIO
-----END CERTIFICATE-----

@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEICxhres2Z2lICm7/isnm+2iNR12GmgG7KK86BNDZDeIF
-----END PRIVATE KEY-----

@ -0,0 +1,9 @@
-----BEGIN CERTIFICATE-----
MIIBLjCB4aADAgECAhQ25wrNnapZEkFc8kgf5NDHXKxnTzAFBgMrZXAwDDEKMAgG
A1UEAwwBMTAgFw0yMzEwMTAwODU1MTJaGA8yMTIzMDkxNjA4NTUxMlowDDEKMAgG
A1UEAwwBMTAqMAUGAytlcAMhAOyxNxHzcNj7xTkcjVLI09sYUGUGIvdV5s0YWXT8
XAiwo1MwUTAdBgNVHQ4EFgQUmm23oLIu5MgdBb/snZSuE+MrRZ0wHwYDVR0jBBgw
FoAUmm23oLIu5MgdBb/snZSuE+MrRZ0wDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXAD
QQA2KMpIyySC8u4onW2MVW1iK2dJJZbMRaNMLlQuE+ZIHQLwflYW4sH/Pp76pboc
QhqKXcO7xH7f2tD5hE2izcUB
-----END CERTIFICATE-----

@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIJb1A+i5bmilBD9mUbhk1oFVI6FAZQGnhduv7xV6WWEc
-----END PRIVATE KEY-----

@ -0,0 +1,9 @@
-----BEGIN CERTIFICATE-----
MIIBLjCB4aADAgECAhR4TMDk3qg5sKREp16lEHR3bV3M9zAFBgMrZXAwDDEKMAgG
A1UEAwwBMjAgFw0yMzEwMTAwODU1MjBaGA8yMTIzMDkxNjA4NTUyMFowDDEKMAgG
A1UEAwwBMjAqMAUGAytlcAMhADPft6hzyCjHCe5wSprChuuO/CuPIJ2t+l4roS1D
43/wo1MwUTAdBgNVHQ4EFgQUfrRibAWml4Ous4kpnBIggM2xnLcwHwYDVR0jBBgw
FoAUfrRibAWml4Ous4kpnBIggM2xnLcwDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXAD
QQC/MOclal2Cp0B3kmaLbK0M8mapclIOJa78hzBkqPA3URClAF2GmF187wHqi7qV
+xZ+KWv26pLJR46vk8Kc6ZIO
-----END CERTIFICATE-----

@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEICxhres2Z2lICm7/isnm+2iNR12GmgG7KK86BNDZDeIF
-----END PRIVATE KEY-----
Loading…
Cancel
Save