Polish "Allow custom RSocket WebsocketServerSpecs to be defined"

See gh-29567
pull/36604/head
Stephane Nicoll 1 year ago
parent f840141652
commit b0438b0f03

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2020 the original author or authors. * Copyright 2012-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -73,7 +73,7 @@ public class RSocketProperties {
@NestedConfigurationProperty @NestedConfigurationProperty
private Ssl ssl; private Ssl ssl;
private Spec spec = new Spec(); private final Spec spec = new Spec();
public Integer getPort() { public Integer getPort() {
return this.port; return this.port;
@ -127,10 +127,6 @@ public class RSocketProperties {
return this.spec; return this.spec;
} }
public void setSpec(Spec spec) {
this.spec = spec;
}
public static class Spec { public static class Spec {
/** /**
@ -139,19 +135,17 @@ public class RSocketProperties {
private String protocols; private String protocols;
/** /**
* Specifies a custom maximum allowable frame payload length. 65536 by * Maximum allowable frame payload length.
* default.
*/ */
private int maxFramePayloadLength = 65536; private DataSize maxFramePayloadLength = DataSize.ofBytes(65536);
/** /**
* Flag whether to proxy websocket ping frames or respond to them. * Whether to proxy websocket ping frames or respond to them.
*/ */
private boolean handlePing; private boolean handlePing;
/** /**
* Flag whether the websocket compression extension is enabled if the client * Whether the websocket compression extension is enabled.
* request presents websocket extensions headers.
*/ */
private boolean compress; private boolean compress;
@ -163,11 +157,11 @@ public class RSocketProperties {
this.protocols = protocols; this.protocols = protocols;
} }
public int getMaxFramePayloadLength() { public DataSize getMaxFramePayloadLength() {
return this.maxFramePayloadLength; return this.maxFramePayloadLength;
} }
public void setMaxFramePayloadLength(int maxFramePayloadLength) { public void setMaxFramePayloadLength(DataSize maxFramePayloadLength) {
this.maxFramePayloadLength = maxFramePayloadLength; this.maxFramePayloadLength = maxFramePayloadLength;
} }

@ -16,10 +16,13 @@
package org.springframework.boot.autoconfigure.rsocket; package org.springframework.boot.autoconfigure.rsocket;
import java.util.function.Consumer;
import io.rsocket.core.RSocketServer; import io.rsocket.core.RSocketServer;
import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.transport.netty.server.TcpServerTransport;
import reactor.netty.http.server.HttpServer; import reactor.netty.http.server.HttpServer;
import reactor.netty.http.server.WebsocketServerSpec.Builder;
import org.springframework.beans.factory.ObjectProvider; import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration;
@ -31,6 +34,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
import org.springframework.boot.autoconfigure.reactor.netty.ReactorNettyConfigurations; import org.springframework.boot.autoconfigure.reactor.netty.ReactorNettyConfigurations;
import org.springframework.boot.autoconfigure.rsocket.RSocketProperties.Server.Spec;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.context.properties.PropertyMapper; import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.boot.rsocket.context.RSocketServerBootstrap; import org.springframework.boot.rsocket.context.RSocketServerBootstrap;
@ -46,6 +50,7 @@ import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.client.reactive.ReactorResourceFactory; import org.springframework.http.client.reactive.ReactorResourceFactory;
import org.springframework.messaging.rsocket.RSocketStrategies; import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
import org.springframework.util.unit.DataSize;
/** /**
* {@link EnableAutoConfiguration Auto-configuration} for RSocket servers. In the case of * {@link EnableAutoConfiguration Auto-configuration} for RSocket servers. In the case of
@ -73,7 +78,18 @@ public class RSocketServerAutoConfiguration {
RSocketWebSocketNettyRouteProvider rSocketWebsocketRouteProvider(RSocketProperties properties, RSocketWebSocketNettyRouteProvider rSocketWebsocketRouteProvider(RSocketProperties properties,
RSocketMessageHandler messageHandler, ObjectProvider<RSocketServerCustomizer> customizers) { RSocketMessageHandler messageHandler, ObjectProvider<RSocketServerCustomizer> customizers) {
return new RSocketWebSocketNettyRouteProvider(properties.getServer().getMappingPath(), return new RSocketWebSocketNettyRouteProvider(properties.getServer().getMappingPath(),
properties.getServer().getSpec(), messageHandler.responder(), customizers.orderedStream()); messageHandler.responder(), customizeWebsocketServerSpec(properties.getServer().getSpec()),
customizers.orderedStream());
}
private Consumer<Builder> customizeWebsocketServerSpec(Spec spec) {
return (builder) -> {
PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull();
map.from(spec.getProtocols()).to(builder::protocols);
map.from(spec.getMaxFramePayloadLength()).asInt(DataSize::toBytes).to(builder::maxFramePayloadLength);
map.from(spec.isHandlePing()).to(builder::handlePing);
map.from(spec.isCompress()).to(builder::compress);
};
} }
} }

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2022 the original author or authors. * Copyright 2012-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
package org.springframework.boot.autoconfigure.rsocket; package org.springframework.boot.autoconfigure.rsocket;
import java.util.List; import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Stream; import java.util.stream.Stream;
import io.rsocket.SocketAcceptor; import io.rsocket.SocketAcceptor;
@ -25,8 +26,8 @@ import io.rsocket.transport.ServerTransport;
import io.rsocket.transport.netty.server.WebsocketRouteTransport; import io.rsocket.transport.netty.server.WebsocketRouteTransport;
import reactor.netty.http.server.HttpServerRoutes; import reactor.netty.http.server.HttpServerRoutes;
import reactor.netty.http.server.WebsocketServerSpec; import reactor.netty.http.server.WebsocketServerSpec;
import reactor.netty.http.server.WebsocketServerSpec.Builder;
import org.springframework.boot.autoconfigure.rsocket.RSocketProperties.Server.Spec;
import org.springframework.boot.rsocket.server.RSocketServerCustomizer; import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
import org.springframework.boot.web.embedded.netty.NettyRouteProvider; import org.springframework.boot.web.embedded.netty.NettyRouteProvider;
@ -44,14 +45,14 @@ class RSocketWebSocketNettyRouteProvider implements NettyRouteProvider {
private final List<RSocketServerCustomizer> customizers; private final List<RSocketServerCustomizer> customizers;
private final Spec spec; private final Consumer<Builder> serverSpecCustomizer;
RSocketWebSocketNettyRouteProvider(String mappingPath, Spec spec, SocketAcceptor socketAcceptor, RSocketWebSocketNettyRouteProvider(String mappingPath, SocketAcceptor socketAcceptor,
Stream<RSocketServerCustomizer> customizers) { Consumer<Builder> serverSpecCustomizer, Stream<RSocketServerCustomizer> customizers) {
this.mappingPath = mappingPath; this.mappingPath = mappingPath;
this.socketAcceptor = socketAcceptor; this.socketAcceptor = socketAcceptor;
this.serverSpecCustomizer = serverSpecCustomizer;
this.customizers = customizers.toList(); this.customizers = customizers.toList();
this.spec = spec;
} }
@Override @Override
@ -59,11 +60,14 @@ class RSocketWebSocketNettyRouteProvider implements NettyRouteProvider {
RSocketServer server = RSocketServer.create(this.socketAcceptor); RSocketServer server = RSocketServer.create(this.socketAcceptor);
this.customizers.forEach((customizer) -> customizer.customize(server)); this.customizers.forEach((customizer) -> customizer.customize(server));
ServerTransport.ConnectionAcceptor connectionAcceptor = server.asConnectionAcceptor(); ServerTransport.ConnectionAcceptor connectionAcceptor = server.asConnectionAcceptor();
WebsocketServerSpec.Builder build = (this.spec.getProtocols() == null) ? WebsocketServerSpec.builder()
: WebsocketServerSpec.builder().protocols(this.spec.getProtocols());
return httpServerRoutes.ws(this.mappingPath, WebsocketRouteTransport.newHandler(connectionAcceptor), return httpServerRoutes.ws(this.mappingPath, WebsocketRouteTransport.newHandler(connectionAcceptor),
build.maxFramePayloadLength(this.spec.getMaxFramePayloadLength()).handlePing(this.spec.isHandlePing()) createWebsocketServerSpec());
.compress(this.spec.isCompress()).build()); }
private WebsocketServerSpec createWebsocketServerSpec() {
WebsocketServerSpec.Builder builder = WebsocketServerSpec.builder();
this.serverSpecCustomizer.accept(builder);
return builder.build();
} }
} }

@ -0,0 +1,43 @@
/*
* Copyright 2012-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.autoconfigure.rsocket;
import org.junit.jupiter.api.Test;
import reactor.netty.http.server.WebsocketServerSpec;
import org.springframework.boot.autoconfigure.rsocket.RSocketProperties.Server.Spec;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link RSocketProperties}.
*
* @author Stephane Nicoll
*/
class RSocketPropertiesTests {
@Test
void defaultServerSpecValuesAreConsistent() {
WebsocketServerSpec spec = WebsocketServerSpec.builder().build();
Spec properties = new RSocketProperties().getServer().getSpec();
assertThat(properties.getProtocols()).isEqualTo(spec.protocols());
assertThat(properties.getMaxFramePayloadLength().toBytes()).isEqualTo(spec.maxFramePayloadLength());
assertThat(properties.isHandlePing()).isEqualTo(spec.handlePing());
assertThat(properties.isCompress()).isEqualTo(spec.compress());
}
}
Loading…
Cancel
Save