package marubinotto.util;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import javax.sql.DataSource;
import org.apache.commons.lang.UnhandledException;
import org.dbunit.Assertion;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.Column;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.DefaultDataSet;
import org.dbunit.dataset.DefaultTable;
import org.dbunit.dataset.DefaultTableMetaData;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.ITable;
import org.dbunit.dataset.ITableMetaData;
import org.dbunit.dataset.datatype.DataType;
import org.dbunit.dataset.filter.DefaultColumnFilter;
import org.dbunit.dataset.xml.XmlDataSet;
import org.dbunit.dataset.xml.XmlDataSetWriter;
import org.dbunit.ext.h2.H2DataTypeFactory;
import org.dbunit.operation.DatabaseOperation;
import org.springframework.jdbc.datasource.DataSourceUtils;
import org.springframework.jdbc.datasource.SingleConnectionDataSource;
import org.springframework.util.ClassUtils;
/**
 * Static helpers for code that works with relational databases through DbUnit:
 * creating an in-memory H2 {@link DataSource}, building DbUnit data sets from
 * {@code Object[][]} literals, importing/exporting XML, and asserting table
 * contents.
 *
 * <p>A "table literal" in this class is an {@code Object[][]} whose first row
 * is the header (column names, converted with {@code toString()}) and whose
 * remaining rows are data rows.
 */
public class RdbUtils {

    /** JDBC driver class name for H2. */
    public static final String H2_JDBC_DRIVER = "org.h2.Driver";
    /** URL prefix for H2 in-memory databases; a database name may be appended. */
    public static final String H2_JDBC_URL = "jdbc:h2:mem:";
    /** User name used for the in-memory H2 database. */
    public static final String H2_JDBC_USERNAME = "sa";
    /** Password used for the in-memory H2 database (empty). */
    public static final String H2_JDBC_PASSWORD = "";

    /**
     * Returns a {@link DataSource} for an H2 in-memory database.
     *
     * <p>The H2 driver class is loaded (and initialized) through Spring's
     * default class loader. The data source is a
     * {@link SingleConnectionDataSource} created with {@code suppressClose = true}:
     * it reuses one physical connection and ignores {@code close()} calls on it,
     * presumably so the in-memory database is not discarded between uses
     * (NOTE(review): confirm this is the intent).
     *
     * @param name database name appended to {@link #H2_JDBC_URL}, or
     *     {@code null} to use the bare URL {@code "jdbc:h2:mem:"}
     * @return a single-connection data source for the in-memory database
     * @throws UnhandledException if the H2 driver class cannot be found
     */
    public static DataSource getInMemoryDataSource(String name) {
        try {
            Class.forName(H2_JDBC_DRIVER, true, ClassUtils.getDefaultClassLoader());
        }
        catch (ClassNotFoundException e) {
            throw new UnhandledException(e);
        }

        String url = H2_JDBC_URL;
        if (name != null) {
            url = url + name;
        }
        String username = H2_JDBC_USERNAME;
        String password = H2_JDBC_PASSWORD;
        return new SingleConnectionDataSource(url, username, password, true);
    }

    // Create DataSet

    /**
     * Builds a single-table DbUnit data set from a table literal.
     *
     * @param tableName name of the table in the data set
     * @param data table literal: header row followed by data rows
     * @return a data set containing exactly one table
     * @throws DataSetException if {@code data} or its header row is empty
     */
    public static IDataSet createDataSet(String tableName, Object[][] data)
            throws DataSetException {
        return new DefaultDataSet(createTable(tableName, data));
    }

    /**
     * Builds a DbUnit table from a table literal.
     *
     * <p>Every column is typed {@link DataType#UNKNOWN}; rows
     * {@code data[1..]} are added in order.
     *
     * @param tableName name of the table
     * @param data table literal: header row followed by data rows
     * @return the populated table
     * @throws DataSetException if {@code data} or its header row is empty
     */
    public static ITable createTable(String tableName, Object[][] data)
            throws DataSetException {
        if (data.length == 0) {
            throw new DataSetException("data is empty.");
        }

        ITableMetaData metaData = getMetaDataFromHeader(tableName, data[0]);
        DefaultTable table = new DefaultTable(metaData);
        for (int i = 1; i < data.length; i++) {
            table.addRow(data[i]);
        }
        return table;
    }

    /**
     * Builds table metadata whose column names are the {@code toString()} of
     * each header element, all typed {@link DataType#UNKNOWN}.
     *
     * @throws DataSetException if {@code header} is empty
     */
    private static ITableMetaData getMetaDataFromHeader(String tableName,
            Object[] header) throws DataSetException {
        if (header.length == 0) {
            throw new DataSetException("header is empty.");
        }

        Column[] columns = new Column[header.length];
        for (int i = 0; i < header.length; i++) {
            columns[i] = new Column(header[i].toString(), DataType.UNKNOWN);
        }
        return new DefaultTableMetaData(tableName, columns);
    }

    // Database connection

    /**
     * Obtains a JDBC connection from {@code dataSource} via Spring's
     * {@link DataSourceUtils#getConnection(DataSource)}, which returns the
     * connection bound to the current Spring-managed transaction when there is one.
     *
     * @param dataSource the data source to obtain a connection from
     * @return a JDBC connection
     */
    public static Connection getSpringTransactionalConnection(
            DataSource dataSource) {
        return DataSourceUtils.getConnection(dataSource);
    }

    /**
     * Wraps a JDBC connection in a DbUnit connection configured with the H2
     * data type factory. The factory is applied regardless of the actual
     * database behind {@code jdbcConnection}.
     */
    private static IDatabaseConnection setUpConnection(Connection jdbcConnection)
            throws SQLException {
        IDatabaseConnection connection = new DatabaseConnection(jdbcConnection);
        connection.getConfig().setProperty(
            DatabaseConfig.PROPERTY_DATATYPE_FACTORY, new H2DataTypeFactory());
        return connection;
    }

    /** Character encoding used for XML exports. */
    private static final String EXPORT_ENCODING = "UTF-8";

    /**
     * Exports every table visible through {@code jdbcConnection} as DbUnit XML.
     *
     * @param jdbcConnection connection to export from
     * @param output stream that receives UTF-8 encoded XML; not closed here
     */
    public static void exportAllAsXml(Connection jdbcConnection,
            OutputStream output) throws SQLException, DataSetException, IOException {
        exportAsXml(jdbcConnection, null, output);
    }

    /**
     * Exports the given tables (or all tables) as DbUnit XML data set format.
     *
     * <p>This method does not close {@code output}. NOTE(review): it also does
     * not flush the wrapping writer itself; confirm that
     * {@link XmlDataSetWriter#write(IDataSet)} flushes, otherwise buffered
     * characters may not reach {@code output}.
     *
     * @param jdbcConnection connection to export from; must not be null
     * @param tableNames tables to export, or {@code null} to export all tables
     * @param output stream that receives UTF-8 encoded XML; must not be null
     */
    public static void exportAsXml(Connection jdbcConnection,
            String[] tableNames, OutputStream output) throws SQLException,
            DataSetException, IOException {
        Assert.Arg.notNull(jdbcConnection, "jdbcConnection");
        Assert.Arg.notNull(output, "output");

        IDatabaseConnection connection = setUpConnection(jdbcConnection);
        IDataSet dataSet = null;
        if (tableNames != null) {
            dataSet = connection.createDataSet(tableNames);
        }
        else {
            dataSet = connection.createDataSet();
        }

        Writer writer = new OutputStreamWriter(output, EXPORT_ENCODING);
        XmlDataSetWriter xmlWriter = new XmlDataSetWriter(writer, EXPORT_ENCODING);
        xmlWriter.write(dataSet);
    }

    /**
     * Reads a DbUnit XML data set from {@code input} and applies it with
     * {@link DatabaseOperation#CLEAN_INSERT} (existing rows of the tables in
     * the data set are deleted, then the data set rows are inserted).
     *
     * @param jdbcConnection target connection; must not be null
     * @param input DbUnit XML data set; must not be null; not closed here
     */
    public static void cleanImportXml(Connection jdbcConnection, InputStream input)
            throws SQLException, IOException, DatabaseUnitException {
        Assert.Arg.notNull(jdbcConnection, "jdbcConnection");
        Assert.Arg.notNull(input, "input");

        XmlDataSet dataSet = new XmlDataSet(input);
        IDatabaseConnection connection = setUpConnection(jdbcConnection);
        DatabaseOperation.CLEAN_INSERT.execute(connection, dataSet);
    }

    /**
     * Deletes all rows of {@code tableName} using
     * {@link DatabaseOperation#DELETE_ALL} on a data set holding an empty
     * table of that name.
     */
    public static void deleteAll(Connection jdbcConnection, String tableName)
            throws DatabaseUnitException, SQLException {
        IDatabaseConnection connection = setUpConnection(jdbcConnection);
        IDataSet dataSet = new DefaultDataSet(new DefaultTable(tableName));
        DatabaseOperation.DELETE_ALL.execute(connection, dataSet);
    }

    /**
     * Replaces the contents of {@code tableName} with the rows of the table
     * literal {@code data} ({@link DatabaseOperation#CLEAN_INSERT}).
     */
    public static void cleanInsert(Connection jdbcConnection, String tableName,
            Object[][] data) throws DatabaseUnitException, SQLException {
        IDatabaseConnection connection = setUpConnection(jdbcConnection);
        IDataSet dataSet = createDataSet(tableName, data);
        DatabaseOperation.CLEAN_INSERT.execute(connection, dataSet);
    }

    /**
     * Updates existing rows of {@code tableName} from the table literal
     * {@code data} ({@link DatabaseOperation#UPDATE}).
     */
    public static void update(Connection jdbcConnection, String tableName,
            Object[][] data) throws DatabaseUnitException, SQLException {
        IDatabaseConnection connection = setUpConnection(jdbcConnection);
        IDataSet dataSet = createDataSet(tableName, data);
        DatabaseOperation.UPDATE.execute(connection, dataSet);
    }

    /**
     * Clean-inserts {@code base} into {@code table}, first merging {@code diff}
     * into it with {@link #merge(Object[][], Object[][])} when {@code diff} is
     * not {@code null}.
     */
    public static void cleanInsertMergedDataSet(Connection jdbcConnection,
            String table, Object[][] base, Object[][] diff) throws Exception {
        if (diff != null) {
            RdbUtils.cleanInsert(jdbcConnection, table, RdbUtils.merge(base, diff));
        }
        else {
            RdbUtils.cleanInsert(jdbcConnection, table, base);
        }
    }

    /**
     * Inserts the rows of the table literal {@code data} into
     * {@code tableName} ({@link DatabaseOperation#INSERT}).
     */
    public static void insert(Connection jdbcConnection, String tableName,
            Object[][] data) throws DatabaseUnitException, SQLException {
        IDatabaseConnection connection = setUpConnection(jdbcConnection);
        IDataSet dataSet = createDataSet(tableName, data);
        DatabaseOperation.INSERT.execute(connection, dataSet);
    }

    /**
     * Returns the table {@code tableName} from a DbUnit data set built over
     * the whole database.
     */
    public static ITable getTableData(Connection jdbcConnection, String tableName)
            throws SQLException, DataSetException {
        IDatabaseConnection connection = setUpConnection(jdbcConnection);
        IDataSet databaseDataSet = connection.createDataSet();
        return databaseDataSet.getTable(tableName);
    }

    /**
     * Runs {@code sql} and returns its result as a DbUnit table named
     * {@code resultName}.
     *
     * <p>The SQL is executed verbatim; do not build it from untrusted input.
     */
    public static ITable query(Connection jdbcConnection, String resultName,
            String sql) throws SQLException, DataSetException {
        IDatabaseConnection connection = setUpConnection(jdbcConnection);
        return connection.createQueryTable(resultName, sql);
    }

    /**
     * Fails (via JUnit 3 {@code Assert}) unless {@code tableName} has no rows.
     */
    public static void assertTableEmpty(Connection jdbcConnection,
            String tableName) throws SQLException, DataSetException {
        ITable table = getTableData(jdbcConnection, tableName);
        junit.framework.Assert.assertEquals("Table <" + tableName
            + "> should be empty.", 0, table.getRowCount());
    }

    /**
     * Asserts that {@code actualTable} equals {@code expectedTable}, comparing
     * only the columns declared in the expected table's metadata (extra
     * columns in the actual table are filtered out first).
     */
    public static void assertEquals(ITable expectedTable, ITable actualTable)
            throws Exception {
        Assertion.assertEquals(expectedTable, DefaultColumnFilter
            .includedColumnsTable(actualTable, expectedTable.getTableMetaData()
                .getColumns()));
    }

    /**
     * Merges two table literals column-wise and row-wise.
     *
     * <p>The merged header is the base header followed by the diff columns
     * that are not already in it, in the diff's order. Row {@code i} of the
     * result starts as a copy of {@code base[i]} (when it exists); then each
     * value of {@code diff[i]} (when it exists) is written into the merged
     * column of the same header name, overriding the base value for shared
     * columns. The result has {@code max(base.length, diff.length)} rows, and
     * cells with no source value are {@code null}. Row 0 follows the same rule,
     * so {@code result[0]} is the merged header.
     *
     * <p>If either argument has no rows, the other argument is returned as-is
     * (not copied).
     *
     * @param base base table literal; must not be null
     * @param diff table literal whose values are added or override base values;
     *     must not be null
     * @return the merged table literal
     */
    public static Object[][] merge(Object[][] base, Object[][] diff) {
        Assert.Arg.notNull(base, "base");
        Assert.Arg.notNull(diff, "diff");

        if (base.length == 0) {
            return diff;
        }
        if (diff.length == 0) {
            return base;
        }

        Object[] diffHeader = diff[0];

        // Base columns first (in order), then new diff columns in order of first appearance.
        LinkedHashSet<Object> headerSet = new LinkedHashSet<Object>(Arrays.asList(base[0]));
        headerSet.addAll(Arrays.asList(diffHeader));
        List<Object> mergedHeader = Arrays.asList(headerSet.toArray());

        // Precompute where each diff column lands in the merged row. This is the
        // same position the base header uses for shared columns, and also handles
        // a diff header that repeats a new column name.
        int[] diffToMerged = new int[diffHeader.length];
        for (int i = 0; i < diffHeader.length; i++) {
            diffToMerged[i] = mergedHeader.indexOf(diffHeader[i]);
        }

        int resultFieldCount = mergedHeader.size();
        int resultRowCount = Math.max(base.length, diff.length);

        // Object[] (not String[]): cells may be numbers, dates, etc. A String[]
        // would throw ArrayStoreException on any non-String value.
        Object[][] result = new Object[resultRowCount][];
        for (int rowIndex = 0; rowIndex < resultRowCount; rowIndex++) {
            Object[] mergedRow = new Object[resultFieldCount];
            if (rowIndex < base.length) {
                System.arraycopy(base[rowIndex], 0, mergedRow, 0,
                    base[rowIndex].length);
            }
            if (rowIndex < diff.length) {
                Object[] diffRow = diff[rowIndex];
                for (int fieldIndex = 0; fieldIndex < diffRow.length; fieldIndex++) {
                    mergedRow[diffToMerged[fieldIndex]] = diffRow[fieldIndex];
                }
            }
            result[rowIndex] = mergedRow;
        }
        return result;
    }
}