/*
* Copyright 2013 The Sculptor Project Team, including the original
* author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.sculptor.framework.test;
import java.io.IOException;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Properties;
import java.util.Set;
import javax.persistence.EntityManager;
import javax.persistence.Table;
import javax.sql.DataSource;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.IDatabaseConnection;
import org.junit.After;
import org.junit.Before;
import org.sculptor.framework.test.ejbtestbean.jpa.JpaTestLocal;
import org.sculptor.framework.util.db.DbUnitDataSourceUtils;
import org.sculptor.framework.util.db.HsqlDataTypeFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Base class for <a href=
* "http://www.oracle.com/technetwork/java/javaee/tech/persistence-jsp-140049.html"
* >JPA</a> and <a href="http://www.dbunit.org">DBUnit</a> tests in an <a
* href="http://openejb.apache.org/">OpenEJB</a> environment.
* <p>
* Inject dependencies to EJBs with the ordinary <code>@EJB</code> annotation.
* <p>
* Override the method {@link #getDataSetFile} to specify the XML file with the
* DBUnit test data.
*
* @author Patrik Nordwall
*
*/
public abstract class AbstractOpenEJBDbUnitTest extends AbstractOpenEJBTest {
	private final Logger log = LoggerFactory.getLogger(getClass());
	private EntityManager entityManager;
	private DataSource dataSource;
	private JpaTestLocal jpaTestBean;

	public AbstractOpenEJBDbUnitTest() {
	}

	/**
	 * Runs the superclass initialization and then loads the DBUnit test data
	 * via {@link #setUpDatabaseTester()}.
	 */
	@Before
	@Override
	public void initialize() throws Exception {
		super.initialize();
		setUpDatabaseTester();
	}

	/**
	 * Returns the names of the persistence units declared in
	 * <code>/META-INF/persistence.xml</code>.
	 *
	 * @return persistence unit names parsed from persistence.xml
	 * @throws RuntimeException
	 *             wrapping the IOException if persistence.xml can't be read
	 */
	protected Set<String> getPersistentUnitNames() {
		try {
			PersistenceXmlParser persistenceXmlParser = new PersistenceXmlParser();
			String persistenceXml = DataHelper.content("/META-INF/persistence.xml");
			persistenceXmlParser.parse(persistenceXml);
			return persistenceXmlParser.getPersictenceUnitNames();
		} catch (IOException e) {
			throw new RuntimeException(e.getMessage(), e);
		}
	}

	/**
	 * Looks up the JPA test bean under {@link #getTestBeanJndiName()} and keeps
	 * its EntityManager and DataSource for use by the tests.
	 *
	 * @throws IllegalStateException
	 *             if the lookup returns null
	 */
	@Override
	protected void initOpenEjb() throws Exception {
		super.initOpenEjb();
		String jndiName = getTestBeanJndiName();
		jpaTestBean = lookup(jndiName);
		if (jpaTestBean == null) {
			// Report the name that was actually looked up, not some other bean's name.
			throw new IllegalStateException("Couldn't find " + jndiName);
		}
		entityManager = jpaTestBean.getEntityManager();
		dataSource = jpaTestBean.getDataSource();
	}

	/**
	 * @return JNDI name of the JPA test bean looked up in
	 *         {@link #initOpenEjb()}; override to use another bean
	 */
	protected String getTestBeanJndiName() {
		return "JpaTestBeanLocal";
	}

	/**
	 * Adds the per-unit overrides from
	 * {@link #initPersistenceUnitProperties(String, Properties)} for every
	 * persistence unit found in persistence.xml.
	 */
	@Override
	protected void additionalInitialContextProperties(Properties defaultProperties) {
		for (String unitName : getPersistentUnitNames()) {
			initPersistenceUnitProperties(unitName, defaultProperties);
		}
	}

	/**
	 * Overrides some properties defined for persistent units in
	 * "persistence.xml": HSQL dialect, SQL logging, create-drop schema
	 * generation, and disabled query/second-level caches.
	 *
	 * @param unitName
	 *            persistence unit name, used as the property key prefix
	 * @param properties
	 *            properties to add the overrides to
	 */
	protected void initPersistenceUnitProperties(String unitName, Properties properties) {
		properties.put(unitName + ".hibernate.dialect", "org.sculptor.framework.persistence.CustomHSQLDialect");
		properties.put(unitName + ".hibernate.show_sql", "true");
		properties.put(unitName + ".hibernate.hbm2ddl.auto", "create-drop");
		properties.put(unitName + ".hibernate.cache.use_query_cache", "false");
		properties.put(unitName + ".hibernate.cache.use_second_level_cache", "false");
	}

	protected EntityManager getEntityManager() {
		return entityManager;
	}

	protected DataSource getDataSource() {
		return dataSource;
	}

	/**
	 * setup dbunit DatabaseTester/DataSet in transaction, then restart the id
	 * sequence (see {@link #restartSequence()}).
	 *
	 * @throws Exception
	 */
	protected void setUpDatabaseTester() throws Exception {
		DbUnitDataSourceUtils.setUpDatabaseTester(getClass(), getDataSource(), getDataSetFile());
		restartSequence();
	}

	/**
	 * Start the id sequence from a high value to avoid conflicts with test
	 * data. You can define the sequence name with {@link #getSequenceName};
	 * nothing is done when it returns null. Failures are logged at debug level
	 * and otherwise ignored.
	 */
	protected void restartSequence() {
		String sequenceName = getSequenceName();
		if (sequenceName == null) {
			return;
		}
		IDatabaseConnection connection = null;
		try {
			connection = getConnection();
			DbUnitDataSourceUtils.restartSequence(connection, sequenceName);
		} catch (Exception e) {
			// Best effort: a failure here must not fail the test, but keep the cause.
			log.debug("Couldn't restart sequence: " + sequenceName, e);
		} finally {
			// getConnection() opens a new connection each call; release it.
			closeQuietly(connection);
		}
	}

	/**
	 * Returns the name of the id sequence that {@link #restartSequence()}
	 * restarts. The default is null, which skips the restart. Override this
	 * method and return the sequence name when generated ids could otherwise
	 * conflict with the test data.
	 *
	 * @return sequence name, or null to skip restarting
	 */
	protected String getSequenceName() {
		return null;
	}

	@After
	public void tearDownDatabaseTester() throws Exception {
		DbUnitDataSourceUtils.tearDownDatabaseTester();
	}

	/**
	 * Override this method to specify the XML file with DBUnit test data. If
	 * filename is not set, DbUnitDataSourceUtils will guess a filename.
	 *
	 * @return the filename with test data
	 */
	protected String getDataSetFile() {
		return null;
	}

	/**
	 * Opens a new DBUnit connection on a fresh JDBC connection from
	 * {@link #getDataSource()}, configured with {@link HsqlDataTypeFactory}.
	 * The caller is responsible for closing it.
	 *
	 * @return a new DBUnit connection
	 */
	protected IDatabaseConnection getConnection() throws Exception {
		IDatabaseConnection connection = new DatabaseConnection(getDataSource().getConnection());
		DatabaseConfig config = connection.getConfig();
		config.setProperty(DatabaseConfig.PROPERTY_DATATYPE_FACTORY, new HsqlDataTypeFactory());
		return connection;
	}

	/**
	 * Counts all rows in the table mapped by domainObjectClass.
	 *
	 * @see #countRowsInTable(Class, String)
	 */
	protected int countRowsInTable(Class<?> domainObjectClass) throws Exception {
		return countRowsInTable(domainObjectClass, "");
	}

	/**
	 * Counts the number of rows from a table via jdbc. The table name is taken
	 * from the @Table annotation of the domainObjectClass; when the annotation
	 * is missing or does not set a name, the simple class name is used.
	 *
	 * @param domainObjectClass
	 *            persistent class defining the name of the table for counting
	 *            rows
	 * @param condition
	 *            additional condition appended after the table name (e.g. a
	 *            where clause); null or "" for none
	 * @return number of rows
	 */
	protected int countRowsInTable(Class<?> domainObjectClass, String condition) throws Exception {
		String table;
		Table tableAnnotation = domainObjectClass.getAnnotation(Table.class);
		// @Table.name() defaults to "", which would produce invalid SQL.
		if (tableAnnotation != null && tableAnnotation.name().length() > 0) {
			table = tableAnnotation.name();
		} else {
			table = domainObjectClass.getSimpleName();
		}
		return countRowsInTable(table, condition);
	}

	/**
	 * Counts all rows in the given table.
	 *
	 * @see #countRowsInTable(String, String)
	 */
	protected int countRowsInTable(String table) throws Exception {
		return countRowsInTable(table, "");
	}

	/**
	 * counts the number of rows from a table via jdbc
	 *
	 * @param table
	 *            name of the table for counting rows; concatenated into the SQL
	 *            as-is, so only pass trusted values
	 * @param condition
	 *            additional condition appended after the table name (e.g. a
	 *            where clause); null or "" for none
	 * @return number of rows
	 */
	protected int countRowsInTable(String table, String condition) throws Exception {
		String sql = "select count(*) as rowcount from " + table + " " + (condition == null ? "" : condition);
		Connection con = null;
		Statement stmt = null;
		ResultSet rs = null;
		try {
			con = getConnection().getConnection();
			stmt = con.createStatement();
			rs = stmt.executeQuery(sql);
			rs.next();
			return rs.getInt("rowcount");
		} finally {
			close(con, stmt, rs);
		}
	}

	/**
	 * Logs the database content via {@link DbUnitDataSourceUtils#logDb}.
	 *
	 * @throws RuntimeException
	 *             wrapping any failure
	 */
	protected void logDb() {
		IDatabaseConnection connection = null;
		try {
			connection = getConnection();
			DbUnitDataSourceUtils.logDb(connection);
		} catch (Exception e) {
			throw new RuntimeException(e.getMessage(), e);
		} finally {
			closeQuietly(connection);
		}
	}

	/**
	 * Closes the given JDBC resources in reverse order of creation. Null
	 * arguments are skipped and SQLExceptions from close are ignored.
	 */
	protected static void close(Connection con, Statement stmt, ResultSet rs) {
		if (rs != null) {
			try {
				rs.close();
			} catch (SQLException ignore) {
			}
		}
		if (stmt != null) {
			try {
				stmt.close();
			} catch (SQLException ignore) {
			}
		}
		if (con != null) {
			try {
				con.close();
			} catch (SQLException ignore) {
			}
		}
	}

	/** Closes a DBUnit connection if non-null, ignoring SQLExceptions. */
	private static void closeQuietly(IDatabaseConnection connection) {
		if (connection != null) {
			try {
				connection.close();
			} catch (SQLException ignore) {
				// cleanup only; nothing useful to do on failure
			}
		}
	}
}