package com.github.looly.hutool.db;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map.Entry;

import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.sql.DataSource;

import com.github.looly.hutool.Log;
import com.github.looly.hutool.StrUtil;
import com.github.looly.hutool.exceptions.UtilException;
import org.slf4j.Logger;

import com.github.looly.hutool.db.dialect.Dialect;
import com.github.looly.hutool.db.dialect.DialectFactory;
import com.github.looly.hutool.db.meta.Column;
import com.github.looly.hutool.db.meta.Table;
import com.github.looly.hutool.exceptions.DbRuntimeException;

/**
 * Database utility class: static helpers for closing JDBC resources, reading
 * database metadata, filling statement parameters and identifying JDBC drivers.
 *
 * @author Luxiaolei
 *
 */
public class DbUtil {
    private static final Logger log = Log.get();

    private DbUtil() {
        // Utility class, not meant to be instantiated
    }

    /**
     * Creates a new SQL runner.
     *
     * @param ds data source
     * @return SQL runner
     */
    public static SqlRunner newSqlRunner(DataSource ds) {
        return new SqlRunner(ds);
    }

    /**
     * Creates a new SQL runner that uses the given dialect.
     *
     * @param ds data source
     * @param dialect SQL dialect
     * @return SQL runner
     */
    public static SqlRunner newSqlRunner(DataSource ds, Dialect dialect) {
        return new SqlRunner(ds, dialect);
    }

    /**
     * Closes a series of SQL related objects one after another.<br>
     * The objects are closed in the order given, so pass them in the order they must
     * be released (typically ResultSet, then Statement, then Connection).<br>
     * {@code null} elements are skipped. A failure to close one object is logged and
     * does not prevent the remaining objects from being closed.
     *
     * @param objsToClose objects to close; may be {@code null} or contain {@code null}
     */
    public static void close(Object... objsToClose) {
        if (objsToClose == null) {
            return;
        }

        for (Object obj : objsToClose) {
            if (obj == null) {
                continue;
            }
            try {
                if (obj instanceof ResultSet) {
                    ((ResultSet) obj).close();
                } else if (obj instanceof Statement) {
                    // PreparedStatement extends Statement, so it is handled here as well
                    ((Statement) obj).close();
                } else if (obj instanceof Connection) {
                    ((Connection) obj).close();
                } else {
                    log.warn("Object " + obj.getClass().getName() + " not a ResultSet or Statement or PreparedStatement or Connection!");
                }
            } catch (SQLException e) {
                // Closing is best effort: report the problem and keep closing the rest
                log.warn("Close " + obj.getClass().getName() + " error!", e);
            }
        }
    }

    /**
     * Looks up a JNDI data source.
     *
     * @param jndiName JNDI name
     * @return the data source, or {@code null} if the lookup fails (the error is logged)
     */
    public static DataSource getJndiDs(String jndiName) {
        try {
            return (DataSource) new InitialContext().lookup(jndiName);
        } catch (NamingException e) {
            log.error("Find JNDI datasource error!", e);
        }
        return null;
    }

    /**
     * Gets the names of all tables in the connection's current catalog.
     *
     * @param ds data source
     * @return table names; empty if no table is found
     * @throws UtilException if the metadata cannot be read
     */
    public static List<String> getTables(DataSource ds) {
        final List<String> tables = new ArrayList<String>();
        Connection conn = null;
        ResultSet rs = null;
        try {
            conn = ds.getConnection();
            final DatabaseMetaData metaData = conn.getMetaData();
            // The standard JDBC table type for ordinary tables is "TABLE"
            rs = metaData.getTables(conn.getCatalog(), null, null, new String[] { "TABLE" });
            if (rs != null) {
                while (rs.next()) {
                    final String table = rs.getString("TABLE_NAME");
                    if (StrUtil.isBlank(table) == false) {
                        tables.add(table);
                    }
                }
            }
        } catch (Exception e) {
            throw new UtilException("Get tables error!", e);
        } finally {
            close(rs, conn);
        }
        return tables;
    }

    /**
     * Gets all column labels of a result set.
     *
     * @param rs result set
     * @return column labels, in column order
     * @throws UtilException if the result set metadata cannot be read
     */
    public static String[] getColumnNames(ResultSet rs) {
        try {
            final ResultSetMetaData rsmd = rs.getMetaData();
            final int columnCount = rsmd.getColumnCount();
            final String[] labelNames = new String[columnCount];
            for (int i = 0; i < columnCount; i++) {
                // JDBC column indexes are 1-based
                labelNames[i] = rsmd.getColumnLabel(i + 1);
            }
            return labelNames;
        } catch (Exception e) {
            throw new UtilException("Get colunms error!", e);
        }
    }

    /**
     * Gets all column names of a table.
     *
     * @param ds data source
     * @param tableName table name
     * @return column names
     * @throws UtilException if the metadata cannot be read
     */
    public static String[] getColumnNames(DataSource ds, String tableName) {
        final List<String> columnNames = new ArrayList<String>();
        Connection conn = null;
        ResultSet rs = null;
        try {
            conn = ds.getConnection();
            final DatabaseMetaData metaData = conn.getMetaData();
            rs = metaData.getColumns(conn.getCatalog(), null, tableName, null);
            while (rs.next()) {
                columnNames.add(rs.getString("COLUMN_NAME"));
            }
            return columnNames.toArray(new String[columnNames.size()]);
        } catch (Exception e) {
            throw new UtilException("Get columns error!", e);
        } finally {
            close(rs, conn);
        }
    }

    /**
     * Gets the meta information (primary keys and columns) of a table.
     *
     * @param ds data source
     * @param tableName table name
     * @return Table object
     * @throws UtilException if the metadata cannot be read
     */
    @SuppressWarnings("resource")
    public static Table getTableMeta(DataSource ds, String tableName) {
        final Table table = Table.create(tableName);
        Connection conn = null;
        ResultSet rs = null;
        try {
            conn = ds.getConnection();
            final DatabaseMetaData metaData = conn.getMetaData();
            final String catalog = conn.getCatalog();

            // Primary keys: read the actual column name from each row
            rs = metaData.getPrimaryKeys(catalog, null, tableName);
            while (rs.next()) {
                table.addPk(rs.getString("COLUMN_NAME"));
            }
            // Release the primary key result set before the variable is reused
            close(rs);
            rs = null;

            // Columns
            rs = metaData.getColumns(catalog, null, tableName, null);
            while (rs.next()) {
                table.setColumn(Column.create(tableName, rs));
            }
        } catch (Exception e) {
            throw new UtilException("Get columns error!", e);
        } finally {
            close(rs, conn);
        }
        return table;
    }

    /**
     * Fills the parameters of a prepared statement.<br>
     * Non-null values are set with {@code setObject}. For a {@code null} value the SQL
     * type is taken from the statement's parameter metadata; if that metadata cannot
     * be obtained, {@link Types#VARCHAR} is used. The metadata is only requested when
     * a {@code null} value is actually present.
     *
     * @param ps PreparedStatement
     * @param params SQL parameters; {@code null} means there is nothing to fill
     * @throws SQLException if setting a parameter fails
     */
    public static void fillParams(PreparedStatement ps, Object... params) throws SQLException {
        if (params == null) {
            return;
        }

        ParameterMetaData pmd = null;
        for (int i = 0; i < params.length; i++) {
            final int paramIndex = i + 1;
            final Object param = params[i];
            if (param != null) {
                ps.setObject(paramIndex, param);
                continue;
            }

            int sqlType = Types.VARCHAR;
            try {
                if (pmd == null) {
                    // Fetched lazily: not needed at all when no parameter is null
                    pmd = ps.getParameterMetaData();
                }
                sqlType = pmd.getParameterType(paramIndex);
            } catch (SQLException e) {
                log.warn("Param get type fail, by: " + e.getMessage());
            }
            ps.setNull(paramIndex, sqlType);
        }
    }

    /**
     * Gets the value of the generated (auto increment) key.
     *
     * @param ps PreparedStatement that has been executed
     * @return the first generated key, or {@code null} if there is none
     * @throws SQLException if the generated keys cannot be read
     */
    public static Long getGeneratedKey(PreparedStatement ps) throws SQLException {
        ResultSet rs = null;
        try {
            rs = ps.getGeneratedKeys();
            Long generatedKey = null;
            if (rs != null && rs.next()) {
                generatedKey = rs.getLong(1);
            }
            return generatedKey;
        } finally {
            close(rs);
        }
    }

    /**
     * Builds a WHERE clause made of equality conditions joined by "and".<br>
     * If there is no condition an empty string is returned, meaning "no condition".<br>
     * Column names are quoted with backticks and each value becomes a {@code ?}
     * placeholder whose value is appended to {@code paramValues}.
     *
     * @param entity condition entity (column name to value)
     * @param paramValues list that receives the condition values, in placeholder order
     * @return the SQL fragment including the WHERE keyword, or an empty string
     */
    public static String buildEqualsWhere(Entity entity, List<Object> paramValues) {
        if (null == entity || entity.isEmpty()) {
            return StrUtil.EMPTY;
        }

        final StringBuilder sb = new StringBuilder(" WHERE ");
        boolean isNotFirst = false;
        for (Entry<String, Object> entry : entity.entrySet()) {
            if (isNotFirst) {
                sb.append(" and ");
            } else {
                isNotFirst = true;
            }

            sb.append("`").append(entry.getKey()).append("`").append(" = ?");
            paramValues.add(entry.getValue());
        }

        return sb.toString();
    }

    /**
     * Identifies the JDBC driver name from a string containing product information.
     *
     * @param nameContainsProductInfo string that contains a database identifier
     * @return the driver, or {@code null} if the string is blank or not recognized
     */
    public static String identifyDriver(String nameContainsProductInfo) {
        if (StrUtil.isBlank(nameContainsProductInfo)) {
            return null;
        }
        // Locale independent lower-casing, so matching does not depend on the default locale
        nameContainsProductInfo = nameContainsProductInfo.toLowerCase(Locale.ROOT);

        String driver = null;
        if (nameContainsProductInfo.contains("mysql")) {
            driver = DialectFactory.DRIVER_MYSQL;
        } else if (nameContainsProductInfo.contains("oracle")) {
            driver = DialectFactory.DRIVER_ORACLE;
        } else if (nameContainsProductInfo.contains("postgresql")) {
            driver = DialectFactory.DRIVER_POSTGRESQL;
        } else if (nameContainsProductInfo.contains("sqlite")) {
            driver = DialectFactory.DRIVER_SQLLITE3;
        }

        return driver;
    }

    /**
     * Identifies the JDBC driver name of a data source.<br>
     * A connection is obtained from the data source and closed again afterwards.
     *
     * @param ds data source
     * @return the driver, or {@code null} if it is not recognized
     * @throws DbRuntimeException if the driver cannot be identified because of an error
     */
    public static String identifyDriver(DataSource ds) {
        Connection conn = null;
        String driver = null;
        try {
            conn = ds.getConnection();
            driver = identifyDriver(conn);
        } catch (DbRuntimeException e) {
            // Already the right exception type, do not wrap it a second time
            throw e;
        } catch (Exception e) {
            throw new DbRuntimeException("Identify driver error!", e);
        } finally {
            close(conn);
        }

        return driver;
    }

    /**
     * Identifies the JDBC driver name of a connection.<br>
     * The database product name is tried first, then the driver name. The connection
     * is not closed by this method.
     *
     * @param conn database connection
     * @return the driver, or {@code null} if it is not recognized
     * @throws DbRuntimeException if the connection metadata cannot be read
     */
    public static String identifyDriver(Connection conn) {
        String driver = null;
        try {
            DatabaseMetaData meta = conn.getMetaData();
            driver = identifyDriver(meta.getDatabaseProductName());
            if (StrUtil.isBlank(driver)) {
                driver = identifyDriver(meta.getDriverName());
            }
        } catch (SQLException e) {
            throw new DbRuntimeException("Identify driver error!", e);
        }

        return driver;
    }
    //---------------------------------------------------------------------------- Private method start
    //---------------------------------------------------------------------------- Private method end
}