package com.anjlab.csv2db; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.UUID; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Callable; import com.codahale.metrics.RatioGauge; import com.codahale.metrics.Timer; public class SharedBlockingQueueMediator implements Mediator { private final int[] deadQueueConsumers; private final int[] deadRouterProducers; private final BlockingQueue<Map<String, Object>>[] routerQueues; private final BlockingQueue<String[]> queue; private final Map<String, Object> terminalNameValues; private final String terminalMessage; private final Timer queuePuts; private final Timer queueTakes; private final Timer[] routerQueuePuts; private final Timer[] routerQueueTakes; @SuppressWarnings("unchecked") public SharedBlockingQueueMediator(Configuration config, int numberOfThreads) { // Each thread will take batch of lines with the size of batchSize from the queue, // that's why it's necessary to always have enough lines for those who read from its thread queue = new ArrayBlockingQueue<String[]>(config.getBatchSize() * numberOfThreads); terminalNameValues = new HashMap<>(); terminalMessage = UUID.randomUUID().toString(); deadRouterProducers = new int[numberOfThreads]; Arrays.fill(deadRouterProducers, 0); deadQueueConsumers = new int[numberOfThreads]; Arrays.fill(deadQueueConsumers, 0); if (config.isIgnoreDuplicatePK()) { routerQueues = new BlockingQueue[numberOfThreads]; if (Import.isMetricsEnabled()) { routerQueuePuts = new Timer[numberOfThreads]; routerQueueTakes = new Timer[numberOfThreads]; } else { routerQueuePuts = null; routerQueueTakes = null; } for (int i = 0; i < numberOfThreads; i++) { // This is likely not the best estimation of router's queue size, // but we need to limit it with some value routerQueues[i] = new ArrayBlockingQueue<>(config.getBatchSize() * numberOfThreads); if (Import.isMetricsEnabled()) { routerQueuePuts[i] = Import.METRIC_REGISTRY.timer("mediator.router." + i + ".queue.puts"); routerQueueTakes[i] = Import.METRIC_REGISTRY.timer("mediator.router." + i + ".queue.takes"); final int threadId = i; Import.registerMetric("mediator.router." + i + ".queue.ratio", new RatioGauge() { @Override protected Ratio getRatio() { return Ratio.of( routerQueuePuts[threadId].getOneMinuteRate(), routerQueueTakes[threadId].getOneMinuteRate()); } }); } } } else { routerQueues = null; routerQueuePuts = null; routerQueueTakes = null; } if (Import.isMetricsEnabled()) { queuePuts = Import.METRIC_REGISTRY.timer("mediator.queue.puts"); queueTakes = Import.METRIC_REGISTRY.timer("mediator.queue.takes"); Import.registerMetric("mediator.queue.ratio", new RatioGauge() { @Override protected Ratio getRatio() { return Ratio.of(queuePuts.getOneMinuteRate(), queueTakes.getOneMinuteRate()); } }); } else { queuePuts = null; queueTakes = null; } } @Override public void dispatch(String[] line) throws InterruptedException { Import.measureTime(queuePuts, new VoidCallable<InterruptedException>() { @Override public void run() throws InterruptedException { queue.put(line); } }); } @Override public void producerDone() throws InterruptedException { Import.measureTime(queuePuts, new VoidCallable<InterruptedException>() { @Override public void run() throws InterruptedException { queue.put(new String[] { terminalMessage }); } }); } @Override public Object take(int forThreadId) throws InterruptedException { if (isInTerminalPhase(forThreadId)) { return terminalPhaseTake(forThreadId); } while (routerQueueHasData(forThreadId)) { Object nameValues = takeFromRouter(forThreadId); if (nameValues == terminalNameValues) { deadRouterProducers[forThreadId]++; } else { return nameValues; } } String[] line = Import.measureTime(queueTakes, new Callable<String[]>() { @Override public String[] call() throws InterruptedException { return queue.take(); } }); if (!isTerminalLine(line)) { return line; } // Let other consumers know that producer has finished reading lines, // and there won't be new records in the shared queue producerDone(); if (isRouterEnabled()) { deadQueueConsumers[forThreadId]++; deadRouterProducers[forThreadId]++; // Notify other consumers that this thread has done processing shared queue, // and is waiting for confirmation from other consumers that they don't have any // messages for it for (int i = 0; i < routerQueues.length; i++) { if (i != forThreadId) { dispatch(terminalNameValues, i); } } return terminalPhaseTake(forThreadId); } // Empty array is a terminal line return new String[0]; } private boolean isInTerminalPhase(int threadId) { // Some other thread might completed earlier // which could increment deadRouterProducers for this thread // before it went through `producerDone` execution path. // That's why we can not rely on deadRouterProducers here. return deadQueueConsumers[threadId] > 0; } private Object takeFromRouter(int forThreadId) { Timer timer = routerQueueTakes == null ? null : routerQueueTakes[forThreadId]; return Import.measureTime(timer, new Callable<Object>() { @Override public Object call() throws InterruptedException { return routerQueues[forThreadId].take(); } }); } private Object terminalPhaseTake(int forThreadId) { while (deadRouterProducers[forThreadId] < deadRouterProducers.length) { Object nameValues = takeFromRouter(forThreadId); if (nameValues == terminalNameValues) { deadRouterProducers[forThreadId]++; } else { return nameValues; } } // Empty array is a terminal line return new String[0]; } private boolean routerQueueHasData(int forThreadId) { return isRouterEnabled() && !routerQueues[forThreadId].isEmpty(); } private boolean isRouterEnabled() { return routerQueues != null; } private boolean isTerminalLine(String[] line) { return line.length == 0 || terminalMessage.equals((line)[0]); } @Override public void consumerDone(int threadId) throws InterruptedException { Import.measureTime(queuePuts, new VoidCallable<InterruptedException>() { @Override public void run() throws InterruptedException { queue.put(new String[] { terminalMessage, String.valueOf(threadId) }); } }); } @Override public void dispatch(final Map<String, Object> nameValues, final int forThreadId) throws InterruptedException { Timer timer = routerQueuePuts == null ? null : routerQueuePuts[forThreadId]; Import.measureTime(timer, new VoidCallable<InterruptedException>() { @Override public void run() throws InterruptedException { routerQueues[forThreadId].put(nameValues); } }); } }