/**
* Copyright 2015 StreamSets Inc.
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.streamsets.pipeline.cluster;

import com.streamsets.pipeline.BootstrapCluster;
import com.streamsets.pipeline.EmbeddedSDC;
import com.streamsets.pipeline.EmbeddedSDCPool;
import com.streamsets.pipeline.api.impl.ClusterSource;
import com.streamsets.pipeline.api.impl.Utils;
import com.streamsets.pipeline.impl.ClusterFunction;
import com.streamsets.pipeline.spark.RecordCloner;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.reflect.Method;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;

public class ClusterFunctionImpl implements ClusterFunction {
private static final Logger LOG = LoggerFactory.getLogger(ClusterFunctionImpl.class);
private static final boolean IS_TRACE_ENABLED = LOG.isTraceEnabled();
private static volatile EmbeddedSDCPool sdcPool;
private static volatile boolean initialized = false;
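// Stack trace of the first failure, stored as a String so later invocations can
// rethrow it without needing the original exception classes on the classpath.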
private static volatile String errorStackTrace;
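// Result of ClusterSource.put() for the last batch; presumably the batch offset.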
private volatile Object offset;
private static final String GET_BATCH = "getBatch";
private static final String SET_ERRORS = "setErrors";
private static final String CONTINUE_PROCESSING = "continueProcessing";
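
/**
 * One-time, per-JVM initialization: copies the bootstrap data directory into the
 * executor's working directory, registers the sdc:id value, and creates the shared
 * {@link EmbeddedSDCPool}. Synchronized and guarded by {@code initialized} so that
 * concurrent callers initialize at most once.
 */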
private static synchronized void initialize(Properties properties, Integer id, String rootDataDir) throws Exception {
if (initialized) {
return;
}
File dataDir = new File(System.getProperty("user.dir"), "data");
FileUtils.copyDirectory(new File(rootDataDir), dataDir);
System.setProperty("sdc.data.dir", dataDir.getAbsolutePath());
// Must occur before creating the EmbeddedSDCPool, as the HDFS target
// validation evaluates the sdc:id EL
NumberFormat numberFormat = NumberFormat.getInstance();
numberFormat.setMinimumIntegerDigits(6);
numberFormat.setGroupingUsed(false);
final String sdcId = numberFormat.format(id);
Utils.setSdcIdCallable(() -> sdcId);
sdcPool = EmbeddedSDCPool.createPool(properties);
initialized = true;
}
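
/**
 * Propagates the number of Spark processors in the pipeline to the pool,
 * presumably so it can size its per-processor bookkeeping.
 */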
public void setSparkProcessorCount(int count) {
sdcPool.setSparkProcessorCount(count);
}
public static synchronized boolean isInitialized() {
return initialized;
}
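
/**
 * Factory method: performs one-time initialization and returns a function
 * configured with the number of Spark processor libraries in the pipeline.
 */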
public static ClusterFunction create(Properties properties, Integer id, String rootDataDir) throws Exception {
initialize(Utils.checkNotNull(properties, "Properties"), id, rootDataDir);
ClusterFunctionImpl fn = new ClusterFunctionImpl();
fn.setSparkProcessorCount(BootstrapCluster.getSparkProcessorLibraryNames().size());
return fn;
}
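
/**
 * Feeds a batch of key/value pairs into an idle embedded SDC's cluster source and
 * returns the output of the first Spark processor (empty if there is none).
 * Failures are captured as a stack-trace string so later batches can fail fast
 * even though the driver may not have the exception classes on its classpath.
 */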
@Override
public Iterable startBatch(List<Map.Entry> batch) throws Exception {
if (IS_TRACE_ENABLED) {
LOG.trace("In executor function " + " " + Thread.currentThread().getName() + ": " + batch.size());
}
if (errorStackTrace != null) {
LOG.info("Not proceeding as error in previous run");
throw new RuntimeException(errorStackTrace);
}
EmbeddedSDC sdc = null;
try {
sdc = sdcPool.getNotStartedSDC();
ClusterSource source = sdc.getSource();
offset = source.put(batch);
return getNextBatch(0, sdc);
} catch (Exception | Error e) {
// Capture the stack trace as a string, since the Spark driver won't have the
// jars required to deserialize the classes from the exception cause
errorStackTrace = getErrorStackTrace(e);
throw new RuntimeException(errorStackTrace);
} finally {
if (sdc != null) {
if (IS_TRACE_ENABLED) {
LOG.trace("Checking SDC: " + sdc + " back in after starting batch");
}
try {
sdcPool.checkInAfterReadingBatch(0, sdc);
} catch (Exception ex) {
errorStackTrace = getErrorStackTrace(ex);
throw new RuntimeException(errorStackTrace);
}
}
}
}
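
/**
 * Reflectively pulls the current batch from the Spark processor at {@code id},
 * cloning each record via {@link RecordCloner}. Reflection is needed because the
 * processor appears to be loaded in a separate classloader from this function.
 */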
@SuppressWarnings("unchecked")
public static Iterable<Object> getNextBatch(int id, EmbeddedSDC sdc) {
return Optional.ofNullable(sdc.getSparkProcessorAt(id)).flatMap(t -> {
try {
Object processor = t.getClass().getMethod("get").invoke(t);
final Method getBatch = processor.getClass().getDeclaredMethod(GET_BATCH);
List<Object> cloned = new ArrayList<>();
for (Object record : (Iterable<Object>) getBatch.invoke(processor)) {
cloned.add(RecordCloner.clone(record));
}
return Optional.of(cloned);
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}).orElse(Collections.emptyList());
}
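
/**
 * Hands error records produced on the Spark side back to the processor at
 * {@code id}, then checks the SDC back into the pool.
 */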
@Override
public void writeErrorRecords(List errors, int id) throws Exception {
EmbeddedSDC sdc = sdcPool.getSDCBatchRead(id);
try {
if (IS_TRACE_ENABLED) {
if (sdc == null) {
LOG.trace("No SDC at ID " + id);
} else {
LOG.trace("Writing " + errors.size() + " errors to SDC " + sdc.toString());
}
}
writeErrorsToProcessor(errors, id, sdc);
} finally {
if (sdc != null) {
if (IS_TRACE_ENABLED) {
LOG.trace("Checking SDC: " + sdc +" back in after writing errors for id " + id);
}
// No exception handling here: this does not proceed to the next batch, so there is no waitForCommit call.
sdcPool.checkInAfterReadingBatch(id, sdc);
}
}
}
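
/**
 * Reflectively invokes {@code setErrors(List)} on the Spark processor at
 * {@code id}, if one exists.
 */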
public static void writeErrorsToProcessor(List errors, int id, EmbeddedSDC sdc) {
Optional.ofNullable(sdc.getSparkProcessorAt(id)).ifPresent(t -> {
try {
Object processor = t.getClass().getMethod("get").invoke(t);
final Method setErrors = processor.getClass().getDeclaredMethod(SET_ERRORS, List.class);
setErrors.invoke(processor, errors);
} catch (Exception ex) {
throw new RuntimeException(ex);
}
});
}
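
/**
 * Passes records transformed by Spark back to the processor at {@code id} and
 * returns the batch emitted by the next processor. The SDC is checked back in
 * under {@code id + 1}, matching the processor whose output is returned.
 */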
@Override
public Iterable forwardTransformedBatch(Iterator<Object> batch, int id) throws Exception {
EmbeddedSDC sdc = sdcPool.getSDCBatchRead(id);
try {
if (IS_TRACE_ENABLED) {
if (sdc == null) {
LOG.trace("No SDC at ID " + id);
} else {
LOG.trace("Writing batch to SDC " + sdc.toString());
}
}
return writeTransformedToProcessor(batch, id, sdc);
} finally {
if (sdc != null) {
if (IS_TRACE_ENABLED) {
LOG.trace("Checking SDC: " + sdc +" back in after writing batch for id " + id + 1);
}
try {
sdcPool.checkInAfterReadingBatch(id + 1, sdc);
} catch (Exception ex) {
errorStackTrace = getErrorStackTrace(ex);
throw new RuntimeException(errorStackTrace);
}
}
}
}
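
/**
 * Reflectively invokes {@code continueProcessing(Iterator)} on the processor at
 * {@code id} with the transformed records, then fetches the next processor's output.
 */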
@SuppressWarnings("unchecked")
public static Iterable<Object> writeTransformedToProcessor(Iterator<Object> batch, int id, EmbeddedSDC sdc) {
return Optional.ofNullable(sdc.getSparkProcessorAt(id)).map(t -> {
try {
Class<?> sparkDProcessorClass = t.getClass();
Object processor = sparkDProcessorClass.getMethod("get").invoke(t);
final Method continueProcessing = processor.getClass().getDeclaredMethod(CONTINUE_PROCESSING, Iterator.class);
continueProcessing.invoke(processor, batch);
return getNextBatch(id + 1, sdc);
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}).orElse(Collections.emptyList());
}
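
/**
 * Renders a throwable's stack trace as a string, with a leading newline
 * (presumably for readability when embedded in an exception message).
 */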
private static String getErrorStackTrace(Throwable e) {
StringWriter errorWriter = new StringWriter();
PrintWriter pw = new PrintWriter(errorWriter);
pw.println();
e.printStackTrace(pw);
pw.flush();
return errorWriter.toString();
}
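
/**
 * Shuts down the shared pool; fails if {@link #initialize} was never called.
 */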
@Override
public void shutdown() throws Exception {
LOG.info("Shutdown");
Utils.checkState(initialized, "Not initialized");
sdcPool.shutdown();
}
}