/* * SonarQube * Copyright (C) 2009-2017 SonarSource SA * mailto:info AT sonarsource DOT com * * This program is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 3 of the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public License * along with this program; if not, write to the Free Software Foundation, * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. */ package org.sonar.db; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Ordering; import java.io.InputStream; import java.math.BigDecimal; import java.sql.Clob; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; import java.sql.Timestamp; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Date; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; import javax.annotation.CheckForNull; import javax.annotation.Nullable; import org.apache.commons.dbutils.QueryRunner; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.StringUtils; import org.dbunit.Assertion; import org.dbunit.DatabaseUnitException; import org.dbunit.assertion.DiffCollectingFailureHandler; import org.dbunit.assertion.Difference; import org.dbunit.database.DatabaseConfig; import org.dbunit.database.IDatabaseConnection; import org.dbunit.dataset.CompositeDataSet; import org.dbunit.dataset.IDataSet; import org.dbunit.dataset.ITable; import org.dbunit.dataset.ReplacementDataSet; import org.dbunit.dataset.filter.DefaultColumnFilter; import org.dbunit.dataset.xml.FlatXmlDataSet; import org.dbunit.ext.mssql.InsertIdentityOperation; import org.dbunit.operation.DatabaseOperation; import org.junit.rules.ExternalResource; import org.sonar.api.utils.log.Loggers; import org.sonar.core.util.stream.MoreCollectors; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Lists.asList; import static com.google.common.collect.Lists.newArrayList; import static com.google.common.collect.Maps.newHashMap; import static java.sql.ResultSetMetaData.columnNoNulls; import static java.sql.ResultSetMetaData.columnNullable; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.fail; public class AbstractDbTester<T extends CoreTestDb> extends ExternalResource { protected static final Joiner COMMA_JOINER = Joiner.on(", "); protected final T db; public AbstractDbTester(T db) { this.db = db; } public void executeUpdateSql(String sql, Object... params) { try (Connection connection = getConnection()) { new QueryRunner().update(connection, sql, params); if (!connection.getAutoCommit()) { connection.commit(); } } catch (SQLException e) { SQLException nextException = e.getNextException(); if (nextException != null) { throw new IllegalStateException("Fail to execute sql: " + sql, new SQLException(e.getMessage(), nextException.getSQLState(), nextException.getErrorCode(), nextException)); } throw new IllegalStateException("Fail to execute sql: " + sql, e); } catch (Exception e) { throw new IllegalStateException("Fail to execute sql: " + sql, e); } } public void executeDdl(String ddl) { try (Connection connection = getConnection(); Statement stmt = connection.createStatement()) { stmt.execute(ddl); } catch (SQLException e) { throw new IllegalStateException("Failed to execute DDL: " + ddl, e); } } /** * Very simple helper method to insert some data into a table. * It's the responsibility of the caller to convert column values to string. */ public void executeInsert(String table, String firstColumn, Object... others) { executeInsert(table, mapOf(firstColumn, others)); } private static Map<String, Object> mapOf(String firstColumn, Object... values) { ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder(); List<Object> args = asList(firstColumn, values); for (int i = 0; i < args.size(); i++) { String key = args.get(i).toString(); Object value = args.get(i + 1); if (value != null) { builder.put(key, value); } i++; } return builder.build(); } /** * Very simple helper method to insert some data into a table. * It's the responsibility of the caller to convert column values to string. */ public void executeInsert(String table, Map<String, Object> valuesByColumn) { if (valuesByColumn.isEmpty()) { throw new IllegalArgumentException("Values cannot be empty"); } String sql = "insert into " + table.toLowerCase(Locale.ENGLISH) + " (" + COMMA_JOINER.join(valuesByColumn.keySet()) + ") values (" + COMMA_JOINER.join(Collections.nCopies(valuesByColumn.size(), '?')) + ")"; executeUpdateSql(sql, valuesByColumn.values().toArray(new Object[valuesByColumn.size()])); } /** * Returns the number of rows in the table. Example: * <pre>int issues = countRowsOfTable("issues")</pre> */ public int countRowsOfTable(String tableName) { return countRowsOfTable(tableName, new NewConnectionSupplier()); } protected int countRowsOfTable(String tableName, ConnectionSupplier connectionSupplier) { checkArgument(StringUtils.containsNone(tableName, " "), "Parameter must be the name of a table. Got " + tableName); return countSql("select count(1) from " + tableName.toLowerCase(Locale.ENGLISH), connectionSupplier); } /** * Executes a SQL request starting with "SELECT COUNT(something) FROM", for example: * <pre>int OpenIssues = countSql("select count('id') from issues where status is not null")</pre> */ public int countSql(String sql) { return countSql(sql, new NewConnectionSupplier()); } protected int countSql(String sql, ConnectionSupplier connectionSupplier) { checkArgument(StringUtils.contains(sql, "count("), "Parameter must be a SQL request containing 'count(x)' function. Got " + sql); try ( ConnectionSupplier supplier = connectionSupplier; PreparedStatement stmt = supplier.get().prepareStatement(sql); ResultSet rs = stmt.executeQuery()) { if (rs.next()) { return rs.getInt(1); } throw new IllegalStateException("No results for " + sql); } catch (Exception e) { throw new IllegalStateException("Fail to execute sql: " + sql, e); } } public List<Map<String, Object>> select(String selectSql) { return select(selectSql, new NewConnectionSupplier()); } protected List<Map<String, Object>> select(String selectSql, ConnectionSupplier connectionSupplier) { try ( ConnectionSupplier supplier = connectionSupplier; PreparedStatement stmt = supplier.get().prepareStatement(selectSql); ResultSet rs = stmt.executeQuery()) { return getHashMap(rs); } catch (Exception e) { throw new IllegalStateException("Fail to execute sql: " + selectSql, e); } } public Map<String, Object> selectFirst(String selectSql) { return selectFirst(selectSql, new NewConnectionSupplier()); } protected Map<String, Object> selectFirst(String selectSql, ConnectionSupplier connectionSupplier) { List<Map<String, Object>> rows = select(selectSql, connectionSupplier); if (rows.isEmpty()) { throw new IllegalStateException("No results for " + selectSql); } else if (rows.size() > 1) { throw new IllegalStateException("Too many results for " + selectSql); } return rows.get(0); } private static List<Map<String, Object>> getHashMap(ResultSet resultSet) throws Exception { ResultSetMetaData metaData = resultSet.getMetaData(); int colCount = metaData.getColumnCount(); List<Map<String, Object>> rows = newArrayList(); while (resultSet.next()) { Map<String, Object> columns = newHashMap(); for (int i = 1; i <= colCount; i++) { Object value = resultSet.getObject(i); if (value instanceof Clob) { Clob clob = (Clob) value; value = IOUtils.toString((clob.getAsciiStream())); doClobFree(clob); } else if (value instanceof BigDecimal) { // In Oracle, INTEGER types are mapped as BigDecimal BigDecimal bgValue = ((BigDecimal) value); if (bgValue.scale() == 0) { value = bgValue.longValue(); } else { value = bgValue.doubleValue(); } } else if (value instanceof Integer) { // To be consistent, all INTEGER types are mapped as Long value = ((Integer) value).longValue(); } else if (value instanceof Timestamp) { value = new Date(((Timestamp) value).getTime()); } columns.put(metaData.getColumnLabel(i), value); } rows.add(columns); } return rows; } public void prepareDbUnit(Class testClass, String... testNames) { InputStream[] streams = new InputStream[testNames.length]; try { for (int i = 0; i < testNames.length; i++) { String path = "/" + testClass.getName().replace('.', '/') + "/" + testNames[i]; streams[i] = testClass.getResourceAsStream(path); if (streams[i] == null) { throw new IllegalStateException("DbUnit file not found: " + path); } } prepareDbUnit(streams); db.getCommands().resetPrimaryKeys(db.getDatabase().getDataSource()); } catch (SQLException e) { throw translateException("Could not setup DBUnit data", e); } finally { for (InputStream stream : streams) { IOUtils.closeQuietly(stream); } } } private void prepareDbUnit(InputStream... dataSetStream) { IDatabaseConnection connection = null; try { IDataSet[] dataSets = new IDataSet[dataSetStream.length]; for (int i = 0; i < dataSetStream.length; i++) { dataSets[i] = dbUnitDataSet(dataSetStream[i]); } db.getDbUnitTester().setDataSet(new CompositeDataSet(dataSets)); connection = dbUnitConnection(); new InsertIdentityOperation(DatabaseOperation.INSERT).execute(connection, db.getDbUnitTester().getDataSet()); } catch (Exception e) { throw translateException("Could not setup DBUnit data", e); } finally { closeQuietly(connection); } } public void assertDbUnitTable(Class testClass, String filename, String table, String... columns) { IDatabaseConnection connection = dbUnitConnection(); try { IDataSet dataSet = connection.createDataSet(); String path = "/" + testClass.getName().replace('.', '/') + "/" + filename; IDataSet expectedDataSet = dbUnitDataSet(testClass.getResourceAsStream(path)); ITable filteredTable = DefaultColumnFilter.includedColumnsTable(dataSet.getTable(table), columns); ITable filteredExpectedTable = DefaultColumnFilter.includedColumnsTable(expectedDataSet.getTable(table), columns); Assertion.assertEquals(filteredExpectedTable, filteredTable); } catch (DatabaseUnitException e) { fail(e.getMessage()); } catch (SQLException e) { throw translateException("Error while checking results", e); } finally { closeQuietly(connection); } } public void assertDbUnit(Class testClass, String filename, String... tables) { assertDbUnit(testClass, filename, new String[0], tables); } public void assertDbUnit(Class testClass, String filename, String[] excludedColumnNames, String... tables) { IDatabaseConnection connection = null; try { connection = dbUnitConnection(); IDataSet dataSet = connection.createDataSet(); String path = "/" + testClass.getName().replace('.', '/') + "/" + filename; InputStream inputStream = testClass.getResourceAsStream(path); if (inputStream == null) { throw new IllegalStateException(String.format("File '%s' does not exist", path)); } IDataSet expectedDataSet = dbUnitDataSet(inputStream); for (String table : tables) { DiffCollectingFailureHandler diffHandler = new DiffCollectingFailureHandler(); ITable filteredTable = DefaultColumnFilter.excludedColumnsTable(dataSet.getTable(table), excludedColumnNames); ITable filteredExpectedTable = DefaultColumnFilter.excludedColumnsTable(expectedDataSet.getTable(table), excludedColumnNames); Assertion.assertEquals(filteredExpectedTable, filteredTable, diffHandler); // Evaluate the differences and ignore some column values List diffList = diffHandler.getDiffList(); for (Object o : diffList) { Difference diff = (Difference) o; if (!"[ignore]".equals(diff.getExpectedValue())) { throw new DatabaseUnitException(diff.toString()); } } } } catch (DatabaseUnitException e) { e.printStackTrace(); fail(e.getMessage()); } catch (Exception e) { throw translateException("Error while checking results", e); } finally { closeQuietly(connection); } } public void assertColumnDefinition(String table, String column, int expectedType, @Nullable Integer expectedSize) { assertColumnDefinition(table, column, expectedType, expectedSize, null); } public void assertColumnDefinition(String table, String column, int expectedType, @Nullable Integer expectedSize, @Nullable Boolean isNullable) { try (Connection connection = getConnection(); PreparedStatement stmt = connection.prepareStatement("select * from " + table); ResultSet res = stmt.executeQuery()) { Integer columnIndex = getColumnIndex(res, column); if (columnIndex == null) { fail("The column '" + column + "' does not exist"); } assertThat(res.getMetaData().getColumnType(columnIndex)).isEqualTo(expectedType); if (expectedSize != null) { assertThat(res.getMetaData().getColumnDisplaySize(columnIndex)).isEqualTo(expectedSize); } if (isNullable != null) { assertThat(res.getMetaData().isNullable(columnIndex)).isEqualTo(isNullable ? columnNullable : columnNoNulls); } } catch (Exception e) { throw new IllegalStateException("Fail to check column", e); } } public void assertColumnDoesNotExist(String table, String column) throws SQLException { try (Connection connection = getConnection(); PreparedStatement stmt = connection.prepareStatement("select * from " + table); ResultSet res = stmt.executeQuery()) { assertThat(getColumnNames(res)).doesNotContain(column); } } public void assertTableDoesNotExist(String table) { try (Connection connection = getConnection()) { boolean tableExists = DatabaseUtils.tableExists(table, connection); assertThat(tableExists).isFalse(); } catch (Exception e) { throw new IllegalStateException("Fail to check if table exists", e); } } /** * Verify that non-unique index exists on columns */ public void assertIndex(String tableName, String indexName, String expectedColumn, String... expectedSecondaryColumns) { assertIndexImpl(tableName, indexName, false, expectedColumn, expectedSecondaryColumns); } /** * Verify that unique index exists on columns */ public void assertUniqueIndex(String tableName, String indexName, String expectedColumn, String... expectedSecondaryColumns) { assertIndexImpl(tableName, indexName, true, expectedColumn, expectedSecondaryColumns); } private void assertIndexImpl(String tableName, String indexName, boolean expectedUnique, String expectedColumn, String... expectedSecondaryColumns) { try (Connection connection = getConnection(); ResultSet rs = connection.getMetaData().getIndexInfo(null, null, tableName.toUpperCase(Locale.ENGLISH), false, false)) { List<String> onColumns = new ArrayList<>(); while (rs.next()) { if (indexName.equalsIgnoreCase(rs.getString("INDEX_NAME"))) { assertThat(rs.getBoolean("NON_UNIQUE")).isEqualTo(!expectedUnique); int position = rs.getInt("ORDINAL_POSITION"); onColumns.add(position - 1, rs.getString("COLUMN_NAME").toLowerCase(Locale.ENGLISH)); } } assertThat(asList(expectedColumn, expectedSecondaryColumns)).isEqualTo(onColumns); } catch (SQLException e) { throw new IllegalStateException("Fail to check index", e); } } /** * Verify that index with name {@code indexName} does not exist on the table {@code tableName} */ public void assertIndexDoesNotExist(String tableName, String indexName) { try (Connection connection = getConnection(); ResultSet rs = connection.getMetaData().getIndexInfo(null, null, tableName.toUpperCase(Locale.ENGLISH), false, false)) { List<String> indices = new ArrayList<>(); while (rs.next()) { indices.add(rs.getString("INDEX_NAME").toLowerCase(Locale.ENGLISH)); } assertThat(indices).doesNotContain(indexName); } catch (SQLException e) { throw new IllegalStateException("Fail to check existence of index", e); } } public void assertPrimaryKey(String tableName, @Nullable String expectedPkName, String columnName, String... otherColumnNames) { try (Connection connection = getConnection()) { PK pk = pkOf(connection, tableName.toUpperCase(Locale.ENGLISH)); if (pk == null) { pkOf(connection, tableName.toLowerCase(Locale.ENGLISH)); } assertThat(pk).as("No primary key is defined on table %s", tableName).isNotNull(); if (expectedPkName != null) { assertThat(pk.getName()).isEqualToIgnoringCase(expectedPkName); } List<String> expectedColumns = ImmutableList.copyOf(Iterables.concat(Collections.singletonList(columnName), Arrays.asList(otherColumnNames))); assertThat(pk.getColumns()).as("Primary key does not have the '%s' expected columns", expectedColumns.size()).hasSize(expectedColumns.size()); Iterator<String> expectedColumnsIt = expectedColumns.iterator(); Iterator<String> actualColumnsIt = pk.getColumns().iterator(); while (expectedColumnsIt.hasNext() && actualColumnsIt.hasNext()) { assertThat(actualColumnsIt.next()).isEqualToIgnoringCase(expectedColumnsIt.next()); } } catch (SQLException e) { throw new IllegalStateException("Fail to check primary key", e); } } @CheckForNull private PK pkOf(Connection connection, String tableName) throws SQLException { try (ResultSet resultSet = connection.getMetaData().getPrimaryKeys(null, null, tableName)) { String pkName = null; List<PkColumn> columnNames = null; while (resultSet.next()) { if (columnNames == null) { pkName = resultSet.getString("PK_NAME"); columnNames = new ArrayList<>(1); } else { assertThat(pkName).as("Multiple primary keys found").isEqualTo(resultSet.getString("PK_NAME")); } columnNames.add(new PkColumn(resultSet.getInt("KEY_SEQ") - 1, resultSet.getString("COLUMN_NAME"))); } if (columnNames == null) { return null; } return new PK( pkName, columnNames.stream() .sorted(PkColumn.ORDERING_BY_INDEX) .map(PkColumn::getName) .collect(MoreCollectors.toList())); } } private static final class PkColumn { private static final Ordering<PkColumn> ORDERING_BY_INDEX = Ordering.natural().onResultOf(PkColumn::getIndex); /** 0-based */ private final int index; private final String name; private PkColumn(int index, String name) { this.index = index; this.name = name; } public int getIndex() { return index; } public String getName() { return name; } } @CheckForNull private Integer getColumnIndex(ResultSet res, String column) { try { ResultSetMetaData meta = res.getMetaData(); int numCol = meta.getColumnCount(); for (int i = 1; i < numCol + 1; i++) { if (meta.getColumnLabel(i).toLowerCase().equals(column.toLowerCase())) { return i; } } return null; } catch (Exception e) { throw new IllegalStateException("Fail to get column index"); } } private Set<String> getColumnNames(ResultSet res) { try { Set<String> columnNames = new HashSet<>(); ResultSetMetaData meta = res.getMetaData(); int numCol = meta.getColumnCount(); for (int i = 1; i < numCol + 1; i++) { columnNames.add(meta.getColumnLabel(i).toLowerCase()); } return columnNames; } catch (Exception e) { throw new IllegalStateException("Fail to get column names"); } } private IDataSet dbUnitDataSet(InputStream stream) { try { ReplacementDataSet dataSet = new ReplacementDataSet(new FlatXmlDataSet(stream)); dataSet.addReplacementObject("[null]", null); dataSet.addReplacementObject("[false]", Boolean.FALSE); dataSet.addReplacementObject("[true]", Boolean.TRUE); return dataSet; } catch (Exception e) { throw translateException("Could not read the dataset stream", e); } } private IDatabaseConnection dbUnitConnection() { try { IDatabaseConnection connection = db.getDbUnitTester().getConnection(); connection.getConfig().setProperty(DatabaseConfig.PROPERTY_DATATYPE_FACTORY, db.getDbUnitFactory()); return connection; } catch (Exception e) { throw translateException("Error while getting connection", e); } } public static RuntimeException translateException(String msg, Exception cause) { RuntimeException runtimeException = new RuntimeException(String.format("%s: [%s] %s", msg, cause.getClass().getName(), cause.getMessage())); runtimeException.setStackTrace(cause.getStackTrace()); return runtimeException; } private static void doClobFree(Clob clob) throws SQLException { try { clob.free(); } catch (AbstractMethodError e) { // JTS driver do not implement free() as it's using JDBC 3.0 } } private void closeQuietly(@Nullable IDatabaseConnection connection) { try { if (connection != null) { connection.close(); } } catch (SQLException e) { // ignore } } public Connection openConnection() throws SQLException { return getConnection(); } private Connection getConnection() throws SQLException { return db.getDatabase().getDataSource().getConnection(); } public Database database() { return db.getDatabase(); } public DatabaseCommands getCommands() { return db.getCommands(); } /** * An {@link AutoCloseable} supplier of {@link Connection}. */ protected interface ConnectionSupplier extends AutoCloseable { Connection get() throws SQLException; @Override void close(); } private static class PK { @CheckForNull private final String name; private final List<String> columns; private PK(@Nullable String name, List<String> columns) { this.name = name; this.columns = ImmutableList.copyOf(columns); } @CheckForNull public String getName() { return name; } public List<String> getColumns() { return columns; } } private class NewConnectionSupplier implements ConnectionSupplier { private Connection connection; @Override public Connection get() throws SQLException { if (this.connection == null) { this.connection = getConnection(); } return this.connection; } @Override public void close() { if (this.connection != null) { try { this.connection.close(); } catch (SQLException e) { Loggers.get(CoreDbTester.class).warn("Fail to close connection", e); // do not re-throw the exception } } } } }