/* * Copyright 2008 Pavel Syrtsov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and limitations * under the License. */ package com.sf.ddao.shards; import com.google.inject.Guice; import com.google.inject.Injector; import com.mockrunner.jdbc.JDBCTestModule; import com.mockrunner.jdbc.PreparedStatementResultSetHandler; import com.mockrunner.mock.jdbc.JDBCMockObjectFactory; import com.mockrunner.mock.jdbc.MockResultSet; import com.sf.ddao.*; import com.sf.ddao.chain.ChainModule; import com.sf.ddao.factory.param.ThreadLocalParameter; import com.sf.ddao.orm.RSMapper; import com.sf.ddao.orm.UseRSMapper; import com.sf.ddao.orm.rsmapper.rowmapper.BeanRowMapperFactory; import com.sf.ddao.orm.rsmapper.rowmapper.RowMapper; import junit.framework.TestCase; import org.apache.commons.chain.Context; import org.mockejb.jndi.MockContextFactory; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.List; /** * Created by: Pavel Syrtsov * Date: Apr 6, 2007 * Time: 7:00:11 PM */ public class ShardedDaoTest extends TestCase { Injector injector; private static final String PART_NAME = "testPartName"; private JDBCTestModule testModule1; private JDBCTestModule testModule2; @ShardedDao(TestShardingService.class) public static interface TestUserDao extends TransactionableDao { /** * in this statement we assume that 1st method arg is Java Bean * and refer to property by name. It works same way for Map. * * @param userBean - parameter object * @return object created from data returned by sql */ @Select("select id, name from user where id = #id#") TestUserBean getUser(@ShardKey("id") TestUserBean userBean); /** * MultiShardSelect annotation allows to execute SQl statement * on multiple shards and takes care of merging results from multiple shards * by default it assumes that result is a collection of objects and will do merge of collections. * To provide custom merger logic annotation allows to define value for resultMerger class. * * @param userIdList * @return merged */ @MultiShardSelect("select id, name from user_data where user_id in ($ctx:keyList$)") List<TestUserBean> getUserDataList(@ShardKey List<Integer> userIdList); /** * 1st parameter passed by reference, 2nd by value (by injecting result of toString() into SQL). * * @param tableName name of table * @param size - max size of array * @param userId - query parameter * @return objects created from data returned by sql */ @Select("select id, name from $0$ where user_id = #2# limit #1#") TestUserBean[] getUserDataArray(String tableName, int size, @ShardKey long userId); @Select("select id, name from user_data where user_id = #0#") void processUserData(@ShardKey long userId, @UseRSMapper RSMapper selectCallback); /** * values that have ':' with prefix assumed to be call to predefined static function registered by ParameterService, * there are few if them predefined: * prefix threadLocal: allows to pass value using ThreadLocal * prefix ctx: allows to pass value using Context object in method arguments * prefix joinList: allows to join list of keys in comma separated string * * @param userId - query paramter * @return value returned by query */ @Select("select id from user_data where part = '$threadLocal:" + PART_NAME + "$' and user_id = #0#") int getUserData(@ShardKey long userId); @SelectThenInsert({"select nextval from userIdSequence", "insert into user(id,name) values(#threadLocal:id#, #name#)"}) long addUser(@ShardKey("id") TestUserBean user); } protected void setUp() throws Exception { this.injector = Guice.createInjector(new ChainModule(TestUserDao.class)); super.setUp(); MockContextFactory.setAsInitial(); JDBCMockObjectFactory mockFactory1 = new JDBCMockObjectFactory(); testModule1 = new JDBCTestModule(mockFactory1); JDBCMockObjectFactory mockFactory2 = new JDBCMockObjectFactory(); testModule2 = new JDBCTestModule(mockFactory2); final TestShardingService controlDao = injector.getInstance(TestShardingService.class); controlDao.setDS1(mockFactory1.getMockDataSource()); controlDao.setDS2(mockFactory2.getMockDataSource()); } protected void tearDown() throws Exception { super.tearDown(); MockContextFactory.revertSetAsInitial(); } private void createResultSet(JDBCTestModule testModule, Object... data) { PreparedStatementResultSetHandler handler = testModule.getPreparedStatementResultSetHandler(); MockResultSet rs = handler.createResultSet(); for (int i = 0; i < data.length; i++) { Object colName = data[i++]; Object colValues = data[i]; rs.addColumn(colName.toString(), (Object[]) colValues); } handler.prepareGlobalResultSet(rs); } public void testSingleRecordGet() throws Exception { // create dao object TestUserDao dao = injector.getInstance(TestUserDao.class); // reuse it for multiple invocations getUserOnce(testModule1, dao, 1, "foo1", false); getUserOnce(testModule1, dao, 10, "foo2", false); getUserOnce(testModule2, dao, 11, "bar1", false); getUserOnce(testModule2, dao, 20, "bar2", false); } private void getUserOnce(JDBCTestModule testModule, TestUserDao dao, int id, String name, boolean inTx) throws Exception { // setup test TestUserBean data = new TestUserBean(true); data.setId(id); data.setName(name); createResultSet(testModule, "id", new Object[]{data.getId()}, "name", new Object[]{data.getName()}); // execute dao method TestUserBean res = dao.getUser(data); // verify result assertNotNull(res); assertEquals(res.getId(), data.getId()); assertEquals(res.getName(), data.getName()); testModule.verifySQLStatementExecuted("select id, name from user where id = ?"); testModule.verifyAllResultSetsClosed(); testModule.verifyAllStatementsClosed(); if (!inTx) { testModule.verifyConnectionClosed(); } } public void testMultiShardGetRecordList() throws Exception { TestUserDao dao = injector.getInstance(TestUserDao.class); // setup test createResultSet(testModule1, "id", new Object[]{1, 2}, "name", new Object[]{"u1", "u2"}); createResultSet(testModule2, "id", new Object[]{15, 16}, "name", new Object[]{"u15", "u16"}); List<Integer> userIdList = new ArrayList<Integer>(); userIdList.add(1); userIdList.add(2); userIdList.add(15); userIdList.add(16); // execute dao method List<TestUserBean> res = dao.getUserDataList(userIdList); Collections.sort(res, new Comparator<TestUserBean>() { public int compare(TestUserBean testUserBean, TestUserBean testUserBean1) { return (int) (testUserBean.getId() - testUserBean1.getId()); } }); // verify result assertNotNull(res); assertEquals(4, res.size()); assertEquals(1, res.get(0).getId()); assertEquals("u1", res.get(0).getName()); assertEquals(2, res.get(1).getId()); assertEquals("u2", res.get(1).getName()); assertEquals(15, res.get(2).getId()); assertEquals("u15", res.get(2).getName()); assertEquals(16, res.get(3).getId()); assertEquals("u16", res.get(3).getName()); testModule1.verifySQLStatementExecuted("select id, name from user"); testModule1.verifyAllResultSetsClosed(); testModule1.verifyAllStatementsClosed(); testModule1.verifyConnectionClosed(); testModule2.verifySQLStatementExecuted("select id, name from user"); testModule2.verifyAllResultSetsClosed(); testModule2.verifyAllStatementsClosed(); testModule2.verifyConnectionClosed(); } public void testGetUserArray() throws Exception { // execute dao method TestUserDao dao = injector.getInstance(TestUserDao.class); getUserDataArray(dao, testModule1, 1); getUserDataArray(dao, testModule1, 10); getUserDataArray(dao, testModule2, 11); getUserDataArray(dao, testModule2, 20); } private void getUserDataArray(TestUserDao dao, JDBCTestModule testModule, int userId) { // setup test createResultSet(testModule, "id", new Object[]{1, 2}, "name", new Object[]{"foo", "bar"}); TestUserBean[] res = dao.getUserDataArray("user", 2, userId); // verify result assertNotNull(res); assertEquals(res.length, 2); assertEquals(res[0].getId(), 1); assertEquals(res[0].getName(), "foo"); assertEquals(res[1].getId(), 2); assertEquals(res[1].getName(), "bar"); testModule.verifySQLStatementExecuted("select id, name from user where user_id = ? limit ?"); testModule.verifySQLStatementParameter("select id, name from user where user_id = ? limit ?", 0, 2, 2); testModule.verifyAllResultSetsClosed(); testModule.verifyAllStatementsClosed(); testModule.verifyConnectionClosed(); } public void testSelectCallback() throws Exception { TestUserDao dao = injector.getInstance(TestUserDao.class); processUserData(dao, testModule1); } private void processUserData(TestUserDao dao, JDBCTestModule testModule) { // setup test createResultSet(testModule, "id", new Object[]{1, 2}, "name", new Object[]{"foo", "bar"}); final List<TestUserBean> res = new ArrayList<TestUserBean>(); // execute dao method dao.processUserData(1, new RSMapper() { RowMapper rowMapper = new BeanRowMapperFactory(TestUserBean.class).get(); public Object handle(Context context, ResultSet rs) throws SQLException { while (rs.next()) { res.add((TestUserBean) rowMapper.map(rs)); } return null; } }); // verify result assertNotNull(res); assertEquals(res.size(), 2); assertEquals(res.get(0).getId(), 1); assertEquals(res.get(0).getName(), "foo"); assertEquals(res.get(1).getId(), 2); assertEquals(res.get(1).getName(), "bar"); testModule.verifySQLStatementExecuted("select id, name from user"); testModule.verifyAllResultSetsClosed(); testModule.verifyAllStatementsClosed(); testModule.verifyConnectionClosed(); } public void testUsingStaticFunction() throws Exception { TestUserDao dao = injector.getInstance(TestUserDao.class); getUserData(dao, 1, testModule1, 0); getUserData(dao, 10, testModule1, 1); getUserData(dao, 11, testModule2, 0); getUserData(dao, 20, testModule2, 1); } private void getUserData(TestUserDao dao, long userId, JDBCTestModule testModule, int idx) { // setup test final int id = 11; final String testPart = "testPart"; createResultSet(testModule, "id", new Object[]{id}); ThreadLocalParameter.put(PART_NAME, testPart); // execute dao method int res = dao.getUserData(userId); // verify result ThreadLocalParameter.remove(PART_NAME); assertEquals(id, res); testModule.verifySQLStatementExecuted("select id from user_data where part = '" + testPart + "' and user_id = ?"); testModule.verifyPreparedStatementParameter(idx, 1, userId); testModule.verifyAllResultSetsClosed(); testModule.verifyAllStatementsClosed(); testModule.verifyConnectionClosed(); } public void testTx() throws Exception { final long id = 7; final String testName = "testName"; // execute dao method final TestUserDao dao = injector.getInstance(TestUserDao.class); final TestUserBean user = new TestUserBean(true); user.setName(testName); TxHelper.execInTx(dao, new Runnable() { public void run() { try { createResultSet(testModule1, "nextval", new Object[]{id}); final long res = dao.addUser(user); final Connection connection1 = TxHelper.getConnectionOnHold(); assertNotNull(connection1); assertFalse(connection1.isClosed()); testModule1.verifyNotCommitted(); getUserOnce(testModule1, dao, 11, "user11", true); final Connection connection2 = TxHelper.getConnectionOnHold(); assertSame(connection1, connection2); assertFalse(connection2.isClosed()); testModule1.verifyNotCommitted(); assertEquals(id, res); } catch (Exception e) { throw new RuntimeException(e); } } }, id); final Connection connection = TxHelper.getConnectionOnHold(); assertNull(connection); testModule1.verifyCommitted(); testModule1.verifySQLStatementExecuted("select nextval from userIdSequence"); testModule1.verifySQLStatementExecuted("insert into user(id,name) values(?, ?)"); testModule1.verifyPreparedStatementParameter(1, 1, id); testModule1.verifyPreparedStatementParameter(1, 2, testName); testModule1.verifyAllResultSetsClosed(); testModule1.verifyAllStatementsClosed(); testModule1.verifyConnectionClosed(); } }