/*
* This file is part of the Jikes RVM project (http://jikesrvm.org).
*
* This file is licensed to You under the Common Public License (CPL);
* You may not use this file except in compliance with the License. You
* may obtain a copy of the License at
*
* http://www.opensource.org/licenses/cpl1.0.php
*
* See the COPYRIGHT.txt file distributed with this work for information
* regarding copyright ownership.
*/
package org.jikesrvm.compilers.opt.ir;
import org.jikesrvm.VM;
import org.jikesrvm.compilers.opt.OPT_OptimizingCompilerException;
/**
* Used to iterate over the branch targets (including the fall through edge)
* and associated probabilities of a basic block.
* Takes into account the ordering of branch instructions when
* computing the edge weights such that the total target weight will always
* be equal to 1.0 (flow in == flow out).
*/
public final class OPT_WeightedBranchTargets {
  /** Edge targets in discovery order; only indices [0, max) are valid. */
  private OPT_BasicBlock[] targets;
  /** weights[i] is the probability of the edge to targets[i]; same capacity as targets. */
  private float[] weights;
  /** Index of the edge the iteration protocol currently points at. */
  private int cur;
  /** Number of valid entries in targets/weights. */
  private int max;

  /** Restarts the iteration at the first edge. */
  public void reset() { cur = 0; }

  /** @return true while the iteration still points at a valid edge */
  public boolean hasMoreElements() { return cur < max; }

  /** Moves the iteration to the next edge. */
  public void advance() { cur++; }

  /** @return the target block of the current edge */
  public OPT_BasicBlock curBlock() { return targets[cur]; }

  /** @return the probability (weight) of the current edge */
  public float curWeight() { return weights[cur]; }

  /**
   * Collects the outgoing edges of a basic block together with their weights.
   * Branch instructions are visited in order; each conditional branch takes its
   * taken-probability share of whatever flow earlier branches have not already
   * claimed, and the remainder goes to the fall-through block (if any).
   *
   * @param bb the block whose outgoing edges are enumerated
   * @throws OPT_OptimizingCompilerException if a branch instruction of an
   *         unsupported kind is encountered
   */
  public OPT_WeightedBranchTargets(OPT_BasicBlock bb) {
    targets = new OPT_BasicBlock[3];
    weights = new float[3];
    cur = 0;
    max = 0;

    // Fraction of the block's flow not yet claimed by an earlier branch.
    float prob = 1f;
    for (OPT_InstructionEnumeration ie = bb.enumerateBranchInstructions(); ie.hasMoreElements();) {
      OPT_Instruction s = ie.next();
      if (IfCmp.conforms(s)) {
        prob = addConditionalEdge(IfCmp.getTarget(s).target.getBasicBlock(),
                                  IfCmp.getBranchProfile(s), prob);
      } else if (Goto.conforms(s)) {
        // An unconditional branch receives all of the remaining flow.
        addEdge(Goto.getTarget(s).target.getBasicBlock(), prob);
      } else if (InlineGuard.conforms(s)) {
        prob = addConditionalEdge(InlineGuard.getTarget(s).target.getBasicBlock(),
                                  InlineGuard.getBranchProfile(s), prob);
      } else if (IfCmp2.conforms(s)) {
        // Two conditional tests in sequence: the second only sees flow the first did not take.
        prob = addConditionalEdge(IfCmp2.getTarget1(s).target.getBasicBlock(),
                                  IfCmp2.getBranchProfile1(s), prob);
        prob = addConditionalEdge(IfCmp2.getTarget2(s).target.getBasicBlock(),
                                  IfCmp2.getBranchProfile2(s), prob);
      } else if (TableSwitch.conforms(s)) {
        int lowLimit = TableSwitch.getLow(s).value;
        int highLimit = TableSwitch.getHigh(s).value;
        int number = highLimit - lowLimit + 1;
        float total = 0f;
        for (int i = 0; i < number; i++) {
          OPT_BasicBlock target = TableSwitch.getTarget(s, i).target.getBasicBlock();
          OPT_BranchProfileOperand prof = TableSwitch.getBranchProfile(s, i);
          total += prof.takenProbability;
          addEdge(target, prob * prof.takenProbability);
        }
        OPT_BasicBlock defaultTarget = TableSwitch.getDefault(s).target.getBasicBlock();
        OPT_BranchProfileOperand defaultProf = TableSwitch.getDefaultBranchProfile(s);
        float defaultTaken = prob * defaultProf.takenProbability;
        total += defaultProf.takenProbability;
        verifyOutflow(total, s);
        addEdge(defaultTarget, defaultTaken);
      } else if (LowTableSwitch.conforms(s)) {
        // Only the explicit targets are added here; no separate default edge is read.
        int number = LowTableSwitch.getNumberOfTargets(s);
        float total = 0f;
        for (int i = 0; i < number; i++) {
          OPT_BasicBlock target = LowTableSwitch.getTarget(s, i).target.getBasicBlock();
          OPT_BranchProfileOperand prof = LowTableSwitch.getBranchProfile(s, i);
          total += prof.takenProbability;
          addEdge(target, prob * prof.takenProbability);
        }
        verifyOutflow(total, s);
      } else if (LookupSwitch.conforms(s)) {
        int number = LookupSwitch.getNumberOfTargets(s);
        float total = 0f;
        for (int i = 0; i < number; i++) {
          OPT_BasicBlock target = LookupSwitch.getTarget(s, i).target.getBasicBlock();
          OPT_BranchProfileOperand prof = LookupSwitch.getBranchProfile(s, i);
          total += prof.takenProbability;
          addEdge(target, prob * prof.takenProbability);
        }
        OPT_BasicBlock defaultTarget = LookupSwitch.getDefault(s).target.getBasicBlock();
        OPT_BranchProfileOperand defaultProf = LookupSwitch.getDefaultBranchProfile(s);
        float defaultTaken = prob * defaultProf.takenProbability;
        total += defaultProf.takenProbability;
        verifyOutflow(total, s);
        addEdge(defaultTarget, defaultTaken);
      } else {
        throw new OPT_OptimizingCompilerException("TODO " + s + "\n");
      }
    }
    // Whatever no branch claimed flows to the fall-through block.
    OPT_BasicBlock ft = bb.getFallThroughBlock();
    if (ft != null) addEdge(ft, prob);
  }

  /**
   * Adds the taken edge of a conditional branch and computes the flow left over.
   *
   * @param target the block reached when the branch is taken
   * @param prof   the branch's profile; its takenProbability is the conditional
   *               probability of taking the branch given that it is reached
   * @param prob   the flow that reaches this branch
   * @return the flow that does not take this branch
   */
  private float addConditionalEdge(OPT_BasicBlock target, OPT_BranchProfileOperand prof, float prob) {
    addEdge(target, prob * prof.takenProbability);
    return prob * (1f - prof.takenProbability);
  }

  /**
   * When assertions are enabled, fails the VM if the summed per-target
   * probabilities of a switch are not within tolerance of 1.
   *
   * @param total the sum of the switch's takenProbability values
   * @param s     the switch instruction, used in the failure message
   */
  private static void verifyOutflow(float total, OPT_Instruction s) {
    if (VM.VerifyAssertions && !epsilon(total, 1f)) {
      VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s);
    }
  }

  /**
   * Appends an edge, doubling the backing arrays when they are full.
   *
   * @param target the successor block
   * @param weight the probability of the edge
   */
  private void addEdge(OPT_BasicBlock target, float weight) {
    if (max == targets.length) {
      int newCapacity = targets.length << 1;
      OPT_BasicBlock[] newTargets = new OPT_BasicBlock[newCapacity];
      System.arraycopy(targets, 0, newTargets, 0, max);
      targets = newTargets;
      float[] newWeights = new float[newCapacity];
      System.arraycopy(weights, 0, newWeights, 0, max);
      weights = newWeights;
    }
    targets[max] = target;
    weights[max] = weight;
    max++;
  }

  /**
   * @return true if a and b differ by less than an absolute tolerance of 0.003
   *         (presumably slack for rounding in the profile probabilities)
   */
  private static boolean epsilon(float a, float b) {
    return Math.abs(a - b) < 0.003;
  }
}