package jane.core;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.concurrent.locks.ReentrantLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock.ReadLock;
import java.util.concurrent.locks.ReentrantReadWriteLock.WriteLock;
import jane.core.SContext.Safe;
/**
 * Base class of all procedures (transactions).
 * <p>
 * A subclass implements {@link #onProcess()}; {@link #execute()} runs it on a {@link ProcThread}, retrying
 * when {@link #redo()} is signalled, and committing or rolling back the thread's {@link SContext}.<br>
 * Record locks come from a global, lazily populated pool indexed by {@code lockId & _lockMask}.
 */
public abstract class Procedure implements Runnable
{
    /**
     * Callback for exceptions raised while running a procedure.
     */
    public interface ExceptionHandler
    {
        void onException(Throwable e);
    }

    /**
     * Reentrant lock that remembers its slot index in the global lock pool.
     */
    static final class IndexLock extends ReentrantLock
    {
        private static final long serialVersionUID = 1L;
        public final int index;

        IndexLock(int i)
        {
            index = i;
        }
    }

    private static final IndexLock[] _lockPool = new IndexLock[Const.lockPoolSize]; // globally shared lock pool (plain-array cache of _lockCreator)
    private static final AtomicIntegerArray _lockVersions = new AtomicIntegerArray(Const.lockPoolSize); // globally shared per-slot modification versions
    private static final AtomicReferenceArray<IndexLock> _lockCreator = new AtomicReferenceArray<>(_lockPool); // thread-safe creator of pool locks (a separate copy of the pool array)
    private static final int _lockMask = Const.lockPoolSize - 1; // mask for pool indexes; presumably lockPoolSize is a power of two - TODO confirm in Const
    private static final ReentrantReadWriteLock _rwlCommit = new ReentrantReadWriteLock(); // read-write lock used for data commit
    private static ExceptionHandler _defaultEh; // default global exception handler
    private ProcThread _pt; // thread context of this procedure; only valid while it is running
    private final AtomicInteger _running = new AtomicInteger(); // 1 while running (a procedure must not run concurrently with itself)
    private volatile Object _sid; // session id bound to this procedure

    /**
     * Returns the write lock of the commit read-write lock.
     * <p>
     * Used by the commit thread to pause worker threads: {@link #execute()} holds the read lock while running.
     */
    static WriteLock getWriteLock()
    {
        return _rwlCommit.writeLock();
    }

    /**
     * Sets the default exception handler used by {@link #onException} when it is non-null.
     */
    public static void setDefaultOnException(ExceptionHandler eh)
    {
        _defaultEh = eh;
    }

    /**
     * Returns the procedure currently running on the calling thread, or null if the thread is not a
     * {@link ProcThread} or is not running one.
     */
    public static Procedure getCurProcedure()
    {
        Thread t = Thread.currentThread();
        return t instanceof ProcThread ? ((ProcThread)t).proc : null;
    }

    /**
     * Returns whether the calling thread is currently executing a procedure.
     */
    public static boolean inProcedure()
    {
        return getCurProcedure() != null;
    }

    /**
     * Increments the modification version of the pool slot that {@code lockId} maps to.
     */
    static void incVersion(int lockId)
    {
        _lockVersions.incrementAndGet(lockId & _lockMask);
    }

    /**
     * Returns the session id bound to this procedure.
     */
    public final Object getSid()
    {
        return _sid;
    }

    /**
     * Binds a session id to this procedure.
     * <p>
     * Package-private so that only internal classes can set it.
     */
    final void setSid(Object sid)
    {
        _sid = sid;
    }

    /**
     * Registers {@code r} with the running thread's {@link SContext#addOnCommit}.
     * @throws IllegalStateException if called while this procedure is not running
     */
    protected final void addOnCommit(Runnable r)
    {
        ProcThread pt = _pt;
        if(pt == null) throw new IllegalStateException("invalid addOnCommit out of procedure");
        pt.sctx.addOnCommit(r);
    }

    /**
     * Registers {@code r} with the running thread's {@link SContext#addOnRollback}.
     * @throws IllegalStateException if called while this procedure is not running
     */
    protected final void addOnRollback(Runnable r)
    {
        ProcThread pt = _pt;
        if(pt == null) throw new IllegalStateException("invalid addOnRollback out of procedure");
        pt.sctx.addOnRollback(r);
    }

    /**
     * Marks the current procedure as non-interruptible.
     * <p>
     * Prevents it from being interrupted on timeout by setting the thread's begin time to
     * {@code Long.MAX_VALUE}, which invalidates the begin time. Has no effect unless the procedure is running.<br>
     * Meant for procedures that may run for a long time; generally not recommended.<br>
     * (The misspelled name is kept for API compatibility.)
     */
    protected final void setUnintterrupted()
    {
        ProcThread pt = _pt;
        if(pt != null)
            pt.beginTime = Long.MAX_VALUE;
    }

    /**
     * Control-flow signal thrown by {@link #redo()}: makes {@link #execute()} roll back and retry.
     * <p>
     * A shared singleton. Suppression and stack-trace writing are disabled, and the cause is fixed to null,
     * so the shared instance can never be mutated (e.g. by try-with-resources adding suppressed exceptions).
     */
    @SuppressWarnings("serial")
    private static final class Redo extends Error
    {
        private static final Redo _instance = new Redo();

        private Redo()
        {
            super(null, null, false, false);
        }

        @SuppressWarnings("sync-override")
        @Override
        public Throwable fillInStackTrace()
        {
            return this;
        }
    }

    /**
     * Control-flow signal thrown by {@link #undo()}: makes {@link #execute()} roll back without calling
     * {@link #onException}.
     * <p>
     * A shared, immutable singleton, built the same way as {@link Redo}.
     */
    @SuppressWarnings("serial")
    private static final class Undo extends Error
    {
        private static final Undo _instance = new Undo();

        private Undo()
        {
            super(null, null, false, false);
        }

        @SuppressWarnings("sync-override")
        @Override
        public Throwable fillInStackTrace()
        {
            return this;
        }
    }

    /**
     * Returns the shared redo signal, for callers that want to throw it themselves.
     */
    public static Error redoException()
    {
        return Redo._instance;
    }

    /**
     * Returns the shared undo signal, for callers that want to throw it themselves.
     */
    public static Error undoException()
    {
        return Undo._instance;
    }

    /**
     * Throws the shared redo signal.
     */
    public static void redo()
    {
        throw Redo._instance;
    }

    /**
     * Throws the shared undo signal.
     */
    public static void undo()
    {
        throw Undo._instance;
    }

    /**
     * Appends the lock for key {@code k} of table {@code t} (see {@link #appendLock}) and returns
     * {@code t.getNoLock(k)}.
     */
    public <V extends Bean<V>, S extends Safe<V>> S lockGet(TableLong<V, S> t, long k) throws InterruptedException
    {
        appendLock(t.lockId(k));
        return t.getNoLock(k);
    }

    /**
     * Appends the lock for key {@code k} of table {@code t} (see {@link #appendLock}) and returns
     * {@code t.getNoLock(k)}.
     */
    public <K, V extends Bean<V>, S extends Safe<V>> S lockGet(Table<K, V, S> t, K k) throws InterruptedException
    {
        appendLock(t.lockId(k));
        return t.getNoLock(k);
    }

    // check(a, b): throws the redo signal unless a equals b.
    // checkNot(a, b): throws the redo signal if a equals b.

    public static void check(boolean a, boolean b)
    {
        if(a != b) throw Redo._instance;
    }

    public static void check(int a, int b)
    {
        if(a != b) throw Redo._instance;
    }

    public static void check(long a, long b)
    {
        if(a != b) throw Redo._instance;
    }

    public static void check(float a, float b)
    {
        if(a != b) throw Redo._instance;
    }

    public static void check(double a, double b)
    {
        if(a != b) throw Redo._instance;
    }

    /**
     * Throws the redo signal unless {@code a} and {@code b} are equal; null-safe (two nulls are equal).
     */
    public static void check(Object a, Object b)
    {
        if(a == null ? b != null : !a.equals(b)) throw Redo._instance;
    }

    /**
     * Throws the redo signal if {@code a} is null.
     */
    public static void check(Object a)
    {
        if(a == null) throw Redo._instance;
    }

    public static void checkNot(boolean a, boolean b)
    {
        if(a == b) throw Redo._instance;
    }

    public static void checkNot(int a, int b)
    {
        if(a == b) throw Redo._instance;
    }

    public static void checkNot(long a, long b)
    {
        if(a == b) throw Redo._instance;
    }

    public static void checkNot(float a, float b)
    {
        if(a == b) throw Redo._instance;
    }

    public static void checkNot(double a, double b)
    {
        if(a == b) throw Redo._instance;
    }

    /**
     * Throws the redo signal if {@code a} and {@code b} are equal; null-safe (two nulls are equal).
     */
    public static void checkNot(Object a, Object b)
    {
        if(a == null ? b == null : a.equals(b)) throw Redo._instance;
    }

    /**
     * Throws the redo signal if {@code a} is not null.
     */
    public static void checkNull(Object a)
    {
        if(a != null) throw Redo._instance;
    }

    /**
     * Releases every lock held by the current procedure.
     * <p>
     * Only callable inside a procedure.
     * @throws IllegalStateException if called outside a procedure, or if locks are held while the context
     *         has dirty records
     */
    protected final void unlock()
    {
        ProcThread pt = _pt;
        if(pt == null) throw new IllegalStateException("invalid lock/unlock out of procedure");
        if(pt.lockCount == 0) return;
        if(pt.sctx.hasDirty()) throw new IllegalStateException("invalid unlock after any dirty record");
        unlockAll(pt);
    }

    /**
     * Releases every lock recorded in {@code pt} (last acquired first) without any state check, then resets
     * {@code pt.lockCount} to 0. A failure to release one lock is logged and does not stop the others.
     */
    private static void unlockAll(ProcThread pt)
    {
        IndexLock[] locks = pt.locks;
        for(int i = pt.lockCount - 1; i >= 0; --i)
        {
            try
            {
                locks[i].unlock();
            }
            catch(Throwable e)
            {
                Log.log.error("UNLOCK FAILED!!!", e);
            }
        }
        pt.lockCount = 0;
    }

    /**
     * Returns the lock object of pool slot {@code lockIdx}, creating it on first use.
     * <p>
     * The CAS on {@link #_lockCreator} ensures only one lock object is ever created per slot. A racy overwrite
     * of {@link #_lockPool} always stores that same instance.
     */
    private static IndexLock getLock(int lockIdx)
    {
        IndexLock lock = _lockPool[lockIdx];
        if(lock != null) return lock;
        if(!_lockCreator.compareAndSet(lockIdx, null, lock = new IndexLock(lockIdx))) // ensure init lock object only once
            lock = _lockCreator.get(lockIdx); // should not be null
        _lockPool[lockIdx] = lock; // still safe when overwritten
        return lock;
    }

    /**
     * Returns whether the lock that {@code lockId} maps to is held by any thread.
     */
    public static boolean isLocked(int lockId)
    {
        return getLock(lockId & _lockMask).isLocked();
    }

    /**
     * Returns whether the lock that {@code lockId} maps to is held by the calling thread.
     */
    public static boolean isLockedByCurrentThread(int lockId)
    {
        return getLock(lockId & _lockMask).isHeldByCurrentThread();
    }

    /**
     * Tries to acquire the lock that {@code lockId} maps to, without blocking.
     * <p>
     * Only for internal data commit.
     * @return the acquired lock, or null if it is currently held by another thread
     */
    static IndexLock tryLock(int lockId)
    {
        IndexLock lock = getLock(lockId & _lockMask);
        return lock.tryLock() ? lock : null;
    }

    /**
     * Locks one lockId.
     * <p>
     * Obtain lockId from the lockId method of {@link Table}/{@link TableLong}.<br>
     * Only callable inside a procedure. Releases every lock the procedure already holds before locking.
     */
    protected final void lock(int lockId) throws InterruptedException
    {
        unlock();
        ProcThread pt = _pt;
        (pt.locks[0] = getLock(lockId & _lockMask)).lockInterruptibly();
        pt.lockCount = 1;
    }

    /**
     * Adds the lock for one lockId to the locks already held.
     * <p>
     * Held locks are kept sorted by pool index. If the new lock sorts before some held locks, those locks are
     * released, the new one is taken, and they are re-taken in order. If any re-taken slot's modification
     * version changed while it was released, the redo signal is thrown.<br>
     * Only callable inside a procedure, and no records may have been modified before this call.
     * @throws IllegalStateException if called outside a procedure, after dirty records, or when the per-procedure
     *         lock limit would be exceeded
     */
    protected final void appendLock(int lockId) throws InterruptedException
    {
        final ProcThread pt = _pt;
        if(pt == null) throw new IllegalStateException("invalid appendLock out of procedure");
        final IndexLock[] locks = pt.locks;
        final int lockIdx = lockId & _lockMask;
        final int n = pt.lockCount;
        if(n == 0) // nothing held yet: just take the lock
        {
            (locks[0] = getLock(lockIdx)).lockInterruptibly();
            pt.lockCount = 1;
            return;
        }
        if(pt.sctx.hasDirty()) throw new IllegalStateException("invalid appendLock after any dirty record");
        IndexLock lastLock = locks[n - 1];
        int lastLockIdx = lastLock.index;
        if(lastLockIdx == lockIdx) return; // already held
        if(lastLockIdx < lockIdx) // sorts after every held lock: append without releasing anything
        {
            if(n >= Const.maxLockPerProcedure)
                throw new IllegalStateException("appendLock exceed: " + (n + 1) + '>' + Const.maxLockPerProcedure);
            (locks[n] = getLock(lockIdx)).lockInterruptibly();
            pt.lockCount = n + 1;
            return;
        }
        // find insertion position i: locks[i - 1].index < lockIdx < locks[i].index (or i == 0)
        int i = n - 1;
        for(; i > 0; --i)
        {
            lastLockIdx = locks[i - 1].index;
            if(lastLockIdx == lockIdx) return; // already held
            if(lastLockIdx < lockIdx) break;
        }
        if(n >= Const.maxLockPerProcedure)
            throw new IllegalStateException("appendLock exceed: " + (n + 1) + '>' + Const.maxLockPerProcedure);
        // release locks[i..n-1] (highest first), remembering each slot's version at release time;
        // afterwards lastLock is the original locks[i]
        final int[] versions = pt.versions;
        for(int j = n - 1; j >= i; --j)
        {
            lastLock = locks[j];
            versions[j] = _lockVersions.get(lastLock.index);
            lastLock.unlock();
        }
        pt.lockCount = i;
        // take the new lock at slot i, then re-take the released locks shifted up by one slot, in ascending order
        (locks[i] = getLock(lockIdx)).lockInterruptibly();
        pt.lockCount = ++i;
        for(;;)
        {
            final IndexLock lock = locks[i]; // next released lock to re-take (slot i is about to be overwritten)
            (locks[i] = lastLock).lockInterruptibly();
            pt.lockCount = ++i;
            // lastLock originally sat at slot i - 2; a version change means it was modified while released
            if(_lockVersions.get(lastLock.index) != versions[i - 2])
                redo();
            if(i > n) return;
            lastLock = lock;
        }
    }

    /**
     * Locks an array of lockIds.
     * <p>
     * Obtain lockIds from the lockId method of {@link Table}/{@link TableLong}.<br>
     * Only callable inside a procedure. Releases every lock the procedure already holds before locking.
     * @param lockIds note: the elements of this array are masked and sorted in place
     */
    protected final void lock(int[] lockIds) throws InterruptedException
    {
        unlock();
        int n = lockIds.length;
        if(n > Const.maxLockPerProcedure)
            throw new IllegalStateException("lock exceed: " + n + '>' + Const.maxLockPerProcedure);
        for(int i = 0; i < n; ++i)
            lockIds[i] &= _lockMask;
        Arrays.sort(lockIds); // acquire in ascending pool order
        ProcThread pt = _pt;
        IndexLock[] locks = pt.locks;
        for(int i = 0; i < n;)
        {
            (locks[i] = getLock(lockIds[i])).lockInterruptibly();
            pt.lockCount = ++i;
        }
    }

    /**
     * Locks a collection of lockIds.
     * <p>
     * Obtain lockIds from the lockId method of {@link Table}/{@link TableLong}.<br>
     * Only callable inside a procedure. Releases every lock the procedure already holds before locking.
     * @param lockIds this collection is not modified
     */
    protected final void lock(Collection<Integer> lockIds) throws InterruptedException
    {
        unlock();
        int n = lockIds.size();
        if(n > Const.maxLockPerProcedure)
            throw new IllegalStateException("lock exceed: " + n + '>' + Const.maxLockPerProcedure);
        int[] idxes = new int[n];
        int i = 0;
        if(lockIds instanceof ArrayList)
        {
            // indexed access instead of an iterator for the ArrayList case
            ArrayList<Integer> lockList = (ArrayList<Integer>)lockIds;
            for(; i < n; ++i)
                idxes[i] = lockList.get(i) & _lockMask;
        }
        else
        {
            for(int lockId : lockIds)
                idxes[i++] = lockId & _lockMask;
        }
        Arrays.sort(idxes); // acquire in ascending pool order
        ProcThread pt = _pt;
        IndexLock[] locks = pt.locks;
        for(i = 0; i < n;)
        {
            (locks[i] = getLock(idxes[i])).lockInterruptibly();
            pt.lockCount = ++i;
        }
    }

    /**
     * Locks four or more lockIds.
     * <p>
     * Obtain lockIds from the lockId method of {@link Table}/{@link TableLong}.<br>
     * Only callable inside a procedure. Releases every lock the procedure already holds before locking.
     */
    protected final void lock(int lockId0, int lockId1, int lockId2, int lockId3, int... lockIds) throws InterruptedException
    {
        int n = lockIds.length;
        if(n + 4 > Const.maxLockPerProcedure)
            throw new IllegalStateException("lock exceed: " + (n + 4) + '>' + Const.maxLockPerProcedure);
        lockIds = Arrays.copyOf(lockIds, n + 4); // copy so the caller's varargs array is not sorted in place
        lockIds[n] = lockId0;
        lockIds[n + 1] = lockId1;
        lockIds[n + 2] = lockId2;
        lockIds[n + 3] = lockId3;
        lock(lockIds);
    }

    /**
     * Internal: locks two already-masked pool indexes in ascending order, appending after pt.lockCount.
     */
    private void lock2(int lockIdx0, int lockIdx1) throws InterruptedException
    {
        ProcThread pt = _pt;
        IndexLock[] locks = pt.locks;
        int i = pt.lockCount;
        if(lockIdx0 < lockIdx1)
        {
            (locks[i] = getLock(lockIdx0)).lockInterruptibly();
            pt.lockCount = ++i;
            (locks[i] = getLock(lockIdx1)).lockInterruptibly();
            pt.lockCount = ++i;
        }
        else
        {
            (locks[i] = getLock(lockIdx1)).lockInterruptibly();
            pt.lockCount = ++i;
            (locks[i] = getLock(lockIdx0)).lockInterruptibly();
            pt.lockCount = ++i;
        }
    }

    /**
     * Internal: locks three already-masked pool indexes in ascending order, appending after pt.lockCount.
     */
    private void lock3(int lockIdx0, int lockIdx1, int lockIdx2) throws InterruptedException
    {
        ProcThread pt = _pt;
        IndexLock[] locks = pt.locks;
        int i = pt.lockCount;
        if(lockIdx0 <= lockIdx1)
        {
            if(lockIdx0 < lockIdx2) // idx0 is the smallest
            {
                (locks[i] = getLock(lockIdx0)).lockInterruptibly();
                pt.lockCount = ++i;
                lock2(lockIdx1, lockIdx2);
            }
            else // idx2 <= idx0 <= idx1
            {
                (locks[i] = getLock(lockIdx2)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockIdx0)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockIdx1)).lockInterruptibly();
                pt.lockCount = ++i;
            }
        }
        else
        {
            if(lockIdx1 < lockIdx2) // idx1 is the smallest
            {
                (locks[i] = getLock(lockIdx1)).lockInterruptibly();
                pt.lockCount = ++i;
                lock2(lockIdx0, lockIdx2);
            }
            else // idx2 <= idx1 < idx0
            {
                (locks[i] = getLock(lockIdx2)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockIdx1)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockIdx0)).lockInterruptibly();
                pt.lockCount = ++i;
            }
        }
    }

    /**
     * Locks two lockIds.
     * <p>
     * Obtain lockIds from the lockId method of {@link Table}/{@link TableLong}.<br>
     * Only callable inside a procedure. Releases every lock the procedure already holds before locking.<br>
     * More efficient than locking an array of lockIds.
     */
    protected final void lock(int lockId0, int lockId1) throws InterruptedException
    {
        unlock();
        lock2(lockId0 & _lockMask, lockId1 & _lockMask);
    }

    /**
     * Locks three lockIds.
     * <p>
     * Obtain lockIds from the lockId method of {@link Table}/{@link TableLong}.<br>
     * Only callable inside a procedure. Releases every lock the procedure already holds before locking.<br>
     * More efficient than locking an array of lockIds.
     */
    protected final void lock(int lockId0, int lockId1, int lockId2) throws InterruptedException
    {
        unlock();
        lock3(lockId0 & _lockMask, lockId1 & _lockMask, lockId2 & _lockMask);
    }

    /**
     * Locks four lockIds.
     * <p>
     * Obtain lockIds from the lockId method of {@link Table}/{@link TableLong}.<br>
     * Only callable inside a procedure. Releases every lock the procedure already holds before locking.<br>
     * More efficient than locking an array of lockIds. The branches below acquire the four masked indexes in
     * ascending order.
     */
    protected final void lock(int lockId0, int lockId1, int lockId2, int lockId3) throws InterruptedException
    {
        unlock();
        lockId0 &= _lockMask;
        lockId1 &= _lockMask;
        lockId2 &= _lockMask;
        lockId3 &= _lockMask;
        ProcThread pt = _pt;
        IndexLock[] locks = pt.locks;
        int i = 0;
        if(lockId0 <= lockId1)
        {
            if(lockId0 < lockId2)
            {
                if(lockId0 < lockId3)
                {
                    (locks[i] = getLock(lockId0)).lockInterruptibly();
                    pt.lockCount = ++i;
                    lock3(lockId1, lockId2, lockId3);
                }
                else
                {
                    (locks[i] = getLock(lockId3)).lockInterruptibly();
                    pt.lockCount = ++i;
                    (locks[i] = getLock(lockId0)).lockInterruptibly();
                    pt.lockCount = ++i;
                    lock2(lockId1, lockId2);
                }
            }
            else if(lockId2 < lockId3)
            {
                (locks[i] = getLock(lockId2)).lockInterruptibly();
                pt.lockCount = ++i;
                if(lockId0 < lockId3)
                {
                    (locks[i] = getLock(lockId0)).lockInterruptibly();
                    pt.lockCount = ++i;
                    lock2(lockId1, lockId3);
                }
                else
                {
                    (locks[i] = getLock(lockId3)).lockInterruptibly();
                    pt.lockCount = ++i;
                    (locks[i] = getLock(lockId0)).lockInterruptibly();
                    pt.lockCount = ++i;
                    (locks[i] = getLock(lockId1)).lockInterruptibly();
                    pt.lockCount = ++i;
                }
            }
            else
            {
                (locks[i] = getLock(lockId3)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockId2)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockId0)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockId1)).lockInterruptibly();
                pt.lockCount = ++i;
            }
        }
        else
        {
            if(lockId1 < lockId2)
            {
                if(lockId1 < lockId3)
                {
                    (locks[i] = getLock(lockId1)).lockInterruptibly();
                    pt.lockCount = ++i;
                    lock3(lockId0, lockId2, lockId3);
                }
                else
                {
                    (locks[i] = getLock(lockId3)).lockInterruptibly();
                    pt.lockCount = ++i;
                    (locks[i] = getLock(lockId1)).lockInterruptibly();
                    pt.lockCount = ++i;
                    lock2(lockId0, lockId2);
                }
            }
            else if(lockId2 < lockId3)
            {
                (locks[i] = getLock(lockId2)).lockInterruptibly();
                pt.lockCount = ++i;
                if(lockId1 < lockId3)
                {
                    (locks[i] = getLock(lockId1)).lockInterruptibly();
                    pt.lockCount = ++i;
                    lock2(lockId0, lockId3);
                }
                else
                {
                    (locks[i] = getLock(lockId3)).lockInterruptibly();
                    pt.lockCount = ++i;
                    (locks[i] = getLock(lockId1)).lockInterruptibly();
                    pt.lockCount = ++i;
                    (locks[i] = getLock(lockId0)).lockInterruptibly();
                    pt.lockCount = ++i;
                }
            }
            else
            {
                (locks[i] = getLock(lockId3)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockId2)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockId1)).lockInterruptibly();
                pt.lockCount = ++i;
                (locks[i] = getLock(lockId0)).lockInterruptibly();
                pt.lockCount = ++i;
            }
        }
    }

    /**
     * Runnable entry point of the procedure.
     * <p>
     * Same as {@link #execute}, but with no return value; any throwable escaping execute() is logged.
     */
    @Override
    public void run()
    {
        try
        {
            execute();
        }
        catch(Throwable e)
        {
            Log.log.error("procedure fatal exception: " + toString(), e);
        }
    }

    /**
     * Runs the procedure.
     * <p>
     * Must run on a {@link ProcThread} thread. Normally scheduled via {@link DBManager#submit}. If ordering
     * is not a concern the user may call it directly, but never nested inside another procedure.<br>
     * While running it holds the commit read lock. {@link #onProcess()} is retried on redo, up to
     * {@code Const.maxProceduerRedo} times. The context is committed on success and rolled back on failure.
     * All record locks, the read lock and the running flag are released afterwards.
     * @return true if committed; false if already running, undone, or failed (failures go to {@link #onException})
     * @throws Exception only on the shutdown path, where the thread parks until it is terminated
     */
    public boolean execute() throws Exception
    {
        if(!_running.compareAndSet(0, 1))
        {
            Log.log.error("procedure already running: " + toString());
            return false;
        }
        if(DBManager.instance().isExiting())
        {
            Thread.sleep(Long.MAX_VALUE); // on an exit signal, park this thread until it is terminated
            throw new IllegalStateException();
        }
        ProcThread pt = null;
        SContext sctx = null;
        ReadLock rl = null;
        try
        {
            _pt = pt = (ProcThread)Thread.currentThread();
            ReadLock readLock = _rwlCommit.readLock();
            readLock.lock();
            rl = readLock; // record only after acquisition so finally never releases an unheld lock
            sctx = pt.sctx;
            pt.beginTime = System.currentTimeMillis();
            pt.proc = this;
            for(int n = Const.maxProceduerRedo;;)
            {
                if(Thread.interrupted())
                    throw new InterruptedException();
                try
                {
                    onProcess();
                    break;
                }
                catch(Redo ignored)
                {
                    // fall through: roll back, release locks and retry
                }
                sctx.rollback();
                unlock();
                if(--n <= 0)
                    throw new Exception("procedure redo too many times=" + Const.maxProceduerRedo + ": " + toString());
            }
            sctx.commit();
            return true;
        }
        catch(Throwable e)
        {
            try
            {
                if(e != Undo._instance)
                    onException(e);
            }
            catch(Throwable ex)
            {
                Log.log.error("procedure.onException exception: " + toString(), ex);
            }
            finally
            {
                if(sctx != null)
                    sctx.rollback();
            }
            return false;
        }
        finally
        {
            try
            {
                // Unconditional release: if rollback failed and records are still dirty, unlock() would throw
                // here and leak the record locks, the commit read lock and the running flag.
                if(pt != null)
                    unlockAll(pt);
            }
            finally
            {
                synchronized(this)
                {
                    if(pt != null)
                        pt.proc = null;
                    Thread.interrupted(); // clear the interrupted flag
                }
                if(rl != null)
                    rl.unlock();
                _pt = null;
                _running.set(0);
            }
        }
    }

    /**
     * The procedure body, implemented by subclasses.
     */
    protected abstract void onProcess() throws Exception;

    /**
     * Handles an exception raised while running the procedure; subclasses may override.
     * <p>
     * Delegates to the default handler if one is set, otherwise logs the exception.
     */
    protected void onException(Throwable e)
    {
        if(_defaultEh != null)
            _defaultEh.onException(e);
        else
            Log.log.error("procedure exception: " + toString(), e);
    }

    @Override
    public String toString()
    {
        return getClass().getName() + ":sid=" + _sid;
    }
}