/*
* This file is part of the Jikes RVM project (http://jikesrvm.org).
*
* This file is licensed to You under the Eclipse Public License (EPL);
* You may not use this file except in compliance with the License. You
* may obtain a copy of the License at
*
* http://www.opensource.org/licenses/eclipse-1.0.php
*
* See the COPYRIGHT.txt file distributed with this work for information
* regarding copyright ownership.
*/
package org.jikesrvm.compilers.opt.ir;
import java.util.Enumeration;
import org.jikesrvm.VM;
import org.jikesrvm.compilers.opt.OptimizingCompilerException;
import org.jikesrvm.compilers.opt.ir.operand.BranchProfileOperand;
/**
* Used to iterate over the branch targets (including the fall through edge)
* and associated probabilities of a basic block.
* Takes into account the ordering of branch instructions when
* computing the edge weights such that the total target weight will always
* be equal to 1.0 (flow in == flow out).
*/
public final class WeightedBranchTargets {
  /** Initial length of the edge arrays; {@link #addEdge} doubles them when they fill up. */
  private static final int INITIAL_CAPACITY = 3;

  /**
   * Largest deviation from 1.0 that is tolerated when the case probabilities
   * of a switch are summed up. Kept as a double so the comparison in
   * {@link #epsilon} is carried out in double precision.
   */
  private static final double OUTFLOW_TOLERANCE = 0.1;

  /** Target blocks of the recorded edges; only indices {@code [0, max)} are in use. */
  private BasicBlock[] targets;
  /** Weight of each recorded edge, parallel to {@link #targets}. */
  private float[] weights;
  /** Index of the edge the iteration is currently positioned on. */
  private int cur;
  /** Number of edges recorded so far. */
  private int max;

  /** Repositions the iteration on the first recorded edge. */
  public void reset() {
    cur = 0;
  }

  /**
   * @return {@code true} if the iteration is positioned on a recorded edge,
   *         i.e. {@link #curBlock()} and {@link #curWeight()} may be queried
   */
  public boolean hasMoreElements() {
    return cur < max;
  }

  /** Moves the iteration to the next edge. */
  public void advance() {
    cur++;
  }

  /**
   * @return the target block of the current edge
   */
  public BasicBlock curBlock() {
    return targets[cur];
  }

  /**
   * @return the weight of the current edge
   */
  public float curWeight() {
    return weights[cur];
  }

  /**
   * Collects the weighted out edges of a basic block.
   * <p>
   * The branch instructions of the block are visited in enumeration order
   * while tracking the fraction of the flow that has not been branched away
   * yet (initially 1). A conditional branch or inline guard gets that
   * fraction multiplied by its taken probability and reduces the fraction
   * accordingly; a goto gets the whole remaining fraction; each case of a
   * switch gets the remaining fraction multiplied by the probability of
   * that case. Finally the fall through block, if there is one, gets
   * whatever fraction remains.
   * <p>
   * When {@code VM.VerifyAssertions} is set, the case probabilities of every
   * switch are checked to sum to approximately 1.
   *
   * @param bb the basic block whose out edges are to be collected
   * @throws OptimizingCompilerException if a branch instruction is
   *         encountered that is none of the handled forms
   */
  public WeightedBranchTargets(BasicBlock bb) {
    targets = new BasicBlock[INITIAL_CAPACITY];
    weights = new float[INITIAL_CAPACITY];
    cur = 0;
    max = 0;
    // Fraction of the flow that reaches the instruction currently examined.
    float prob = 1f;
    for (Enumeration<Instruction> ie = bb.enumerateBranchInstructions(); ie.hasMoreElements();) {
      Instruction s = ie.nextElement();
      if (IfCmp.conforms(s)) {
        prob = addConditionalEdge(IfCmp.getTarget(s).target.getBasicBlock(),
                                  IfCmp.getBranchProfile(s),
                                  prob);
      } else if (Goto.conforms(s)) {
        // NOTE(review): prob is not reduced here or after a switch; presumably
        // such blocks have no fall through block - confirm against BasicBlock.
        addEdge(Goto.getTarget(s).target.getBasicBlock(), prob);
      } else if (InlineGuard.conforms(s)) {
        prob = addConditionalEdge(InlineGuard.getTarget(s).target.getBasicBlock(),
                                  InlineGuard.getBranchProfile(s),
                                  prob);
      } else if (IfCmp2.conforms(s)) {
        // Two targets; the second one only sees what the first did not take.
        prob = addConditionalEdge(IfCmp2.getTarget1(s).target.getBasicBlock(),
                                  IfCmp2.getBranchProfile1(s),
                                  prob);
        prob = addConditionalEdge(IfCmp2.getTarget2(s).target.getBasicBlock(),
                                  IfCmp2.getBranchProfile2(s),
                                  prob);
      } else if (TableSwitch.conforms(s)) {
        int lowLimit = TableSwitch.getLow(s).value;
        int highLimit = TableSwitch.getHigh(s).value;
        int number = highLimit - lowLimit + 1;
        float total = 0f;
        for (int i = 0; i < number; i++) {
          total += addSwitchEdge(TableSwitch.getTarget(s, i).target.getBasicBlock(),
                                 TableSwitch.getBranchProfile(s, i),
                                 prob);
        }
        BasicBlock defaultTarget = TableSwitch.getDefault(s).target.getBasicBlock();
        BranchProfileOperand defaultProf = TableSwitch.getDefaultBranchProfile(s);
        total += defaultProf.takenProbability;
        // Verify before the default edge is recorded.
        verifyOutflow(total, s);
        addEdge(defaultTarget, prob * defaultProf.takenProbability);
      } else if (LowTableSwitch.conforms(s)) {
        int number = LowTableSwitch.getNumberOfTargets(s);
        float total = 0f;
        for (int i = 0; i < number; i++) {
          total += addSwitchEdge(LowTableSwitch.getTarget(s, i).target.getBasicBlock(),
                                 LowTableSwitch.getBranchProfile(s, i),
                                 prob);
        }
        verifyOutflow(total, s);
      } else if (LookupSwitch.conforms(s)) {
        int number = LookupSwitch.getNumberOfTargets(s);
        float total = 0f;
        for (int i = 0; i < number; i++) {
          total += addSwitchEdge(LookupSwitch.getTarget(s, i).target.getBasicBlock(),
                                 LookupSwitch.getBranchProfile(s, i),
                                 prob);
        }
        BasicBlock defaultTarget = LookupSwitch.getDefault(s).target.getBasicBlock();
        BranchProfileOperand defaultProf = LookupSwitch.getDefaultBranchProfile(s);
        total += defaultProf.takenProbability;
        // Verify before the default edge is recorded.
        verifyOutflow(total, s);
        addEdge(defaultTarget, prob * defaultProf.takenProbability);
      } else {
        throw new OptimizingCompilerException("TODO " + s + "\n");
      }
    }
    BasicBlock ft = bb.getFallThroughBlock();
    if (ft != null) addEdge(ft, prob);
  }

  /**
   * Records the taken edge of a two-way branch.
   *
   * @param target the block the branch jumps to when taken
   * @param prof the profile holding the taken probability of the branch
   * @param prob the fraction of the flow that reaches the branch
   * @return the fraction of the flow that is left after the branch, i.e.
   *         {@code prob} multiplied by the not-taken probability
   */
  private float addConditionalEdge(BasicBlock target, BranchProfileOperand prof, float prob) {
    float taken = prob * prof.takenProbability;
    float remaining = prob * (1f - prof.takenProbability);
    addEdge(target, taken);
    return remaining;
  }

  /**
   * Records the edge of one switch case.
   *
   * @param target the block the case jumps to
   * @param prof the profile holding the probability of the case
   * @param prob the fraction of the flow that reaches the switch
   * @return the unscaled probability of the case, so that the caller can
   *         sum up the probabilities of all cases
   */
  private float addSwitchEdge(BasicBlock target, BranchProfileOperand prof, float prob) {
    addEdge(target, prob * prof.takenProbability);
    return prof.takenProbability;
  }

  /**
   * Fails via {@code VM.sysFail} if assertions are being verified and the
   * summed case probabilities of a switch are not approximately 1.
   *
   * @param total the sum of the case probabilities
   * @param s the switch instruction, only used for the failure message
   */
  private void verifyOutflow(float total, Instruction s) {
    if (VM.VerifyAssertions && !epsilon(total, 1f)) {
      VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s);
    }
  }

  /**
   * Appends an edge, doubling the length of the backing arrays first if
   * they are full.
   *
   * @param target the target block of the edge
   * @param weight the weight of the edge
   */
  private void addEdge(BasicBlock target, float weight) {
    if (max == targets.length) {
      BasicBlock[] grownTargets = new BasicBlock[targets.length << 1];
      System.arraycopy(targets, 0, grownTargets, 0, targets.length);
      targets = grownTargets;
      float[] grownWeights = new float[weights.length << 1];
      System.arraycopy(weights, 0, grownWeights, 0, weights.length);
      weights = grownWeights;
    }
    targets[max] = target;
    weights[max] = weight;
    max++;
  }

  /**
   * @param a first value
   * @param b second value
   * @return {@code true} if the two values differ by less than
   *         {@link #OUTFLOW_TOLERANCE}
   */
  private static boolean epsilon(float a, float b) {
    return Math.abs(a - b) < OUTFLOW_TOLERANCE;
  }
}