/* Copyright (c) 2011 Danish Maritime Authority.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package test.util;

import static java.util.Objects.requireNonNull;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.UUID;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

/**
 * A simple TCP proxy for tests. It listens on {@code proxyAddress}; for every accepted client it opens a new
 * connection to {@code remoteAddress} and copies bytes in both directions, one thread per direction. Proxied
 * connections can be killed individually, periodically or all at once to simulate network failures.
 *
 * @author Kasper Nielsen
 */
public class ProxyTester {

    /** Size of the copy buffer used by each forwarding thread. */
    private static final int BUFFER_SIZE = 1024;

    /**
     * Connections opened by the proxy that have not been removed by a kill. Connections that end on their own are not
     * removed from this list.
     */
    final CopyOnWriteArrayList<Connection> connections = new CopyOnWriteArrayList<>();

    /** The local address the proxy listens on. */
    final SocketAddress proxyAddress;

    /** The address every proxied connection is forwarded to. */
    final SocketAddress remoteAddress;

    /** Runs the accept loop started by {@link #start()}. */
    final ExecutorService es = Executors.newSingleThreadExecutor();

    /** Runs the periodic kills scheduled by {@link #killRandom(long, TimeUnit)}. */
    final ScheduledExecutorService ses = Executors.newSingleThreadScheduledExecutor();

    /** The listening socket; {@code null} until {@link #start()} has successfully bound it. */
    volatile ServerSocket ss;

    /**
     * Latch toggled by {@link #pause()} and {@link #noPause()}. NOTE(review): nothing in this class awaits this latch,
     * so pausing does not currently hold back accepting or forwarding.
     */
    volatile CountDownLatch pause = new CountDownLatch(0);

    /**
     * Creates a proxy; call {@link #start()} to begin listening.
     *
     * @param proxyAddress
     *            the local address to listen on
     * @param remoteAddress
     *            the address to forward every accepted connection to
     * @throws NullPointerException
     *             if either address is {@code null}
     */
    public ProxyTester(SocketAddress proxyAddress, SocketAddress remoteAddress) {
        this.proxyAddress = requireNonNull(proxyAddress);
        this.remoteAddress = requireNonNull(remoteAddress);
    }

    /** Installs a fresh, closed pause latch unless one is already closed. */
    public synchronized void pause() {
        if (pause.getCount() == 0) {
            pause = new CountDownLatch(1);
        }
    }

    /** Opens the current pause latch. */
    public void noPause() {
        pause.countDown();
    }

    /**
     * Binds the proxy to {@code proxyAddress} and starts accepting connections in the background.
     *
     * @throws IOException
     *             if the listening socket could not be created or bound
     */
    public void start() throws IOException {
        ServerSocket server = new ServerSocket();
        try {
            server.bind(proxyAddress);
        } catch (IOException | RuntimeException e) {
            // Do not leak the unbound socket when binding fails.
            try {
                server.close();
            } catch (IOException suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
        ss = server;
        System.out.println("Starting proxy");
        es.submit(() -> acceptLoop(server));
    }

    /**
     * Accepts clients until {@code server} is closed, bridging each one to {@code remoteAddress}. Failing to reach the
     * remote only drops that one client; the loop keeps accepting.
     */
    private void acceptLoop(ServerSocket server) {
        while (!server.isClosed()) {
            Socket in;
            try {
                in = server.accept();
            } catch (IOException e) {
                // accept() fails once shutdown() closes the server socket; only report unexpected failures.
                if (!server.isClosed()) {
                    e.printStackTrace();
                }
                return;
            }
            try {
                openConnection(in);
            } catch (IOException | RuntimeException e) {
                e.printStackTrace();
                closeSocket(in);
            }
        }
    }

    /**
     * Connects to {@code remoteAddress}, starts forwarding between it and the accepted client socket, and registers the
     * resulting connection.
     *
     * @throws IOException
     *             if the remote could not be reached; the outgoing socket is closed before rethrowing
     */
    private void openConnection(Socket in) throws IOException {
        Socket out = new Socket();
        try {
            out.connect(remoteAddress);
        } catch (IOException | RuntimeException e) {
            closeSocket(out);
            throw e;
        }
        Connection con = new Connection(in, out);
        con.inToOut.start();
        con.outToIn.start();
        connections.add(con);
    }

    /**
     * Removes one randomly chosen connection from {@link #connections} and closes both of its sockets. Returns
     * immediately if there are no connections.
     */
    public void killRandom() {
        while (!connections.isEmpty()) {
            Connection[] a = connections.toArray(new Connection[connections.size()]);
            if (a.length > 0) {
                Connection con = a[ThreadLocalRandom.current().nextInt(a.length)];
                // Another thread may have removed it concurrently; only the thread that removes it closes it.
                if (connections.remove(con)) {
                    close(con);
                    return;
                }
            }
        }
    }

    /**
     * Schedules {@link #killRandom()} to run immediately and then repeatedly with the given delay between runs.
     *
     * @param time
     *            the delay between the end of one kill and the start of the next
     * @param unit
     *            the unit of {@code time}
     * @return a future that can be used to cancel the periodic kills
     */
    public Future<?> killRandom(long time, TimeUnit unit) {
        return ses.scheduleWithFixedDelay(() -> killRandom(), 0, time, unit);
    }

    /** Kills connections until {@link #connections} is empty. */
    public void killAll() {
        while (!connections.isEmpty()) {
            killRandom();
        }
    }

    /**
     * Stops the proxy. It opens the pause latch, shuts down both executors and closes the listening socket. It then
     * waits (at most 10 seconds) for the accept loop to finish and kills every remaining connection. Does not fail if
     * {@link #start()} was never called or did not succeed.
     *
     * @throws InterruptedException
     *             if interrupted while waiting for the accept loop; connections are still killed in that case
     */
    public void shutdown() throws InterruptedException {
        pause.countDown();
        ses.shutdown();
        es.shutdown();
        ServerSocket server = ss;
        if (server != null) {
            try {
                server.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        try {
            // Let the accept loop exit first so it cannot register a new connection after killAll() has run.
            es.awaitTermination(10, TimeUnit.SECONDS);
        } finally {
            killAll();
        }
    }

    /** Closes both sockets of {@code c}; does nothing if {@code c} is {@code null}. */
    private void close(Connection c) {
        if (c != null) {
            closeSocket(c.incoming);
            closeSocket(c.outgoing);
        }
    }

    /** Closes {@code socket}, printing (but otherwise ignoring) any I/O error. */
    private static void closeSocket(Socket socket) {
        try {
            socket.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** One proxied client: the accepted socket, the socket to the remote, and one forwarding thread per direction. */
    static class Connection {
        /** The socket accepted from the client. */
        final Socket incoming;

        /** The socket connected to the remote address. */
        final Socket outgoing;

        /** Forwards client-to-remote traffic. */
        final Thread inToOut;

        /** Forwards remote-to-client traffic. */
        final Thread outToIn;

        /** Random identifier used in thread names and {@link #toString()}. */
        final String id = UUID.randomUUID().toString();

        Connection(Socket incoming, Socket outgoing) {
            this.incoming = requireNonNull(incoming);
            this.outgoing = requireNonNull(outgoing);
            this.inToOut = new Thread(this::inToOut, "proxy-" + id + "-in");
            this.outToIn = new Thread(this::outToIn, "proxy-" + id + "-out");
        }

        /** Copies client-to-remote traffic until either side stops, then closes both sockets. */
        void inToOut() {
            pipe(incoming, outgoing);
        }

        /** Copies remote-to-client traffic until either side stops, then closes both sockets. */
        void outToIn() {
            pipe(outgoing, incoming);
        }

        /**
         * Copies bytes from {@code from} to {@code to} until end of stream or an I/O error. Both sockets are closed when
         * copying stops. This unblocks the thread forwarding the opposite direction, which would otherwise wait on a
         * read that may never complete.
         */
        private void pipe(Socket from, Socket to) {
            byte[] buffer = new byte[BUFFER_SIZE];
            try {
                InputStream is = from.getInputStream();
                OutputStream os = to.getOutputStream();
                int bytesRead;
                while ((bytesRead = is.read(buffer)) != -1) {
                    os.write(buffer, 0, bytesRead);
                }
            } catch (IOException ignored) {
                // Expected when the connection is killed or the peer resets it; just stop forwarding.
            } finally {
                closeSocket(incoming);
                closeSocket(outgoing);
            }
        }

        @Override
        public String toString() {
            return id + " [" + incoming.getLocalPort() + " -> " + outgoing.getPort() + "]";
        }
    }
}