package com.ctrip.platform.dal.dao.task; import static com.ctrip.platform.dal.dao.helper.DalShardingHelper.isAlreadySharded; import static com.ctrip.platform.dal.dao.helper.DalShardingHelper.isShardingEnabled; import static com.ctrip.platform.dal.dao.helper.DalShardingHelper.isTableShardingEnabled; import static com.ctrip.platform.dal.dao.helper.DalShardingHelper.shuffle; import static com.ctrip.platform.dal.dao.helper.DalShardingHelper.shuffleByTable; import java.sql.SQLException; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import com.ctrip.platform.dal.dao.DalHints; import com.ctrip.platform.dal.exceptions.DalException; import com.ctrip.platform.dal.exceptions.ErrorCode; public class DalBulkTaskRequest<K, T> implements DalRequest<K>{ private String logicDbName; private String rawTableName; private DalHints hints; private List<T> rawPojos; private List<Map<String, ?>> daoPojos; private BulkTask<K, T> task; private BulkTaskContext<T> taskContext; private BulkTaskResultMerger<K> dbShardMerger; Map<String, Map<Integer, Map<String, ?>>> shuffled; public DalBulkTaskRequest(String logicDbName, String rawTableName, DalHints hints, List<T> rawPojos, BulkTask<K, T> task) { this.logicDbName = logicDbName; this.rawTableName = rawTableName; this.hints = hints; this.rawPojos = rawPojos; this.task = task; } @Override public void validate() throws SQLException { if(null == rawPojos) throw new DalException(ErrorCode.ValidatePojoList); if(task == null) throw new DalException(ErrorCode.ValidateTask); dbShardMerger = task.createMerger(); daoPojos = task.getPojosFields(rawPojos); taskContext = task.createTaskContext(hints, daoPojos, rawPojos); } @Override public boolean isCrossShard() throws SQLException { if(isAlreadySharded(logicDbName, rawTableName, hints)) return false; if(isShardingEnabled(logicDbName)) { shuffled = shuffle(logicDbName, hints.getShardId(), daoPojos); // Only in one or no shard return shuffled.size() <= 1 ? false : true; } // Shard at table level or no shard at all return false; } @Override public Callable<K> createTask() throws SQLException { hints = hints.clone(); handleKeyHolder(false); // If only one shard is shuffled if(shuffled != null) { if(shuffled.size() == 0) return new BulkTaskCallable<>(logicDbName, rawTableName, hints, new HashMap<Integer, Map<String, ?>>(), task, taskContext); String shard = shuffled.keySet().iterator().next(); return new BulkTaskCallable<>(logicDbName, rawTableName, hints.inShard(shard), shuffled.get(shard), task, taskContext); } // Convert to index map Map<Integer, Map<String, ?>> daoPojosMap = new HashMap<>(); for(int i = 0; i < daoPojos.size(); i++) daoPojosMap.put(i, daoPojos.get(i)); return new BulkTaskCallable<>(logicDbName, rawTableName, hints, daoPojosMap, task, taskContext); } @Override public Map<String, Callable<K>> createTasks() throws SQLException { Map<String, Callable<K>> tasks = new HashMap<>(); // I know this is not so elegant. handleKeyHolder(true); for(String shard: shuffled.keySet()) { Map<Integer, Map<String, ?>> pojosInShard = shuffled.get(shard); dbShardMerger.recordPartial(shard, pojosInShard.keySet().toArray(new Integer[pojosInShard.size()])); tasks.put(shard, new BulkTaskCallable<>( logicDbName, rawTableName, hints.clone().inShard(shard), shuffled.get(shard), task, taskContext)); } return tasks; } private void handleKeyHolder(boolean requireMerge) { if(hints.getKeyHolder() == null) return; hints.getKeyHolder().requireMerge(); } @Override public BulkTaskResultMerger<K> getMerger() { return dbShardMerger; } private static class BulkTaskCallable<K, T> implements Callable<K> { private String logicDbName; private String rawTableName; private DalHints hints; private Map<Integer, Map<String, ?>> shaffled; private BulkTask<K, T> task; private BulkTaskContext<T> taskContext; public BulkTaskCallable(String logicDbName, String rawTableName, DalHints hints, Map<Integer, Map<String, ?>> shaffled, BulkTask<K, T> task, BulkTaskContext<T> taskContext){ this.logicDbName = logicDbName; this.rawTableName = rawTableName; this.hints = hints; this.shaffled = shaffled; this.task = task; this.taskContext = taskContext; } @Override public K call() throws Exception { if(shaffled.isEmpty()) return task.getEmptyValue(); if(isTableShardingEnabled(logicDbName, rawTableName)) { return executeByTableShards(); }else{ return task.execute(hints, shaffled, taskContext); } } private K executeByTableShards() throws SQLException { BulkTaskResultMerger<K> merger = task.createMerger(); Map<String, Map<Integer, Map<String, ?>>> pojosInTable = shuffleByTable(logicDbName, hints.getTableShardId(), shaffled); if(pojosInTable.size() > 1 && hints.getKeyHolder() != null) { hints.getKeyHolder().requireMerge(); } DalHints tmpHints; for(String curTableShardId: pojosInTable.keySet()) { Map<Integer, Map<String, ?>> pojosInShard = pojosInTable.get(curTableShardId); tmpHints = hints.clone(); tmpHints.inTableShard(curTableShardId); merger.recordPartial(curTableShardId, pojosInShard.keySet().toArray(new Integer[pojosInShard.size()])); K partial = task.execute(tmpHints, pojosInShard, taskContext); merger.addPartial(curTableShardId, partial); } return merger.merge(); } } }