package org.opencloudb.parser.druid.impl;

import java.sql.SQLNonTransientException;
import java.sql.SQLSyntaxErrorException;
import java.util.List;

import org.opencloudb.config.model.SchemaConfig;
import org.opencloudb.config.model.TableConfig;
import org.opencloudb.route.RouteResultset;

import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;

/**
 * Druid-based parser for MySQL {@code INSERT} statements.
 *
 * <p>Registers the target table in the parser context. For a partitioned table (one whose
 * {@link TableConfig} declares a partition column) it also registers the value supplied for
 * the partition column as a sharding expression. It rejects statements it cannot handle here:
 * multi-row inserts, {@code INSERT ... SELECT}, and {@code ON DUPLICATE KEY UPDATE} clauses
 * that assign the partition column.
 */
public class DruidInsertParser extends DefaultDruidParser {

    /**
     * Does nothing: INSERT statements need no visitor pass, because all routing information
     * is gathered in {@link #statementParse}.
     */
    @Override
    public void visitorParse(RouteResultset rrs, SQLStatement stmt) throws SQLNonTransientException {
        // intentionally empty
    }

    /**
     * Registers the (upper-cased, backquote-stripped) target table. For a partitioned table,
     * also registers the sharding value supplied for its partition column.
     *
     * @param schema schema used to look up the target table's {@link TableConfig}
     * @param rrs    route result set (not used by this implementation)
     * @param stmt   the parsed statement; it is cast to {@link MySqlInsertStatement}
     * @throws SQLSyntaxErrorException  if an insert into a partitioned table omits the column
     *                                  list, or its column and value counts differ
     * @throws SQLNonTransientException if the table is not defined in the schema, the insert has
     *                                  several value rows or a SELECT source, the partition
     *                                  column is missing from the column list, or
     *                                  ON DUPLICATE KEY UPDATE assigns the partition column
     */
    @Override
    public void statementParse(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt)
            throws SQLNonTransientException {
        MySqlInsertStatement insert = (MySqlInsertStatement) stmt;
        String tableName = removeBackquote(insert.getTableName().getSimpleName()).toUpperCase();
        ctx.addTable(tableName);

        TableConfig tc = schema.getTables().get(tableName);
        if (tc == null) {
            String msg = "can't find table define in schema " + tableName + " schema:" + schema.getName();
            LOGGER.warn(msg);
            throw new SQLNonTransientException(msg);
        }

        String partitionColumn = tc.getPartitionColumn();
        if (partitionColumn == null) {
            // Not a partitioned table: no sharding value to extract.
            return;
        }

        // A partitioned table must give an explicit column list; otherwise the value of the
        // sharding column cannot be located.
        List<SQLExpr> columns = insert.getColumns();
        if (columns == null || columns.isEmpty()) {
            throw new SQLSyntaxErrorException("partition table, insert must provide ColumnList");
        }

        // Neither "insert into ... select ..." nor "insert into t(id) values (),(),..." is
        // supported.
        if (insert.getValuesList().size() > 1 || insert.getQuery() != null) {
            String inf = "insert multi rows not supported!";
            // TODO: could split the rows across several shards to support multi-row inserts.
            LOGGER.warn(inf);
            throw new SQLNonTransientException(inf);
        }

        // Values are matched to columns by position, so the counts must agree (MySQL rejects
        // a mismatch too). This also avoids an out-of-bounds lookup below.
        if (insert.getValues() == null || insert.getValues().getValues().size() != columns.size()) {
            throw new SQLSyntaxErrorException(
                    "partition table, insert column count doesn't match value count");
        }

        addInsertShardingValue(tableName, partitionColumn, insert);
        rejectPartitionColumnUpdate(tableName, partitionColumn, insert);
    }

    /**
     * Finds the partition column in the insert's column list and registers the value at the
     * same position as a sharding expression. Only a single sharding column is supported, so
     * the first match is used.
     */
    private void addInsertShardingValue(String tableName, String partitionColumn,
            MySqlInsertStatement insert) throws SQLNonTransientException {
        List<SQLExpr> columns = insert.getColumns();
        List<SQLExpr> values = insert.getValues().getValues();
        for (int i = 0; i < columns.size(); i++) {
            String column = removeBackquote(columns.get(i).toString());
            if (partitionColumn.equalsIgnoreCase(column)) {
                String value = removeBackquote(values.get(i).toString());
                ctx.addShardingExpr(tableName, column, value);
                return;
            }
        }
        String msg = "bad insert sql (sharding column:" + partitionColumn + " not provided," + insert;
        LOGGER.warn(msg);
        throw new SQLNonTransientException(msg);
    }

    /**
     * Rejects ON DUPLICATE KEY UPDATE items that assign the partition column, for example
     * {@code INSERT INTO t (a,b,c) VALUES (1,2,3) ON DUPLICATE KEY UPDATE b=VALUES(b)}.
     * Items that are not binary (assignment) expressions are rejected rather than skipped,
     * so that a partition-key update cannot slip through unchecked.
     */
    private void rejectPartitionColumnUpdate(String tableName, String partitionColumn,
            MySqlInsertStatement insert) throws SQLNonTransientException {
        List<SQLExpr> updateList = insert.getDuplicateKeyUpdate();
        if (updateList == null) {
            return;
        }
        for (SQLExpr expr : updateList) {
            if (!(expr instanceof SQLBinaryOpExpr)) {
                String msg = "unsupported ON DUPLICATE KEY UPDATE item: " + expr;
                LOGGER.warn(msg);
                throw new SQLNonTransientException(msg);
            }
            String column = removeBackquote(((SQLBinaryOpExpr) expr).getLeft().toString().toUpperCase());
            // Case-insensitive, matching how the sharding column is located in the column list.
            if (column.equalsIgnoreCase(partitionColumn)) {
                String msg = "partion key can't be updated: " + tableName + " -> " + partitionColumn;
                LOGGER.warn(msg);
                throw new SQLNonTransientException(msg);
            }
        }
    }
}