/*
* Copyright (c) 2004, PostgreSQL Global Development Group
* See the LICENSE file in the project root for more information.
*/
package org.postgresql.test.jdbc2;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import org.postgresql.test.TestUtil;
import org.junit.Test;
import java.sql.Array;
import java.sql.CallableStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.sql.Types;
/*
* CallableStatement tests.
*
* @author Paul Bethe
*/
public class CallableStmtTest extends BaseTest4 {
@Override
public void setUp() throws Exception {
super.setUp();
TestUtil.createTable(con, "int_table", "id int");
Statement stmt = con.createStatement();
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__getString (varchar) "
+ "RETURNS varchar AS ' DECLARE inString alias for $1; begin "
+ "return ''bob''; end; ' LANGUAGE plpgsql;");
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__getDouble (float) "
+ "RETURNS float AS ' DECLARE inString alias for $1; begin "
+ "return 42.42; end; ' LANGUAGE plpgsql;");
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__getVoid (float) "
+ "RETURNS void AS ' DECLARE inString alias for $1; begin "
+ " return; end; ' LANGUAGE plpgsql;");
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__getInt (int) RETURNS int "
+ " AS 'DECLARE inString alias for $1; begin "
+ "return 42; end;' LANGUAGE plpgsql;");
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__getShort (int2) RETURNS int2 "
+ " AS 'DECLARE inString alias for $1; begin "
+ "return 42; end;' LANGUAGE plpgsql;");
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__getNumeric (numeric) "
+ "RETURNS numeric AS ' DECLARE inString alias for $1; "
+ "begin return 42; end; ' LANGUAGE plpgsql;");
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__getNumericWithoutArg() "
+ "RETURNS numeric AS ' "
+ "begin return 42; end; ' LANGUAGE plpgsql;");
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__getarray() RETURNS int[] as "
+ "'SELECT ''{1,2}''::int[];' LANGUAGE sql");
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__raisenotice() RETURNS int as "
+ "'BEGIN RAISE NOTICE ''hello''; RAISE NOTICE ''goodbye''; RETURN 1; END;' LANGUAGE plpgsql");
stmt.execute(
"CREATE OR REPLACE FUNCTION testspg__insertInt(int) RETURNS int as "
+ "'BEGIN INSERT INTO int_table(id) VALUES ($1); RETURN 1; END;' LANGUAGE plpgsql");
stmt.close();
}
@Override
public void tearDown() throws SQLException {
Statement stmt = con.createStatement();
TestUtil.dropTable(con, "int_table");
stmt.execute("drop FUNCTION testspg__getString (varchar);");
stmt.execute("drop FUNCTION testspg__getDouble (float);");
stmt.execute("drop FUNCTION testspg__getVoid(float);");
stmt.execute("drop FUNCTION testspg__getInt (int);");
stmt.execute("drop FUNCTION testspg__getShort(int2)");
stmt.execute("drop FUNCTION testspg__getNumeric (numeric);");
stmt.execute("drop FUNCTION testspg__getNumericWithoutArg ();");
stmt.execute("DROP FUNCTION testspg__getarray();");
stmt.execute("DROP FUNCTION testspg__raisenotice();");
stmt.execute("DROP FUNCTION testspg__insertInt(int);");
super.tearDown();
}
final String func = "{ ? = call ";
final String pkgName = "testspg__";
@Test
public void testGetUpdateCount() throws SQLException {
assumeCallableStatementsSupported();
CallableStatement call = con.prepareCall(func + pkgName + "getDouble (?) }");
call.setDouble(2, 3.04);
call.registerOutParameter(1, Types.DOUBLE);
call.execute();
assertEquals(-1, call.getUpdateCount());
assertNull(call.getResultSet());
assertEquals(42.42, call.getDouble(1), 0.00001);
call.close();
// test without an out parameter
call = con.prepareCall("{ call " + pkgName + "getDouble(?) }");
call.setDouble(1, 3.04);
call.execute();
assertEquals(-1, call.getUpdateCount());
ResultSet rs = call.getResultSet();
assertNotNull(rs);
assertTrue(rs.next());
assertEquals(42.42, rs.getDouble(1), 0.00001);
assertTrue(!rs.next());
rs.close();
assertEquals(-1, call.getUpdateCount());
assertTrue(!call.getMoreResults());
call.close();
}
@Test
public void testGetDouble() throws Throwable {
assumeCallableStatementsSupported();
CallableStatement call = con.prepareCall(func + pkgName + "getDouble (?) }");
call.setDouble(2, 3.04);
call.registerOutParameter(1, Types.DOUBLE);
call.execute();
assertEquals(42.42, call.getDouble(1), 0.00001);
// test without an out parameter
call = con.prepareCall("{ call " + pkgName + "getDouble(?) }");
call.setDouble(1, 3.04);
call.execute();
call = con.prepareCall("{ call " + pkgName + "getVoid(?) }");
call.setDouble(1, 3.04);
call.execute();
}
@Test
public void testGetInt() throws Throwable {
assumeCallableStatementsSupported();
CallableStatement call = con.prepareCall(func + pkgName + "getInt (?) }");
call.setInt(2, 4);
call.registerOutParameter(1, Types.INTEGER);
call.execute();
assertEquals(42, call.getInt(1));
}
@Test
public void testGetShort() throws Throwable {
assumeCallableStatementsSupported();
if (TestUtil.isProtocolVersion(con, 3)) {
CallableStatement call = con.prepareCall(func + pkgName + "getShort (?) }");
call.setShort(2, (short) 4);
call.registerOutParameter(1, Types.SMALLINT);
call.execute();
assertEquals(42, call.getShort(1));
}
}
@Test
public void testGetNumeric() throws Throwable {
assumeCallableStatementsSupported();
CallableStatement call = con.prepareCall(func + pkgName + "getNumeric (?) }");
call.setBigDecimal(2, new java.math.BigDecimal(4));
call.registerOutParameter(1, Types.NUMERIC);
call.execute();
assertEquals(new java.math.BigDecimal(42), call.getBigDecimal(1));
}
@Test
public void testGetNumericWithoutArg() throws Throwable {
assumeCallableStatementsSupported();
CallableStatement call = con.prepareCall(func + pkgName + "getNumericWithoutArg () }");
call.registerOutParameter(1, Types.NUMERIC);
call.execute();
assertEquals(new java.math.BigDecimal(42), call.getBigDecimal(1));
}
@Test
public void testGetString() throws Throwable {
assumeCallableStatementsSupported();
CallableStatement call = con.prepareCall(func + pkgName + "getString (?) }");
call.setString(2, "foo");
call.registerOutParameter(1, Types.VARCHAR);
call.execute();
assertEquals("bob", call.getString(1));
}
@Test
public void testGetArray() throws SQLException {
assumeCallableStatementsSupported();
CallableStatement call = con.prepareCall(func + pkgName + "getarray()}");
call.registerOutParameter(1, Types.ARRAY);
call.execute();
Array arr = call.getArray(1);
ResultSet rs = arr.getResultSet();
assertTrue(rs.next());
assertEquals(1, rs.getInt(1));
assertTrue(rs.next());
assertEquals(2, rs.getInt(1));
assertTrue(!rs.next());
}
@Test
public void testRaiseNotice() throws SQLException {
assumeCallableStatementsSupported();
Statement statement = con.createStatement();
statement.execute("SET SESSION client_min_messages = 'NOTICE'");
CallableStatement call = con.prepareCall(func + pkgName + "raisenotice()}");
call.registerOutParameter(1, Types.INTEGER);
call.execute();
SQLWarning warn = call.getWarnings();
assertNotNull(warn);
assertEquals("hello", warn.getMessage());
warn = warn.getNextWarning();
assertNotNull(warn);
assertEquals("goodbye", warn.getMessage());
assertEquals(1, call.getInt(1));
}
@Test
public void testWasNullBeforeFetch() throws SQLException {
CallableStatement cs = con.prepareCall("{? = call lower(?)}");
cs.registerOutParameter(1, Types.VARCHAR);
cs.setString(2, "Hi");
try {
cs.wasNull();
fail("expected exception");
} catch (Exception e) {
assertTrue(e instanceof SQLException);
}
}
@Test
public void testFetchBeforeExecute() throws SQLException {
CallableStatement cs = con.prepareCall("{? = call lower(?)}");
cs.registerOutParameter(1, Types.VARCHAR);
cs.setString(2, "Hi");
try {
cs.getString(1);
fail("expected exception");
} catch (Exception e) {
assertTrue(e instanceof SQLException);
}
}
@Test
public void testFetchWithNoResults() throws SQLException {
CallableStatement cs = con.prepareCall("{call now()}");
cs.execute();
try {
cs.getObject(1);
fail("expected exception");
} catch (Exception e) {
assertTrue(e instanceof SQLException);
}
}
@Test
public void testBadStmt() throws Throwable {
tryOneBadStmt("{ ?= " + pkgName + "getString (?) }");
tryOneBadStmt("{ ?= call getString (?) ");
tryOneBadStmt("{ = ? call getString (?); }");
}
protected void tryOneBadStmt(String sql) throws SQLException {
try {
con.prepareCall(sql);
fail("Bad statement (" + sql + ") was not caught.");
} catch (SQLException e) {
}
}
@Test
public void testBatchCall() throws SQLException {
CallableStatement call = con.prepareCall("{ call " + pkgName + "insertInt(?) }");
call.setInt(1, 1);
call.addBatch();
call.setInt(1, 2);
call.addBatch();
call.setInt(1, 3);
call.addBatch();
call.executeBatch();
call.close();
Statement stmt = con.createStatement();
ResultSet rs = stmt.executeQuery("SELECT id FROM int_table ORDER BY id");
assertTrue(rs.next());
assertEquals(1, rs.getInt(1));
assertTrue(rs.next());
assertEquals(2, rs.getInt(1));
assertTrue(rs.next());
assertEquals(3, rs.getInt(1));
assertTrue(!rs.next());
}
}