/*
* SonarQube
* Copyright (C) 2009-2017 SonarSource SA
* mailto:info AT sonarsource DOT com
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3 of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program; if not, write to the Free Software Foundation,
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package org.sonar.db;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import java.io.InputStream;
import java.math.BigDecimal;
import java.sql.Clob;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import javax.annotation.CheckForNull;
import javax.annotation.Nullable;
import org.apache.commons.dbutils.QueryRunner;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.dbunit.Assertion;
import org.dbunit.DatabaseUnitException;
import org.dbunit.assertion.DiffCollectingFailureHandler;
import org.dbunit.assertion.Difference;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.CompositeDataSet;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.ITable;
import org.dbunit.dataset.ReplacementDataSet;
import org.dbunit.dataset.filter.DefaultColumnFilter;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunit.ext.mssql.InsertIdentityOperation;
import org.dbunit.operation.DatabaseOperation;
import org.junit.rules.ExternalResource;
import org.sonar.api.utils.log.Loggers;
import org.sonar.core.util.stream.MoreCollectors;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Lists.asList;
import static com.google.common.collect.Lists.newArrayList;
import static com.google.common.collect.Maps.newHashMap;
import static java.sql.ResultSetMetaData.columnNoNulls;
import static java.sql.ResultSetMetaData.columnNullable;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.fail;
public class AbstractDbTester<T extends CoreTestDb> extends ExternalResource {
/** Joiner producing comma-separated lists; used to build the column and placeholder lists of generated SQL. */
protected static final Joiner COMMA_JOINER = Joiner.on(", ");
/** Test database that provides the data source, the database commands and the DbUnit helpers used by this tester. */
protected final T db;
/**
 * Creates a tester operating on the given test database.
 *
 * @param db the test database every helper method of this class delegates to
 */
public AbstractDbTester(T db) {
this.db = db;
}
/**
 * Executes a modifying SQL statement on a new connection and commits it when the
 * connection is not in auto-commit mode.
 *
 * @param sql the statement to execute, with {@code ?} placeholders for {@code params}
 * @param params values bound to the placeholders, in order
 * @throws IllegalStateException if the statement fails; when the driver chained a "next"
 *         exception, the reported cause carries its SQL state and error code
 */
public void executeUpdateSql(String sql, Object... params) {
  try (Connection conn = getConnection()) {
    new QueryRunner().update(conn, sql, params);
    boolean autoCommit = conn.getAutoCommit();
    if (!autoCommit) {
      conn.commit();
    }
  } catch (SQLException sqlFailure) {
    SQLException chained = sqlFailure.getNextException();
    // Surface the chained exception details when the driver provides them (typically batch failures)
    SQLException cause = chained == null
      ? sqlFailure
      : new SQLException(sqlFailure.getMessage(), chained.getSQLState(), chained.getErrorCode(), chained);
    throw new IllegalStateException("Fail to execute sql: " + sql, cause);
  } catch (Exception otherFailure) {
    throw new IllegalStateException("Fail to execute sql: " + sql, otherFailure);
  }
}
/**
 * Executes a DDL statement (for example {@code create table ...}) on a new connection.
 *
 * @param ddl the statement to execute
 * @throws IllegalStateException if the statement fails
 */
public void executeDdl(String ddl) {
  try (Connection conn = getConnection();
    Statement statement = conn.createStatement()) {
    statement.execute(ddl);
  } catch (SQLException failure) {
    String message = "Failed to execute DDL: " + ddl;
    throw new IllegalStateException(message, failure);
  }
}
/**
 * Inserts one row into a table from a flat list of alternating column names and values:
 * {@code executeInsert("users", "login", "john", "name", "John")}.
 * Columns whose value is {@code null} are left out of the statement.
 *
 * @param table the table to insert into
 * @param firstColumn the name of the first column
 * @param others the value of the first column, followed by further (column, value) pairs
 */
public void executeInsert(String table, String firstColumn, Object... others) {
  Map<String, Object> valuesByColumn = mapOf(firstColumn, others);
  executeInsert(table, valuesByColumn);
}
/**
 * Builds a column-to-value map from a flat sequence of alternating column names and values.
 * Pairs whose value is {@code null} are skipped.
 *
 * @param firstColumn the name of the first column
 * @param values the value of the first column, followed by further (column, value) pairs
 * @return the non-null values indexed by column name, in declaration order
 * @throws IllegalArgumentException if a column name is not followed by a value
 */
private static Map<String, Object> mapOf(String firstColumn, Object... values) {
  // firstColumn plus values must form complete (column, value) pairs, hence an odd number of values
  checkArgument(values.length % 2 == 1,
    "Columns and values must be provided in pairs. Got " + values.length + " element(s) after first column " + firstColumn);
  ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
  List<Object> args = asList(firstColumn, values);
  for (int i = 0; i < args.size(); i += 2) {
    String key = args.get(i).toString();
    Object value = args.get(i + 1);
    // ImmutableMap does not accept null values: a null value means the column is not part of the insert
    if (value != null) {
      builder.put(key, value);
    }
  }
  return builder.build();
}
/**
 * Inserts one row into a table. The statement is generated from the map: keys are the
 * column names and values are bound as statement parameters. The table name is lower-cased.
 *
 * @param table the table to insert into
 * @param valuesByColumn the values to insert, indexed by column name; must not be empty
 * @throws IllegalArgumentException if {@code valuesByColumn} is empty
 */
public void executeInsert(String table, Map<String, Object> valuesByColumn) {
  checkArgument(!valuesByColumn.isEmpty(), "Values cannot be empty");

  StringBuilder sql = new StringBuilder("insert into ");
  sql.append(table.toLowerCase(Locale.ENGLISH));
  sql.append(" (");
  sql.append(COMMA_JOINER.join(valuesByColumn.keySet()));
  sql.append(") values (");
  // one '?' placeholder per column
  sql.append(COMMA_JOINER.join(Collections.nCopies(valuesByColumn.size(), '?')));
  sql.append(")");

  Object[] parameters = valuesByColumn.values().toArray(new Object[valuesByColumn.size()]);
  executeUpdateSql(sql.toString(), parameters);
}
/**
 * Counts the rows of a table, using a connection opened for this call. Example:
 * <pre>int issues = countRowsOfTable("issues")</pre>
 *
 * @param tableName the name of the table, without any space
 * @return the number of rows
 */
public int countRowsOfTable(String tableName) {
  ConnectionSupplier freshConnection = new NewConnectionSupplier();
  return countRowsOfTable(tableName, freshConnection);
}
/**
 * Counts the rows of a table through the given connection supplier.
 * The table name is lower-cased before being injected into the query.
 *
 * @param tableName the name of the table; rejected if it contains a space
 * @param connectionSupplier provides the connection to use
 * @return the number of rows
 * @throws IllegalArgumentException if {@code tableName} contains a space
 */
protected int countRowsOfTable(String tableName, ConnectionSupplier connectionSupplier) {
  boolean isSingleWord = StringUtils.containsNone(tableName, " ");
  checkArgument(isSingleWord, "Parameter must be the name of a table. Got " + tableName);
  String sql = "select count(1) from " + tableName.toLowerCase(Locale.ENGLISH);
  return countSql(sql, connectionSupplier);
}
/**
 * Runs a counting query, using a connection opened for this call. The query must contain
 * a {@code count(} call, for example:
 * <pre>int openIssues = countSql("select count('id') from issues where status is not null")</pre>
 *
 * @param sql the counting query
 * @return the integer value of the first column of the first row
 */
public int countSql(String sql) {
  ConnectionSupplier freshConnection = new NewConnectionSupplier();
  return countSql(sql, freshConnection);
}
/**
 * Runs a counting query through the given connection supplier. The supplier is closed
 * once the query has been executed.
 *
 * @param sql the query; must contain the text {@code count(}
 * @param connectionSupplier provides the connection to use
 * @return the integer value of the first column of the first row
 * @throws IllegalArgumentException if {@code sql} does not contain {@code count(}
 * @throws IllegalStateException if the query fails or returns no row (the "no results"
 *         error is reported as the cause of the "Fail to execute sql" error)
 */
protected int countSql(String sql, ConnectionSupplier connectionSupplier) {
  boolean hasCountFunction = StringUtils.contains(sql, "count(");
  checkArgument(hasCountFunction,
    "Parameter must be a SQL request containing 'count(x)' function. Got " + sql);
  try (
    ConnectionSupplier closeableSupplier = connectionSupplier;
    PreparedStatement statement = closeableSupplier.get().prepareStatement(sql);
    ResultSet resultSet = statement.executeQuery()) {
    if (!resultSet.next()) {
      throw new IllegalStateException("No results for " + sql);
    }
    return resultSet.getInt(1);
  } catch (Exception failure) {
    throw new IllegalStateException("Fail to execute sql: " + sql, failure);
  }
}
/**
 * Runs a select query, using a connection opened for this call.
 *
 * @param selectSql the query to run
 * @return one map per row, indexed by column label
 */
public List<Map<String, Object>> select(String selectSql) {
  ConnectionSupplier freshConnection = new NewConnectionSupplier();
  return select(selectSql, freshConnection);
}
/**
 * Runs a select query through the given connection supplier. The supplier is closed once
 * the rows have been read.
 *
 * @param selectSql the query to run
 * @param connectionSupplier provides the connection to use
 * @return one map per row, indexed by column label
 * @throws IllegalStateException if the query fails or its rows cannot be read
 */
protected List<Map<String, Object>> select(String selectSql, ConnectionSupplier connectionSupplier) {
  try (
    ConnectionSupplier closeableSupplier = connectionSupplier;
    PreparedStatement statement = closeableSupplier.get().prepareStatement(selectSql);
    ResultSet resultSet = statement.executeQuery()) {
    List<Map<String, Object>> rows = getHashMap(resultSet);
    return rows;
  } catch (Exception failure) {
    throw new IllegalStateException("Fail to execute sql: " + selectSql, failure);
  }
}
/**
 * Runs a select query that must return exactly one row, using a connection opened for this call.
 *
 * @param selectSql the query to run
 * @return the single row, indexed by column label
 */
public Map<String, Object> selectFirst(String selectSql) {
  ConnectionSupplier freshConnection = new NewConnectionSupplier();
  return selectFirst(selectSql, freshConnection);
}
/**
 * Runs a select query that must return exactly one row.
 *
 * @param selectSql the query to run
 * @param connectionSupplier provides the connection to use
 * @return the single row, indexed by column label
 * @throws IllegalStateException if the query returns no row or more than one row
 */
protected Map<String, Object> selectFirst(String selectSql, ConnectionSupplier connectionSupplier) {
  List<Map<String, Object>> rows = select(selectSql, connectionSupplier);
  int rowCount = rows.size();
  if (rowCount == 1) {
    return rows.get(0);
  }
  String problem = rowCount == 0 ? "No results for " : "Too many results for ";
  throw new IllegalStateException(problem + selectSql);
}
/**
 * Reads all the remaining rows of a result set.
 * Values are normalized so that tests see the same Java types whatever the database:
 * CLOBs become strings, integers and scale-less decimals become {@code Long}, other
 * decimals become {@code Double} and timestamps become {@link Date}.
 *
 * @param resultSet the result set to read; it is not closed by this method
 * @return one map per row, indexed by column label
 * @throws Exception if the result set or one of its CLOBs cannot be read
 */
private static List<Map<String, Object>> getHashMap(ResultSet resultSet) throws Exception {
  ResultSetMetaData metaData = resultSet.getMetaData();
  int colCount = metaData.getColumnCount();
  List<Map<String, Object>> rows = newArrayList();
  while (resultSet.next()) {
    Map<String, Object> columns = newHashMap();
    for (int i = 1; i <= colCount; i++) {
      columns.put(metaData.getColumnLabel(i), toPortableValue(resultSet.getObject(i)));
    }
    rows.add(columns);
  }
  return rows;
}

/**
 * Converts a raw JDBC value to the database-independent type exposed to tests.
 */
@CheckForNull
private static Object toPortableValue(@Nullable Object value) throws Exception {
  if (value instanceof Clob) {
    Clob clob = (Clob) value;
    try {
      // Read characters, not ASCII bytes, so that non-ASCII content is not corrupted
      return IOUtils.toString(clob.getCharacterStream());
    } finally {
      // Release the CLOB even when reading it fails
      doClobFree(clob);
    }
  }
  if (value instanceof BigDecimal) {
    // In Oracle, INTEGER types are mapped as BigDecimal
    BigDecimal decimal = (BigDecimal) value;
    if (decimal.scale() == 0) {
      return decimal.longValue();
    }
    return decimal.doubleValue();
  }
  if (value instanceof Integer) {
    // To be consistent, all INTEGER types are mapped as Long
    return ((Integer) value).longValue();
  }
  if (value instanceof Timestamp) {
    return new Date(((Timestamp) value).getTime());
  }
  return value;
}
/**
 * Loads DbUnit data files into the database, then resets the primary keys.
 * Each file is looked up on the classpath in the directory named after the test class
 * ({@code /fully/qualified/TestClass/<testName>}).
 *
 * @param testClass the class used to locate the files
 * @param testNames the names of the files to load
 * @throws IllegalStateException if one of the files does not exist
 */
public void prepareDbUnit(Class testClass, String... testNames) {
  InputStream[] streams = new InputStream[testNames.length];
  try {
    int next = 0;
    for (String testName : testNames) {
      String resourcePath = "/" + testClass.getName().replace('.', '/') + "/" + testName;
      InputStream resource = testClass.getResourceAsStream(resourcePath);
      if (resource == null) {
        throw new IllegalStateException("DbUnit file not found: " + resourcePath);
      }
      streams[next] = resource;
      next++;
    }

    prepareDbUnit(streams);
    db.getCommands().resetPrimaryKeys(db.getDatabase().getDataSource());
  } catch (SQLException sqlFailure) {
    throw translateException("Could not setup DBUnit data", sqlFailure);
  } finally {
    // entries may still be null when a file was missing; closeQuietly accepts null
    for (InputStream opened : streams) {
      IOUtils.closeQuietly(opened);
    }
  }
}
/**
 * Parses the given DbUnit streams, registers them as the data set of the DbUnit tester
 * and inserts their content into the database.
 *
 * @param dataSetStream the streams of the DbUnit files to load; they are not closed by this method
 */
private void prepareDbUnit(InputStream... dataSetStream) {
  IDatabaseConnection dbUnitConnection = null;
  try {
    int count = dataSetStream.length;
    IDataSet[] parsed = new IDataSet[count];
    for (int idx = 0; idx < count; idx++) {
      parsed[idx] = dbUnitDataSet(dataSetStream[idx]);
    }
    IDataSet composite = new CompositeDataSet(parsed);
    db.getDbUnitTester().setDataSet(composite);

    dbUnitConnection = dbUnitConnection();
    DatabaseOperation insert = new InsertIdentityOperation(DatabaseOperation.INSERT);
    insert.execute(dbUnitConnection, db.getDbUnitTester().getDataSet());
  } catch (Exception failure) {
    throw translateException("Could not setup DBUnit data", failure);
  } finally {
    closeQuietly(dbUnitConnection);
  }
}
/**
 * Asserts that the given columns of a table contain the data described in a DbUnit file.
 * The file is looked up on the classpath in the directory named after the test class
 * ({@code /fully/qualified/TestClass/<filename>}).
 *
 * @param testClass the class used to locate the file
 * @param filename the name of the DbUnit file describing the expected content
 * @param table the table to check
 * @param columns the columns to compare; the other columns are ignored
 * @throws IllegalStateException if the file does not exist
 * @throws AssertionError if the content differs or cannot be compared by DbUnit
 */
public void assertDbUnitTable(Class testClass, String filename, String table, String... columns) {
  String path = "/" + testClass.getName().replace('.', '/') + "/" + filename;
  InputStream expectedStream = testClass.getResourceAsStream(path);
  if (expectedStream == null) {
    // Fail with an explicit message rather than with an obscure parsing error
    throw new IllegalStateException(String.format("File '%s' does not exist", path));
  }
  IDatabaseConnection connection = null;
  try {
    connection = dbUnitConnection();
    IDataSet dataSet = connection.createDataSet();
    IDataSet expectedDataSet = dbUnitDataSet(expectedStream);
    ITable filteredTable = DefaultColumnFilter.includedColumnsTable(dataSet.getTable(table), columns);
    ITable filteredExpectedTable = DefaultColumnFilter.includedColumnsTable(expectedDataSet.getTable(table), columns);
    Assertion.assertEquals(filteredExpectedTable, filteredTable);
  } catch (DatabaseUnitException e) {
    // Same failure as fail(message), but keeps the DbUnit exception as the cause
    throw new AssertionError(e.getMessage(), e);
  } catch (SQLException e) {
    throw translateException("Error while checking results", e);
  } finally {
    closeQuietly(connection);
    IOUtils.closeQuietly(expectedStream);
  }
}
/**
 * Asserts that the given tables contain the data described in a DbUnit file, comparing
 * all the columns.
 *
 * @param testClass the class used to locate the file
 * @param filename the name of the DbUnit file describing the expected content
 * @param tables the tables to check
 */
public void assertDbUnit(Class testClass, String filename, String... tables) {
  String[] noExcludedColumns = new String[0];
  assertDbUnit(testClass, filename, noExcludedColumns, tables);
}
/**
 * Asserts that the given tables contain the data described in a DbUnit file.
 * The file is looked up on the classpath in the directory named after the test class
 * ({@code /fully/qualified/TestClass/<filename>}). A difference is tolerated when the
 * expected value is the literal {@code [ignore]}.
 *
 * @param testClass the class used to locate the file
 * @param filename the name of the DbUnit file describing the expected content
 * @param excludedColumnNames the columns that are not compared
 * @param tables the tables to check
 * @throws AssertionError if the content of a table differs from the expected one
 */
public void assertDbUnit(Class testClass, String filename, String[] excludedColumnNames, String... tables) {
  IDatabaseConnection connection = null;
  InputStream inputStream = null;
  try {
    connection = dbUnitConnection();
    IDataSet dataSet = connection.createDataSet();
    String path = "/" + testClass.getName().replace('.', '/') + "/" + filename;
    inputStream = testClass.getResourceAsStream(path);
    if (inputStream == null) {
      throw new IllegalStateException(String.format("File '%s' does not exist", path));
    }
    IDataSet expectedDataSet = dbUnitDataSet(inputStream);
    for (String table : tables) {
      DiffCollectingFailureHandler diffHandler = new DiffCollectingFailureHandler();
      ITable filteredTable = DefaultColumnFilter.excludedColumnsTable(dataSet.getTable(table), excludedColumnNames);
      ITable filteredExpectedTable = DefaultColumnFilter.excludedColumnsTable(expectedDataSet.getTable(table), excludedColumnNames);
      Assertion.assertEquals(filteredExpectedTable, filteredTable, diffHandler);
      // Evaluate the differences and ignore some column values
      for (Object o : diffHandler.getDiffList()) {
        Difference diff = (Difference) o;
        if (!"[ignore]".equals(diff.getExpectedValue())) {
          throw new DatabaseUnitException(diff.toString());
        }
      }
    }
  } catch (DatabaseUnitException e) {
    // Same failure as fail(message), but the DbUnit exception is attached as the cause
    // instead of being dumped on the standard error stream
    throw new AssertionError(e.getMessage(), e);
  } catch (Exception e) {
    throw translateException("Error while checking results", e);
  } finally {
    closeQuietly(connection);
    // closeQuietly accepts null, which happens when the connection could not be opened
    IOUtils.closeQuietly(inputStream);
  }
}
/**
 * Asserts the type and, optionally, the size of a column, without checking its nullability.
 *
 * @param table the table holding the column
 * @param column the name of the column (case-insensitive)
 * @param expectedType the expected type, as a {@link java.sql.Types} constant
 * @param expectedSize the expected display size, or {@code null} to skip this check
 */
public void assertColumnDefinition(String table, String column, int expectedType, @Nullable Integer expectedSize) {
  Boolean nullabilityNotChecked = null;
  assertColumnDefinition(table, column, expectedType, expectedSize, nullabilityNotChecked);
}
/**
 * Asserts the definition of a column, based on the metadata of a {@code select *} query
 * on its table.
 *
 * @param table the table holding the column
 * @param column the name of the column (case-insensitive)
 * @param expectedType the expected type, as a {@link java.sql.Types} constant
 * @param expectedSize the expected display size, or {@code null} to skip this check
 * @param isNullable whether the column is expected to accept nulls, or {@code null} to skip this check
 * @throws IllegalStateException if the metadata cannot be read
 */
public void assertColumnDefinition(String table, String column, int expectedType, @Nullable Integer expectedSize, @Nullable Boolean isNullable) {
  try (Connection conn = getConnection();
    PreparedStatement statement = conn.prepareStatement("select * from " + table);
    ResultSet resultSet = statement.executeQuery()) {
    Integer columnIndex = getColumnIndex(resultSet, column);
    if (columnIndex == null) {
      fail("The column '" + column + "' does not exist");
    }

    ResultSetMetaData metaData = resultSet.getMetaData();
    assertThat(metaData.getColumnType(columnIndex)).isEqualTo(expectedType);
    if (expectedSize != null) {
      assertThat(metaData.getColumnDisplaySize(columnIndex)).isEqualTo(expectedSize);
    }
    if (isNullable != null) {
      int expectedNullability;
      if (isNullable) {
        expectedNullability = columnNullable;
      } else {
        expectedNullability = columnNoNulls;
      }
      assertThat(metaData.isNullable(columnIndex)).isEqualTo(expectedNullability);
    }
  } catch (Exception failure) {
    // assertion failures are Errors, so they are not caught here
    throw new IllegalStateException("Fail to check column", failure);
  }
}
/**
 * Asserts that a table has no column with the given name. The existing column names are
 * lower-cased before the comparison, so {@code column} is expected in lower case.
 *
 * @param table the table to inspect
 * @param column the lower-case name of the column that must not exist
 * @throws SQLException if the table cannot be queried
 */
public void assertColumnDoesNotExist(String table, String column) throws SQLException {
  String query = "select * from " + table;
  try (Connection conn = getConnection();
    PreparedStatement statement = conn.prepareStatement(query);
    ResultSet resultSet = statement.executeQuery()) {
    Set<String> existingColumns = getColumnNames(resultSet);
    assertThat(existingColumns).doesNotContain(column);
  }
}
/**
 * Asserts that the given table does not exist in the database.
 *
 * @param table the name of the table
 * @throws IllegalStateException if the existence of the table cannot be checked
 */
public void assertTableDoesNotExist(String table) {
  try (Connection conn = getConnection()) {
    assertThat(DatabaseUtils.tableExists(table, conn)).isFalse();
  } catch (Exception failure) {
    throw new IllegalStateException("Fail to check if table exists", failure);
  }
}
/**
 * Asserts that a non-unique index exists on the given columns, in the given order.
 *
 * @param tableName the table holding the index
 * @param indexName the name of the index (case-insensitive)
 * @param expectedColumn the first column of the index, in lower case
 * @param expectedSecondaryColumns the following columns of the index, in lower case
 */
public void assertIndex(String tableName, String indexName, String expectedColumn, String... expectedSecondaryColumns) {
  boolean unique = false;
  assertIndexImpl(tableName, indexName, unique, expectedColumn, expectedSecondaryColumns);
}
/**
 * Asserts that a unique index exists on the given columns, in the given order.
 *
 * @param tableName the table holding the index
 * @param indexName the name of the index (case-insensitive)
 * @param expectedColumn the first column of the index, in lower case
 * @param expectedSecondaryColumns the following columns of the index, in lower case
 */
public void assertUniqueIndex(String tableName, String indexName, String expectedColumn, String... expectedSecondaryColumns) {
  boolean unique = true;
  assertIndexImpl(tableName, indexName, unique, expectedColumn, expectedSecondaryColumns);
}
/**
 * Asserts that an index exists with the expected uniqueness and columns. The index is
 * looked up in the JDBC metadata of the upper-cased table name; column names reported by
 * the metadata are lower-cased before being compared.
 *
 * @param tableName the table holding the index
 * @param indexName the name of the index (case-insensitive)
 * @param expectedUnique whether the index must be unique
 * @param expectedColumn the first column of the index, in lower case
 * @param expectedSecondaryColumns the following columns of the index, in lower case
 * @throws IllegalStateException if the metadata cannot be read
 */
private void assertIndexImpl(String tableName, String indexName, boolean expectedUnique, String expectedColumn, String... expectedSecondaryColumns) {
  try (Connection conn = getConnection();
    ResultSet indexInfo = conn.getMetaData().getIndexInfo(null, null, tableName.toUpperCase(Locale.ENGLISH), false, false)) {
    List<String> actualColumns = new ArrayList<>();
    while (indexInfo.next()) {
      String currentIndex = indexInfo.getString("INDEX_NAME");
      if (!indexName.equalsIgnoreCase(currentIndex)) {
        continue;
      }
      boolean nonUnique = indexInfo.getBoolean("NON_UNIQUE");
      assertThat(nonUnique).isEqualTo(!expectedUnique);
      // ORDINAL_POSITION is 1-based
      int zeroBasedPosition = indexInfo.getInt("ORDINAL_POSITION") - 1;
      String columnName = indexInfo.getString("COLUMN_NAME").toLowerCase(Locale.ENGLISH);
      actualColumns.add(zeroBasedPosition, columnName);
    }
    List<String> expectedColumns = asList(expectedColumn, expectedSecondaryColumns);
    assertThat(expectedColumns).isEqualTo(actualColumns);
  } catch (SQLException sqlFailure) {
    throw new IllegalStateException("Fail to check index", sqlFailure);
  }
}
/**
 * Verify that index with name {@code indexName} does not exist on the table {@code tableName}.
 * Index names are compared case-insensitively.
 *
 * @param tableName the table to inspect; it is looked up by its upper-cased name
 * @param indexName the name of the index that must not exist
 * @throws IllegalStateException if the metadata cannot be read
 */
public void assertIndexDoesNotExist(String tableName, String indexName) {
  try (Connection connection = getConnection();
    ResultSet rs = connection.getMetaData().getIndexInfo(null, null, tableName.toUpperCase(Locale.ENGLISH), false, false)) {
    List<String> indices = new ArrayList<>();
    while (rs.next()) {
      String existingIndex = rs.getString("INDEX_NAME");
      // JDBC reports a null INDEX_NAME for table statistics rows
      if (existingIndex != null) {
        indices.add(existingIndex.toLowerCase(Locale.ENGLISH));
      }
    }
    // Existing names are lower-cased, so the searched name must be lower-cased too
    assertThat(indices).doesNotContain(indexName.toLowerCase(Locale.ENGLISH));
  } catch (SQLException e) {
    throw new IllegalStateException("Fail to check existence of index", e);
  }
}
/**
 * Asserts that a table has a primary key made of the given columns, in the given order.
 * The table is looked up by its upper-cased name first, then by its lower-cased name.
 * Names of primary key and columns are compared case-insensitively.
 *
 * @param tableName the table to inspect
 * @param expectedPkName the expected name of the primary key, or {@code null} to skip this check
 * @param columnName the first column of the primary key
 * @param otherColumnNames the following columns of the primary key
 * @throws IllegalStateException if the metadata cannot be read
 */
public void assertPrimaryKey(String tableName, @Nullable String expectedPkName, String columnName, String... otherColumnNames) {
  try (Connection connection = getConnection()) {
    PK pk = pkOf(connection, tableName.toUpperCase(Locale.ENGLISH));
    if (pk == null) {
      // Some databases store unquoted identifiers in lower case
      pk = pkOf(connection, tableName.toLowerCase(Locale.ENGLISH));
    }
    assertThat(pk).as("No primary key is defined on table %s", tableName).isNotNull();
    if (expectedPkName != null) {
      assertThat(pk.getName()).isEqualToIgnoringCase(expectedPkName);
    }
    List<String> expectedColumns = ImmutableList.copyOf(Iterables.concat(Collections.singletonList(columnName), Arrays.asList(otherColumnNames)));
    assertThat(pk.getColumns()).as("Primary key does not have the '%s' expected columns", expectedColumns.size()).hasSize(expectedColumns.size());
    Iterator<String> expectedColumnsIt = expectedColumns.iterator();
    Iterator<String> actualColumnsIt = pk.getColumns().iterator();
    while (expectedColumnsIt.hasNext() && actualColumnsIt.hasNext()) {
      assertThat(actualColumnsIt.next()).isEqualToIgnoringCase(expectedColumnsIt.next());
    }
  } catch (SQLException e) {
    throw new IllegalStateException("Fail to check primary key", e);
  }
}
/**
 * Reads the primary key of a table from the JDBC metadata.
 *
 * @param connection the connection used to read the metadata
 * @param tableName the name of the table, in the case expected by the database
 * @return the primary key with its columns ordered by position, or {@code null} if the
 *         metadata reports no primary key column for this table
 * @throws SQLException if the metadata cannot be read
 */
@CheckForNull
private PK pkOf(Connection connection, String tableName) throws SQLException {
  try (ResultSet rs = connection.getMetaData().getPrimaryKeys(null, null, tableName)) {
    if (!rs.next()) {
      return null;
    }

    // the first row gives the name of the key, all the rows must then agree on it
    String pkName = rs.getString("PK_NAME");
    List<PkColumn> pkColumns = new ArrayList<>(1);
    pkColumns.add(new PkColumn(rs.getInt("KEY_SEQ") - 1, rs.getString("COLUMN_NAME")));
    while (rs.next()) {
      assertThat(pkName).as("Multiple primary keys found").isEqualTo(rs.getString("PK_NAME"));
      pkColumns.add(new PkColumn(rs.getInt("KEY_SEQ") - 1, rs.getString("COLUMN_NAME")));
    }

    List<String> orderedNames = pkColumns.stream()
      .sorted(PkColumn.ORDERING_BY_INDEX)
      .map(PkColumn::getName)
      .collect(MoreCollectors.toList());
    return new PK(pkName, orderedNames);
  }
}
/**
 * A column of a primary key, together with its position in the key.
 */
private static final class PkColumn {
/** Orders columns by ascending {@link #getIndex() index}. */
private static final Ordering<PkColumn> ORDERING_BY_INDEX = Ordering.natural().onResultOf(PkColumn::getIndex);
/** 0-based position of the column in the primary key */
private final int index;
/** Name of the column */
private final String name;
private PkColumn(int index, String name) {
this.index = index;
this.name = name;
}
/** @return the 0-based position of the column in the primary key */
public int getIndex() {
return index;
}
/** @return the name of the column */
public String getName() {
return name;
}
}
/**
 * Finds the 1-based index of a column in a result set, comparing labels case-insensitively.
 *
 * @param res the result set whose metadata is inspected
 * @param column the label of the column to find
 * @return the 1-based index of the column, or {@code null} if no column has this label
 * @throws IllegalStateException if the metadata cannot be read
 */
@CheckForNull
private Integer getColumnIndex(ResultSet res, String column) {
  try {
    ResultSetMetaData meta = res.getMetaData();
    int numCol = meta.getColumnCount();
    for (int i = 1; i <= numCol; i++) {
      // equalsIgnoreCase does not depend on the default locale, unlike toLowerCase()
      if (column.equalsIgnoreCase(meta.getColumnLabel(i))) {
        return i;
      }
    }
    return null;
  } catch (Exception e) {
    // keep the cause, otherwise the actual failure is lost
    throw new IllegalStateException("Fail to get column index", e);
  }
}
/**
 * Reads the labels of all the columns of a result set, in lower case.
 *
 * @param res the result set whose metadata is inspected
 * @return the lower-cased column labels
 * @throws IllegalStateException if the metadata cannot be read
 */
private Set<String> getColumnNames(ResultSet res) {
  try {
    Set<String> columnNames = new HashSet<>();
    ResultSetMetaData meta = res.getMetaData();
    int numCol = meta.getColumnCount();
    for (int i = 1; i <= numCol; i++) {
      // fixed locale, so that the result does not depend on the default locale of the JVM
      columnNames.add(meta.getColumnLabel(i).toLowerCase(Locale.ENGLISH));
    }
    return columnNames;
  } catch (Exception e) {
    // keep the cause, otherwise the actual failure is lost
    throw new IllegalStateException("Fail to get column names", e);
  }
}
/**
 * Parses a flat XML DbUnit file. The placeholders {@code [null]}, {@code [false]} and
 * {@code [true]} found in the file are replaced by the corresponding Java values.
 *
 * @param stream the content of the DbUnit file; it is not closed by this method
 * @return the parsed data set
 */
private IDataSet dbUnitDataSet(InputStream stream) {
  try {
    FlatXmlDataSet rawDataSet = new FlatXmlDataSet(stream);
    ReplacementDataSet withPlaceholders = new ReplacementDataSet(rawDataSet);
    withPlaceholders.addReplacementObject("[null]", null);
    withPlaceholders.addReplacementObject("[false]", Boolean.FALSE);
    withPlaceholders.addReplacementObject("[true]", Boolean.TRUE);
    return withPlaceholders;
  } catch (Exception failure) {
    throw translateException("Could not read the dataset stream", failure);
  }
}
/**
 * Opens a DbUnit connection configured with the data type factory of the test database.
 *
 * @return the DbUnit connection; the caller is responsible for closing it
 */
private IDatabaseConnection dbUnitConnection() {
  try {
    IDatabaseConnection dbUnitConnection = db.getDbUnitTester().getConnection();
    DatabaseConfig config = dbUnitConnection.getConfig();
    config.setProperty(DatabaseConfig.PROPERTY_DATATYPE_FACTORY, db.getDbUnitFactory());
    return dbUnitConnection;
  } catch (Exception failure) {
    throw translateException("Error while getting connection", failure);
  }
}
/**
 * Converts an exception into a {@link RuntimeException} whose message is
 * {@code "<msg>: [<class of the cause>] <message of the cause>"}.
 * The returned exception has the same stack trace as {@code cause} and references it as
 * its cause, so that the chain of nested causes is not lost.
 *
 * @param msg the description of the operation that failed
 * @param cause the exception to translate
 * @return the runtime exception to throw
 */
public static RuntimeException translateException(String msg, Exception cause) {
  String message = String.format("%s: [%s] %s", msg, cause.getClass().getName(), cause.getMessage());
  RuntimeException runtimeException = new RuntimeException(message, cause);
  // Point directly to the location of the original failure
  runtimeException.setStackTrace(cause.getStackTrace());
  return runtimeException;
}
/**
 * Releases the resources held by a CLOB, tolerating drivers that do not implement
 * {@link Clob#free()}.
 *
 * @param clob the CLOB to release
 * @throws SQLException if the driver fails to release the CLOB
 */
private static void doClobFree(Clob clob) throws SQLException {
  try {
    clob.free();
  } catch (AbstractMethodError ignored) {
    // Nothing to release: drivers limited to JDBC 3.0 do not implement free()
    // (the original note names the "JTS" driver)
  }
}
/**
 * Closes a DbUnit connection, ignoring the SQL failure that may occur.
 *
 * @param connection the connection to close; nothing is done when {@code null}
 */
private void closeQuietly(@Nullable IDatabaseConnection connection) {
  if (connection == null) {
    return;
  }
  try {
    connection.close();
  } catch (SQLException ignored) {
    // ignore
  }
}
/**
 * Opens a new connection to the test database.
 *
 * @return the connection; the caller is responsible for closing it
 * @throws SQLException if the connection cannot be opened
 */
public Connection openConnection() throws SQLException {
  return this.getConnection();
}
/**
 * Gets a connection from the data source of the test database.
 *
 * @return the connection; the caller is responsible for closing it
 * @throws SQLException if the connection cannot be obtained
 */
private Connection getConnection() throws SQLException {
  Database database = db.getDatabase();
  return database.getDataSource().getConnection();
}
/**
 * @return the database of the underlying test database
 */
public Database database() {
  return this.db.getDatabase();
}
/**
 * @return the database commands of the underlying test database
 */
public DatabaseCommands getCommands() {
  return this.db.getCommands();
}
/**
 * An {@link AutoCloseable} supplier of {@link Connection}.
 * Being {@link AutoCloseable}, it can be declared in a try-with-resources statement,
 * as done by the query helpers of this class.
 */
protected interface ConnectionSupplier extends AutoCloseable {
/**
 * @return the connection to use
 * @throws SQLException if the connection cannot be provided
 */
Connection get() throws SQLException;
/**
 * Releases the supplier. Redeclared without any checked exception, so that callers do
 * not have to handle one.
 */
@Override
void close();
}
/**
 * A primary key: its name, if any, and the names of its columns.
 */
private static class PK {
/** Name of the primary key, may be {@code null} */
@CheckForNull
private final String name;
/** Names of the columns of the primary key, stored as an immutable copy */
private final List<String> columns;
private PK(@Nullable String name, List<String> columns) {
this.name = name;
this.columns = ImmutableList.copyOf(columns);
}
/** @return the name of the primary key, or {@code null} if it has none */
@CheckForNull
public String getName() {
return name;
}
/** @return the immutable list of the names of the columns */
public List<String> getColumns() {
return columns;
}
}
/**
 * A {@link ConnectionSupplier} that opens a connection on first use, returns this same
 * connection on the following calls and closes it when the supplier is closed.
 */
private class NewConnectionSupplier implements ConnectionSupplier {
  /** Connection opened by the first call to {@link #get()}, {@code null} until then. */
  private Connection connection;

  @Override
  public Connection get() throws SQLException {
    if (connection != null) {
      return connection;
    }
    connection = getConnection();
    return connection;
  }

  @Override
  public void close() {
    if (connection == null) {
      // nothing was opened
      return;
    }
    try {
      connection.close();
    } catch (SQLException closeFailure) {
      // do not re-throw the exception
      Loggers.get(CoreDbTester.class).warn("Fail to close connection", closeFailure);
    }
  }
}
}