package com.ctrip.platform.dal.dao.task; import java.sql.SQLException; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.commons.lang.StringUtils; import com.ctrip.platform.dal.dao.DalHints; import com.ctrip.platform.dal.dao.StatementParameters; import com.ctrip.platform.dal.dao.UpdatableEntity; import com.ctrip.platform.dal.exceptions.DalException; import com.ctrip.platform.dal.exceptions.ErrorCode; public class BatchUpdateTask<T> extends AbstractIntArrayBulkTask<T> { public static final String TMPL_SQL_UPDATE = "UPDATE %s SET %s WHERE %s"; @Override public BulkTaskContext<T> createTaskContext(DalHints hints, List<Map<String, ?>> daoPojos, List<T> rawPojos) throws DalException { BulkTaskContext<T> taskContext = new BulkTaskContext<T>(rawPojos); Map<String, Boolean> pojoFieldStatus = taskContext.isUpdatableEntity() ? filterUpdatableEntity(hints, rawPojos) : filterNullColumns(hints, daoPojos); if(pojoFieldStatus.size() == 0) throw new DalException(ErrorCode.ValidateFieldCount); taskContext.setPojoFieldStatus(pojoFieldStatus); return taskContext; } @Override public int[] execute(DalHints hints, Map<Integer, Map<String, ?>> daoPojos, BulkTaskContext<T> taskContext) throws SQLException { List<T> rawPojos = taskContext.getRawPojos(); boolean isUpdatableEntity = taskContext.isUpdatableEntity(); Map<String, Boolean> pojoFieldStatus = taskContext.getPojoFieldStatus(); StatementParameters[] parametersList = new StatementParameters[daoPojos.size()]; int i = 0; String[] updateColumnNames = pojoFieldStatus.keySet().toArray(new String[pojoFieldStatus.size()]); for (Integer index :daoPojos.keySet()) { Map<String, ?> pojo = daoPojos.get(index); StatementParameters parameters = new StatementParameters(); if(isUpdatableEntity && !hints.isUpdateUnchangedField()) addParameters(parameters, pojo, updateColumnNames, ((UpdatableEntity)rawPojos.get(index)).getUpdatedColumns()); else addParameters(parameters, pojo, updateColumnNames); addParameters(parameters, pojo, parser.getPrimaryKeyNames()); addVersion(parameters, pojo); parametersList[i++] = parameters; } String batchUpdateSql = buildBatchUpdateSql(getTableName(hints), pojoFieldStatus); int[] result = client.batchUpdate(batchUpdateSql, parametersList, hints); return result; } public void addParameters(StatementParameters parameters, Map<String, ?> entries, String[] validColumns, Set<String> updatedColumns) { int index = parameters.size() + 1; for(String column : validColumns){ Object value = updatedColumns.contains(column) ? entries.get(column) : null; addParameter(parameters, index++, column, value); } } /** * Find out all columns that are not changed to reduce the batch update sql size * E.g * C1 C2 C3 C4 C5 C6 * E1 x x x * E2 x x x * * final not changed columns: C1 and C2 * final always changed columns: C3 and C4 * final may changed columns: C2 and C5 * * So C1 and C2 will be removed from final update sql * C3 and C4 will using set value * C2 and C5 will use set ifnull */ private Map<String, Boolean> filterUpdatableEntity(DalHints hints, List<T> rawPojos) { Set<String> qualifiedColumns = filterColumns(hints); Map<String, Boolean> columnStatus = new HashMap<String, Boolean>(); for(String column: qualifiedColumns) columnStatus.put(column, false); if(hints.isUpdateUnchangedField()) { return columnStatus; } Set<String> unChangedFields = new HashSet<>(qualifiedColumns); Set<String> changedFields = new HashSet<>(qualifiedColumns); for (T pojo: rawPojos) { if(unChangedFields.isEmpty()) break; Set<String> updatedColumns = getUpdatedColumns(pojo); if(updatedColumns.size() == 0) continue; unChangedFields.removeAll(updatedColumns); changedFields.retainAll(updatedColumns); } for(String unChangedField: unChangedFields) columnStatus.remove(unChangedField); Set<String> remain = new HashSet<>(columnStatus.keySet()); remain.removeAll(changedFields); for(String maybeChangedField: remain) columnStatus.put(maybeChangedField, true); return columnStatus; } private Map<String, Boolean> filterNullColumns(DalHints hints, List<Map<String, ?>> daoPojos) { Set<String> qualifiedColumns = filterColumns(hints); Map<String, Boolean> columnStatus = new HashMap<String, Boolean>(); for(String column: qualifiedColumns) columnStatus.put(column, false); if(hints.isUpdateNullField()) { return columnStatus; } String[] columnsToCheck = qualifiedColumns.toArray(new String[qualifiedColumns.size()]); Set<String> nullFields = new HashSet<>(qualifiedColumns); Set<String> notNullFields = new HashSet<>(nullFields); for (Map<String, ?> pojo: daoPojos) { if(notNullFields.isEmpty() && nullFields.isEmpty()) break; for (int i = 0; i < columnsToCheck.length; i++) { String colName = columnsToCheck[i]; boolean isNull = pojo.get(colName) == null; Set<String> check = isNull ? notNullFields : nullFields; if(!check.isEmpty() && check.contains(colName)) check.remove(colName); } } for(String nullField: nullFields) columnStatus.remove(nullField); Set<String> remain = new HashSet<>(columnStatus.keySet()); remain.removeAll(notNullFields); for(String maybeNullField: remain) columnStatus.put(maybeNullField, true); return columnStatus; } private String buildBatchUpdateSql(String tableName, Map<String, Boolean> pojoFieldStatus) { List<String> updateColumnTmpls = new ArrayList<>(pojoFieldStatus.size()); for(Map.Entry<String, Boolean> fieldStatus: pojoFieldStatus.entrySet()) { String columnName = fieldStatus.getKey(); String quotedColumnName = quote(columnName); // If the field contains null value if(fieldStatus.getValue()) updateColumnTmpls.add(String.format(setValueTmpl, quotedColumnName, quotedColumnName)); else updateColumnTmpls.add(String.format(TMPL_SET_VALUE, quotedColumnName)); } if(isVersionUpdatable) updateColumnTmpls.add(setVersionValueTmpl); String updateColumnsTmpl = StringUtils.join(updateColumnTmpls, COLUMN_SEPARATOR); return String.format(TMPL_SQL_UPDATE, tableName, updateColumnsTmpl, updateCriteriaTmpl); } private void addVersion(StatementParameters parameters, Map<String, ?> pojo) throws DalException { if(!hasVersion) return; Object version = pojo.get(parser.getVersionColumn()); if(version == null) throw new DalException(ErrorCode.ValidateVersion); addParameter(parameters, parameters.size() + 1, parser.getVersionColumn(), version); } }