package com.ctrip.platform.dal.dao.task;

import static com.ctrip.platform.dal.dao.helper.DalShardingHelper.buildShardStr;
import static com.ctrip.platform.dal.dao.helper.DalShardingHelper.detectDistributedTransaction;
import static com.ctrip.platform.dal.dao.helper.DalShardingHelper.isTableShardingEnabled;
import static com.ctrip.platform.dal.dao.helper.DalShardingHelper.locateTableShardId;

import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;

import com.ctrip.platform.dal.common.enums.ParameterDirection;
import com.ctrip.platform.dal.dao.DalClient;
import com.ctrip.platform.dal.dao.DalClientFactory;
import com.ctrip.platform.dal.dao.DalHintEnum;
import com.ctrip.platform.dal.dao.DalHints;
import com.ctrip.platform.dal.dao.ResultMerger;
import com.ctrip.platform.dal.dao.StatementParameter;
import com.ctrip.platform.dal.dao.StatementParameters;
import com.ctrip.platform.dal.dao.client.DalLogger;
import com.ctrip.platform.dal.dao.helper.DalShardingHelper;
import com.ctrip.platform.dal.dao.sqlbuilder.SqlBuilder;
import com.ctrip.platform.dal.dao.sqlbuilder.TableSqlBuilder;

/**
 * A {@link DalRequest} that runs one {@link SqlTask} against a logic database,
 * either as a single task ({@link #createTask()}) or as one task per shard
 * ({@link #createTasks()}).
 *
 * <p>The set of target shards is resolved once, in the constructor, from the
 * supplied hints and the parameters produced by the SQL builder.
 *
 * @param <T> result type produced by the task and consumed by the merger
 */
public class DalSqlTaskRequest<T> implements DalRequest<T> {
    private DalLogger logger;
    private String logicDbName;
    private SqlBuilder builder;
    private StatementParameters parameters;
    private DalHints hints;
    private SqlTask<T> task;
    private ResultMerger<T> merger;

    /** Shards this request targets; {@code null} when sharding is disabled or no shard hint applies. */
    private Set<String> shards;

    /** Shard id to shard-by values; only populated when the shards were derived from a shardBy hint. */
    private Map<String, List<?>> parametersByShard;

    /**
     * Captures the request inputs, builds the statement parameters and resolves
     * the target shards.
     *
     * @param logicDbName logic database the task runs against
     * @param builder     source of the SQL text and of the statement parameters
     * @param hints       hints used to resolve shards and passed on to the task
     * @param task        task executed by every created callable
     * @param merger      merger returned by {@link #getMerger()}
     * @throws SQLException propagated from shard resolution
     */
    public DalSqlTaskRequest(String logicDbName, SqlBuilder builder, DalHints hints, SqlTask<T> task,
            ResultMerger<T> merger) throws SQLException {
        this.logger = DalClientFactory.getDalLogger();
        this.logicDbName = logicDbName;
        this.builder = builder;
        this.parameters = builder.buildParameters();
        this.hints = hints;
        this.task = task;
        this.merger = merger;
        // Must run last: shard resolution reads the fields assigned above.
        this.shards = getShards();
    }

    /**
     * Passes the resolved shards to
     * {@link DalShardingHelper#detectDistributedTransaction}.
     */
    @Override
    public void validate() throws SQLException {
        detectDistributedTransaction(shards);
    }

    /**
     * @return {@code true} only when more than one shard was resolved
     */
    @Override
    public boolean isCrossShard() {
        if (shards == null) {
            return false;
        }
        return shards.size() > 1;
    }

    /**
     * Creates the single callable for this request. The hints are cloned first,
     * and when exactly one shard was resolved the clone is pinned to it.
     */
    @Override
    public Callable<T> createTask() throws SQLException {
        DalHints taskHints = hints.clone();

        boolean singleShard = shards != null && shards.size() == 1;
        if (singleShard) {
            taskHints.inShard(shards.iterator().next());
        }

        return create(parameters, taskHints);
    }

    /**
     * Creates one callable per shard, keyed by shard id. Every callable gets its
     * own clone of the hints pinned to its shard.
     *
     * <p>NOTE(review): the by-shard-list branch iterates {@code shards} without a
     * null check; presumably this is only invoked when {@link #isCrossShard()}
     * is true - confirm against callers.
     */
    @Override
    @SuppressWarnings({ "unchecked", "rawtypes" }) // raw List cast kept from the original duplicateWith call
    public Map<String, Callable<T>> createTasks() throws SQLException {
        Map<String, Callable<T>> tasksByShard = new HashMap<>();

        if (parametersByShard != null) {
            // Create by sharded values: each shard only receives its own slice of the shard-by values.
            for (Map.Entry<String, List<?>> entry : parametersByShard.entrySet()) {
                String shardId = entry.getKey();
                StatementParameters shardParameters =
                        parameters.duplicateWith(hints.getShardBy(), (List) entry.getValue());
                tasksByShard.put(shardId, create(shardParameters, hints.clone().inShard(shardId)));
            }
            return tasksByShard;
        }

        // Create by given shards: every shard receives a duplicate of the parameters.
        for (String shardId : shards) {
            StatementParameters shardParameters = parameters.duplicate();
            DalHints shardHints = hints.clone().inShard(shardId);
            tasksByShard.put(shardId, create(shardParameters, shardHints));
        }
        return tasksByShard;
    }

    /**
     * Builds the callable for one execution. When the builder is a
     * {@link TableSqlBuilder} whose table has table sharding enabled, the SQL is
     * built for the located table shard; otherwise the builder's plain SQL is used.
     */
    private Callable<T> create(StatementParameters taskParameters, DalHints taskHints) throws SQLException {
        boolean tableSharded = builder instanceof TableSqlBuilder
                && isTableShardingEnabled(logicDbName, ((TableSqlBuilder) builder).getTableName());

        String tableShardStr = null;
        if (tableSharded) {
            tableShardStr = buildShardStr(logicDbName,
                    locateTableShardId(logicDbName, taskHints, taskParameters, null));
        }

        // The client is looked up before the SQL is built, as in the original call order.
        DalClient client = DalClientFactory.getClient(logicDbName);
        String sql = tableSharded ? ((TableSqlBuilder) builder).build(tableShardStr) : builder.build();

        return new SqlTaskCallable<>(client, sql, taskParameters, taskHints, task);
    }

    @Override
    public ResultMerger<T> getMerger() {
        return merger;
    }

    /**
     * Resolves the shards this request targets.
     *
     * <ul>
     * <li>sharding disabled for the logic database: {@code null}</li>
     * <li>all-shards hint: every shard of the database set</li>
     * <li>in-shards hint: the set stored under {@link DalHintEnum#shards}</li>
     * <li>shardBy hint: the key set of the shuffled shard-by values, which are
     * also remembered in {@code parametersByShard}</li>
     * <li>none of the above: {@code null}</li>
     * </ul>
     *
     * A warning is logged whenever more than one shard is resolved.
     */
    @SuppressWarnings({ "unchecked", "rawtypes" }) // hint value and parameter value are untyped
    private Set<String> getShards() throws SQLException {
        if (!DalShardingHelper.isShardingEnabled(logicDbName)) {
            return null;
        }

        Set<String> resolved = null;

        if (hints.isAllShards()) {
            resolved = DalClientFactory.getDalConfigure().getDatabaseSet(logicDbName).getAllShards();
        } else if (hints.isInShards()) {
            resolved = (Set<String>) hints.get(DalHintEnum.shards);
        } else if (hints.isShardBy()) {
            // The new code gen will set hints shardBy to indicate this is a potential cross shard operation.
            // Check parameters. It can only support DB shard at this level.
            // NOTE(review): no null check on the looked-up parameter - confirm that
            // StatementParameters.get cannot return null for the shardBy name.
            StatementParameter shardByParameter = parameters.get(hints.getShardBy(), ParameterDirection.Input);
            parametersByShard = DalShardingHelper.shuffle(logicDbName, (List) shardByParameter.getValue());
            resolved = parametersByShard.keySet();
        }

        boolean multipleShards = resolved != null && resolved.size() > 1;
        if (multipleShards) {
            logger.warn("Execute on multiple shards detected: " + builder.build());
        }

        return resolved;
    }

    /**
     * Callable that executes a {@link SqlTask} with a fixed client, SQL text,
     * parameters and hints. If the parameters contain an IN parameter, the SQL
     * and the parameters are compiled once at construction time.
     */
    private static class SqlTaskCallable<T> implements Callable<T> {
        private DalClient client;
        private String sql;
        private StatementParameters parameters;
        private DalHints hints;
        private SqlTask<T> task;

        public SqlTaskCallable(DalClient client, String sql, StatementParameters parameters, DalHints hints,
                SqlTask<T> task) throws SQLException {
            this.client = client;
            this.sql = sql;
            this.hints = hints;
            this.task = task;
            this.parameters = parameters;
            compileInClause();
        }

        /**
         * Rewrites the SQL through {@code SQLCompiler} and compiles the parameters,
         * but only when an IN parameter is present; otherwise both are left untouched.
         */
        private void compileInClause() throws SQLException {
            if (parameters.containsInParameter()) {
                sql = SQLCompiler.compile(sql, parameters.getAllInParameters());
                parameters.compile();
            }
        }

        @Override
        public T call() throws Exception {
            return task.execute(client, sql, parameters, hints);
        }
    }
}