/*
* SelectorOperator.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.operators;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inferencexml.operators.SelectorOperatorParser;
import dr.math.MathUtils;
import java.util.ArrayList;
import java.util.List;
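/**
 * MCMC operator that proposes changing a single entry of an integer-valued selector parameter.
 * <p>
 * The code is much more elegant in Python; there I definitely would not have had to write my
 * own max for a list of integers (yikes).
 *
 * @author Joseph Heled
 * Date: 4/09/2009
 */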
public class SelectorOperator extends SimpleMCMCOperator {

    // Parameter operated on. Entries are read as integers (see intVals). In the m2 move set used
    // by doOperation, negative entries mean "unused" and non-negative entries are labels 0..max.
    private final Parameter selector;

    // np_m1[l] = npos(len, l) for l = 0..len. Only referenced by the disabled m1 variant below.
    private final int[] np_m1;

    // np_m2[0] = 1 and np_m2[mx + 1] = sum over l = 1..len of npos(l, mx), for mx = 0..len-1.
    // doOperation indexes it with (maximum entry + 1), so index 0 is used when the maximum is -1.
    private final int[] np_m2;

    /**
     * Creates the operator and precomputes the npos-based tables np_m1 and np_m2.
     *
     * @param selector parameter whose entries are interpreted as integer labels
     */
    public SelectorOperator(Parameter selector) {
        this.selector = selector;

        final int len = selector.getSize();

        np_m1 = new int[len + 1];
        for (int l = 0; l < np_m1.length; ++l) {
            np_m1[l] = npos(len, l);
        }

        np_m2 = new int[len + 1];
        np_m2[0] = 1;
        for (int mx = 0; mx < len; ++mx) {
            np_m2[mx + 1] = 0;
            for (int l = 1; l < len + 1; ++l) {
                np_m2[mx + 1] += npos(l, mx);
            }
        }
    }

    /**
     * @return the parser's operator tag followed by the selector parameter name in parentheses
     */
    public String getOperatorName() {
        return SelectorOperatorParser.SELECTOR_OPERATOR + "(" + selector.getParameterName() + ")";
    }

    /**
     * Proposes a new selector state and applies it to the parameter.
     * <p>
     * One (index, newValue) move is drawn uniformly from {@code movesFrom_m2(current)}. The
     * returned value is {@code log(hr)}, where
     * {@code hr = count_sr_m2(s, y) * (|moves(s)| * np_m2[max(s)+1]) / (|moves(y)| * np_m2[max(y)+1])}.
     * Here {@code s} is the current state and {@code y} is the proposed state. The framework
     * presumably consumes this as the log Hastings ratio.
     *
     * @return the log of the proposal ratio described above
     */
    public double doOperation() {
        final int[] s = vals();
        // Flat list of (index, newValue) pairs; choose one pair uniformly.
        final List<Integer> poss = movesFrom_m2(s);
        final int i = MathUtils.nextInt(poss.size() / 2);

        final int[] y = new int[s.length];
        System.arraycopy(s, 0, y, 0, s.length);
        final int p = poss.get(2 * i);
        y[p] = poss.get(2 * i + 1);

        double hr = count_sr_m2(s, y);
        // Widen to double BEFORE multiplying. The int products of a move count and an np_m2
        // entry can overflow int for larger selectors, which would corrupt the ratio.
        final double fromCurrent = (double) poss.size() * np_m2[max(s) + 1];
        final double fromProposed = (double) movesFrom_m2(y).size() * np_m2[max(y) + 1];
        hr *= fromCurrent / fromProposed;

        selector.setParameterValue(p, y[p]);

        return Math.log(hr);
    }

    // Disabled variant using the m1 move set (no "unused" state). If re-enabled, the int products
    // in the ratio below must be widened to double before multiplying, as done in doOperation.
//    public double doOperation_m1() throws OperatorFailedException {
//
//        final int[] s = vals();
//        final List<Integer> poss = movesFrom_m1(s);
//        final int i = MathUtils.nextInt(poss.size()/2);
//
//        final int[] y = new int[s.length];
//        System.arraycopy(s, 0, y, 0, s.length);
//        final Integer p = poss.get(2 * i);
//        y[p] = poss.get(2*i+1);
//
//        double hr = count_sr_m1(s, y);
//        hr *= (double)(poss.size()* np_m1[max(s)])/(movesFrom_m1(y).size() * np_m1[max(y)]);
//
//        selector.setParameterValue(p, y[p]);
//
//        return Math.log(hr);
//    }

    /**
     * @return always {@code null}; this operator offers no tuning suggestion
     */
    public String getPerformanceSuggestion() {
        return null;
    }

    // Current selector values rounded to integers.
    private int[] vals() {
        return intVals(selector);
    }

    /**
     * Rounds every value of {@code var} to the nearest int, with halves rounded away from zero.
     * Note that -0.5 becomes -1 here, whereas {@link Math#round(double)} would give 0.
     *
     * @param var variable to read
     * @return rounded values, one per dimension of {@code var}
     */
    static int[] intVals(Variable<Double> var) {
        int[] v = new int[var.getSize()];
        for (int k = 0; k < v.length; ++k) {
            final double vk = var.getValue(k);
            v[k] = (int) (vk + ((vk >= 0) ? 0.5 : -0.5));
        }
        return v;
    }

    /**
     * Lists the single-entry moves allowed from state {@code s} under the m2 scheme, where
     * negative entries are "unused". The checks below only admit moves that keep the used-label
     * counts non-increasing: counts[0] >= counts[1] >= ... >= counts[max].
     *
     * @param s current state (must be non-empty)
     * @return flat list of (index, newValue) pairs
     */
    private List<Integer> movesFrom_m2(int[] s) {
        final int mx = max(s);
        final int[] counts = counts_used_m2(s, mx);

        final List<Integer> opt = new ArrayList<Integer>(5);

        for (int k = 0; k < s.length; ++k) {
            final int si = s[k];
            if (si < 0) {
                // Unused entry: it may always join label 0, or open a new label mx+1 when at
                // least one label is in use.
                opt.add(k); opt.add(0);
                if (mx >= 0) {
                    opt.add(k); opt.add(mx + 1);
                }
                // It may join label x only if that keeps counts[x] <= counts[x-1].
                for (int x = 1; x < mx + 1; ++x) {
                    if (counts[x] + 1 <= counts[x - 1]) {
                        opt.add(k); opt.add(x);
                    }
                }
            } else {
                if (si < mx && ((counts[si] == 1) || counts[si] == counts[si + 1])) {
                    // only or breaks order -> no moves
                } else {
                    for (int x = 0; x < mx + 1; ++x) {
                        if (x == si) {
                            continue;
                        }
                        if ((x > si && counts[si] - 1 >= counts[x] + 1 && counts[x - 1] >= counts[x] + 1)
                                ||
                                (x < si && (x > 0 && counts[x] + 1 <= counts[x - 1] || x == 0))) {
                            opt.add(k); opt.add(x);
                        }
                    }
                    // Split off into a brand-new label, only if si would remain non-empty.
                    if (counts[si] > 1) {
                        opt.add(k); opt.add(mx + 1);
                    }
                    // Mark the entry as unused.
                    opt.add(k); opt.add(-1);
                }
            }
        }
        return opt;
    }

    /**
     * m1 counterpart of {@link #movesFrom_m2}, with no "unused" state. Entries must lie in 0..max.
     * Only referenced by the disabled doOperation_m1.
     *
     * @param s current state (must be non-empty, no negative entries)
     * @return flat list of (index, newValue) pairs
     */
    private List<Integer> movesFrom_m1(int[] s) {
        final int mx = max(s);
        final int[] counts = counts_m1(s, mx);

        final List<Integer> opt = new ArrayList<Integer>(5);

        for (int k = 0; k < s.length; ++k) {
            final int si = s[k];
            if (si < mx && ((counts[si] == 1) || counts[si] == counts[si + 1])) {
                // only or breaks order -> no moves
            } else {
                for (int x = 0; x < mx + 1; ++x) {
                    if (x == si) {
                        continue;
                    }
                    if ((x > si && counts[si] - 1 >= counts[x] + 1 && counts[x - 1] >= counts[x] + 1)
                            ||
                            (x < si && (x > 0 && counts[x] + 1 <= counts[x - 1] || x == 0))) {
                        opt.add(k);
                        opt.add(x);
                    }
                }
                if (counts[si] > 1) {
                    opt.add(k);
                    opt.add(mx + 1);
                }
            }
        }
        return opt;
    }

    // Shorthand for npos(s, m, 1).
    private static int npos(int s, int m) {
        return npos(s, m, 1);
    }

    /**
     * Recursive count used to build np_m1 and np_m2. Returns 1 when {@code m == 0} or
     * {@code s == 0}. Otherwise it sums {@code npos(s - k*(m+1), m-1, 0)} over
     * k = mn, mn+1, ... while that first argument stays non-negative.
     * NOTE(review): presumably this counts label-count configurations; the exact combinatorial
     * meaning is not documented here — confirm before relying on it. Results are int and may
     * overflow for large s.
     */
    private static int npos(int s, int m, int mn) {
        if (m == 0 || s == 0) {
            return 1;
        }
        int tot = 0;
        for (int k = mn; k < 1 + s / m; ++k) {
            final int r = s - k * (m + 1);
            if (r < 0) {
                break;
            }
            tot += npos(r, m - 1, 0);
        }
        return tot;
    }

    // Sum of all elements.
    private static int sum(int[] s) {
        int sum = 0;
        for (int si : s) {
            sum += si;
        }
        return sum;
    }

    // Largest element of s; s must be non-empty.
    private static int max(int[] s) {
        int mx = s[0];
        for (int k = 1; k < s.length; ++k) {
            if (mx < s[k]) {
                mx = s[k];
            }
        }
        return mx;
    }

    // c[v] = number of entries equal to v, for v = 0..mx. Entries must lie in 0..mx.
    private static int[] counts_m1(int[] s, int mx) {
        int[] c = new int[mx + 1];
        for (int si : s) {
            c[si]++;
        }
        return c;
    }

    // Counts in s including unused (-1). indices are shifted by 1, so c[0] counts -1 entries.
    // Entries must lie in -1..mx.
    private static int[] counts_m2(int[] s, int mx) {
        int[] c = new int[mx + 2];
        for (int si : s) {
            c[si + 1]++;
        }
        return c;
    }

    // counts_m2 using the maximum of s.
    static int[] counts_m2(int[] s) {
        return counts_m2(s, max(s));
    }

    // c[v] = number of entries equal to v, for v = 0..mx; negative (unused) entries are ignored.
    private static int[] counts_used_m2(int[] s, int mx) {
        int[] c = new int[mx + 1];
        for (int si : s) {
            if (si >= 0) {
                c[si]++;
            }
        }
        return c;
    }

    // counts_used_m2 using the maximum of s.
    static int[] counts_used_m2(int[] s) {
        return counts_used_m2(s, max(s));
    }

    /**
     * Binomial coefficient C(n, k), computed as a running floating-point product and rounded to
     * the nearest long. Returns 1 whenever {@code n <= k} (including n < k, where the true value
     * would be 0).
     */
    private static long choose(int n, int k) {
        double r = 1;
        while (n > k) {
            r *= n;
            r /= (n - k);
            --n;
        }
        return (long) (r + 0.5);
    }

    /**
     * Returns r with {@code r[i] = C(remaining, ls[i])}, where remaining is
     * {@code sum(ls) - ls[0] - ... - ls[i-1]}. Filling stops once remaining reaches 0, and later
     * entries stay 0. The product of the filled entries is the multinomial coefficient of ls.
     */
    private static long[] countl_m1(int[] ls) {
        int l = sum(ls);
        int i = 0;
        long[] r = new long[ls.length];
        while (l > 0) {
            r[i] = choose(l, ls[i]);
            l -= ls[i];
            i += 1;
        }
        return r;
    }

    // Ratio prod(countl_m1(counts of x)) / prod(countl_m1(counts of y)), using m1 counts.
    private static double count_sr_m1(int[] x, int[] y) {
        long[] r1 = countl_m1(counts_m1(x, max(x)));
        long[] r2 = countl_m1(counts_m1(y, max(y)));
        int k = Math.min(r1.length, r2.length);
        double r = 1;
        for (int i = 0; i < k; ++i) {
            r *= r1[i];
            r /= r2[i];
        }
        for (int i = k; i < r1.length; ++i) {
            r *= r1[i];
        }
        for (int i = k; i < r2.length; ++i) {
            r /= r2[i];
        }
        return r;
    }

    /**
     * Like countl_m1, but for m2 counts, where ls[0] is the number of unused entries. A
     * single-element ls (every entry unused) yields {1}.
     */
    private static long[] countl_m2(int[] ls) {
        if (ls.length == 1) {
            return new long[]{1};
        }
        int l = sum(ls);
        long[] r = new long[ls.length];
        r[0] = choose(l, ls[0]);
        l -= ls[0];
        int i = 1;
        while (l > 0) {
            r[i] = choose(l, ls[i]);
            l -= ls[i];
            i += 1;
        }
        return r;
    }

    // Ratio prod(countl_m2(counts_m2(x))) / prod(countl_m2(counts_m2(y))).
    private static double count_sr_m2(int[] x, int[] y) {
        long[] r1 = countl_m2(counts_m2(x));
        long[] r2 = countl_m2(counts_m2(y));
        int k = Math.min(r1.length, r2.length);
        double r = 1;
        for (int i = 0; i < k; ++i) {
            r *= r1[i];
            r /= r2[i];
        }
        for (int i = k; i < r1.length; ++i) {
            r *= r1[i];
        }
        for (int i = k; i < r2.length; ++i) {
            r /= r2[i];
        }
        return r;
    }
}