/* * Copyright 2017 ZhangJiupeng * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package cc.agentx.util.tunnel; import java.io.IOException; import java.net.ConnectException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Set; /** * a transparent socket tunnel * <p> * data exchange with src will be automatically redirect to dst.<br/> * for users, data forwarding is hardly perceptible.<br/> * for developers, it is no need to change your external code.<br/> * </p> * just enjoy! */ public class SocketTunnel implements Runnable { private InetSocketAddress srcAddr; private InetSocketAddress dstAddr; private ServerSocketChannel server; private Map<SocketChannel, SocketChannel> bridge; private Selector selector; private ByteBuffer buffer; public SocketTunnel(InetSocketAddress src, InetSocketAddress dst) { this(src, dst, 65536); } public SocketTunnel(InetSocketAddress src, InetSocketAddress dst, int bufferSize) { this.srcAddr = src; this.dstAddr = dst; this.bridge = new HashMap<>(1 << 6, 0.75f); this.buffer = ByteBuffer.allocate(bufferSize); } public static void main(String[] args) throws IOException { if (args.length != 2) { System.out.println("Example - " + SocketTunnel.class.getName() + " 0.0.0.0:80 123.123.123.123:80"); return; } String[] src = args[0].split(":"); String[] dst = args[1].split(":"); new SocketTunnel( new InetSocketAddress(src[0], Integer.parseInt(src[1])), new InetSocketAddress(dst[0], Integer.parseInt(dst[1])) ).startup(); System.out.println("socket tunnel started!\t" + args[0] + " -> " + args[1]); } public void startup() throws IOException { selector = Selector.open(); server = ServerSocketChannel.open(); server.configureBlocking(false); server.socket().setReuseAddress(true); server.socket().bind(srcAddr); server.register(selector, SelectionKey.OP_ACCEPT); } public void shutdown() throws IOException { if (server.isOpen()) { server.close(); } selector.selectNow(); buffer.clear(); } public void restart() throws IOException { server.close(); selector.selectNow(); server = ServerSocketChannel.open(); server.configureBlocking(false); server.socket().setReuseAddress(true); server.socket().bind(srcAddr); server.register(selector, SelectionKey.OP_ACCEPT); bridge.clear(); } @Override public void run() { try { while (true) { selector.select(); Set<SelectionKey> keys = selector.selectedKeys(); Iterator<SelectionKey> i = keys.iterator(); if (i.hasNext()) { SelectionKey key = i.next(); keys.remove(key); handleEvent(key); } } } catch (IOException e) { e.printStackTrace(); } } public void handleEvent(SelectionKey key) { if (key.isAcceptable()) buildConnection(key); else if (key.isReadable()) transferData(key); } private void buildConnection(SelectionKey key) { try { // connect to dstAddr, sign bridge SocketChannel dstSocket; try { dstSocket = SocketChannel.open(dstAddr); } catch (ConnectException ce) { System.err.print("connection broke (" + ce.getMessage() + "), restarting... "); restart(); System.err.println("ok!"); return; } SocketChannel srcSocket; srcSocket = server.accept(); srcSocket.configureBlocking(false); srcSocket.socket().setSoLinger(true, 0); srcSocket.register(selector, SelectionKey.OP_READ); // copy src-socket attributes dstSocket.socket().setReceiveBufferSize(srcSocket.socket().getReceiveBufferSize()); dstSocket.socket().setSoTimeout(srcSocket.socket().getSoTimeout()); dstSocket.socket().setTcpNoDelay(srcSocket.socket().getTcpNoDelay()); dstSocket.socket().setKeepAlive(srcSocket.socket().getKeepAlive()); dstSocket.socket().setOOBInline(srcSocket.socket().getOOBInline()); // urgent data dstSocket.socket().setSoLinger(true, 0); dstSocket.configureBlocking(false); bridge.put(srcSocket, dstSocket); bridge.put(dstSocket, srcSocket); dstSocket.register(selector, SelectionKey.OP_READ); } catch (IOException e) { e.printStackTrace(); } key.interestOps(SelectionKey.OP_ACCEPT); } public void transferData(SelectionKey key) { SocketChannel activeSocket = (SocketChannel) key.channel(); SocketChannel passiveSocket = bridge.get(activeSocket); try { // throws when closed activeSocket.read(buffer); byte[] bytes = new byte[buffer.position()]; if (buffer.position() == 0) { // end of stream doRecycle(activeSocket, passiveSocket, key); return; } else { System.arraycopy(buffer.array(), 0, bytes, 0, bytes.length); } buffer.flip(); if (activeSocket.socket().getLocalSocketAddress() == srcAddr) { passiveSocket.write(buffer); } else { passiveSocket.write(buffer); } buffer.clear(); } catch (Exception e) { doRecycle(activeSocket, passiveSocket, key); } finally { buffer.clear(); } } public void doRecycle(SocketChannel activeSocket, SocketChannel passiveSocket, SelectionKey key) { try { if (passiveSocket != null) passiveSocket.close(); } catch (IOException ioe) { ioe.printStackTrace(); } finally { bridge.remove(passiveSocket); bridge.remove(activeSocket); key.cancel(); } } public InetSocketAddress getDstAddr() { return dstAddr; } public InetSocketAddress getSrcAddr() { return srcAddr; } }