/*
* Copyright (C) 2011 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.bordeaux.learning;
import android.util.Log;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
/**
* A histogram based predictor which records co-occurrences of applations with a speficic
* feature, for example, location, * time of day, etc. The histogram is kept in a two level
* hash table. The first level key is the feature value and the second level key is the app
* id.
*/
// TODOS:
// 1. Use forgetting factor to downweight istances propotional to the time
// 2. Different features could have different weights on prediction scores.
// 3. Add function to remove sampleid (i.e. remove apps that are uninstalled).
public class HistogramPredictor {
final static String TAG = "HistogramPredictor";
private HashMap<String, HistogramCounter> mPredictor =
new HashMap<String, HistogramCounter>();
private HashMap<String, Integer> mClassCounts = new HashMap<String, Integer>();
private HashSet<String> mBlacklist = new HashSet<String>();
private static final int MINIMAL_FEATURE_VALUE_COUNTS = 5;
private static final int MINIMAL_APP_APPEARANCE_COUNTS = 5;
// This parameter ranges from 0 to 1 which determines the effect of app prior.
// When it is set to 0, app prior means completely neglected. When it is set to 1
// the predictor is a standard naive bayes model.
private static final int PRIOR_K_VALUE = 1;
private static final String[] APP_BLACKLIST = {
"com.android.contacts",
"com.android.chrome",
"com.android.providers.downloads.ui",
"com.android.settings",
"com.android.vending",
"com.android.mms",
"com.google.android.gm",
"com.google.android.gallery3d",
"com.google.android.apps.googlevoice",
};
public HistogramPredictor(String[] blackList) {
for (String appName : blackList) {
mBlacklist.add(appName);
}
}
/*
* This class keeps the histogram counts for each feature and provide the
* joint probabilities of <feature, class>.
*/
private class HistogramCounter {
private HashMap<String, HashMap<String, Integer> > mCounter =
new HashMap<String, HashMap<String, Integer> >();
public HistogramCounter() {
mCounter.clear();
}
public void setCounter(HashMap<String, HashMap<String, Integer> > counter) {
resetCounter();
mCounter.putAll(counter);
}
public void resetCounter() {
mCounter.clear();
}
public void addSample(String className, String featureValue) {
HashMap<String, Integer> classCounts;
if (!mCounter.containsKey(featureValue)) {
classCounts = new HashMap<String, Integer>();
mCounter.put(featureValue, classCounts);
} else {
classCounts = mCounter.get(featureValue);
}
int count = (classCounts.containsKey(className)) ?
classCounts.get(className) + 1 : 1;
classCounts.put(className, count);
}
public HashMap<String, Double> getClassScores(String featureValue) {
HashMap<String, Double> classScores = new HashMap<String, Double>();
if (mCounter.containsKey(featureValue)) {
int totalCount = 0;
for(Map.Entry<String, Integer> entry :
mCounter.get(featureValue).entrySet()) {
String app = entry.getKey();
int count = entry.getValue();
// For apps with counts less than or equal to one, we treated
// those as having count one. Hence their score, i.e. log(count)
// would be zero. classScroes stores only apps with non-zero scores.
// Note that totalCount also neglect app with single occurrence.
if (count > 1) {
double score = Math.log((double) count);
classScores.put(app, score);
totalCount += count;
}
}
if (totalCount < MINIMAL_FEATURE_VALUE_COUNTS) {
classScores.clear();
}
}
return classScores;
}
public byte[] getModel() {
try {
ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
synchronized(mCounter) {
objStream.writeObject(mCounter);
}
byte[] bytes = byteStream.toByteArray();
return bytes;
} catch (IOException e) {
throw new RuntimeException("Can't get model");
}
}
public boolean setModel(final byte[] modelData) {
mCounter.clear();
HashMap<String, HashMap<String, Integer> > model;
try {
ByteArrayInputStream input = new ByteArrayInputStream(modelData);
ObjectInputStream objStream = new ObjectInputStream(input);
model = (HashMap<String, HashMap<String, Integer> >) objStream.readObject();
} catch (IOException e) {
throw new RuntimeException("Can't load model");
} catch (ClassNotFoundException e) {
throw new RuntimeException("Learning class not found");
}
synchronized(mCounter) {
mCounter.putAll(model);
}
return true;
}
public HashMap<String, HashMap<String, Integer> > getCounter() {
return mCounter;
}
public String toString() {
String result = "";
for (Map.Entry<String, HashMap<String, Integer> > entry :
mCounter.entrySet()) {
result += "{ " + entry.getKey() + " : " +
entry.getValue().toString() + " }";
}
return result;
}
}
/*
* Given a map of feature name -value pairs returns topK mostly likely apps to
* be launched with corresponding likelihoods. If topK is set zero, it will return
* the whole list.
*/
public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) {
// Most sophisticated function in this class
HashMap<String, Double> appScores = new HashMap<String, Double>();
int validFeatureCount = 0;
// compute all app scores
for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
String featureName = entry.getKey();
HistogramCounter counter = entry.getValue();
if (features.containsKey(featureName)) {
String featureValue = features.get(featureName);
HashMap<String, Double> scoreMap = counter.getClassScores(featureValue);
if (scoreMap.isEmpty()) {
continue;
}
validFeatureCount++;
for (Map.Entry<String, Double> item : scoreMap.entrySet()) {
String appName = item.getKey();
double appScore = item.getValue();
if (appScores.containsKey(appName)) {
appScore += appScores.get(appName);
}
appScores.put(appName, appScore);
}
}
}
HashMap<String, Double> appCandidates = new HashMap<String, Double>();
for (Map.Entry<String, Double> entry : appScores.entrySet()) {
String appName = entry.getKey();
if (mBlacklist.contains(appName)) {
Log.i(TAG, appName + " is in blacklist");
continue;
}
if (!mClassCounts.containsKey(appName)) {
throw new RuntimeException("class count error!");
}
int appCount = mClassCounts.get(appName);
if (appCount < MINIMAL_APP_APPEARANCE_COUNTS) {
Log.i(TAG, appName + " doesn't have enough counts");
continue;
}
double appScore = entry.getValue();
double appPrior = Math.log((double) appCount);
appCandidates.put(appName,
appScore - appPrior * (validFeatureCount - PRIOR_K_VALUE));
}
// sort app scores
List<Map.Entry<String, Double> > appList =
new ArrayList<Map.Entry<String, Double> >(appCandidates.size());
appList.addAll(appCandidates.entrySet());
Collections.sort(appList, new Comparator<Map.Entry<String, Double> >() {
public int compare(Map.Entry<String, Double> o1,
Map.Entry<String, Double> o2) {
return o2.getValue().compareTo(o1.getValue());
}
});
if (topK == 0) {
topK = appList.size();
}
return appList.subList(0, Math.min(topK, appList.size()));
}
/*
* Add a new observation of given sample id and features to the histograms
*/
public void addSample(String sampleId, Map<String, String> features) {
for (Map.Entry<String, String> entry : features.entrySet()) {
String featureName = entry.getKey();
String featureValue = entry.getValue();
useFeature(featureName);
HistogramCounter counter = mPredictor.get(featureName);
counter.addSample(sampleId, featureValue);
}
int sampleCount = (mClassCounts.containsKey(sampleId)) ?
mClassCounts.get(sampleId) + 1 : 1;
mClassCounts.put(sampleId, sampleCount);
}
/*
* reset predictor to a empty model
*/
public void resetPredictor() {
// TODO: not sure this step would reduce memory waste
for (HistogramCounter counter : mPredictor.values()) {
counter.resetCounter();
}
mPredictor.clear();
mClassCounts.clear();
}
/*
* convert the prediction model into a byte array
*/
public byte[] getModel() {
// TODO: convert model to a more memory efficient data structure.
HashMap<String, HashMap<String, HashMap<String, Integer > > > model =
new HashMap<String, HashMap<String, HashMap<String, Integer > > >();
for(Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
model.put(entry.getKey(), entry.getValue().getCounter());
}
try {
ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
objStream.writeObject(model);
byte[] bytes = byteStream.toByteArray();
return bytes;
} catch (IOException e) {
throw new RuntimeException("Can't get model");
}
}
/*
* set the prediction model from a model data in the format of byte array
*/
public boolean setModel(final byte[] modelData) {
HashMap<String, HashMap<String, HashMap<String, Integer > > > model;
try {
ByteArrayInputStream input = new ByteArrayInputStream(modelData);
ObjectInputStream objStream = new ObjectInputStream(input);
model = (HashMap<String, HashMap<String, HashMap<String, Integer > > >)
objStream.readObject();
} catch (IOException e) {
throw new RuntimeException("Can't load model");
} catch (ClassNotFoundException e) {
throw new RuntimeException("Learning class not found");
}
resetPredictor();
for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry :
model.entrySet()) {
useFeature(entry.getKey());
mPredictor.get(entry.getKey()).setCounter(entry.getValue());
}
// TODO: this is a temporary fix for now
loadClassCounter();
return true;
}
private void loadClassCounter() {
String TIME_OF_WEEK = "Time of Week";
if (!mPredictor.containsKey(TIME_OF_WEEK)) {
throw new RuntimeException("Precition model error: missing Time of Week!");
}
HashMap<String, HashMap<String, Integer> > counter =
mPredictor.get(TIME_OF_WEEK).getCounter();
mClassCounts.clear();
for (HashMap<String, Integer> map : counter.values()) {
for (Map.Entry<String, Integer> entry : map.entrySet()) {
int classCount = entry.getValue();
String className = entry.getKey();
// mTotalClassCount += classCount;
if (mClassCounts.containsKey(className)) {
classCount += mClassCounts.get(className);
}
mClassCounts.put(className, classCount);
}
}
Log.i(TAG, "class counts: " + mClassCounts);
}
private void useFeature(String featureName) {
if (!mPredictor.containsKey(featureName)) {
mPredictor.put(featureName, new HistogramCounter());
}
}
}