package shef.strategies.uct; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Deque; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map.Entry; import player.gamer.statemachine.StateMachineGamer; import player.gamer.statemachine.reflex.event.ReflexMoveSelectionEvent; import player.gamer.statemachine.reflex.gui.ReflexDetailPanel; import shef.strategies.uct.tree.Level; import shef.strategies.uct.tree.StateActionPair; import shef.strategies.uct.tree.StateModel; import shef.strategies.uct.tree.Tree; import util.statemachine.MachineState; import util.statemachine.Move; import util.statemachine.Role; import util.statemachine.StateMachine; import util.statemachine.exceptions.GoalDefinitionException; import util.statemachine.exceptions.MoveDefinitionException; import util.statemachine.exceptions.TransitionDefinitionException; import util.statemachine.implementation.prover.cache.CachedProverStateMachine; import apps.player.detail.DetailPanel; /** * A base class for any player based on the UCT and changing the out of tree * exploration. * * Creates an UCT tree and performs expansion on nodes based on the UCT * algorithm. <cite> * * @author jonathan * */ public abstract class UCTGamer extends StateMachineGamer { /** * C in the UCT equation, this alters the balance between exploration and * exploitation */ protected static final float C = 50; /** Role of the player */ private Role myRole; /** Index of the player's role in the player list */ private int myRoleID; /** Total number of players */ public static int roleCount; /** UCT tree */ private Tree tree; /** Number of moves played */ private int moveCount; /** Handle to the StateMachine governing this player */ protected StateMachine theMachine; /** */ protected List<Role> roles; /** * Uses a CachedProverStateMachine */ @Override public StateMachine getInitialStateMachine() { return new CachedProverStateMachine(); } /** * Setup the UCT game tree and perform rollouts for as long as possible. * * @param timeout * time in ms this meta game stage should be finished by */ @Override public void stateMachineMetaGame(final long timeout) throws TransitionDefinitionException, MoveDefinitionException, GoalDefinitionException { System.out.println("init"); final long finishBy = timeout; // - 1000; theMachine = getStateMachine(); myRole = getRole(); roles = theMachine.getRoles(); myRoleID = roles.indexOf(myRole); roleCount = roles.size(); moveCount = 0; try { tree = new Tree(getCurrentState(), this, roleCount); } catch (InterruptedException e) { // TODO Auto-generated catch block e.printStackTrace(); } final StateModel currentSM = tree.getStateLists().get(moveCount).states.get(getStateMachine().getInitialState()); int rollCount = 0; System.out.println("beginning rollouts"); while (System.currentTimeMillis() < finishBy) { rollout(currentSM); rollCount++; } System.out.println(rollCount + " initial"); } /** * As many times as possible in the time available perform rollouts from the * current state * * @param timeout * when in ms this move selection should be completed by * @return the move attributed to the most promising {@link StateActionPair} */ @Override public Move stateMachineSelectMove(final long timeout) throws TransitionDefinitionException, MoveDefinitionException, GoalDefinitionException { final long start = System.currentTimeMillis(); final long finishBy = timeout; // - 1000; final MachineState cState = getCurrentState(); final StateModel currentSM = tree.getStateLists().get(moveCount).states.get(cState); final List<Move> moves = theMachine.getLegalMoves(cState, myRole); Move selection = moves.get(0); int rollCount = 0; while (true) { if (System.currentTimeMillis() > finishBy) { // select best move! double maxVal = Float.NEGATIVE_INFINITY; List<Move> maxMove = null; HashMap<List<Move>, StateActionPair> saps = currentSM.actionsPairs; for (Entry<List<Move>, StateActionPair> sap : saps.entrySet()) { System.out.println("Move " + sap.getKey() + " explored " + sap.getValue().exploreCount + " " + Arrays.toString(sap.getValue().value)); double v = sap.getValue().value[myRoleID]; if (v > maxVal || maxMove == null) { maxMove = sap.getKey(); maxVal = v; } } selection = maxMove.get(myRoleID); break; } rollout(currentSM); rollCount++; } final long stop = System.currentTimeMillis(); moveCount++; // StringBuilder sb = new StringBuilder(); // tree.print(sb); // System.out.println(sb.toString()); notifyObservers(new ReflexMoveSelectionEvent(moves, selection, stop - start)); System.out.println(rollCount + " " + selection); return selection; } private Deque<StateActionPair> backupSAPs; private Deque<StateModel> backupStates; /** * Perform a single UCT rollout * * @param rolloutRootSM * state to begin rollout from * @throws MoveDefinitionException * @throws TransitionDefinitionException * @throws GoalDefinitionException */ private void rollout(final StateModel rolloutRootSM) throws MoveDefinitionException, TransitionDefinitionException, GoalDefinitionException { StateModel traverser = rolloutRootSM; List<StateActionPair> actions = new ArrayList<StateActionPair>(traverser.actionsPairs.values()); backupSAPs = new ArrayDeque<StateActionPair>(); // new // ArrayList<StateActionPair>(); backupStates = new ArrayDeque<StateModel>();// new // ArrayList<StateModel>(); boolean expandLeaf = true; int lvl = 0; while (!actions.isEmpty()) { List<Move> toPlay = new ArrayList<Move>(); // for each player for (int p = 0; p < roleCount; p++) { expandLeaf = true; int i = 0; float[] v = new float[actions.size()]; for (StateActionPair sap : actions) { if (sap.exploreCount == 0) { v[i] = Float.POSITIVE_INFINITY; expandLeaf = false; } else { float uctBonus = (float) Math.sqrt(Math.log(traverser.timesExplored) / (float) sap.exploreCount); v[i] = (float) (sap.value[p] + C * uctBonus); } i++; } // index of highest valued node int index = 0; float lowest = Integer.MIN_VALUE; for (int j = 0; j < v.length; j++) { if (v[j] > lowest) { index = j; lowest = v[j]; } } toPlay.add(actions.get(index).action.get(p)); } backupStates.push(traverser); StateActionPair chosenSAP = traverser.actionsPairs.get(toPlay); backupSAPs.push(chosenSAP); traverser = chosenSAP.result; actions = new ArrayList<StateActionPair>(traverser.actionsPairs.values()); lvl++; } backupStates.add(traverser); if (expandLeaf && !theMachine.isTerminal(traverser.state)) { tree.expandNode(traverser, lvl); // Level cur = tree.getStateLists().get(traverser.depth + 1); // traverser = cur.states.get(theMachine.getRandomNextState(traverser.state)); } List<Double> outcome; if (!theMachine.isTerminal(traverser.state)) { outcome = completeRollout(traverser.state, lvl); } else { outcome = theMachine.getDoubleGoals(traverser.state); } // distribute goal to each player backpropogate(backupSAPs, backupStates, outcome); } /** * Complete the rest of this UCT rollout past the UCT horizon * * @param from * the state to complete rollout from * @return the terminal state reached * @throws MoveDefinitionException * @throws TransitionDefinitionException * @throws GoalDefinitionException */ protected abstract List<Double> completeRollout(final MachineState from, final int fromLvl) throws TransitionDefinitionException, MoveDefinitionException, GoalDefinitionException; /** * Discount factor applied to each backup of the reward. The reward should * have a great effect on the states close to it and less to those further * away. */ private static final double discountFactor = 0.9; /** * Update every state visited in this path and update its average. Applying * a discount factor to the result at every stage. * * degrade reward to prefer earlier wins * * @param backupStatesPairs * @param outcome */ private void backpropogate(final Deque<StateActionPair> backupStatesPairs, final Deque<StateModel> backupStates, final List<Double> outcome) { for (StateModel m : backupStates) { m.timesExplored++; } int size = backupStatesPairs.size(); while (size > 0) { StateActionPair s = backupStatesPairs.pop(); s.updateAverage(outcome); // for (int i = 0; i < roleCount; i++) { // outcome.set(i, outcome.get(i) * discountFactor); // } size--; } } @Override public DetailPanel getDetailPanel() { return new ReflexDetailPanel(); } }