package org.opencloudb.parser.druid.impl; import java.sql.SQLNonTransientException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.log4j.Logger; import org.opencloudb.config.model.SchemaConfig; import org.opencloudb.mpp.RangeValue; import org.opencloudb.parser.druid.DruidParser; import org.opencloudb.parser.druid.DruidShardingParseInfo; import org.opencloudb.parser.druid.MycatSchemaStatVisitor; import org.opencloudb.route.RouteResultset; import com.alibaba.druid.sql.ast.SQLStatement; import com.alibaba.druid.stat.TableStat.Condition; /** * 对SQLStatement解析 * 主要通过visitor解析和statement解析:有些类型的SQLStatement通过visitor解析足够了, * 有些只能通过statement解析才能得到所有信息 * 有些需要通过两种方式解析才能得到完整信息 * @author wang.dw * */ public class DefaultDruidParser implements DruidParser { protected static final Logger LOGGER = Logger.getLogger(DefaultDruidParser.class); /** * 解析得到的结果 */ protected DruidShardingParseInfo ctx; private Map<String,String> tableAliasMap = new HashMap<String,String>(); private List<Condition> conditions = new ArrayList<Condition>(); public Map<String, String> getTableAliasMap() { return tableAliasMap; } public List<Condition> getConditions() { return conditions; } /** * 使用MycatSchemaStatVisitor解析,得到tables、tableAliasMap、conditions等 * @param schema * @param stmt */ public void parser(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt) throws SQLNonTransientException { ctx = new DruidShardingParseInfo(); //通过visitor解析 visitorParse(rrs,stmt); //通过Statement解析 statementParse(schema, rrs, stmt); //改写sql:如insert语句主键自增长的可以 changeSql(schema, rrs, stmt); ctx.setSql(stmt.toString()); } /** * 子类可覆盖(如果visitorParse解析得不到表名、字段等信息的,就通过覆盖该方法来解析) * 子类覆盖该方法一般是将SQLStatement转型后再解析(如转型为MySqlInsertStatement) */ @Override public void statementParse(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt) throws SQLNonTransientException { } /** * 改写sql:如insert是 */ @Override public void changeSql(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt) throws SQLNonTransientException { } /** * 子类可覆盖(如果该方法解析得不到表名、字段等信息的,就覆盖该方法,覆盖成空方法,然后通过statementPparse去解析) * 通过visitor解析:有些类型的Statement通过visitor解析得不到表名、 * @param stmt */ @Override public void visitorParse(RouteResultset rrs, SQLStatement stmt) throws SQLNonTransientException{ MycatSchemaStatVisitor visitor = new MycatSchemaStatVisitor(); stmt.accept(visitor); if(visitor.getAliasMap() != null) { for(Map.Entry<String, String> entry : visitor.getAliasMap().entrySet()) { String key = entry.getKey(); String value = entry.getValue(); if(key != null && key.indexOf("`") >= 0) { key = key.replaceAll("`", ""); } if(value != null && value.indexOf("`") >= 0) { value = value.replaceAll("`", ""); } //表名前面带database的,去掉 if(key != null) { int pos = key.indexOf("."); if(pos> 0) { key = key.substring(pos + 1); } } if(key.equals(value)) { ctx.addTable(key.toUpperCase()); } else { tableAliasMap.put(key, value); } } ctx.setTableAliasMap(tableAliasMap); } conditions = visitor.getConditions(); //遍历condition ,找分片字段 for(Condition condition : conditions) { List<Object> values = condition.getValues(); if(values.size() == 0) { break; } if(checkConditionValues(values)) { String columnName = removeBackquote(condition.getColumn().getName().toUpperCase()); String tableName = removeBackquote(condition.getColumn().getTable().toUpperCase()); String operator = condition.getOperator(); //between \ in 、>= > = < =< ,in和=是一样的处理逻辑 if(operator.equals("between")) { RangeValue rv = new RangeValue(values.get(0), values.get(1), RangeValue.EE); ctx.addShardingExpr(tableName.toUpperCase(), columnName, rv); } else { ctx.addShardingExpr(tableName.toUpperCase(), columnName, values.toArray()); } } } } private boolean checkConditionValues(List<Object> values) { for(Object value : values) { if(value != null && !value.toString().equals("")) { return true; } } return false; } public DruidShardingParseInfo getCtx() { return ctx; } /** * 移除`符号 * @param str * @return */ public String removeBackquote(String str){ //删除名字中的`tablename`和'value' if (str.length() > 0) { StringBuilder sb = new StringBuilder(str); if (sb.charAt(0) == '`'||sb.charAt(0) == '\'') { sb.deleteCharAt(0); } if (sb.charAt(sb.length() - 1) == '`'||sb.charAt(sb.length() - 1) == '\'') { sb.deleteCharAt(sb.length() - 1); } return sb.toString(); } return ""; } }