From 2123b267aa2f1fd5406d69364d669e8541bdc491 Mon Sep 17 00:00:00 2001 From: Phillip Webb Date: Mon, 1 Jun 2015 13:23:54 -0700 Subject: [PATCH] Add HTTP tunnel support Add server and client components to support tunneling of binary TCP protocols over HTTP. Primarily designed to support Java's remote debug protocol (JDWP). See gh-3087 --- .../tunnel/client/HttpTunnelConnection.java | 216 ++++++++ .../tunnel/client/TunnelClient.java | 207 ++++++++ .../tunnel/client/TunnelClientListener.java | 41 ++ .../tunnel/client/TunnelClientListeners.java | 56 ++ .../tunnel/client/TunnelConnection.java | 42 ++ .../tunnel/client/package-info.java | 21 + .../developertools/tunnel/package-info.java | 23 + .../tunnel/payload/HttpTunnelPayload.java | 185 +++++++ .../payload/HttpTunnelPayloadForwarder.java | 69 +++ .../tunnel/payload/package-info.java | 21 + .../tunnel/server/HttpTunnelServer.java | 486 ++++++++++++++++++ .../server/HttpTunnelServerHandler.java | 51 ++ .../tunnel/server/PortProvider.java | 34 ++ .../server/RemoteDebugPortProvider.java | 61 +++ .../server/SocketTargetServerConnection.java | 101 ++++ .../tunnel/server/StaticPortProvider.java | 41 ++ .../tunnel/server/TargetServerConnection.java | 38 ++ .../tunnel/server/package-info.java | 21 + .../client/HttpTunnelConnectionTests.java | 166 ++++++ .../tunnel/client/TunnelClientTests.java | 199 +++++++ .../HttpTunnelPayloadForwarderTests.java | 85 +++ .../payload/HttpTunnelPayloadTests.java | 151 ++++++ .../server/HttpTunnelServerHandlerTests.java | 55 ++ .../tunnel/server/HttpTunnelServerTests.java | 480 +++++++++++++++++ .../SocketTargetServerConnectionTests.java | 178 +++++++ .../server/StaticPortProviderTests.java | 49 ++ 26 files changed, 3077 insertions(+) create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/HttpTunnelConnection.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClient.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClientListener.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClientListeners.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelConnection.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/package-info.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/package-info.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayload.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadForwarder.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/package-info.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServer.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerHandler.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/PortProvider.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/RemoteDebugPortProvider.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/SocketTargetServerConnection.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/StaticPortProvider.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/TargetServerConnection.java create mode 100644 spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/package-info.java create mode 100644 spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/client/HttpTunnelConnectionTests.java create mode 100644 spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/client/TunnelClientTests.java create mode 100644 spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadForwarderTests.java create mode 100644 spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadTests.java create mode 100644 spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerHandlerTests.java create mode 100644 spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerTests.java create mode 100644 spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/SocketTargetServerConnectionTests.java create mode 100644 spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/StaticPortProviderTests.java diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/HttpTunnelConnection.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/HttpTunnelConnection.java new file mode 100644 index 0000000000..67a325603d --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/HttpTunnelConnection.java @@ -0,0 +1,216 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.client; + +import java.io.Closeable; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayload; +import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayloadForwarder; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.Assert; + +/** + * {@link TunnelConnection} implementation that uses HTTP to transfer data. + * + * @author Phillip Webb + * @author Rob Winch + * @since 1.3.0 + * @see TunnelClient + * @see org.springframework.boot.developertools.tunnel.server.HttpTunnelServer + */ +public class HttpTunnelConnection implements TunnelConnection { + + private static Log logger = LogFactory.getLog(HttpTunnelConnection.class); + + private final URI uri; + + private final ClientHttpRequestFactory requestFactory; + + private final Executor executor; + + /** + * Create a new {@link HttpTunnelConnection} instance. + * @param url the URL to connect to + * @param requestFactory the HTTP request factory + */ + public HttpTunnelConnection(String url, ClientHttpRequestFactory requestFactory) { + this(url, requestFactory, null); + } + + /** + * Create a new {@link HttpTunnelConnection} instance. + * @param url the URL to connect to + * @param requestFactory the HTTP request factory + * @param executor the executor used to handle connections + */ + protected HttpTunnelConnection(String url, ClientHttpRequestFactory requestFactory, + Executor executor) { + Assert.hasLength(url, "URL must not be empty"); + Assert.notNull(requestFactory, "RequestFactory must not be null"); + try { + this.uri = new URL(url).toURI(); + } + catch (URISyntaxException ex) { + throw new IllegalArgumentException("Malformed URL '" + url + "'"); + } + catch (MalformedURLException ex) { + throw new IllegalArgumentException("Malformed URL '" + url + "'"); + } + this.requestFactory = requestFactory; + this.executor = (executor == null ? Executors + .newCachedThreadPool(new TunnelThreadFactory()) : executor); + } + + @Override + public TunnelChannel open(WritableByteChannel incomingChannel, Closeable closeable) + throws Exception { + logger.trace("Opening HTTP tunnel to " + this.uri); + return new TunnelChannel(incomingChannel, closeable); + } + + protected final ClientHttpRequest createRequest(boolean hasPayload) + throws IOException { + HttpMethod method = (hasPayload ? HttpMethod.POST : HttpMethod.GET); + return this.requestFactory.createRequest(this.uri, method); + } + + /** + * A {@link WritableByteChannel} used to transfer traffic. + */ + protected class TunnelChannel implements WritableByteChannel { + + private final HttpTunnelPayloadForwarder forwarder; + + private final Closeable closeable; + + private boolean open = true; + + private AtomicLong requestSeq = new AtomicLong(); + + public TunnelChannel(WritableByteChannel incomingChannel, Closeable closeable) { + this.forwarder = new HttpTunnelPayloadForwarder(incomingChannel); + this.closeable = closeable; + openNewConnection(null); + } + + @Override + public boolean isOpen() { + return this.open; + } + + @Override + public void close() throws IOException { + if (this.open) { + this.open = false; + this.closeable.close(); + } + } + + @Override + public int write(ByteBuffer src) throws IOException { + int size = src.remaining(); + if (size > 0) { + openNewConnection(new HttpTunnelPayload( + this.requestSeq.incrementAndGet(), src)); + } + return size; + } + + private synchronized void openNewConnection(final HttpTunnelPayload payload) { + HttpTunnelConnection.this.executor.execute(new Runnable() { + + @Override + public void run() { + try { + sendAndReceive(payload); + } + catch (IOException ex) { + logger.trace("Unexpected connection error", ex); + closeQuitely(); + } + } + + private void closeQuitely() { + try { + close(); + } + catch (IOException ex) { + } + } + + }); + } + + private void sendAndReceive(HttpTunnelPayload payload) throws IOException { + ClientHttpRequest request = createRequest(payload != null); + if (payload != null) { + payload.logIncoming(); + payload.assignTo(request); + } + handleResponse(request.execute()); + } + + private void handleResponse(ClientHttpResponse response) throws IOException { + if (response.getStatusCode() == HttpStatus.GONE) { + close(); + return; + } + if (response.getStatusCode() == HttpStatus.OK) { + HttpTunnelPayload payload = HttpTunnelPayload.get(response); + if (payload != null) { + this.forwarder.forward(payload); + } + } + if (response.getStatusCode() != HttpStatus.TOO_MANY_REQUESTS) { + openNewConnection(null); + } + } + + } + + /** + * {@link ThreadFactory} used to create the tunnel thread. + */ + private static class TunnelThreadFactory implements ThreadFactory { + + @Override + public Thread newThread(Runnable runnable) { + Thread thread = new Thread(runnable, "HTTP Tunnel Connection"); + thread.setDaemon(true); + return thread; + } + + } + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClient.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClient.java new file mode 100644 index 0000000000..85e5bb13c9 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClient.java @@ -0,0 +1,207 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.client; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.nio.ByteBuffer; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.channels.WritableByteChannel; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.util.Assert; + +/** + * The client side component of a socket tunnel. Starts a {@link ServerSocket} of the + * specified port for local clients to connect to. + * + * @author Phillip Webb + * @since 1.3.0 + */ +public class TunnelClient implements SmartInitializingSingleton { + + private static final int BUFFER_SIZE = 1024 * 100; + + private static final Log logger = LogFactory.getLog(TunnelClient.class); + + private final int listenPort; + + private final TunnelConnection tunnelConnection; + + private TunnelClientListeners listeners = new TunnelClientListeners(); + + private ServerThread serverThread; + + public TunnelClient(int listenPort, TunnelConnection tunnelConnection) { + Assert.isTrue(listenPort > 0, "ListenPort must be positive"); + Assert.notNull(tunnelConnection, "TunnelConnection must not be null"); + this.listenPort = listenPort; + this.tunnelConnection = tunnelConnection; + } + + @Override + public void afterSingletonsInstantiated() { + if (this.serverThread == null) { + try { + start(); + } + catch (IOException ex) { + throw new IllegalStateException(ex); + } + } + } + + /** + * Start the client and accept incoming connections on the port. + * @throws IOException + */ + public synchronized void start() throws IOException { + Assert.state(this.serverThread == null, "Server already started"); + ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); + serverSocketChannel.socket().bind(new InetSocketAddress(this.listenPort)); + logger.trace("Listening for TCP traffic to tunnel on port " + this.listenPort); + this.serverThread = new ServerThread(serverSocketChannel); + this.serverThread.start(); + } + + /** + * Stop the client, disconnecting any servers. + * @throws IOException + */ + public synchronized void stop() throws IOException { + if (this.serverThread != null) { + logger.trace("Closing tunnel client on port " + this.listenPort); + this.serverThread.close(); + try { + this.serverThread.join(2000); + } + catch (InterruptedException ex) { + } + this.serverThread = null; + } + } + + protected final ServerThread getServerThread() { + return this.serverThread; + } + + public void addListener(TunnelClientListener listener) { + this.listeners.addListener(listener); + } + + public void removeListener(TunnelClientListener listener) { + this.listeners.removeListener(listener); + } + + /** + * The main server thread. + */ + protected class ServerThread extends Thread { + + private final ServerSocketChannel serverSocketChannel; + + private boolean acceptConnections = true; + + public ServerThread(ServerSocketChannel serverSocketChannel) { + this.serverSocketChannel = serverSocketChannel; + setName("Tunnel Server"); + setDaemon(true); + } + + public void close() throws IOException { + this.serverSocketChannel.close(); + this.acceptConnections = false; + interrupt(); + } + + @Override + public void run() { + try { + while (this.acceptConnections) { + SocketChannel socket = this.serverSocketChannel.accept(); + try { + handleConnection(socket); + } + finally { + socket.close(); + } + } + } + catch (Exception ex) { + logger.trace("Unexpected exception from tunnel client", ex); + } + } + + private void handleConnection(SocketChannel socketChannel) throws Exception { + Closeable closeable = new SocketCloseable(socketChannel); + WritableByteChannel outputChannel = TunnelClient.this.tunnelConnection.open( + socketChannel, closeable); + TunnelClient.this.listeners.fireOpenEvent(socketChannel); + try { + logger.trace("Accepted connection to tunnel client from " + + socketChannel.socket().getRemoteSocketAddress()); + while (true) { + ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE); + int amountRead = socketChannel.read(buffer); + if (amountRead == -1) { + outputChannel.close(); + return; + } + if (amountRead > 0) { + buffer.flip(); + outputChannel.write(buffer); + } + } + } + finally { + outputChannel.close(); + } + } + + protected void stopAcceptingConnections() { + this.acceptConnections = false; + } + } + + /** + * {@link Closeable} used to close a {@link SocketChannel} and fire an event. + */ + private class SocketCloseable implements Closeable { + + private final SocketChannel socketChannel; + + private boolean closed = false; + + public SocketCloseable(SocketChannel socketChannel) { + this.socketChannel = socketChannel; + } + + @Override + public void close() throws IOException { + if (!this.closed) { + this.socketChannel.close(); + TunnelClient.this.listeners.fireCloseEvent(this.socketChannel); + this.closed = true; + } + } + } +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClientListener.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClientListener.java new file mode 100644 index 0000000000..af7e7af634 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClientListener.java @@ -0,0 +1,41 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.client; + +import java.nio.channels.SocketChannel; + +/** + * Listener that can be used to receive {@link TunnelClient} events. + * + * @author Phillip Webb + * @since 1.3.0 + */ +public interface TunnelClientListener { + + /** + * Called when a socket channel is opened. + * @param socket the socket channel + */ + void onOpen(SocketChannel socket); + + /** + * Called when a socket channel is closed. + * @param socket the socket channel + */ + void onClose(SocketChannel socket); + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClientListeners.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClientListeners.java new file mode 100644 index 0000000000..dc5b33e3b4 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelClientListeners.java @@ -0,0 +1,56 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.client; + +import java.nio.channels.SocketChannel; +import java.util.ArrayList; +import java.util.List; + +import org.springframework.util.Assert; + +/** + * A collection of {@link TunnelClientListener}. + * + * @author Phillip Webb + */ +class TunnelClientListeners { + + private final List listeners = new ArrayList(); + + public void addListener(TunnelClientListener listener) { + Assert.notNull(listener, "Listener must not be null"); + this.listeners.add(listener); + } + + public void removeListener(TunnelClientListener listener) { + Assert.notNull(listener, "Listener must not be null"); + this.listeners.remove(listener); + } + + public void fireOpenEvent(SocketChannel socket) { + for (TunnelClientListener listener : this.listeners) { + listener.onOpen(socket); + } + } + + public void fireCloseEvent(SocketChannel socket) { + for (TunnelClientListener listener : this.listeners) { + listener.onClose(socket); + } + } + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelConnection.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelConnection.java new file mode 100644 index 0000000000..f885b2a1ee --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/TunnelConnection.java @@ -0,0 +1,42 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.client; + +import java.io.Closeable; +import java.nio.channels.WritableByteChannel; + +/** + * Interface used to manage socket tunnel connections. + * + * @author Phillip Webb + * @since 1.3.0 + */ +public interface TunnelConnection { + + /** + * Open the tunnel connection. + * @param incomingChannel A {@link WritableByteChannel} that should be used to write + * any incoming data received from the remote server. + * @param closeable + * @return A {@link WritableByteChannel} that should be used to send any outgoing data + * destined for the remote server + * @throws Exception + */ + WritableByteChannel open(WritableByteChannel incomingChannel, Closeable closeable) + throws Exception; + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/package-info.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/package-info.java new file mode 100644 index 0000000000..109e9d16f8 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/client/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Client side TCP tunnel support. + */ +package org.springframework.boot.developertools.tunnel.client; + diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/package-info.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/package-info.java new file mode 100644 index 0000000000..ee9ad911d2 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides support for tunneling TCP traffic over HTTP. Tunneling is primarily designed + * for the Java Debug Wire Protocol (JDWP) and as such only expects a single connection + * and isn't particularly worried about resource usage. + */ +package org.springframework.boot.developertools.tunnel; + diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayload.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayload.java new file mode 100644 index 0000000000..72dfe63422 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayload.java @@ -0,0 +1,185 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.payload; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Encapsulates a payload data sent via a HTTP tunnel. + * + * @author Phillip Webb + * @since 1.3.0 + */ +public class HttpTunnelPayload { + + private static final String SEQ_HEADER = "x-seq"; + + private static final int BUFFER_SIZE = 1024 * 100; + + final protected static char[] HEX_CHARS = "0123456789ABCDEF".toCharArray(); + + private static final Log logger = LogFactory.getLog(HttpTunnelPayload.class); + + private final long sequence; + + private final ByteBuffer data; + + /** + * Create a new {@link HttpTunnelPayload} instance. + * @param sequence the sequence number of the payload + * @param data the payload data + */ + public HttpTunnelPayload(long sequence, ByteBuffer data) { + Assert.isTrue(sequence > 0, "Sequence must be positive"); + Assert.notNull(data, "Data must not be null"); + this.sequence = sequence; + this.data = data; + } + + /** + * Return the sequence number of the payload. + * @return the sequence + */ + public long getSequence() { + return this.sequence; + } + + /** + * Assign this payload to the given {@link HttpOutputMessage}. + * @param message the message to assign this payload to + * @throws IOException + */ + public void assignTo(HttpOutputMessage message) throws IOException { + Assert.notNull(message, "Message must not be null"); + HttpHeaders headers = message.getHeaders(); + headers.setContentLength(this.data.remaining()); + headers.add(SEQ_HEADER, Long.toString(getSequence())); + headers.setContentType(MediaType.APPLICATION_OCTET_STREAM); + WritableByteChannel body = Channels.newChannel(message.getBody()); + while (this.data.hasRemaining()) { + body.write(this.data); + } + body.close(); + } + + /** + * Write the content of this payload to the given target channel. + * @param channel the channel to write to + * @throws IOException + */ + public void writeTo(WritableByteChannel channel) throws IOException { + Assert.notNull(channel, "Channel must not be null"); + while (this.data.hasRemaining()) { + channel.write(this.data); + } + } + + /** + * Return the {@link HttpTunnelPayload} for the given message or {@code null} if there + * is no payload. + * @param message the HTTP message + * @return the payload or {@code null} + * @throws IOException + */ + public static HttpTunnelPayload get(HttpInputMessage message) throws IOException { + long length = message.getHeaders().getContentLength(); + if (length <= 0) { + return null; + } + String seqHeader = message.getHeaders().getFirst(SEQ_HEADER); + Assert.state(StringUtils.hasLength(seqHeader), "Missing sequence header"); + ReadableByteChannel body = Channels.newChannel(message.getBody()); + ByteBuffer payload = ByteBuffer.allocate((int) length); + while (payload.hasRemaining()) { + body.read(payload); + } + body.close(); + payload.flip(); + return new HttpTunnelPayload(Long.valueOf(seqHeader), payload); + } + + /** + * Return the payload data for the given source {@link ReadableByteChannel} or null if + * the channel timed out whilst reading. + * @param channel the source channel + * @return payload data or {@code null} + * @throws IOException + */ + public static ByteBuffer getPayloadData(ReadableByteChannel channel) + throws IOException { + ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE); + try { + int amountRead = channel.read(buffer); + Assert.state(amountRead != -1, "Target server connection closed"); + buffer.flip(); + return buffer; + } + catch (InterruptedIOException ex) { + return null; + } + } + + /** + * Log incoming payload information at trace level to aid diagnostics. + */ + public void logIncoming() { + log("< "); + } + + /** + * Log incoming payload information at trace level to aid diagnostics. + */ + public void logOutgoing() { + log("> "); + } + + private void log(String prefix) { + if (logger.isTraceEnabled()) { + logger.trace(prefix + toHexString()); + } + } + + /** + * Return the payload as a hexadecimal string. + * @return the payload as a hex string + */ + public String toHexString() { + byte[] bytes = this.data.array(); + char[] hex = new char[this.data.remaining() * 2]; + for (int i = this.data.position(); i < this.data.remaining(); i++) { + int b = bytes[i] & 0xFF; + hex[i * 2] = HEX_CHARS[b >>> 4]; + hex[i * 2 + 1] = HEX_CHARS[b & 0x0F]; + } + return new String(hex); + } + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadForwarder.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadForwarder.java new file mode 100644 index 0000000000..328b1954c5 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadForwarder.java @@ -0,0 +1,69 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.payload; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; +import java.util.HashMap; +import java.util.Map; + +import org.springframework.util.Assert; + +/** + * Utility class that forwards {@link HttpTunnelPayload} instances to a destination + * channel, respecting sequence order. + * + * @author Phillip Webb + * @since 1.3.0 + */ +public class HttpTunnelPayloadForwarder { + + private static final int MAXIMUM_QUEUE_SIZE = 100; + + private final WritableByteChannel targetChannel; + + private long lastRequestSeq = 0; + + private final Map queue = new HashMap(); + + /** + * Create a new {@link HttpTunnelPayloadForwarder} instance. + * @param targetChannel the target channel + */ + public HttpTunnelPayloadForwarder(WritableByteChannel targetChannel) { + Assert.notNull(targetChannel, "TargetChannel must not be null"); + this.targetChannel = targetChannel; + } + + public synchronized void forward(HttpTunnelPayload payload) throws IOException { + long seq = payload.getSequence(); + if (this.lastRequestSeq != seq - 1) { + Assert.state(this.queue.size() < MAXIMUM_QUEUE_SIZE, + "Too many messages queued"); + this.queue.put(seq, payload); + return; + } + payload.logOutgoing(); + payload.writeTo(this.targetChannel); + this.lastRequestSeq = seq; + HttpTunnelPayload queuedItem = this.queue.get(seq + 1); + if (queuedItem != null) { + forward(queuedItem); + } + } + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/package-info.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/package-info.java new file mode 100644 index 0000000000..fdf6429f18 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/payload/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Classes to deal with payloads sent over a HTTP tunnel. + */ +package org.springframework.boot.developertools.tunnel.payload; + diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServer.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServer.java new file mode 100644 index 0000000000..976109f45f --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServer.java @@ -0,0 +1,486 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import java.io.IOException; +import java.net.ConnectException; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayload; +import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayloadForwarder; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.server.ServerHttpAsyncRequestControl; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.util.Assert; + +/** + * A server that can be used to tunnel TCP traffic over HTTP. Similar in design to the Bidirectional-streams Over Synchronous + * HTTP (BOSH) XMPP extension protocol, the server uses long polling with HTTP + * requests held open until a response is available. A typical traffic pattern would be as + * follows: + * + *
+ * [ CLIENT ]                      [ SERVER ]
+ *     | (a) Initial empty request     |
+ *     |------------------------------}|
+ *     | (b) Data I                    |
+ *  --}|------------------------------}|---}
+ *     | Response I (a)                |
+ *  {--|<------------------------------|{---
+ *     |                               |
+ *     | (c) Data II                   |
+ *  --}|------------------------------}|---}
+ *     | Response II (b)               |
+ *  {--|{------------------------------|{---
+ *     .                               .
+ *     .                               .
+ * 
+ * + * Each incoming request is held open to be used to carry the next available response. The + * server will hold at most two connections open at any given time. + *

+ * Requests should be made using HTTP GET or POST (depending if there is a payload), with + * any payload contained in the body. The following response codes can be returned from + * the server: + *

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
StatusMeaning
200 (OK)Data payload response.
204 (No Content)The long poll has timed out and the client should start a new request.
429 (Too many requests)There are already enough connections open, this one can be dropped.
410 (Gone)The target server has disconnected.
+ *

+ * Requests and responses that contain payloads include a {@code x-seq} header that + * contains a running sequence number (used to ensure data is applied in the correct + * order). The first request containing a payload should have a {@code x-seq} value of + * {@code 1}. + * + * @author Phillip Webb + * @since 1.3.0 + * @see org.springframework.boot.developertools.tunnel.client.HttpTunnelConnection + */ +public class HttpTunnelServer { + + private static final int SECONDS = 1000; + + private static final int DEFAULT_LONG_POLL_TIMEOUT = 10 * SECONDS; + + private static final long DEFAULT_DISCONNECT_TIMEOUT = 30 * SECONDS; + + private static final MediaType DISCONNECT_MEDIA_TYPE = new MediaType("application", + "x-disconnect"); + + private static final Log logger = LogFactory.getLog(HttpTunnelServer.class); + + private final TargetServerConnection serverConnection; + + private int longPollTimeout = DEFAULT_LONG_POLL_TIMEOUT; + + private long disconnectTimeout = DEFAULT_DISCONNECT_TIMEOUT; + + private volatile ServerThread serverThread; + + /** + * Creates a new {@link HttpTunnelServer} instance. + * @param serverConnection the connection to the target server + */ + public HttpTunnelServer(TargetServerConnection serverConnection) { + Assert.notNull(serverConnection, "ServerConnection must not be null"); + this.serverConnection = serverConnection; + } + + /** + * Handle an incoming HTTP connection. + * @param request the HTTP request + * @param response the HTTP response + * @throws IOException + */ + public void handle(ServerHttpRequest request, ServerHttpResponse response) + throws IOException { + handle(new HttpConnection(request, response)); + } + + /** + * Handle an incoming HTTP connection. + * @param httpConnection the HTTP connection + * @throws IOException + */ + protected void handle(HttpConnection httpConnection) throws IOException { + try { + getServerThread().handleIncomingHttp(httpConnection); + httpConnection.waitForResponse(); + } + catch (ConnectException ex) { + httpConnection.respond(HttpStatus.GONE); + } + } + + /** + * Returns the active server thread, creating and starting it if necessary. + * @return the {@code ServerThread} (never {@code null}) + * @throws IOException + */ + protected ServerThread getServerThread() throws IOException { + synchronized (this) { + if (this.serverThread == null) { + ByteChannel channel = this.serverConnection.open(this.longPollTimeout); + this.serverThread = new ServerThread(channel); + this.serverThread.start(); + } + return this.serverThread; + } + } + + /** + * Called when the server thread exits. + */ + void clearServerThread() { + synchronized (this) { + this.serverThread = null; + } + } + + /** + * Set the long poll timeout for the server. + * @param longPollTimeout the long poll timeout in milliseconds + */ + public void setLongPollTimeout(int longPollTimeout) { + Assert.isTrue(longPollTimeout > 0, "LongPollTimeout must be a positive value"); + this.longPollTimeout = longPollTimeout; + } + + /** + * Set the maximum amount of time to wait for a client before closing the connection. + * @param disconnectTimeout the disconnect timeout in milliseconds + */ + public void setDisconnectTimeout(long disconnectTimeout) { + Assert.isTrue(disconnectTimeout > 0, "DisconnectTimeout must be a positive value"); + this.disconnectTimeout = disconnectTimeout; + } + + /** + * The main server thread used to transfer tunnel traffic. + */ + protected class ServerThread extends Thread { + + private final ByteChannel targetServer; + + private final Deque httpConnections; + + private final HttpTunnelPayloadForwarder payloadForwarder; + + private boolean closed; + + private AtomicLong responseSeq = new AtomicLong(); + + private long lastHttpRequestTime; + + public ServerThread(ByteChannel targetServer) { + Assert.notNull(targetServer, "TargetServer must not be null"); + this.targetServer = targetServer; + this.httpConnections = new ArrayDeque(2); + this.payloadForwarder = new HttpTunnelPayloadForwarder(targetServer); + } + + @Override + public void run() { + try { + try { + readAndForwardTargetServerData(); + } + catch (Exception ex) { + logger.trace("Unexpected exception from tunnel server", ex); + } + } + finally { + this.closed = true; + closeHttpConnections(); + closeTargetServer(); + HttpTunnelServer.this.clearServerThread(); + } + } + + private void readAndForwardTargetServerData() throws IOException { + while (this.targetServer.isOpen()) { + closeStaleHttpConnections(); + ByteBuffer data = HttpTunnelPayload.getPayloadData(this.targetServer); + synchronized (this.httpConnections) { + if (data != null) { + HttpTunnelPayload payload = new HttpTunnelPayload( + this.responseSeq.incrementAndGet(), data); + payload.logIncoming(); + HttpConnection connection = getOrWaitForHttpConnection(); + connection.respond(payload); + } + } + } + } + + private HttpConnection getOrWaitForHttpConnection() { + synchronized (this.httpConnections) { + HttpConnection httpConnection = this.httpConnections.pollFirst(); + while (httpConnection == null) { + try { + this.httpConnections.wait(HttpTunnelServer.this.longPollTimeout); + } + catch (InterruptedException ex) { + closeHttpConnections(); + } + httpConnection = this.httpConnections.pollFirst(); + } + return httpConnection; + } + } + + private void closeStaleHttpConnections() throws IOException { + checkNotDisconnected(); + synchronized (this.httpConnections) { + Iterator iterator = this.httpConnections.iterator(); + while (iterator.hasNext()) { + HttpConnection httpConnection = iterator.next(); + if (httpConnection.isOlderThan(HttpTunnelServer.this.longPollTimeout)) { + httpConnection.respond(HttpStatus.NO_CONTENT); + iterator.remove(); + } + } + } + } + + private void checkNotDisconnected() { + long timeout = HttpTunnelServer.this.disconnectTimeout; + long duration = System.currentTimeMillis() - this.lastHttpRequestTime; + Assert.state(duration < timeout, "Disconnect timeout"); + } + + private void closeHttpConnections() { + synchronized (this.httpConnections) { + while (!this.httpConnections.isEmpty()) { + try { + this.httpConnections.removeFirst().respond(HttpStatus.GONE); + } + catch (Exception ex) { + logger.trace("Unable to close remote HTTP connection"); + } + } + } + } + + private void closeTargetServer() { + try { + this.targetServer.close(); + } + catch (IOException ex) { + logger.trace("Unable to target server connection"); + } + } + + /** + * Handle an incoming {@link HttpConnection}. + * @param httpConnection the connection to handle. + * @throws IOException + */ + public void handleIncomingHttp(HttpConnection httpConnection) throws IOException { + if (this.closed) { + httpConnection.respond(HttpStatus.GONE); + } + synchronized (this.httpConnections) { + while (this.httpConnections.size() > 1) { + this.httpConnections.removeFirst().respond( + HttpStatus.TOO_MANY_REQUESTS); + } + this.lastHttpRequestTime = System.currentTimeMillis(); + this.httpConnections.addLast(httpConnection); + this.httpConnections.notify(); + } + forwardToTargetServer(httpConnection); + } + + private void forwardToTargetServer(HttpConnection httpConnection) + throws IOException { + if (httpConnection.isDisconnectRequest()) { + this.targetServer.close(); + interrupt(); + } + ServerHttpRequest request = httpConnection.getRequest(); + HttpTunnelPayload payload = HttpTunnelPayload.get(request); + if (payload != null) { + this.payloadForwarder.forward(payload); + } + } + + } + + /** + * Encapsulates a HTTP request/response pair. + */ + protected static class HttpConnection { + + private final long createTime; + + private final ServerHttpRequest request; + + private final ServerHttpResponse response; + + private ServerHttpAsyncRequestControl async; + + private volatile boolean complete = false; + + public HttpConnection(ServerHttpRequest request, ServerHttpResponse response) { + this.createTime = System.currentTimeMillis(); + this.request = request; + this.response = response; + this.async = startAsync(); + } + + /** + * Start asynchronous support or if unavailble return {@code null} to cause + * {@link #waitForResponse()} to block. + * @return the async request control + */ + protected ServerHttpAsyncRequestControl startAsync() { + try { + // Try to use async to save blocking + ServerHttpAsyncRequestControl async = this.request + .getAsyncRequestControl(this.response); + async.start(); + return async; + } + catch (Exception ex) { + return null; + } + } + + /** + * Return the underlying request. + * @return the request + */ + public final ServerHttpRequest getRequest() { + return this.request; + } + + /** + * Return the underlying response. + * @return the response + */ + protected final ServerHttpResponse getResponse() { + return this.response; + } + + /** + * Determine if a connection is older than the specified time. + * @param time the time to check + * @return {@code true} if the request is older than the time + */ + public boolean isOlderThan(int time) { + long runningTime = System.currentTimeMillis() - this.createTime; + return (runningTime > time); + } + + /** + * Cause the request to block or use asynchronous methods to wait until a response + * is available. + */ + public void waitForResponse() { + if (this.async == null) { + while (!this.complete) { + try { + synchronized (this) { + wait(1000); + } + } + catch (InterruptedException ex) { + } + } + } + } + + /** + * Detect if the request is actually a signal to disconnect. + * @return if the request is a signal to disconnect + */ + public boolean isDisconnectRequest() { + return DISCONNECT_MEDIA_TYPE.equals(this.request.getHeaders() + .getContentType()); + } + + /** + * Send a HTTP status response. + * @param status the status to send + * @throws IOException + */ + public void respond(HttpStatus status) throws IOException { + Assert.notNull(status, "Status must not be null"); + this.response.setStatusCode(status); + complete(); + } + + /** + * Send a payload response. + * @param payload the payload to send + * @throws IOException + */ + public void respond(HttpTunnelPayload payload) throws IOException { + Assert.notNull(payload, "Payload must not be null"); + this.response.setStatusCode(HttpStatus.OK); + payload.assignTo(this.response); + complete(); + } + + /** + * Called when a request is complete. + */ + protected void complete() { + if (this.async != null) { + this.async.complete(); + } + else { + synchronized (this) { + this.complete = true; + notifyAll(); + } + } + } + + } + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerHandler.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerHandler.java new file mode 100644 index 0000000000..2ad4c976e6 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerHandler.java @@ -0,0 +1,51 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import java.io.IOException; + +import org.springframework.boot.developertools.remote.server.Handler; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.util.Assert; + +/** + * Adapts a {@link HttpTunnelServer} to a {@link Handler}. + * + * @author Phillip Webb + * @since 1.3.0 + */ +public class HttpTunnelServerHandler implements Handler { + + private HttpTunnelServer server; + + /** + * Create a new {@link HttpTunnelServerHandler} instance. + * @param server the server to adapt + */ + public HttpTunnelServerHandler(HttpTunnelServer server) { + Assert.notNull(server, "Server must not be null"); + this.server = server; + } + + @Override + public void handle(ServerHttpRequest request, ServerHttpResponse response) + throws IOException { + this.server.handle(request, response); + } + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/PortProvider.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/PortProvider.java new file mode 100644 index 0000000000..adb7518502 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/PortProvider.java @@ -0,0 +1,34 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +/** + * Strategy interface to provide access to a port (which may change if an existing + * connection is closed). + * + * @author Phillip Webb + * @since 1.3.0 + */ +public interface PortProvider { + + /** + * Return the port number + * @return the port number + */ + int getPort(); + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/RemoteDebugPortProvider.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/RemoteDebugPortProvider.java new file mode 100644 index 0000000000..1dae7635c6 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/RemoteDebugPortProvider.java @@ -0,0 +1,61 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.boot.lang.UsesUnsafeJava; +import org.springframework.util.Assert; + +/** + * {@link PortProvider} that provides the port being used by the Java remote debugging. + * + * @author Phillip Webb + */ +public class RemoteDebugPortProvider implements PortProvider { + + private static final String JDWP_ADDRESS_PROPERTY = "sun.jdwp.listenerAddress"; + + private static final Log logger = LogFactory.getLog(RemoteDebugPortProvider.class); + + @Override + public int getPort() { + Assert.state(isRemoteDebugRunning(), "Remote debug is not running"); + return getRemoteDebugPort(); + } + + public static boolean isRemoteDebugRunning() { + return getRemoteDebugPort() != -1; + } + + @UsesUnsafeJava + @SuppressWarnings("restriction") + private static int getRemoteDebugPort() { + String property = sun.misc.VMSupport.getAgentProperties().getProperty( + JDWP_ADDRESS_PROPERTY); + try { + if (property != null && property.contains(":")) { + return Integer.valueOf(property.split(":")[1]); + } + } + catch (Exception ex) { + logger.trace("Unable to get JDWP port from property value '" + property + "'"); + } + return -1; + } + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/SocketTargetServerConnection.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/SocketTargetServerConnection.java new file mode 100644 index 0000000000..ddc070257e --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/SocketTargetServerConnection.java @@ -0,0 +1,101 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SocketChannel; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.util.Assert; + +/** + * Socket based {@link TargetServerConnection}. + * + * @author Phillip Webb + * @since 1.3.0 + */ +public class SocketTargetServerConnection implements TargetServerConnection { + + private static final Log logger = LogFactory + .getLog(SocketTargetServerConnection.class); + + private final PortProvider portProvider; + + /** + * Create a new {@link SocketTargetServerConnection}. + * @param portProvider the port provider + */ + public SocketTargetServerConnection(PortProvider portProvider) { + Assert.notNull(portProvider, "PortProvider must not be null"); + this.portProvider = portProvider; + } + + @Override + public ByteChannel open(int socketTimeout) throws IOException { + SocketAddress address = new InetSocketAddress(this.portProvider.getPort()); + logger.trace("Opening tunnel connection to target server on " + address); + SocketChannel channel = SocketChannel.open(address); + channel.socket().setSoTimeout(socketTimeout); + return new TimeoutAwareChannel(channel); + } + + /** + * Wrapper to expose the {@link SocketChannel} in such a way that + * {@code SocketTimeoutExceptions} are still thrown from read methods. + */ + private static class TimeoutAwareChannel implements ByteChannel { + + private final SocketChannel socketChannel; + + private final ReadableByteChannel readChannel; + + public TimeoutAwareChannel(SocketChannel socketChannel) throws IOException { + this.socketChannel = socketChannel; + this.readChannel = Channels.newChannel(socketChannel.socket() + .getInputStream()); + } + + @Override + public int read(ByteBuffer dst) throws IOException { + return this.readChannel.read(dst); + } + + @Override + public int write(ByteBuffer src) throws IOException { + return this.socketChannel.write(src); + } + + @Override + public boolean isOpen() { + return this.socketChannel.isOpen(); + } + + @Override + public void close() throws IOException { + this.socketChannel.close(); + } + + } + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/StaticPortProvider.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/StaticPortProvider.java new file mode 100644 index 0000000000..34c129f6a1 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/StaticPortProvider.java @@ -0,0 +1,41 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import org.springframework.util.Assert; + +/** + * {@link PortProvider} for a static port that won't change. + * + * @author Phillip Webb + * @since 1.3.0 + */ +public class StaticPortProvider implements PortProvider { + + private final int port; + + public StaticPortProvider(int port) { + Assert.isTrue(port > 0, "Port must be positive"); + this.port = port; + } + + @Override + public int getPort() { + return this.port; + } + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/TargetServerConnection.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/TargetServerConnection.java new file mode 100644 index 0000000000..26c5d2565f --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/TargetServerConnection.java @@ -0,0 +1,38 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import java.io.IOException; +import java.nio.channels.ByteChannel; + +/** + * Manages the connection to the ultimate tunnel target server. + * + * @author Phillip Webb + * @since 1.3.0 + */ +public interface TargetServerConnection { + + /** + * Open a connection to the target server with the specified timeout. + * @param timeout the read timeout + * @return a {@link ByteChannel} providing read/write access to the server + * @throws IOException + */ + ByteChannel open(int timeout) throws IOException; + +} diff --git a/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/package-info.java b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/package-info.java new file mode 100644 index 0000000000..85f70719a9 --- /dev/null +++ b/spring-boot-developer-tools/src/main/java/org/springframework/boot/developertools/tunnel/server/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Server side TCP tunnel support. + */ +package org.springframework.boot.developertools.tunnel.server; + diff --git a/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/client/HttpTunnelConnectionTests.java b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/client/HttpTunnelConnectionTests.java new file mode 100644 index 0000000000..718c83286c --- /dev/null +++ b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/client/HttpTunnelConnectionTests.java @@ -0,0 +1,166 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.client; + +import java.io.ByteArrayOutputStream; +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.util.concurrent.Executor; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.boot.developertools.test.MockClientHttpRequestFactory; +import org.springframework.boot.developertools.tunnel.client.HttpTunnelConnection.TunnelChannel; +import org.springframework.http.HttpStatus; +import org.springframework.util.SocketUtils; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link HttpTunnelConnection}. + * + * @author Phillip Webb + * @author Rob Winch + */ +public class HttpTunnelConnectionTests { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private int port = SocketUtils.findAvailableTcpPort(); + + private String url; + + private ByteArrayOutputStream incommingData; + + private WritableByteChannel incomingChannel; + + @Mock + private Closeable closeable; + + private MockClientHttpRequestFactory requestFactory = new MockClientHttpRequestFactory(); + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + this.url = "http://localhost:" + this.port; + this.incommingData = new ByteArrayOutputStream(); + this.incomingChannel = Channels.newChannel(this.incommingData); + } + + @Test + public void urlMustNotBeNull() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("URL must not be empty"); + new HttpTunnelConnection(null, this.requestFactory); + } + + @Test + public void urlMustNotBeEmpty() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("URL must not be empty"); + new HttpTunnelConnection("", this.requestFactory); + } + + @Test + public void urlMustNotBeMalformed() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Malformed URL 'htttttp:///ttest'"); + new HttpTunnelConnection("htttttp:///ttest", this.requestFactory); + } + + @Test + public void requestFactoryMustNotBeNull() { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("RequestFactory must not be null"); + new HttpTunnelConnection(this.url, null); + } + + @Test + public void closeTunnelChangesIsOpen() throws Exception { + this.requestFactory.willRespondAfterDelay(1000, HttpStatus.GONE); + WritableByteChannel channel = openTunnel(false); + assertThat(channel.isOpen(), equalTo(true)); + channel.close(); + assertThat(channel.isOpen(), equalTo(false)); + } + + @Test + public void closeTunnelCallsCloseableOnce() throws Exception { + this.requestFactory.willRespondAfterDelay(1000, HttpStatus.GONE); + WritableByteChannel channel = openTunnel(false); + verify(this.closeable, never()).close(); + channel.close(); + channel.close(); + verify(this.closeable, times(1)).close(); + } + + @Test + public void typicalTraffic() throws Exception { + this.requestFactory.willRespond("hi", "=2", "=3"); + TunnelChannel channel = openTunnel(true); + write(channel, "hello"); + write(channel, "1+1"); + write(channel, "1+2"); + assertThat(this.incommingData.toString(), equalTo("hi=2=3")); + } + + @Test + public void trafficWithLongPollTimeouts() throws Exception { + for (int i = 0; i < 10; i++) { + this.requestFactory.willRespond(HttpStatus.NO_CONTENT); + } + this.requestFactory.willRespond("hi"); + TunnelChannel channel = openTunnel(true); + write(channel, "hello"); + assertThat(this.incommingData.toString(), equalTo("hi")); + assertThat(this.requestFactory.getExecutedRequests().size(), greaterThan(10)); + } + + private void write(TunnelChannel channel, String string) throws IOException { + channel.write(ByteBuffer.wrap(string.getBytes())); + } + + private TunnelChannel openTunnel(boolean singleThreaded) throws Exception { + HttpTunnelConnection connection = new HttpTunnelConnection(this.url, + this.requestFactory, + (singleThreaded ? new CurrentThreadExecutor() : null)); + return connection.open(this.incomingChannel, this.closeable); + } + + private static class CurrentThreadExecutor implements Executor { + + @Override + public void execute(Runnable command) { + command.run(); + } + + } + +} diff --git a/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/client/TunnelClientTests.java b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/client/TunnelClientTests.java new file mode 100644 index 0000000000..807d43b9a7 --- /dev/null +++ b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/client/TunnelClientTests.java @@ -0,0 +1,199 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.client; + +import java.io.ByteArrayOutputStream; +import java.io.Closeable; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.SocketChannel; +import java.nio.channels.WritableByteChannel; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.springframework.util.SocketUtils; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link TunnelClient}. + * + * @author Phillip Webb + */ +public class TunnelClientTests { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private int listenPort = SocketUtils.findAvailableTcpPort(); + + private MockTunnelConnection tunnelConnection = new MockTunnelConnection(); + + @Test + public void listenPortMustBePositive() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("ListenPort must be positive"); + new TunnelClient(0, this.tunnelConnection); + } + + @Test + public void tunnelConnectionMustNotBeNull() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("TunnelConnection must not be null"); + new TunnelClient(1, null); + } + + @Test + public void typicalTraffic() throws Exception { + TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection); + client.start(); + SocketChannel channel = SocketChannel + .open(new InetSocketAddress(this.listenPort)); + channel.write(ByteBuffer.wrap("hello".getBytes())); + ByteBuffer buffer = ByteBuffer.allocate(5); + channel.read(buffer); + channel.close(); + this.tunnelConnection.verifyWritten("hello"); + assertThat(new String(buffer.array()), equalTo("olleh")); + } + + @Test + public void socketChannelClosedTriggersTunnelClose() throws Exception { + TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection); + client.start(); + SocketChannel channel = SocketChannel + .open(new InetSocketAddress(this.listenPort)); + channel.close(); + client.getServerThread().stopAcceptingConnections(); + client.getServerThread().join(2000); + assertThat(this.tunnelConnection.getOpenedTimes(), equalTo(1)); + assertThat(this.tunnelConnection.isOpen(), equalTo(false)); + } + + @Test + public void stopTriggersTunnelClose() throws Exception { + TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection); + client.start(); + SocketChannel channel = SocketChannel + .open(new InetSocketAddress(this.listenPort)); + client.stop(); + assertThat(this.tunnelConnection.getOpenedTimes(), equalTo(1)); + assertThat(this.tunnelConnection.isOpen(), equalTo(false)); + assertThat(channel.read(ByteBuffer.allocate(1)), equalTo(-1)); + } + + @Test + public void addListener() throws Exception { + TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection); + TunnelClientListener listener = mock(TunnelClientListener.class); + client.addListener(listener); + client.start(); + SocketChannel channel = SocketChannel + .open(new InetSocketAddress(this.listenPort)); + channel.close(); + client.getServerThread().stopAcceptingConnections(); + client.getServerThread().join(2000); + verify(listener).onOpen(any(SocketChannel.class)); + verify(listener).onClose(any(SocketChannel.class)); + } + + private static class MockTunnelConnection implements TunnelConnection { + + private final ByteArrayOutputStream written = new ByteArrayOutputStream(); + + private boolean open; + + private int openedTimes; + + @Override + public WritableByteChannel open(WritableByteChannel incomingChannel, + Closeable closeable) throws Exception { + this.openedTimes++; + this.open = true; + return new TunnelChannel(incomingChannel, closeable); + } + + public void verifyWritten(String expected) { + verifyWritten(expected.getBytes()); + } + + public void verifyWritten(byte[] expected) { + synchronized (this.written) { + assertThat(this.written.toByteArray(), equalTo(expected)); + this.written.reset(); + } + } + + public boolean isOpen() { + return this.open; + } + + public int getOpenedTimes() { + return this.openedTimes; + } + + private class TunnelChannel implements WritableByteChannel { + + private final WritableByteChannel incomingChannel; + + private final Closeable closeable; + + public TunnelChannel(WritableByteChannel incomingChannel, Closeable closeable) { + this.incomingChannel = incomingChannel; + this.closeable = closeable; + } + + @Override + public boolean isOpen() { + return MockTunnelConnection.this.open; + } + + @Override + public void close() throws IOException { + MockTunnelConnection.this.open = false; + this.closeable.close(); + } + + @Override + public int write(ByteBuffer src) throws IOException { + int remaining = src.remaining(); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + Channels.newChannel(stream).write(src); + byte[] bytes = stream.toByteArray(); + synchronized (MockTunnelConnection.this.written) { + MockTunnelConnection.this.written.write(bytes); + } + byte[] reversed = new byte[bytes.length]; + for (int i = 0; i < reversed.length; i++) { + reversed[i] = bytes[bytes.length - 1 - i]; + } + this.incomingChannel.write(ByteBuffer.wrap(reversed)); + return remaining; + } + + } + + } + +} diff --git a/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadForwarderTests.java b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadForwarderTests.java new file mode 100644 index 0000000000..260edf30b8 --- /dev/null +++ b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadForwarderTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.payload; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * Tests for {@link HttpTunnelPayloadForwarder}. + * + * @author Phillip Webb + */ +public class HttpTunnelPayloadForwarderTests { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void targetChannelMustNoBeNull() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("TargetChannel must not be null"); + new HttpTunnelPayloadForwarder(null); + } + + @Test + public void forwardInSequence() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + WritableByteChannel channel = Channels.newChannel(out); + HttpTunnelPayloadForwarder forwarder = new HttpTunnelPayloadForwarder(channel); + forwarder.forward(payload(1, "he")); + forwarder.forward(payload(2, "ll")); + forwarder.forward(payload(3, "o")); + assertThat(out.toByteArray(), equalTo("hello".getBytes())); + } + + @Test + public void forwardOutOfSequence() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + WritableByteChannel channel = Channels.newChannel(out); + HttpTunnelPayloadForwarder forwarder = new HttpTunnelPayloadForwarder(channel); + forwarder.forward(payload(3, "o")); + forwarder.forward(payload(2, "ll")); + forwarder.forward(payload(1, "he")); + assertThat(out.toByteArray(), equalTo("hello".getBytes())); + } + + @Test + public void overflow() throws Exception { + WritableByteChannel channel = Channels.newChannel(new ByteArrayOutputStream()); + HttpTunnelPayloadForwarder forwarder = new HttpTunnelPayloadForwarder(channel); + this.thrown.expect(IllegalStateException.class); + this.thrown.expectMessage("Too many messages queued"); + for (int i = 2; i < 130; i++) { + forwarder.forward(payload(i, "data" + i)); + } + } + + private HttpTunnelPayload payload(long sequence, String data) { + return new HttpTunnelPayload(sequence, ByteBuffer.wrap(data.getBytes())); + } + +} diff --git a/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadTests.java b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadTests.java new file mode 100644 index 0000000000..3d6d5f93d1 --- /dev/null +++ b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/payload/HttpTunnelPayloadTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.payload; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link HttpTunnelPayload}. + * + * @author Phillip Webb + */ +public class HttpTunnelPayloadTests { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void sequenceMustBePositive() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Sequence must be positive"); + new HttpTunnelPayload(0, ByteBuffer.allocate(1)); + } + + @Test + public void dataMustNotBeNull() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Data must not be null"); + new HttpTunnelPayload(1, null); + } + + @Test + public void getSequence() throws Exception { + HttpTunnelPayload payload = new HttpTunnelPayload(1, ByteBuffer.allocate(1)); + assertThat(payload.getSequence(), equalTo(1L)); + } + + @Test + public void getData() throws Exception { + ByteBuffer data = ByteBuffer.wrap("hello".getBytes()); + HttpTunnelPayload payload = new HttpTunnelPayload(1, data); + assertThat(getData(payload), equalTo(data.array())); + } + + @Test + public void assignTo() throws Exception { + ByteBuffer data = ByteBuffer.wrap("hello".getBytes()); + HttpTunnelPayload payload = new HttpTunnelPayload(2, data); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + HttpOutputMessage response = new ServletServerHttpResponse(servletResponse); + payload.assignTo(response); + assertThat(servletResponse.getHeader("x-seq"), equalTo("2")); + assertThat(servletResponse.getContentAsString(), equalTo("hello")); + } + + @Test + public void getNoData() throws Exception { + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + HttpInputMessage request = new ServletServerHttpRequest(servletRequest); + HttpTunnelPayload payload = HttpTunnelPayload.get(request); + assertThat(payload, nullValue()); + } + + @Test + public void getWithMissingHeader() throws Exception { + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + servletRequest.setContent("hello".getBytes()); + HttpInputMessage request = new ServletServerHttpRequest(servletRequest); + this.thrown.expect(IllegalStateException.class); + this.thrown.expectMessage("Missing sequence header"); + HttpTunnelPayload.get(request); + } + + @Test + public void getWithData() throws Exception { + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + servletRequest.setContent("hello".getBytes()); + servletRequest.addHeader("x-seq", 123); + HttpInputMessage request = new ServletServerHttpRequest(servletRequest); + HttpTunnelPayload payload = HttpTunnelPayload.get(request); + assertThat(payload.getSequence(), equalTo(123L)); + assertThat(getData(payload), equalTo("hello".getBytes())); + } + + @Test + public void getPayloadData() throws Exception { + ReadableByteChannel channel = Channels.newChannel(new ByteArrayInputStream( + "hello".getBytes())); + ByteBuffer payloadData = HttpTunnelPayload.getPayloadData(channel); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + WritableByteChannel writeChannel = Channels.newChannel(out); + while (payloadData.hasRemaining()) { + writeChannel.write(payloadData); + } + assertThat(out.toByteArray(), equalTo("hello".getBytes())); + } + + @Test + public void getPayloadDataWithTimeout() throws Exception { + ReadableByteChannel channel = mock(ReadableByteChannel.class); + given(channel.read(any(ByteBuffer.class))) + .willThrow(new SocketTimeoutException()); + ByteBuffer payload = HttpTunnelPayload.getPayloadData(channel); + assertThat(payload, nullValue()); + } + + private byte[] getData(HttpTunnelPayload payload) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + WritableByteChannel channel = Channels.newChannel(out); + payload.writeTo(channel); + return out.toByteArray(); + } + +} diff --git a/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerHandlerTests.java b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerHandlerTests.java new file mode 100644 index 0000000000..6b392b37e2 --- /dev/null +++ b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerHandlerTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link HttpTunnelServerHandler}. + * + * @author Phillip Webb + */ +public class HttpTunnelServerHandlerTests { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void serverMustNotBeNull() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Server must not be null"); + new HttpTunnelServerHandler(null); + } + + @Test + public void handleDelegatesToServer() throws Exception { + HttpTunnelServer server = mock(HttpTunnelServer.class); + HttpTunnelServerHandler handler = new HttpTunnelServerHandler(server); + ServerHttpRequest request = mock(ServerHttpRequest.class); + ServerHttpResponse response = mock(ServerHttpResponse.class); + handler.handle(request, response); + verify(server).handle(request, response); + } + +} diff --git a/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerTests.java b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerTests.java new file mode 100644 index 0000000000..ff5601ed03 --- /dev/null +++ b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/HttpTunnelServerTests.java @@ -0,0 +1,480 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.nio.channels.Channels; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayload; +import org.springframework.boot.developertools.tunnel.server.HttpTunnelServer.HttpConnection; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.ServerHttpAsyncRequestControl; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link HttpTunnelServer}. + * + * @author Phillip Webb + */ +public class HttpTunnelServerTests { + + private static final int DEFAULT_LONG_POLL_TIMEOUT = 10000; + + private static final byte[] NO_DATA = {}; + + private static final String SEQ_HEADER = "x-seq"; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private HttpTunnelServer server; + + @Mock + private TargetServerConnection serverConnection; + + private MockHttpServletRequest servletRequest; + + private MockHttpServletResponse servletResponse; + + private ServerHttpRequest request; + + private ServerHttpResponse response; + + private MockServerChannel serverChannel; + + @Before + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + this.server = new HttpTunnelServer(this.serverConnection); + given(this.serverConnection.open(anyInt())).willAnswer(new Answer() { + @Override + public ByteChannel answer(InvocationOnMock invocation) throws Throwable { + MockServerChannel channel = HttpTunnelServerTests.this.serverChannel; + channel.setTimeout((Integer) invocation.getArguments()[0]); + return channel; + } + }); + this.servletRequest = new MockHttpServletRequest(); + this.servletRequest.setAsyncSupported(true); + this.servletResponse = new MockHttpServletResponse(); + this.request = new ServletServerHttpRequest(this.servletRequest); + this.response = new ServletServerHttpResponse(this.servletResponse); + this.serverChannel = new MockServerChannel(); + } + + @Test + public void serverConnectionIsRequired() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("ServerConnection must not be null"); + new HttpTunnelServer(null); + } + + @Test + public void serverConnectedOnFirstRequest() throws Exception { + verify(this.serverConnection, never()).open(anyInt()); + this.server.handle(this.request, this.response); + verify(this.serverConnection, times(1)).open(DEFAULT_LONG_POLL_TIMEOUT); + } + + @Test + public void longPollTimeout() throws Exception { + this.server.setLongPollTimeout(800); + this.server.handle(this.request, this.response); + verify(this.serverConnection, times(1)).open(800); + } + + @Test + public void longPollTimeoutMustBePositiveValue() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("LongPollTimeout must be a positive value"); + this.server.setLongPollTimeout(0); + } + + @Test + public void initialRequestIsSentToServer() throws Exception { + this.servletRequest.addHeader(SEQ_HEADER, "1"); + this.servletRequest.setContent("hello".getBytes()); + this.server.handle(this.request, this.response); + this.serverChannel.disconnect(); + this.server.getServerThread().join(); + this.serverChannel.verifyReceived("hello"); + } + + @Test + public void intialRequestIsUsedForFirstServerResponse() throws Exception { + this.servletRequest.addHeader(SEQ_HEADER, "1"); + this.servletRequest.setContent("hello".getBytes()); + this.server.handle(this.request, this.response); + System.out.println("sending"); + this.serverChannel.send("hello"); + this.serverChannel.disconnect(); + this.server.getServerThread().join(); + assertThat(this.servletResponse.getContentAsString(), equalTo("hello")); + this.serverChannel.verifyReceived("hello"); + } + + @Test + public void initialRequestHasNoPayload() throws Exception { + this.server.handle(this.request, this.response); + this.serverChannel.disconnect(); + this.server.getServerThread().join(); + this.serverChannel.verifyReceived(NO_DATA); + } + + @Test + public void typicalReqestResponseTraffic() throws Exception { + MockHttpConnection h1 = new MockHttpConnection(); + this.server.handle(h1); + MockHttpConnection h2 = new MockHttpConnection("hello server", 1); + this.server.handle(h2); + this.serverChannel.verifyReceived("hello server"); + this.serverChannel.send("hello client"); + h1.verifyReceived("hello client", 1); + MockHttpConnection h3 = new MockHttpConnection("1+1", 2); + this.server.handle(h3); + this.serverChannel.send("=2"); + h2.verifyReceived("=2", 2); + MockHttpConnection h4 = new MockHttpConnection("1+2", 3); + this.server.handle(h4); + this.serverChannel.send("=3"); + h3.verifyReceived("=3", 3); + this.serverChannel.disconnect(); + this.server.getServerThread().join(); + } + + @Test + public void clientIsAwareOfServerClose() throws Exception { + MockHttpConnection h1 = new MockHttpConnection("1", 1); + this.server.handle(h1); + this.serverChannel.disconnect(); + this.server.getServerThread().join(); + assertThat(h1.getServletResponse().getStatus(), equalTo(410)); + } + + @Test + public void clientCanCloseServer() throws Exception { + MockHttpConnection h1 = new MockHttpConnection(); + this.server.handle(h1); + MockHttpConnection h2 = new MockHttpConnection("DISCONNECT", 1); + h2.getServletRequest().addHeader("Content-Type", "application/x-disconnect"); + this.server.handle(h2); + this.server.getServerThread().join(); + assertThat(h1.getServletResponse().getStatus(), equalTo(410)); + assertThat(this.serverChannel.isOpen(), equalTo(false)); + } + + @Test + public void neverMoreThanTwoHttpConnections() throws Exception { + MockHttpConnection h1 = new MockHttpConnection(); + this.server.handle(h1); + MockHttpConnection h2 = new MockHttpConnection("1", 2); + this.server.handle(h2); + MockHttpConnection h3 = new MockHttpConnection("2", 3); + this.server.handle(h3); + h1.waitForResponse(); + assertThat(h1.getServletResponse().getStatus(), equalTo(429)); + this.serverChannel.disconnect(); + this.server.getServerThread().join(); + } + + @Test + public void requestRecievedOutOfOrder() throws Exception { + MockHttpConnection h1 = new MockHttpConnection(); + MockHttpConnection h2 = new MockHttpConnection("1+2", 1); + MockHttpConnection h3 = new MockHttpConnection("+3", 2); + this.server.handle(h1); + this.server.handle(h3); + this.server.handle(h2); + this.serverChannel.verifyReceived("1+2+3"); + this.serverChannel.disconnect(); + this.server.getServerThread().join(); + } + + @Test + public void httpConnectionsAreClosedAfterLongPollTimeout() throws Exception { + this.server.setDisconnectTimeout(1000); + this.server.setLongPollTimeout(100); + MockHttpConnection h1 = new MockHttpConnection(); + this.server.handle(h1); + MockHttpConnection h2 = new MockHttpConnection(); + this.server.handle(h2); + Thread.sleep(400); + this.serverChannel.disconnect(); + this.server.getServerThread().join(); + assertThat(h1.getServletResponse().getStatus(), equalTo(204)); + assertThat(h2.getServletResponse().getStatus(), equalTo(204)); + } + + @Test + public void disconnectTimeout() throws Exception { + this.server.setDisconnectTimeout(100); + this.server.setLongPollTimeout(100); + MockHttpConnection h1 = new MockHttpConnection(); + this.server.handle(h1); + this.serverChannel.send("hello"); + this.server.getServerThread().join(); + assertThat(this.serverChannel.isOpen(), equalTo(false)); + } + + @Test + public void disconnectTimeoutMustBePositive() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("DisconnectTimeout must be a positive value"); + this.server.setDisconnectTimeout(0); + } + + @Test + public void httpConnectionRespondWithPayload() throws Exception { + HttpConnection connection = new HttpConnection(this.request, this.response); + connection.waitForResponse(); + connection.respond(new HttpTunnelPayload(1, ByteBuffer.wrap("hello".getBytes()))); + assertThat(this.servletResponse.getStatus(), equalTo(200)); + assertThat(this.servletResponse.getContentAsString(), equalTo("hello")); + assertThat(this.servletResponse.getHeader(SEQ_HEADER), equalTo("1")); + } + + @Test + public void httpConnectionRespondWithStatus() throws Exception { + HttpConnection connection = new HttpConnection(this.request, this.response); + connection.waitForResponse(); + connection.respond(HttpStatus.I_AM_A_TEAPOT); + assertThat(this.servletResponse.getStatus(), equalTo(418)); + assertThat(this.servletResponse.getContentLength(), equalTo(0)); + } + + @Test + public void httpConnectionAsync() throws Exception { + ServerHttpAsyncRequestControl async = mock(ServerHttpAsyncRequestControl.class); + ServerHttpRequest request = mock(ServerHttpRequest.class); + given(request.getAsyncRequestControl(this.response)).willReturn(async); + HttpConnection connection = new HttpConnection(request, this.response); + connection.waitForResponse(); + verify(async).start(); + connection.respond(HttpStatus.NO_CONTENT); + verify(async).complete(); + } + + @Test + public void httpConnectionNonAsync() throws Exception { + testHttpConnectionNonAsync(0); + testHttpConnectionNonAsync(100); + } + + private void testHttpConnectionNonAsync(long sleepBeforeResponse) throws IOException, + InterruptedException { + ServerHttpRequest request = mock(ServerHttpRequest.class); + given(request.getAsyncRequestControl(this.response)).willThrow( + new IllegalArgumentException()); + final HttpConnection connection = new HttpConnection(request, this.response); + final AtomicBoolean responded = new AtomicBoolean(); + Thread connectionThread = new Thread() { + + @Override + public void run() { + connection.waitForResponse(); + responded.set(true); + } + + }; + connectionThread.start(); + assertThat(responded.get(), equalTo(false)); + Thread.sleep(sleepBeforeResponse); + connection.respond(HttpStatus.NO_CONTENT); + connectionThread.join(); + assertThat(responded.get(), equalTo(true)); + } + + @Test + public void httpConnectionRunning() throws Exception { + HttpConnection connection = new HttpConnection(this.request, this.response); + assertThat(connection.isOlderThan(100), equalTo(false)); + Thread.sleep(200); + assertThat(connection.isOlderThan(100), equalTo(true)); + } + + /** + * Mock {@link ByteChannel} used to simulate the server connection. + */ + private static class MockServerChannel implements ByteChannel { + + private static final ByteBuffer DISCONNECT = ByteBuffer.wrap(NO_DATA); + + private int timeout; + + private BlockingDeque outgoing = new LinkedBlockingDeque(); + + private ByteArrayOutputStream written = new ByteArrayOutputStream(); + + private AtomicBoolean open = new AtomicBoolean(true); + + public void setTimeout(int timeout) { + this.timeout = timeout; + } + + public void send(String content) { + send(content.getBytes()); + } + + public void send(byte[] bytes) { + this.outgoing.addLast(ByteBuffer.wrap(bytes)); + } + + public void disconnect() { + this.outgoing.addLast(DISCONNECT); + } + + public void verifyReceived(String expected) { + verifyReceived(expected.getBytes()); + } + + public void verifyReceived(byte[] expected) { + synchronized (this.written) { + assertThat(this.written.toByteArray(), equalTo(expected)); + this.written.reset(); + } + } + + @Override + public int read(ByteBuffer dst) throws IOException { + try { + ByteBuffer bytes = this.outgoing.pollFirst(this.timeout, + TimeUnit.MILLISECONDS); + if (bytes == null) { + throw new SocketTimeoutException(); + } + if (bytes == DISCONNECT) { + this.open.set(false); + return -1; + } + int initialRemaining = dst.remaining(); + bytes.limit(Math.min(bytes.limit(), initialRemaining)); + dst.put(bytes); + bytes.limit(bytes.capacity()); + return initialRemaining - dst.remaining(); + } + catch (InterruptedException ex) { + throw new IllegalStateException(ex); + } + } + + @Override + public int write(ByteBuffer src) throws IOException { + int remaining = src.remaining(); + synchronized (this.written) { + Channels.newChannel(this.written).write(src); + } + return remaining; + } + + @Override + public boolean isOpen() { + return this.open.get(); + } + + @Override + public void close() throws IOException { + this.open.set(false); + } + + } + + /** + * Mock {@link HttpConnection}. + */ + private static class MockHttpConnection extends HttpConnection { + + public MockHttpConnection() { + super(new ServletServerHttpRequest(new MockHttpServletRequest()), + new ServletServerHttpResponse(new MockHttpServletResponse())); + } + + public MockHttpConnection(String content, int seq) { + this(); + MockHttpServletRequest request = getServletRequest(); + request.setContent(content.getBytes()); + request.addHeader(SEQ_HEADER, String.valueOf(seq)); + } + + @Override + protected ServerHttpAsyncRequestControl startAsync() { + getServletRequest().setAsyncSupported(true); + return super.startAsync(); + } + + @Override + protected void complete() { + super.complete(); + getServletResponse().setCommitted(true); + } + + public MockHttpServletRequest getServletRequest() { + return (MockHttpServletRequest) ((ServletServerHttpRequest) getRequest()) + .getServletRequest(); + } + + public MockHttpServletResponse getServletResponse() { + return (MockHttpServletResponse) ((ServletServerHttpResponse) getResponse()) + .getServletResponse(); + } + + public void verifyReceived(String expectedContent, int expectedSeq) + throws Exception { + waitForServletResponse(); + MockHttpServletResponse resp = getServletResponse(); + assertThat(resp.getContentAsString(), equalTo(expectedContent)); + assertThat(resp.getHeader(SEQ_HEADER), equalTo(String.valueOf(expectedSeq))); + } + + public void waitForServletResponse() throws InterruptedException { + while (!getServletResponse().isCommitted()) { + Thread.sleep(10); + } + } + + } + +} diff --git a/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/SocketTargetServerConnectionTests.java b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/SocketTargetServerConnectionTests.java new file mode 100644 index 0000000000..e0032a9ac8 --- /dev/null +++ b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/SocketTargetServerConnectionTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.util.SocketUtils; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +/** + * Tests for {@link SocketTargetServerConnection}. + * + * @author Phillip Webb + */ +public class SocketTargetServerConnectionTests { + + private static final int DEFAULT_TIMEOUT = 1000; + + private int port; + + private MockServer server; + + private SocketTargetServerConnection connection; + + @Before + public void setup() throws IOException { + this.port = SocketUtils.findAvailableTcpPort(); + this.server = new MockServer(this.port); + StaticPortProvider portProvider = new StaticPortProvider(this.port); + this.connection = new SocketTargetServerConnection(portProvider); + } + + @Test + public void readData() throws Exception { + this.server.willSend("hello".getBytes()); + this.server.start(); + ByteChannel channel = this.connection.open(DEFAULT_TIMEOUT); + ByteBuffer buffer = ByteBuffer.allocate(5); + channel.read(buffer); + assertThat(buffer.array(), equalTo("hello".getBytes())); + } + + @Test + public void writeData() throws Exception { + this.server.expect("hello".getBytes()); + this.server.start(); + ByteChannel channel = this.connection.open(DEFAULT_TIMEOUT); + ByteBuffer buffer = ByteBuffer.wrap("hello".getBytes()); + channel.write(buffer); + this.server.closeAndVerify(); + } + + @Test + public void timeout() throws Exception { + this.server.delay(1000); + this.server.start(); + ByteChannel channel = this.connection.open(10); + long startTime = System.currentTimeMillis(); + try { + channel.read(ByteBuffer.allocate(5)); + fail("No socket timeout thrown"); + } + catch (SocketTimeoutException ex) { + // Expected + long runTime = System.currentTimeMillis() - startTime; + assertThat(runTime, greaterThanOrEqualTo(10L)); + assertThat(runTime, lessThan(10000L)); + } + } + + private static class MockServer { + + private ServerSocketChannel serverSocket; + + private byte[] send; + + private byte[] expect; + + private int delay; + + private ByteBuffer actualRead; + + private ServerThread thread; + + public MockServer(int port) throws IOException { + this.serverSocket = ServerSocketChannel.open(); + this.serverSocket.bind(new InetSocketAddress(port)); + } + + public void delay(int delay) { + this.delay = delay; + } + + public void willSend(byte[] send) { + this.send = send; + } + + public void expect(byte[] expect) { + this.expect = expect; + } + + public void start() { + this.thread = new ServerThread(); + this.thread.start(); + } + + public void closeAndVerify() throws InterruptedException { + close(); + assertThat(this.actualRead.array(), equalTo(this.expect)); + } + + public void close() throws InterruptedException { + while (this.thread.isAlive()) { + Thread.sleep(10); + } + } + + private class ServerThread extends Thread { + + @Override + public void run() { + try { + SocketChannel channel = MockServer.this.serverSocket.accept(); + Thread.sleep(MockServer.this.delay); + if (MockServer.this.send != null) { + ByteBuffer buffer = ByteBuffer.wrap(MockServer.this.send); + while (buffer.hasRemaining()) { + channel.write(buffer); + } + } + if (MockServer.this.expect != null) { + ByteBuffer buffer = ByteBuffer + .allocate(MockServer.this.expect.length); + while (buffer.hasRemaining()) { + channel.read(buffer); + } + MockServer.this.actualRead = buffer; + } + channel.close(); + } + catch (Exception ex) { + ex.printStackTrace(); + fail(); + } + } + + } + + } + +} diff --git a/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/StaticPortProviderTests.java b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/StaticPortProviderTests.java new file mode 100644 index 0000000000..88d5c8ab12 --- /dev/null +++ b/spring-boot-developer-tools/src/test/java/org/springframework/boot/developertools/tunnel/server/StaticPortProviderTests.java @@ -0,0 +1,49 @@ +/* + * Copyright 2012-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.developertools.tunnel.server; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * Tests for {@link StaticPortProvider}. + * + * @author Phillip Webb + */ +public class StaticPortProviderTests { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void portMustBePostive() throws Exception { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Port must be positive"); + new StaticPortProvider(0); + } + + @Test + public void getPort() throws Exception { + StaticPortProvider provider = new StaticPortProvider(123); + assertThat(provider.getPort(), equalTo(123)); + } + +}