package shef.network; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Set; import java.util.Map.Entry; import org.encog.mathutil.randomize.GaussianRandomizer; import org.neuroph.core.Connection; import org.neuroph.core.NeuralNetwork; import org.neuroph.core.Neuron; import org.neuroph.nnet.comp.ThresholdNeuron; import shef.instantiator.andortree.Node; import shef.instantiator.andortree.Tuple; import util.gdl.factory.GdlFactory; import util.gdl.factory.exceptions.GdlFormatException; import util.gdl.grammar.Gdl; import util.gdl.grammar.GdlRelation; import util.gdl.grammar.GdlSentence; import util.statemachine.MachineState; import util.symbol.factory.exceptions.SymbolFormatException; import cs227b.teamIago.resolver.Atom; import cs227b.teamIago.resolver.Expression; import cs227b.teamIago.resolver.Predicate; import cs227b.teamIago.resolver.Term; /** * Instantiates a {@link CIL2PNet} shef.network and is able to query it * * @author jonathan * */ public class CIL2PManager { /** List of player names */ public ArrayList<Term> playerList = new ArrayList<Term>(); private ArrayList<Gdl> playerListGGP = new ArrayList<Gdl>(); /** * The number of times calculate is called when output is requested from * <code>getOutput</code> calculate is called * * This MAY not be necessary, but all sample code thus far suggests */ public final int CALCULATE_MULTIPLE_TIMES = 8; /** * Network being managed */ public CIL2PNet network; /** * Create a manager for this shef.network */ public CIL2PManager(CIL2PNet network) { this.network = network; // generate player list for (Predicate g : network.goalHash.keySet()) { Term player = (Atom) g.secondOp(); if (!playerList.contains(player)) { playerList.add(player); try { playerListGGP.add(GdlFactory.create("(role " + player.toString() +")")); } catch (GdlFormatException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (SymbolFormatException e) { // TODO Auto-generated catch block e.printStackTrace(); } } } } /** * Prints a set of goal nodes as multiple CIL2P trees * * @param goalProofs */ public static void printCIL2P_tree(Collection<Node> goalProofs) { for (Node goal : goalProofs) { System.out.println(goal.printTree()); } } /** * Get the value of a given state by querying the CIL2P shef.network. * @param state * @param player */ public List<Double> getStateValues(final MachineState state){ propagateInput(state); ArrayList<Double> scores = new ArrayList<Double>(); for(Term playerName : playerList){ scores.add(getPlayerScore(playerName)); } return scores; } /** * Get the value of a given state by querying the CIL2P shef.network. * @param state * @param player */ public double getStateValue(final MachineState state, int playerID){ propagateInput(state); return getPlayerScore(playerList.get(playerID)); } private final double sigmaOverTwo = 0.8; private final GaussianRandomizer gauss = new GaussianRandomizer(0, sigmaOverTwo*sigmaOverTwo); private final HashMap<MachineState, Double> mem = new HashMap<MachineState, Double>(); /** * Get a state Gaussian * @param state * @param player * @return a Gaussian value for the state and player with a Gaussian random factor */ public double getStateValueGaussian(final MachineState state, int playerID){ if(mem.containsKey(state)){ return mem.get(state); } propagateInput(state); double sc = getPlayerScore(playerList.get(playerID))/100d; double gaussR = gauss.randomize(0); mem.put(state, gaussR); return sc + gaussR; } public void propagateInput(final MachineState state){ Set<GdlSentence> stateElements = state.getContents(); // reset the state for (Neuron in : network.inputLayer.getNeurons()) { in.setInput(0); } for (Entry<GdlRelation, Neuron> fact : network.queryHashGGPBase.entrySet()) { if(stateElements.contains(fact.getKey())){ fact.getValue().setInput(1d); } else { fact.getValue().setInput(-1d); } } // calculate for (int i = 0; i < CALCULATE_MULTIPLE_TIMES; i++) { this.network.n.calculate(); } } private double getPlayerScore(Term playerName) { // get outputs float playerSum = 0f; float totalSum = 0f; for (Entry<Goal, ThresholdNeuron> goal : network.goalHash.entrySet()) { Goal goalClause = goal.getKey(); ThresholdNeuron goalNeuron = goal.getValue(); // modulate the neuron output between 0 and 1 double vNeuron = (goalNeuron.getOutput() + 1) / 2f; // should be modulated betwen 0 and 1 assert vNeuron <= 1 && vNeuron >= 0; double nScore = vNeuron * goalClause.score; // keep a track of possible output totalSum += nScore; // if this neuron is attached to the player we're // looking at add the value to their score if (playerName.equals(goalClause.player)) { playerSum += nScore; } } double stateScore = (playerSum / totalSum) * 100f; return stateScore; } /* * TESTING AND DEBUG METHODS */ /** * Find the output of the shef.network given an input vector of the state This * requires an ordered set of input * * <b>This should be FAST</b> as this will be the function called by the UCT * functions to complete UCT rollouts * * @param queryValues * an ordered array of input values to each neuron * @return the output of the shef.network for the goal neurons */ public double[] getOutput(double... queryValues) { double[] inputNeuronValues = new double[network.inputNeuronCount]; double[] outputNeuronValues = new double[network.goalNeuronCount]; NeuralNetwork net = network.n; // XXX this may be SLOW - may have to precompile all of this? for (int i = 0; i < network.queryNeuronCount; i++) { inputNeuronValues[network.queryNeuronIndices[i]] = queryValues[i]; } net.setInput(inputNeuronValues); // XXX what makes it necessary to for (int i = 0; i < CALCULATE_MULTIPLE_TIMES; i++) { net.calculate(); } double[] networkOut = net.getOutput(); // return only goal neurons output for (int i = 0; i < network.goalNeuronCount; i++) { outputNeuronValues[i] = networkOut[network.goalNeuronDetails.get(i) .getSecond()]; } return outputNeuronValues; } /** * Print the output from <code>getOutput</code> * * @param queryValues * the input values passed into <code>getOutput</code> */ public void printOutput(double... queryValues) { System.out.println(Arrays.toString(getOutput(queryValues))); } /** * Return a score of the current state for each player Each goal neuron has * activation [-1, 1] -> [0, 1] = v_n * * @param queryValues * @return the summed score for each player */ public double[] getAllPlayerScores(double... queryValues) { getOutput(queryValues); double[] player_V = new double[playerList.size()]; for (int i = 0; i < playerList.size(); i++) { player_V[i] = getPlayerScore(playerList.get(i)); } return player_V; } /** * print the scores returned by <code>getScores</code> given the * <code>queryValues</code> provided * * @param queryValues * input to shef.network */ public void printScores(double... queryValues) { System.out.println(Arrays.toString(getAllPlayerScores(queryValues))); } /** * get overall neural layout * * @return #input (preds), #hidden (clauses), #output (heads) */ public int[] getInfo() { return new int[] { network.inputLayer.getNeurons().size(), network.hiddenLayer.getNeurons().size(), network.outputLayer.getNeurons().size() }; } /** * print information about the overall neural layout<br/> * <code>#input units, [input unit expressions], #hidden units, * #output units, [output unit expressions]</code> */ public void printInfo() { System.out .println("# input units:" + network.inputLayer.getNeurons().size() + "(total predicates)"); System.out.println(network.inputHash.keySet()); System.out.println("# hidden units:" + network.hiddenLayer.getNeurons().size() + "(total clauses)"); System.out.println("# output units:" + network.outputLayer.getNeurons().size() + "(total unique heads)"); System.out.println(network.outputHash.keySet()); } /** * get information about the most important neurons in the current shef.network * * @return [#input (truths), #output (goals)] */ public int[] getPlayInfo() { return new int[] { network.queryNeuronDetails.size(), network.goalNeuronDetails.size() }; } /** * print information about the most important neurons in the current shef.network * <b>query neurons</b> and <b>goal neurons</b> */ public void printPlayInfo() { System.out .println("# query units " + network.queryNeuronDetails.size()); System.out.println(network.queryNeuronDetails); System.out.println("# goal units " + network.goalNeuronDetails.size()); System.out.println(network.goalNeuronDetails); } /** * @return the value amin which is the minimum activation a node must have * to consider its output true */ public double getMaxError() { return network.Amin; } /** * for each node matching <code>(true ?term)</code> is an input node in the * graph, print: <br/> * <code>(expression, output)</code> */ public void printQueryInputActivation() { for (Tuple<Expression, Integer> tup : network.queryNeuronDetails) { System.out.println(tup.getFirst() + "\t-> " + network.inputHash.get(tup.getFirst()).getOutput()); } } /** * for each node on the input layer print <br/> * <code>(expression, #inputs, input, #outputs, output)</code> */ public void printInputActivation() { for (Entry<Expression, Neuron> tup : network.inputHash.entrySet()) { System.out.println(tup.getKey() + "\tin[" + tup.getValue().getInputConnections().size() + "]: " + tup.getValue().getNetInput() + "out:[" + tup.getValue().getOutConnections().size() + "]: " + tup.getValue().getOutput()); } } /** * For each node matching <code>(GOAL ?player ?score)</code> print. <br/> * <code>(expression, output)</code> */ public void printGoalOutputActivation() { for (Tuple<Expression, Integer> tup : network.goalNeuronDetails) { System.out.println(tup.getFirst() + "\t-> " + network.outputHash.get(tup.getFirst()).getOutput()); } } /** * Print details about the output layer of the shef.network. <br/> * <code>(predicate, output, input, #inputs, [input activations])</code> */ public void printOutputActivation() { for (Entry<Expression, ThresholdNeuron> tup : network.outputHash .entrySet()) { String inps = ""; for (Connection in : tup.getValue().getInputConnections()) { inps = inps + (in.getWeight().getValue() < 0 ? "N" : "") + in.getFromNeuron().getOutput() + " "; } String out = tup.getKey() + "\t-> " + " " + " in[" + tup.getValue().getInputConnections().size() + "]: " + tup.getValue().getNetInput() + " out:" + tup.getValue().getOutput() + " ~ " + inps; System.out.println(out); } } /** * Print details about the hidden layer of the shef.network. <br/> * <code>(input, output)</code> */ public void printHiddenActivation() { for (Neuron hidden : network.hiddenLayer.getNeurons()) { System.out.println("in:" + hidden.getNetInput() + "out:" + hidden.getOutput()); } } /** * Get query predicates in order they appear in the input layer. */ public List<Expression> getQueryPredicates() { List<Expression> returnQs = new ArrayList<Expression>(); for (Tuple<Expression, Integer> tup : network.queryNeuronDetails) { returnQs.add(tup.getFirst()); } return returnQs; } }