/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tajo.plan.function;

import org.apache.tajo.catalog.CatalogUtil;
import org.apache.tajo.catalog.FunctionDesc;
import org.apache.tajo.common.TajoDataTypes;
import org.apache.tajo.datum.Datum;
import org.apache.tajo.datum.DatumFactory;
import org.apache.tajo.plan.function.python.PythonScriptEngine;
import org.apache.tajo.storage.Tuple;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Aggregation function invoker that forwards every aggregation step of a Python UDAF to a
 * {@link PythonScriptEngine}.
 *
 * Before each step it makes sure that the Python side holds the state that belongs to the
 * {@link FunctionContext} passed by the caller (see {@link #updateContextIfNecessary}).
 */
public class PythonAggFunctionInvoke extends AggFunctionInvoke implements Cloneable {

  /** Engine that executes the Python UDAF; assigned in {@link #init}. */
  private transient PythonScriptEngine scriptEngine;

  /** Context whose state was most recently pushed to the Python side, or null if none yet. */
  private transient PythonAggFunctionContext prevContext;

  /**
   * Source of context ids. The counter is static, i.e. shared by every invoke instance in the JVM,
   * so it is atomic: a plain {@code int++} could lose updates when contexts are created from
   * several threads and hand the same id to two different contexts, which would make
   * {@link #updateContextIfNecessary} treat them as the same aggregation key.
   */
  private static final AtomicInteger nextContextId = new AtomicInteger(0);

  /**
   * Aggregated result should be kept in Tajo task rather than Python UDAF to control memory usage.
   * {@link PythonAggFunctionContext} is to support executing aggregation with keys.
   * It stores a snapshot of the Python UDAF class instance as a json string.
   *
   * For each UDAF call with a different aggregation key,
   * {@link PythonAggFunctionInvoke} calls {@link PythonAggFunctionInvoke#updateContextIfNecessary}
   * to back up and restore the intermediate aggregation states for the previous key and
   * the current key, respectively.
   */
  public static class PythonAggFunctionContext implements FunctionContext {
    final int id; // id to identify each context
    String jsonData; // snapshot of Python class

    /** Creates a context with the next id taken from the shared counter. */
    public PythonAggFunctionContext() {
      this.id = nextContextId.getAndIncrement();
    }

    /**
     * Replaces the stored snapshot.
     *
     * @param jsonData json string describing the Python UDAF instance
     */
    public void setJsonData(String jsonData) {
      this.jsonData = jsonData;
    }

    /**
     * @return the stored json snapshot; null if none has been set
     */
    public String getJsonData() {
      return jsonData;
    }
  }

  /**
   * @param functionDesc descriptor of the function, handed to the superclass
   */
  public PythonAggFunctionInvoke(FunctionDesc functionDesc) {
    super(functionDesc);
  }

  /**
   * Takes the script engine from the given invoke context.
   *
   * @param context invoke context; its script engine is cast to {@link PythonScriptEngine}
   * @throws IOException declared by the overridden signature
   */
  @Override
  public void init(FunctionInvokeContext context) throws IOException {
    this.scriptEngine = (PythonScriptEngine) context.getScriptEngine();
  }

  /**
   * @return a fresh {@link PythonAggFunctionContext} carrying a new id
   */
  @Override
  public FunctionContext newContext() {
    return new PythonAggFunctionContext();
  }

  /**
   * Context does not need to be updated per every UDAF call.
   * If the current aggregation key is the same as the previous one, the python-side context
   * doesn't need to be updated because it already contains the necessary intermediate result.
   *
   * Otherwise the previous context (if any) is refreshed from the Python side first, then the
   * given context is pushed to the Python side and remembered as the previous one.
   *
   * @param context context of the current call; must be a {@link PythonAggFunctionContext}
   * @throws RuntimeException wrapping any {@link IOException} raised by the script engine
   */
  private void updateContextIfNecessary(FunctionContext context) {
    PythonAggFunctionContext givenContext = (PythonAggFunctionContext) context;

    if (prevContext != null && prevContext.id == givenContext.id) {
      // Same context as in the last call: nothing to swap.
      return;
    }

    try {
      if (prevContext != null) {
        // Back up the state of the previous key before it is replaced on the Python side.
        scriptEngine.updateJavaSideContext(prevContext);
      }
      scriptEngine.updatePythonSideContext(givenContext);
      // Only remembered after both engine calls succeeded.
      prevContext = givenContext;
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Runs one aggregation step for the given input.
   *
   * @param context aggregation context of the current key
   * @param params  input tuple passed to the Python aggregation function
   */
  @Override
  public void eval(FunctionContext context, Tuple params) {
    updateContextIfNecessary(context);
    scriptEngine.callAggFunc(context, params);
  }

  /**
   * Runs one merge step. Does nothing when the first field of {@code params} is blank or null.
   *
   * @param context aggregation context of the current key
   * @param params  tuple passed to the Python aggregation function
   */
  @Override
  public void merge(FunctionContext context, Tuple params) {
    if (params.isBlankOrNull(0)) {
      return;
    }

    updateContextIfNecessary(context);
    scriptEngine.callAggFunc(context, params);
  }

  /**
   * @param context aggregation context of the current key
   * @return the partial result reported by the script engine, wrapped as a text datum
   */
  @Override
  public Datum getPartialResult(FunctionContext context) {
    updateContextIfNecessary(context);
    // partial results are stored as json strings.
    String result = scriptEngine.getPartialResult(context);
    return DatumFactory.createText(result);
  }

  /**
   * @return the TEXT data type, matching the text datum built by {@link #getPartialResult}
   */
  @Override
  public TajoDataTypes.DataType getPartialResultType() {
    return CatalogUtil.newSimpleDataType(TajoDataTypes.Type.TEXT);
  }

  /**
   * @param context aggregation context of the current key
   * @return the final result reported by the script engine
   */
  @Override
  public Datum terminate(FunctionContext context) {
    updateContextIfNecessary(context);
    return scriptEngine.getFinalResult(context);
  }

  /**
   * Delegates to the superclass; no field declared in this class is copied or reset here.
   */
  @Override
  public Object clone() throws CloneNotSupportedException {
    // nothing to do
    return super.clone();
  }
}