/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.rest;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterSetting;
import org.apache.zeppelin.notebook.Note;
import org.apache.zeppelin.notebook.Paragraph;
import org.apache.zeppelin.scheduler.Job.Status;
import org.apache.zeppelin.server.ZeppelinServer;
import org.apache.zeppelin.user.AuthenticationInfo;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import com.google.gson.Gson;
/**
 * Tests that run against a live Spark cluster.
 * The Spark cluster is started by the CI server using testing/startSparkCluster.sh.
*/
public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
Gson gson = new Gson();
AuthenticationInfo anonymous;
@BeforeClass
public static void init() throws Exception {
AbstractTestRestApi.startUp();
}
@AfterClass
public static void destroy() throws Exception {
AbstractTestRestApi.shutDown();
}
@Before
public void setUp() {
anonymous = new AuthenticationInfo("anonymous");
}
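  // Poll the paragraph until it reaches a terminal state (FINISHED, ERROR,
  // or ABORT), sleeping 100 ms between checks.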
private void waitForFinish(Paragraph p) {
while (p.getStatus() != Status.FINISHED
&& p.getStatus() != Status.ERROR
&& p.getStatus() != Status.ABORT) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
LOG.error("Exception in WebDriverManager while getWebDriver ", e);
}
}
}
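  /**
   * A %spark paragraph should echo evaluated import statements and
   * println output in its result.
   */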
@Test
public void scalaOutputTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(anonymous);
Paragraph p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark import java.util.Date\n" +
"import java.net.URL\n" +
"println(\"hello\")\n"
);
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("import java.util.Date\n" +
"import java.net.URL\n" +
"hello\n", p.getResult().message().get(0).getData());
ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
}
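  /** Reducing sc.parallelize(1 to 10) with _ + _ should print 1+2+...+10 = 55. */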
@Test
public void basicRDDTransformationAndActionTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(anonymous);
    // run a basic RDD transformation and action
Paragraph p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark print(sc.parallelize(1 to 10).reduce(_ + _))");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("55", p.getResult().message().get(0).getData());
ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
}
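  /**
   * DataFrame creation and collect, z.show() table rendering, and (on
   * Spark 2.x) Dataset rendering through the %spark interpreter.
   */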
@Test
public void sparkSQLTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(anonymous);
int sparkVersion = getSparkVersionNumber(note);
// DataFrame API is available from spark 1.3
if (sparkVersion >= 13) {
// test basic dataframe api
Paragraph p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))\n" +
"df.collect()");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertTrue(p.getResult().message().get(0).getData().contains(
"Array[org.apache.spark.sql.Row] = Array([hello,20])"));
// test display DataFrame
p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))\n" +
"z.show(df)");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getResult().message().get(1).getType());
assertEquals("_1\t_2\nhello\t20\n", p.getResult().message().get(1).getData());
// test display DataSet
if (sparkVersion >= 20) {
p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark val ds=spark.createDataset(Seq((\"hello\",20)))\n" +
"z.show(ds)");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getResult().message().get(1).getType());
assertEquals("_1\t_2\nhello\t20\n", p.getResult().message().get(1).getData());
}
ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
}
}
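  /** Build a SparkR DataFrame from a local data.frame and count its rows. */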
@Test
public void sparkRTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(anonymous);
int sparkVersion = getSparkVersionNumber(note);
if (isSparkR() && sparkVersion >= 14) { // sparkr supported from 1.4.0
// restart spark interpreter
List<InterpreterSetting> settings =
ZeppelinServer.notebook.getBindedInterpreterSettings(note.getId());
for (InterpreterSetting setting : settings) {
if (setting.getName().equals("spark")) {
ZeppelinServer.notebook.getInterpreterSettingManager().restart(setting.getId());
break;
}
}
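      // Spark 2.x exposes a SparkSession named "spark"; 1.x exposes "sqlContext".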
String sqlContextName = "sqlContext";
if (sparkVersion >= 20) {
sqlContextName = "spark";
}
Paragraph p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%r localDF <- data.frame(name=c(\"a\", \"b\", \"c\"), age=c(19, 23, 18))\n" +
"df <- createDataFrame(" + sqlContextName + ", localDF)\n" +
"count(df)"
);
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
System.err.println("sparkRTest=" + p.getResult().message().get(0).getData());
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("[1] 3", p.getResult().message().get(0).getData().trim());
}
ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
}
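  /**
   * %pyspark basics: RDD reduce, DataFrame collect, z.show() rendering, UDF
   * registration through SQLContext and SparkSession, and error reporting for
   * an undefined name.
   */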
@Test
public void pySparkTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(anonymous);
note.setName("note");
int sparkVersion = getSparkVersionNumber(note);
if (isPyspark() && sparkVersion >= 12) { // pyspark supported from 1.2.1
      // run a simple pyspark paragraph
Paragraph p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%pyspark print(sc.parallelize(range(1, 11)).reduce(lambda a, b: a + b))");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("55\n", p.getResult().message().get(0).getData());
if (sparkVersion >= 13) {
// run sqlContext test
p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%pyspark from pyspark.sql import Row\n" +
"df=sqlContext.createDataFrame([Row(id=1, age=20)])\n" +
"df.collect()");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("[Row(age=20, id=1)]\n", p.getResult().message().get(0).getData());
// test display Dataframe
p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%pyspark from pyspark.sql import Row\n" +
"df=sqlContext.createDataFrame([Row(id=1, age=20)])\n" +
"z.show(df)");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getResult().message().get(0).getType());
// TODO (zjffdu), one more \n is appended, need to investigate why.
assertEquals("age\tid\n20\t1\n", p.getResult().message().get(0).getData());
// test udf
p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%pyspark sqlContext.udf.register(\"f1\", lambda x: len(x))\n" +
"sqlContext.sql(\"select f1(\\\"abc\\\") as len\").collect()");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("[Row(len=u'3')]\n", p.getResult().message().get(0).getData());
// test exception
p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
      // Equivalent paragraph:
      //   %pyspark
      //   a=1
      //
      //   print(a2)
p.setText("%pyspark a=1\n\nprint(a2)");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.ERROR, p.getStatus());
assertTrue(p.getResult().message().get(0).getData()
.contains("Fail to execute line 3: print(a2)"));
assertTrue(p.getResult().message().get(0).getData()
.contains("name 'a2' is not defined"));
}
if (sparkVersion >= 20) {
        // on Spark 2.x, sqlContext should still work alongside the SparkSession
p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%pyspark from pyspark.sql import Row\n" +
"df=sqlContext.createDataFrame([Row(id=1, age=20)])\n" +
"df.collect()");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("[Row(age=20, id=1)]\n", p.getResult().message().get(0).getData());
// test udf
p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
        // register the UDF through SQLContext, then invoke it through the SparkSession
p.setText("%pyspark sqlContext.udf.register(\"f1\", lambda x: len(x))\n" +
"spark.sql(\"select f1(\\\"abc\\\") as len\").collect()");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("[Row(len=u'3')]\n", p.getResult().message().get(0).getData());
}
}
ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
}
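  /**
   * With Py4J auto_convert available (Spark >= 1.4), a pyspark column
   * expression such as rand(seed=10) * 3.14 should evaluate without manual
   * conversion.
   */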
@Test
public void pySparkAutoConvertOptionTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(anonymous);
note.setName("note");
int sparkVersionNumber = getSparkVersionNumber(note);
if (isPyspark() && sparkVersionNumber >= 14) { // auto_convert enabled from spark 1.4
      // run a pyspark paragraph exercising auto_convert
Paragraph p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
String sqlContextName = "sqlContext";
if (sparkVersionNumber >= 20) {
sqlContextName = "spark";
}
p.setText("%pyspark\nfrom pyspark.sql.functions import *\n"
+ "print(" + sqlContextName + ".range(0, 10).withColumn('uniform', rand(seed=10) * 3.14).count())");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("10\n", p.getResult().message().get(0).getData());
}
ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
}
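  /**
   * Exercise the ZeppelinContext run family: z.run(index), z.runNote(noteId),
   * and z.run(noteId, paragraphId), each submitted from a %spark paragraph.
   */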
@Test
public void zRunTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(anonymous);
Paragraph p0 = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config0 = p0.getConfig();
config0.put("enabled", true);
p0.setConfig(config0);
p0.setText("%spark z.run(1)");
p0.setAuthenticationInfo(anonymous);
Paragraph p1 = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config1 = p1.getConfig();
config1.put("enabled", true);
p1.setConfig(config1);
p1.setText("%spark val a=10");
p1.setAuthenticationInfo(anonymous);
Paragraph p2 = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config2 = p2.getConfig();
config2.put("enabled", true);
p2.setConfig(config2);
p2.setText("%spark print(a)");
p2.setAuthenticationInfo(anonymous);
note.run(p0.getId());
waitForFinish(p0);
assertEquals(Status.FINISHED, p0.getStatus());
    // z.run is not a blocking call, so p1 may not be finished when p0 is done.
waitForFinish(p1);
note.run(p2.getId());
waitForFinish(p2);
assertEquals(Status.FINISHED, p2.getStatus());
assertEquals("10", p2.getResult().message().get(0).getData());
Paragraph p3 = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config3 = p3.getConfig();
config3.put("enabled", true);
p3.setConfig(config3);
p3.setText("%spark println(new java.util.Date())");
p3.setAuthenticationInfo(anonymous);
p0.setText(String.format("%%spark z.runNote(\"%s\")", note.getId()));
note.run(p0.getId());
waitForFinish(p0);
waitForFinish(p1);
waitForFinish(p2);
waitForFinish(p3);
assertEquals(Status.FINISHED, p3.getStatus());
String p3result = p3.getResult().message().get(0).getData();
assertNotEquals(null, p3result);
assertNotEquals("", p3result);
p0.setText(String.format("%%spark z.run(\"%s\", \"%s\")", note.getId(), p3.getId()));
p3.setText("%%spark println(\"END\")");
note.run(p0.getId());
waitForFinish(p0);
waitForFinish(p3);
    assertNotEquals(p3result, p3.getResult().message().get(0).getData());
ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
}
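  /**
   * Load spark-csv through the %dep interpreter, then read a CSV file with it
   * from %pyspark.
   */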
@Test
public void pySparkDepLoaderTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(anonymous);
int sparkVersionNumber = getSparkVersionNumber(note);
if (isPyspark() && sparkVersionNumber >= 14) {
// restart spark interpreter
List<InterpreterSetting> settings =
ZeppelinServer.notebook.getBindedInterpreterSettings(note.getId());
for (InterpreterSetting setting : settings) {
if (setting.getName().equals("spark")) {
ZeppelinServer.notebook.getInterpreterSettingManager().restart(setting.getId());
break;
}
}
// load dep
Paragraph p0 = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
Map config = p0.getConfig();
config.put("enabled", true);
p0.setConfig(config);
p0.setText("%dep z.load(\"com.databricks:spark-csv_2.11:1.2.0\")");
p0.setAuthenticationInfo(anonymous);
note.run(p0.getId());
waitForFinish(p0);
assertEquals(Status.FINISHED, p0.getStatus());
// write test csv file
      File tmpFile = File.createTempFile("test", ".csv");
      tmpFile.deleteOnExit();
      FileUtils.write(tmpFile, "a,b\n1,2");
// load data using libraries from dep loader
Paragraph p1 = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
p1.setConfig(config);
String sqlContextName = "sqlContext";
if (sparkVersionNumber >= 20) {
sqlContextName = "spark";
}
p1.setText("%pyspark\n" +
"from pyspark.sql import SQLContext\n" +
"print(" + sqlContextName + ".read.format('com.databricks.spark.csv')" +
".load('"+ tmpFile.getAbsolutePath() +"').count())");
p1.setAuthenticationInfo(anonymous);
note.run(p1.getId());
waitForFinish(p1);
assertEquals(Status.FINISHED, p1.getStatus());
assertEquals("2\n", p1.getResult().message().get(0).getData());
}
ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
}
  /**
   * Get the Spark version as a two-digit number by running a
   * "%spark print(sc.version)" paragraph,
   * e.g. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13, ...
   */
private int getSparkVersionNumber(Note note) {
Paragraph p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
note.setName("note");
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark print(sc.version)");
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
String sparkVersion = p.getResult().message().get(0).getData();
System.out.println("Spark version detected " + sparkVersion);
String[] split = sparkVersion.split("\\.");
int version = Integer.parseInt(split[0]) * 10 + Integer.parseInt(split[1]);
return version;
}
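  /**
   * z.textbox / z.select / z.checkbox should register dynamic forms in the
   * paragraph settings in declaration order and evaluate to their defaults.
   */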
@Test
public void testSparkZeppelinContextDynamicForms() throws IOException {
Note note = ZeppelinServer.notebook.createNote(anonymous);
Paragraph p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
note.setName("note");
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
String code = "%spark.spark println(z.textbox(\"my_input\", \"default_name\"))\n" +
"println(z.select(\"my_select\", \"1\"," +
"Seq((\"1\", \"select_1\"), (\"2\", \"select_2\"))))\n" +
"val items=z.checkbox(\"my_checkbox\", Seq(\"2\"), " +
"Seq((\"1\", \"check_1\"), (\"2\", \"check_2\")))\n" +
"println(items(0))";
p.setText(code);
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
Iterator<String> formIter = p.settings.getForms().keySet().iterator();
    assertEquals("my_input", formIter.next());
    assertEquals("my_select", formIter.next());
    assertEquals("my_checkbox", formIter.next());
// check dynamic forms values
String[] result = p.getResult().message().get(0).getData().split("\n");
assertEquals(4, result.length);
assertEquals("default_name", result[0]);
assertEquals("1", result[1]);
assertEquals("items: Seq[Object] = Buffer(2)", result[2]);
assertEquals("2", result[3]);
    ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
  }
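  /** The same dynamic-forms checks as above, driven through %spark.pyspark. */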
@Test
public void testPySparkZeppelinContextDynamicForms() throws IOException {
Note note = ZeppelinServer.notebook.createNote(anonymous);
Paragraph p = note.addNewParagraph(AuthenticationInfo.ANONYMOUS);
note.setName("note");
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
String code = "%spark.pyspark print(z.input('my_input', 'default_name'))\n" +
"print(z.select('my_select', " +
"[('1', 'select_1'), ('2', 'select_2')], defaultValue='1'))\n" +
"items=z.checkbox('my_checkbox', " +
"[('1', 'check_1'), ('2', 'check_2')], defaultChecked=['2'])\n" +
"print(items[0])";
p.setText(code);
p.setAuthenticationInfo(anonymous);
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
Iterator<String> formIter = p.settings.getForms().keySet().iterator();
    assertEquals("my_input", formIter.next());
    assertEquals("my_select", formIter.next());
    assertEquals("my_checkbox", formIter.next());
// check dynamic forms values
String[] result = p.getResult().message().get(0).getData().split("\n");
assertEquals(3, result.length);
assertEquals("default_name", result[0]);
assertEquals("1", result[1]);
assertEquals("2", result[2]);
    ZeppelinServer.notebook.removeNote(note.getId(), anonymous);
  }
}