/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tajo.plan.function.python; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; import org.apache.tajo.BuiltinStorages; import org.apache.tajo.TajoConstants; import org.apache.tajo.catalog.*; import org.apache.tajo.catalog.proto.CatalogProtos; import org.apache.tajo.common.TajoDataTypes; import org.apache.tajo.conf.TajoConf; import org.apache.tajo.datum.Datum; import org.apache.tajo.function.FunctionInvocation; import org.apache.tajo.function.FunctionSignature; import org.apache.tajo.function.FunctionSupplement; import org.apache.tajo.function.UDFInvocationDesc; import org.apache.tajo.plan.function.FunctionContext; import org.apache.tajo.plan.function.PythonAggFunctionInvoke.PythonAggFunctionContext; import org.apache.tajo.plan.function.stream.*; import org.apache.tajo.storage.Tuple; import org.apache.tajo.storage.VTuple; import org.apache.tajo.unit.StorageUnit; import org.apache.tajo.util.FileUtil; import org.apache.tajo.util.TUtil; import java.io.*; import java.net.URI; import java.nio.charset.Charset; import 
java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.regex.Pattern; /** * {@link PythonScriptEngine} is responsible for registering python functions and maintaining the controller process. * The controller is a python process that executes the python UDFs. * (Please refer to 'tajo-core/src/main/resources/python/controller.py') * Data are exchanged via standard I/O between PythonScriptEngine and the controller. */ public class PythonScriptEngine extends TajoScriptEngine { private static final Log LOG = LogFactory.getLog(PythonScriptEngine.class); public static final String FILE_EXTENSION = ".py"; /** * Register functions defined in a python script * * @param path path to the python script file * @param namespace namespace where the functions will be defined * @return set of function descriptions * @throws IOException */ public static Set<FunctionDesc> registerFunctions(URI path, String namespace) throws IOException { // TODO: we should support the namespace for python functions. 
Set<FunctionDesc> functionDescs = new HashSet<>(); List<FunctionInfo> functions = null; try (InputStream in = getScriptAsStream(path)) { functions = getFunctions(in); } for(FunctionInfo funcInfo : functions) { try { FunctionSignature signature; FunctionInvocation invocation = new FunctionInvocation(); FunctionSupplement supplement = new FunctionSupplement(); if (funcInfo instanceof ScalarFuncInfo) { ScalarFuncInfo scalarFuncInfo = (ScalarFuncInfo) funcInfo; TajoDataTypes.DataType returnType = getReturnTypes(scalarFuncInfo)[0]; signature = new FunctionSignature(CatalogProtos.FunctionType.UDF, scalarFuncInfo.funcName, returnType, createParamTypes(scalarFuncInfo.paramNum)); UDFInvocationDesc invocationDesc = new UDFInvocationDesc(CatalogProtos.UDFtype.PYTHON, scalarFuncInfo.funcName, path.getPath(), true); invocation.setUDF(invocationDesc); functionDescs.add(new FunctionDesc(signature, invocation, supplement)); } else { AggFuncInfo aggFuncInfo = (AggFuncInfo) funcInfo; if (isValidUdaf(aggFuncInfo)) { TajoDataTypes.DataType returnType = getReturnTypes(aggFuncInfo.getFinalResultInfo)[0]; signature = new FunctionSignature(CatalogProtos.FunctionType.UDA, aggFuncInfo.funcName, returnType, createParamTypes(aggFuncInfo.evalInfo.paramNum)); UDFInvocationDesc invocationDesc = new UDFInvocationDesc(CatalogProtos.UDFtype.PYTHON, aggFuncInfo.className, path.getPath(), false); invocation.setUDF(invocationDesc); functionDescs.add(new FunctionDesc(signature, invocation, supplement)); } } } catch (Throwable t) { // ignore invalid functions LOG.warn("Cannot parse function " + funcInfo, t); } } return functionDescs; } private static boolean isValidUdaf(AggFuncInfo aggFuncInfo) { if (aggFuncInfo.className != null && aggFuncInfo.evalInfo != null && aggFuncInfo.mergeInfo != null && aggFuncInfo.getPartialResultInfo != null && aggFuncInfo.getFinalResultInfo != null) { return true; } return false; } private static TajoDataTypes.DataType[] createParamTypes(int paramNum) { 
TajoDataTypes.DataType[] paramTypes = new TajoDataTypes.DataType[paramNum]; for (int i = 0; i < paramNum; i++) { paramTypes[i] = TajoDataTypes.DataType.newBuilder().setType(TajoDataTypes.Type.ANY).build(); } return paramTypes; } private static TajoDataTypes.DataType[] getReturnTypes(ScalarFuncInfo scalarFuncInfo) { TajoDataTypes.DataType[] returnTypes = new TajoDataTypes.DataType[scalarFuncInfo.returnTypes.length]; for (int i = 0; i < scalarFuncInfo.returnTypes.length; i++) { returnTypes[i] = CatalogUtil.newSimpleDataType(TajoDataTypes.Type.valueOf(scalarFuncInfo.returnTypes[i])); } return returnTypes; } private static final Pattern pSchema = Pattern.compile("^\\s*\\W+output_type.*"); private static final Pattern pDef = Pattern.compile("^\\s*def\\s+(\\w+)\\s*.+"); private static final Pattern pClass = Pattern.compile("class.*"); private interface FunctionInfo { } private static class AggFuncInfo implements FunctionInfo { String className; String funcName; ScalarFuncInfo evalInfo; ScalarFuncInfo mergeInfo; ScalarFuncInfo getPartialResultInfo; ScalarFuncInfo getFinalResultInfo; @Override public String toString() { return className + "." + funcName; } } private static class ScalarFuncInfo implements FunctionInfo { String[] returnTypes; String funcName; int paramNum; public ScalarFuncInfo(String[] quotedSchemaStrings, String funcName, int paramNum) { this.returnTypes = new String[quotedSchemaStrings.length]; for (int i = 0; i < quotedSchemaStrings.length; i++) { String quotedString = quotedSchemaStrings[i].trim(); String[] tokens = quotedString.substring(1, quotedString.length()-1).split(":"); this.returnTypes[i] = tokens.length == 1 ? tokens[0].toUpperCase() : tokens[1].toUpperCase(); } this.funcName = funcName; this.paramNum = paramNum; } @Override public String toString() { return funcName; } } // TODO: python parser must be improved. 
private static List<FunctionInfo> getFunctions(InputStream is) throws IOException { List<FunctionInfo> functions = new ArrayList<>(); InputStreamReader in = new InputStreamReader(is, Charset.defaultCharset()); BufferedReader br = new BufferedReader(in); String line = br.readLine(); String[] quotedSchemaStrings = null; AggFuncInfo aggFuncInfo = null; while (line != null) { try { if (pSchema.matcher(line).matches()) { int start = line.indexOf("(") + 1; //drop brackets int end = line.lastIndexOf(")"); quotedSchemaStrings = line.substring(start, end).trim().split(","); } else if (pDef.matcher(line).matches()) { boolean isUdaf = aggFuncInfo != null && line.indexOf("def ") > 0; int nameStart = line.indexOf("def ") + "def ".length(); int nameEnd = line.indexOf('('); int signatureEnd = line.indexOf(')'); String[] params = line.substring(nameEnd + 1, signatureEnd).split(","); int paramNum; if (params.length == 1) { paramNum = params[0].equals("") ? 0 : 1; } else { paramNum = params.length; } if (params[0].trim().equals("self")) { paramNum--; } String functionName = line.substring(nameStart, nameEnd).trim(); quotedSchemaStrings = quotedSchemaStrings == null ? 
new String[]{"'blob'"} : quotedSchemaStrings; if (isUdaf) { if (functionName.equals("eval")) { aggFuncInfo.evalInfo = new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum); } else if (functionName.equals("merge")) { aggFuncInfo.mergeInfo = new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum); } else if (functionName.equals("get_partial_result")) { aggFuncInfo.getPartialResultInfo = new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum); } else if (functionName.equals("get_final_result")) { aggFuncInfo.getFinalResultInfo = new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum); } } else { aggFuncInfo = null; functions.add(new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum)); } quotedSchemaStrings = null; } else if (pClass.matcher(line).matches()) { // UDAF if (aggFuncInfo != null) { functions.add(aggFuncInfo); } aggFuncInfo = new AggFuncInfo(); int classNameStart = line.indexOf("class ") + "class ".length(); int classNameEnd = line.indexOf("("); if (classNameEnd < 0) { classNameEnd = line.indexOf(":"); } aggFuncInfo.className = line.substring(classNameStart, classNameEnd).trim(); aggFuncInfo.funcName = aggFuncInfo.className.toLowerCase(); } } catch (Throwable t) { // ignore unexpected function and source lines LOG.warn(t); } line = br.readLine(); } if (aggFuncInfo != null) { functions.add(aggFuncInfo); } br.close(); in.close(); return functions; } private static final String PYTHON_LANGUAGE = "python"; private static final String TAJO_UTIL_NAME = "tajo_util.py"; private static final String CONTROLLER_NAME = "controller.py"; private static final String BASE_DIR = FileUtils.getTempDirectoryPath() + File.separator + "tajo-" + System.getProperty("user.name") + File.separator + "python"; private static final String PYTHON_CONTROLLER_JAR_PATH = "/python/" + CONTROLLER_NAME; // Relative to root of tajo jar. private static final String PYTHON_TAJO_UTIL_JAR_PATH = "/python/" + TAJO_UTIL_NAME; // Relative to root of tajo jar. 
// Indexes for arguments being passed to external process
  enum COMMAND_IDX {
    UDF_LANGUAGE,
    PATH_TO_CONTROLLER_FILE,
    UDF_FILE_NAME,
    UDF_FILE_PATH,
    PATH_TO_FILE_CACHE,
    OUT_SCHEMA,
    FUNCTION_OR_CLASS_NAME,
    FUNCTION_TYPE,
  }

  private TajoConf systemConf;

  private Process process; // Handle to the external execution of python functions

  private InputHandler inputHandler;
  private OutputHandler outputHandler;

  private DataOutputStream stdin; // stdin of the process
  private InputStream stdout; // stdout of the process
  private InputStream stderr; // stderr of the process

  private final FunctionSignature functionSignature;
  private final UDFInvocationDesc invocationDesc;
  private Schema inSchema;
  private Schema outSchema;
  private int[] projectionCols;

  private final CSVLineSerDe lineSerDe = new CSVLineSerDe();
  private TableMeta pipeMeta;

  // Sent with commands that take no arguments.
  private final Tuple EMPTY_INPUT = new VTuple(0);
  private final Schema EMPTY_SCHEMA = SchemaBuilder.builder().build();

  /**
   * Create an engine for the given python function.
   *
   * @param functionDesc description of a python UDF or UDAF
   * @throws IllegalStateException if the function is not a python function
   */
  public PythonScriptEngine(FunctionDesc functionDesc) {
    checkPythonFunction(functionDesc);
    functionSignature = functionDesc.getSignature();
    invocationDesc = functionDesc.getInvocation().getUDF();
    setSchema();
  }

  /**
   * Create an engine for the given python function, running in a specific aggregation phase.
   *
   * @param functionDesc description of a python UDF or UDAF
   * @param firstPhase true if this engine runs the first aggregation phase
   * @param lastPhase true if this engine runs the last aggregation phase
   * @throws IllegalStateException if the function is not a python function
   */
  public PythonScriptEngine(FunctionDesc functionDesc, boolean firstPhase, boolean lastPhase) {
    checkPythonFunction(functionDesc);
    functionSignature = functionDesc.getSignature();
    invocationDesc = functionDesc.getInvocation().getUDF();
    this.firstPhase = firstPhase;
    this.lastPhase = lastPhase;
    setSchema();
  }

  /** Reject function descriptions that are neither a python UDF nor a python UDAF. */
  private static void checkPythonFunction(FunctionDesc functionDesc) {
    if (!functionDesc.getInvocation().hasPython() && !functionDesc.getInvocation().hasPythonAggregation()) {
      throw new IllegalStateException("Function type must be 'python'");
    }
  }

  /**
   * Start the python controller process and connect to its standard streams.
   *
   * @param conf system configuration; must be a {@link TajoConf}
   * @throws IOException if the process cannot be started or the handlers cannot be created
   */
  @Override
  public void start(Configuration conf) throws IOException {
    this.systemConf = TUtil.checkTypeAndGet(conf, TajoConf.class);
    this.pipeMeta = CatalogUtil.newTableMeta(BuiltinStorages.TEXT, systemConf);
    startUdfController();
    setStreams();
    createInputHandlers();
    if (LOG.isDebugEnabled()) {
      LOG.debug("PythonScriptExecutor starts up");
    }
  }

  /**
   * Close the streams and handlers of the controller process and wait for the process to exit.
   * Calling this on an engine whose process was never started only releases the handlers.
   */
  @Override
  public void shutdown() {
    FileUtil.cleanup(LOG, stdin, stdout, stderr, inputHandler, outputHandler);
    stdin = null;
    stdout = stderr = null;
    inputHandler = null;
    outputHandler = null;

    // process is null if start() was never called or failed before the process was created.
    if (process != null) {
      try {
        int exitCode = process.waitFor();
        if (systemConf.get(TajoConstants.TEST_KEY, Boolean.FALSE.toString())
            .equalsIgnoreCase(Boolean.TRUE.toString())) {
          LOG.warn("Process exit code: " + exitCode);
        } else {
          if (LOG.isDebugEnabled()) {
            LOG.debug("Process exit code: " + exitCode);
          }
        }
      } catch (InterruptedException e) {
        LOG.warn(e.getMessage(), e);
        // Keep the interrupt visible to the caller instead of swallowing it.
        Thread.currentThread().interrupt();
      }
    }

    if (LOG.isDebugEnabled()) {
      LOG.debug("PythonScriptExecutor shuts down");
    }
  }

  /** Launch the python controller with the command built by {@link #buildCommand()}. */
  private void startUdfController() throws IOException {
    ProcessBuilder processBuilder = StreamingUtil.createProcess(buildCommand());
    process = processBuilder.start();
  }

  /**
   * Build a command to execute an external process.
   *
   * @return command arguments, positioned according to {@link COMMAND_IDX}
   * @throws IOException if the python code directory is not configured
   */
  private String[] buildCommand() throws IOException {
    String[] command = new String[COMMAND_IDX.values().length];

    String funcName = invocationDesc.getName();
    String filePath = invocationDesc.getPath();

    command[COMMAND_IDX.UDF_LANGUAGE.ordinal()] = PYTHON_LANGUAGE;
    command[COMMAND_IDX.PATH_TO_CONTROLLER_FILE.ordinal()] = getControllerPath();

    // Split the script path into its directory and its file name without the extension.
    int lastSeparator = filePath.lastIndexOf(File.separator) + 1;
    String fileName = filePath.substring(lastSeparator);
    fileName = fileName.endsWith(FILE_EXTENSION)
        ? fileName.substring(0, fileName.length() - FILE_EXTENSION.length())
        : fileName;
    command[COMMAND_IDX.UDF_FILE_NAME.ordinal()] = fileName;
    command[COMMAND_IDX.UDF_FILE_PATH.ordinal()] = lastSeparator <= 0 ? "." : filePath.substring(0, lastSeparator - 1);

    String fileCachePath = systemConf.get(TajoConf.ConfVars.PYTHON_CODE_DIR.keyname());
    if (fileCachePath == null) {
      throw new IOException(TajoConf.ConfVars.PYTHON_CODE_DIR.keyname() + " must be set.");
    }
    command[COMMAND_IDX.PATH_TO_FILE_CACHE.ordinal()] = "'" + fileCachePath + "'";

    // The type name is passed to the controller as text; lower-case it independently of the JVM
    // default locale so that the same name is produced everywhere.
    command[COMMAND_IDX.OUT_SCHEMA.ordinal()] =
        outSchema.getColumn(0).getDataType().getType().name().toLowerCase(java.util.Locale.ROOT);
    command[COMMAND_IDX.FUNCTION_OR_CLASS_NAME.ordinal()] = funcName;
    command[COMMAND_IDX.FUNCTION_TYPE.ordinal()] = invocationDesc.isScalarFunction() ? "UDF" : "UDAF";

    return command;
  }

  /**
   * Set the input and output schemas, and the projected output columns, according to the function
   * type and, for UDAFs, the aggregation phase. Partial aggregation results are exchanged as a
   * single TEXT column named "json".
   */
  private void setSchema() {
    if (invocationDesc.isScalarFunction()) {
      inSchema = buildParamSchema();
      outSchema = SchemaBuilder.builder()
          .addAll(new Column[]{new Column("out", functionSignature.getReturnType())})
          .build();
    } else {
      // UDAF
      if (firstPhase) {
        // first phase
        inSchema = buildParamSchema();
        outSchema = SchemaBuilder.builder().add(new Column("json", TajoDataTypes.Type.TEXT)).build();
      } else if (lastPhase) {
        inSchema = SchemaBuilder.builder().add(new Column("json", TajoDataTypes.Type.TEXT)).build();
        outSchema = SchemaBuilder.builder().add(new Column("out", functionSignature.getReturnType())).build();
      } else {
        // intermediate phase
        inSchema = outSchema = SchemaBuilder.builder().add(new Column("json", TajoDataTypes.Type.TEXT)).build();
      }
    }
    projectionCols = new int[outSchema.size()];
    for (int i = 0; i < outSchema.size(); i++) {
      projectionCols[i] = i;
    }
  }

  /** Build a schema with one column "in_&lt;i&gt;" per parameter of the function signature. */
  private Schema buildParamSchema() {
    TajoDataTypes.DataType[] paramTypes = functionSignature.getParamTypes();
    SchemaBuilder inSchemaBuilder = SchemaBuilder.builder();
    for (int i = 0; i < paramTypes.length; i++) {
      inSchemaBuilder.add(new Column("in_" + i, paramTypes[i]));
    }
    return inSchemaBuilder.build();
  }

  /**
   * Create the handlers that serialize input to, and deserialize output from, the process.
   *
   * @throws IOException if a handler cannot be bound to its stream
   */
  private void createInputHandlers() throws IOException {
    setSchema();
    TextLineSerializer serializer = lineSerDe.createSerializer(pipeMeta);
    serializer.init();
    this.inputHandler = new InputHandler(serializer);
    inputHandler.bindTo(stdin);
    TextLineDeserializer deserializer = lineSerDe.createDeserializer(outSchema, pipeMeta, projectionCols);
    deserializer.init();
    this.outputHandler = new OutputHandler(deserializer);
    outputHandler.bindTo(stdout);
  }

  /**
   * Get the standard input, output, and error streams of the external process
   */
  private void setStreams() {
    stdout = new DataInputStream(new BufferedInputStream(process.getInputStream()));
    stdin = new DataOutputStream(new BufferedOutputStream(process.getOutputStream()));
    stderr = new DataInputStream(new BufferedInputStream(process.getErrorStream()));
  }

  private static final File pythonScriptBaseDir = new File(PythonScriptEngine.getBaseDirPath());
  private static final File pythonScriptControllerCopy = new File(PythonScriptEngine.getControllerPath());
  private static final File pythonScriptUtilCopy = new File(PythonScriptEngine.getTajoUtilPath());

  /**
   * Create the base directory and copy the controller and utility scripts into it.
   *
   * @throws IOException if the directory cannot be created or a script cannot be copied
   */
  public static void initPythonScriptEngineFiles() throws IOException {
    // isDirectory() is re-checked because mkdirs() also returns false when another process
    // created the directory in the meantime.
    if (!pythonScriptBaseDir.exists() && !pythonScriptBaseDir.mkdirs() && !pythonScriptBaseDir.isDirectory()) {
      throw new IOException("Cannot create the python base directory: " + pythonScriptBaseDir);
    }
    // Controller and util should be always overwritten.
    PythonScriptEngine.loadController(pythonScriptControllerCopy);
    PythonScriptEngine.loadTajoUtil(pythonScriptUtilCopy);
  }

  /**
   * Copy the controller script from the classpath to the given file.
   *
   * @param controllerCopy destination file
   * @throws IOException if the resource is missing or cannot be copied
   */
  public static void loadController(File controllerCopy) throws IOException {
    copyResource(PYTHON_CONTROLLER_JAR_PATH, controllerCopy);
  }

  /**
   * Copy the tajo utility script from the classpath to the given file.
   *
   * @param utilCopy destination file
   * @throws IOException if the resource is missing or cannot be copied
   */
  public static void loadTajoUtil(File utilCopy) throws IOException {
    copyResource(PYTHON_TAJO_UTIL_JAR_PATH, utilCopy);
  }

  /** Copy a classpath resource to a file, failing with an IOException if the resource is absent. */
  private static void copyResource(String resourcePath, File target) throws IOException {
    try (InputStream resource = PythonScriptEngine.class.getResourceAsStream(resourcePath)) {
      // getResourceAsStream() reports a missing resource by returning null.
      if (resource == null) {
        throw new FileNotFoundException("Cannot find the resource '" + resourcePath + "' on the classpath");
      }
      FileUtils.copyInputStreamToFile(resource, target);
    }
  }

  /**
   * Get the directory that holds the copies of the python scripts. The path is logged on each call.
   *
   * @return path of the base directory
   */
  public static String getBaseDirPath() {
    LOG.info("Python base dir is " + BASE_DIR);
    return BASE_DIR;
  }

  /**
   * Find the path to the controller file for the streaming language.
   *
   * @return path of the controller script inside the base directory
   */
  public static String getControllerPath() {
    return BASE_DIR + File.separator + CONTROLLER_NAME;
  }

  /**
   * Find the path to the tajo utility script.
   *
   * @return path of the utility script inside the base directory
   */
  public static String getTajoUtilPath() {
    return BASE_DIR + File.separator + TAJO_UTIL_NAME;
  }

  /**
   * Call Python scalar functions.
   *
   * @param input input tuple
   * @return evaluated result datum
   * @throws RuntimeException if the input cannot be sent or no result can be read
   */
  @Override
  public Datum callScalarFunc(Tuple input) {
    try {
      inputHandler.putNext(input, inSchema);
      stdin.flush();
    } catch (Throwable e) {
      throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
    }
    Datum result = null;
    try {
      Tuple next = outputHandler.getNext();
      if (next != null) {
        result = next.asDatum(0);
      } else {
        throw new RuntimeException("Cannot get output result from python controller");
      }
    } catch (Throwable e) {
      throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
    }

    return result;
  }

  /**
   * Call Python aggregation functions. The "eval" method is invoked in the first phase and
   * "merge" in every other phase.
   *
   * @param functionContext python function context
   * @param input input tuple
   * @throws RuntimeException if the input cannot be sent or the response cannot be read
   */
  @Override
  public void callAggFunc(FunctionContext functionContext, Tuple input) {

    String methodName;
    if (firstPhase) {
      // eval
      methodName = "eval";
    } else {
      // merge
      methodName = "merge";
    }

    try {
      inputHandler.putNext(methodName, input, inSchema);
      stdin.flush();
    } catch (Throwable e) {
      throwException(stderr, new RuntimeException("Failed adding input to inputQueue while executing "
          + methodName + " with " + input, e));
    }

    try {
      // The response is read and discarded.
      outputHandler.getNext();
    } catch (Exception e) {
      throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
    }
  }

  /**
   * Get the standard error streams of the external process and throw the exception.
   * If the stream has pending data, up to 100 KB of it are put into the message of the thrown
   * exception; otherwise the given exception is thrown as is. This method never returns normally.
   *
   * @param stderr standard error stream of the process; may be null
   * @param e exception describing the failure
   * @throws RuntimeException always
   */
  private void throwException(InputStream stderr, RuntimeException e) throws RuntimeException {
    // The stream is null before start() and after shutdown(); report the failure itself then.
    if (stderr == null) {
      throw e;
    }
    try {
      int available = stderr.available();
      if (available > 0) {
        byte[] bytes = new byte[Math.min(available, 100 * StorageUnit.KB)];
        IOUtils.readFully(stderr, bytes);
        String message = new String(bytes, Charset.defaultCharset());
        throw new RuntimeException("Python exception caused by: " + message, e);
      } else {
        throw e;
      }
    } catch (IOException ioe) {
      RuntimeException failure = new RuntimeException(ioe.getMessage(), ioe);
      // Do not lose the failure that made us look at stderr in the first place.
      failure.addSuppressed(e);
      throw failure;
    }
  }

  /**
   * Restore the intermediate result in Python UDAF with the snapshot stored in the function context.
   *
   * @param functionContext function context holding the snapshot
   * @throws RuntimeException if the context cannot be sent or the response cannot be read
   */
  public void updatePythonSideContext(PythonAggFunctionContext functionContext) throws IOException {
    try {
      inputHandler.putNext("update_context", functionContext);
      stdin.flush();
    } catch (Throwable e) {
      throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
    }
    try {
      outputHandler.getNext();
    } catch (Throwable e) {
      throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
    }
  }

  /**
   * Get the snapshot of the intermediate result in the Python UDAF.
   *
   * @param functionContext function context that receives the snapshot
   * @throws RuntimeException if the request cannot be sent or the response cannot be read
   */
  public void updateJavaSideContext(PythonAggFunctionContext functionContext) throws IOException {
    try {
      inputHandler.putNext("get_context", EMPTY_INPUT, EMPTY_SCHEMA);
      stdin.flush();
    } catch (Throwable e) {
      throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
    }
    try {
      outputHandler.getNext(functionContext);
    } catch (Throwable e) {
      throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
    }
  }

  /**
   * Get intermediate result after the first stage.
   *
   * @param functionContext python function context
   * @return partial result as returned by the output handler
   * @throws RuntimeException if the request cannot be sent or the response cannot be read
   */
  @Override
  public String getPartialResult(FunctionContext functionContext) {
    try {
      inputHandler.putNext("get_partial_result", EMPTY_INPUT, EMPTY_SCHEMA);
      stdin.flush();
    } catch (Throwable e) {
      throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
    }
    String result = null;
    try {
      result = outputHandler.getPartialResultString();
    } catch (Throwable e) {
      throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
    }

    return result;
  }

  /**
   * Get final result after the last stage.
   *
   * @param functionContext python function context
   * @return final result datum
   * @throws RuntimeException if the request cannot be sent or no result can be read
   */
  @Override
  public Datum getFinalResult(FunctionContext functionContext) {
    try {
      inputHandler.putNext("get_final_result", EMPTY_INPUT, EMPTY_SCHEMA);
      stdin.flush();
    } catch (Exception e) {
      throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
    }
    Datum result = null;
    try {
      Tuple next = outputHandler.getNext();
      if (next != null) {
        result = next.asDatum(0);
      } else {
        throw new RuntimeException("Cannot get output result from python controller");
      }
    } catch (Exception e) {
      throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
    }

    return result;
  }
}