/* Copyright (c) 2011 Danish Maritime Authority.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test.util;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocketFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A TCP forwarding proxy for tests. It uses SSL sockets when the target URI scheme is {@code ssl}.
 *
 * <p>The proxy listens on a local port. For every accepted client connection it opens a new
 * connection to {@code target} and copies bytes in both directions on two {@link Bridge.Pump}
 * threads. Tests can {@link #pause()} and {@link #goOn()} traffic, {@link #halfClose()} the client
 * side of every connection, {@link #close()} everything, and {@link #reopen()} on the same port.
 */
public class SocketProxy {

    static final Logger LOG = LoggerFactory.getLogger(SocketProxy.class);

    /** Accept timeout, so the acceptor thread regularly re-checks whether its socket was closed. */
    public static final int ACCEPT_TIMEOUT_MILLIS = 100;

    private URI proxyUrl;
    private URI target;
    private Acceptor acceptor;
    private ServerSocket serverSocket;

    /**
     * Released by {@link #close()}. Replaced with a fresh latch by every {@link #open()}. Volatile
     * because {@link #reopen()} may swap it while another thread waits in {@link #waitUntilClosed}.
     */
    volatile CountDownLatch closed = new CountDownLatch(1);

    /** Live proxied connections. All access must synchronize on this list. */
    public List<Bridge> connections = new LinkedList<>();

    private int listenPort;
    int receiveBufferSize = -1;
    private boolean pauseAtStart;
    private int acceptBacklog = 50;

    /** Creates an unopened proxy. Call {@link #setTarget(URI)} and then {@link #open()}. */
    public SocketProxy() throws Exception {}

    /** Creates and opens a proxy to {@code uri} on an ephemeral local port. */
    public SocketProxy(URI uri) throws Exception {
        this(0, uri);
    }

    /**
     * Creates and opens a proxy to {@code uri}.
     *
     * @param port local port to listen on; {@code 0} picks an ephemeral port
     * @param uri the target to forward to
     */
    public SocketProxy(int port, URI uri) throws Exception {
        listenPort = port;
        target = uri;
        open();
    }

    /** Sets SO_RCVBUF for server, accepted and outgoing sockets; values {@code <= 0} keep the OS default. */
    public void setReceiveBufferSize(int receiveBufferSize) {
        this.receiveBufferSize = receiveBufferSize;
    }

    public void setTarget(URI tcpBrokerUri) {
        target = tcpBrokerUri;
    }

    /**
     * Binds the listening socket and starts the acceptor thread.
     *
     * <p>On the first call the socket binds to the configured listen port, and {@link #getUrl()}
     * records the actual port. Later calls (see {@link #reopen()}) rebind to that recorded port. If
     * configuring or binding fails, the new server socket is closed before the exception propagates.
     */
    public void open() throws Exception {
        ServerSocket socket = createServerSocket(target);
        try {
            socket.setReuseAddress(true);
            if (receiveBufferSize > 0) {
                socket.setReceiveBufferSize(receiveBufferSize);
            }
            int port = (proxyUrl == null) ? listenPort : proxyUrl.getPort();
            socket.bind(new InetSocketAddress(port), acceptBacklog);
            if (proxyUrl == null) {
                proxyUrl = urlFromSocket(target, socket);
            }
        } catch (Exception e) {
            closeQuietly(socket);
            throw e;
        }
        serverSocket = socket;
        closed = new CountDownLatch(1);
        acceptor = new Acceptor(socket, target);
        if (pauseAtStart) {
            acceptor.pause();
        }
        new Thread(null, acceptor, "SocketProxy-Acceptor-" + socket.getLocalPort()).start();
    }

    private boolean isSsl(URI target) {
        return "ssl".equals(target.getScheme());
    }

    private ServerSocket createServerSocket(URI target) throws Exception {
        if (isSsl(target)) {
            return SSLServerSocketFactory.getDefault().createServerSocket();
        }
        return new ServerSocket();
    }

    Socket createSocket(URI target) throws Exception {
        if (isSsl(target)) {
            return SSLSocketFactory.getDefault().createSocket();
        }
        return new Socket();
    }

    /** Returns the target URI rewritten to the proxy's listen port, or {@code null} before the first open. */
    public URI getUrl() {
        return proxyUrl;
    }

    /** Closes all proxied connections and the acceptor, then releases {@link #waitUntilClosed} waiters. */
    public void close() {
        List<Bridge> snapshot;
        synchronized (connections) {
            snapshot = new ArrayList<>(connections);
        }
        LOG.info("close, numConnections={}", snapshot.size());
        for (Bridge con : snapshot) {
            closeConnection(con);
        }
        if (acceptor != null) {
            acceptor.close();
        }
        closed.countDown();
    }

    /** Closes only the client-facing socket of every connection. The acceptor keeps running. */
    public void halfClose() {
        List<Bridge> snapshot;
        synchronized (connections) {
            snapshot = new ArrayList<>(connections);
        }
        LOG.info("halfClose, numConnections={}", snapshot.size());
        for (Bridge con : snapshot) {
            halfCloseConnection(con);
        }
    }

    /** Returns {@code true} if {@link #close()} happened within {@code timeoutSeconds}. */
    public boolean waitUntilClosed(long timeoutSeconds) throws InterruptedException {
        return closed.await(timeoutSeconds, TimeUnit.SECONDS);
    }

    /** Re-opens after {@link #close()} on the same port. Failures are logged, not thrown. */
    public void reopen() {
        LOG.info("reopen");
        try {
            open();
        } catch (Exception e) {
            LOG.warn("exception on reopen url:{}", getUrl(), e);
        }
    }

    /**
     * Stops accepting new connections and stops data transfer on existing ones. All sockets stay
     * open. Calling this repeatedly is safe; a single {@link #goOn()} resumes.
     */
    public void pause() {
        synchronized (connections) {
            LOG.info("pause, numConnections={}", connections.size());
            if (acceptor != null) {
                acceptor.pause();
            }
            for (Bridge con : connections) {
                con.pause();
            }
        }
    }

    /** Resumes data transfer and accepting after {@link #pause()}. */
    public void goOn() {
        synchronized (connections) {
            LOG.info("goOn, numConnections={}", connections.size());
            for (Bridge con : connections) {
                con.goOn();
            }
        }
        if (acceptor != null) {
            acceptor.goOn();
        }
    }

    private void closeConnection(Bridge c) {
        try {
            c.close();
        } catch (Exception e) {
            LOG.debug("exception on close of: " + c, e);
        }
    }

    private void halfCloseConnection(Bridge c) {
        try {
            c.halfClose();
        } catch (Exception e) {
            LOG.debug("exception on half close of: " + c, e);
        }
    }

    /** Best-effort close of {@code c}. A failure is logged at debug level and otherwise ignored. */
    private static void closeQuietly(Closeable c) {
        if (c == null) {
            return;
        }
        try {
            c.close();
        } catch (IOException e) {
            LOG.debug("exception on close of: " + c, e);
        }
    }

    /**
     * Makes {@code latch} hold a closed (count 1) latch. It installs a new one only if the current
     * latch is already released, so threads already waiting on a pause latch stay on the one that
     * {@code goOn()} will count down.
     */
    private static void pauseLatch(AtomicReference<CountDownLatch> latch) {
        while (true) {
            CountDownLatch current = latch.get();
            if (current.getCount() > 0 || latch.compareAndSet(current, new CountDownLatch(1))) {
                return;
            }
        }
    }

    public boolean isPauseAtStart() {
        return pauseAtStart;
    }

    public void setPauseAtStart(boolean pauseAtStart) {
        this.pauseAtStart = pauseAtStart;
    }

    public int getAcceptBacklog() {
        return acceptBacklog;
    }

    public void setAcceptBacklog(int acceptBacklog) {
        this.acceptBacklog = acceptBacklog;
    }

    /** Returns {@code uri} with its port replaced by the local port of {@code serverSocket}. */
    private URI urlFromSocket(URI uri, ServerSocket serverSocket) throws Exception {
        int listenPort = serverSocket.getLocalPort();
        return new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), listenPort, uri.getPath(), uri.getQuery(),
                uri.getFragment());
    }

    /** One proxied connection: a client socket joined to a fresh socket to the target. */
    public class Bridge {

        Socket receiveSocket;
        private Socket sendSocket;
        private Pump requestThread;
        private Pump responseThread;

        /**
         * Connects to {@code target} and starts both pump threads. If the outgoing connect fails,
         * the outgoing socket is closed. The caller still owns {@code socket} and must close it.
         */
        public Bridge(Socket socket, URI target) throws Exception {
            receiveSocket = socket;
            sendSocket = createSocket(target);
            try {
                if (receiveBufferSize > 0) {
                    sendSocket.setReceiveBufferSize(receiveBufferSize);
                }
                sendSocket.connect(new InetSocketAddress(target.getHost(), target.getPort()));
            } catch (Exception e) {
                closeQuietly(sendSocket);
                throw e;
            }
            linkWithThreads(receiveSocket, sendSocket);
            LOG.info("proxy connection {}, receiveBufferSize={}", sendSocket, sendSocket.getReceiveBufferSize());
        }

        public void goOn() {
            responseThread.goOn();
            requestThread.goOn();
        }

        public void pause() {
            requestThread.pause();
            responseThread.pause();
        }

        /** Removes this bridge from the proxy and closes both sockets, even if the first close fails. */
        public void close() throws Exception {
            synchronized (connections) {
                connections.remove(this);
            }
            try {
                receiveSocket.close();
            } finally {
                sendSocket.close();
            }
        }

        /** Closes only the client-facing socket. */
        public void halfClose() throws Exception {
            receiveSocket.close();
        }

        private void linkWithThreads(Socket source, Socket dest) {
            requestThread = new Pump(source, dest);
            requestThread.start();
            responseThread = new Pump(dest, source);
            responseThread.start();
        }

        /** Copies bytes from one socket to another until EOF or an I/O failure. */
        public class Pump extends Thread {

            protected Socket src;
            private Socket destination;
            private AtomicReference<CountDownLatch> pause = new AtomicReference<>();

            public Pump(Socket source, Socket dest) {
                super("SocketProxy-DataTransfer-" + source.getPort() + ":" + dest.getPort());
                src = source;
                destination = dest;
                pause.set(new CountDownLatch(0));
            }

            /** Blocks forwarding of the next chunk until {@link #goOn()}. Safe to call repeatedly. */
            public void pause() {
                pauseLatch(pause);
            }

            public void goOn() {
                pause.get().countDown();
            }

            /**
             * Reads into a 1 KiB buffer and writes each chunk after waiting out any pause. On EOF
             * it just stops; neither socket is closed. On failure it closes the whole bridge,
             * unless the client socket is already closed (the half-close case).
             */
            @Override
            public void run() {
                byte[] buf = new byte[1024];
                try {
                    InputStream in = src.getInputStream();
                    OutputStream out = destination.getOutputStream();
                    while (true) {
                        int len = in.read(buf);
                        if (len == -1) {
                            LOG.debug("read eof from:{}", src);
                            break;
                        }
                        pause.get().await();
                        out.write(buf, 0, len);
                    }
                } catch (Exception e) {
                    LOG.debug("read/write failed, reason: {}", e.getLocalizedMessage());
                    try {
                        if (!receiveSocket.isClosed()) {
                            // for halfClose, on read/write failure if we close the
                            // remote end will see a close at the same time.
                            close();
                        }
                    } catch (Exception ignored) {
                        // best effort: the connection is already failing
                    }
                }
            }
        }
    }

    /** Accepts client connections and creates a {@link Bridge} for each one. */
    public class Acceptor implements Runnable {

        private ServerSocket socket;
        private URI target;
        private AtomicReference<CountDownLatch> pause = new AtomicReference<>();

        public Acceptor(ServerSocket serverSocket, URI uri) {
            socket = serverSocket;
            target = uri;
            pause.set(new CountDownLatch(0));
            try {
                socket.setSoTimeout(ACCEPT_TIMEOUT_MILLIS);
            } catch (SocketException e) {
                LOG.warn("failed to set accept timeout on {}", socket, e);
            }
        }

        /** Stops accepting until {@link #goOn()}. Safe to call repeatedly. */
        public void pause() {
            pauseLatch(pause);
        }

        public void goOn() {
            pause.get().countDown();
        }

        /**
         * Runs until the server socket is closed. An accept timeout just loops again. If one
         * connection fails to open, only that client socket is closed and accepting continues.
         */
        @Override
        public void run() {
            try {
                while (!socket.isClosed()) {
                    pause.get().await();
                    Socket source;
                    try {
                        source = socket.accept();
                    } catch (SocketTimeoutException expected) {
                        // no client within ACCEPT_TIMEOUT_MILLIS; loop to re-check isClosed()
                        continue;
                    }
                    pause.get().await();
                    addBridge(source);
                }
            } catch (Exception e) {
                LOG.debug("acceptor: finished for reason: {}", e.getLocalizedMessage());
            }
        }

        /**
         * Bridges {@code source} to the target, or closes it if that fails. The bridge is built
         * while holding the {@code connections} lock. That way, if a pump fails at once, its
         * {@code close()} (which removes from the list) waits until the bridge has been added.
         */
        private void addBridge(Socket source) {
            try {
                if (receiveBufferSize > 0) {
                    source.setReceiveBufferSize(receiveBufferSize);
                }
                LOG.info("accepted {}, receiveBufferSize:{}", source, source.getReceiveBufferSize());
                synchronized (connections) {
                    connections.add(new Bridge(source, target));
                }
            } catch (Exception e) {
                LOG.warn("failed to proxy {} to {}, closing client connection", source, target, e);
                closeQuietly(source);
            }
        }

        /** Closes the server socket, then always releases close waiters and wakes a paused acceptor. */
        public void close() {
            try {
                socket.close();
            } catch (IOException ignored) {
                // best-effort shutdown; waiters are still released below
            } finally {
                closed.countDown();
                goOn();
            }
        }
    }
}