package edu.uw.cse.netlab.utils;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import org.gudy.azureus2.core3.util.ByteFormatter;
public class BloomFilter
implements Serializable
{
private static final long serialVersionUID = 1L;
// TODO: rewrite this wihtout the use of this bitset object (which forces us to retain and serialize mBitsCapacity separately!)
//BitSet mBits = null;
BitArray mBits = null;
int mHashesCount = -1;
transient MessageDigest mDigest = null;
transient ByteBuffer buff = null;
int mBitsCapacity;
int mInToStore = -1; // this gives the fractional amount intermediaries should attribute
byte[][] salts = null;
private final static int SALT_LENGTH = 20;
public boolean equals(BloomFilter rhs) {
return mBits.equals(mBits) && mHashesCount == rhs.mHashesCount
&& mBitsCapacity == rhs.mBitsCapacity && mInToStore == rhs.mInToStore;
}
@Override
public String toString() {
String out = "[BloomFilter: store: " + mInToStore + " hashes: "
+ mHashesCount + " bits: " + mBits.length() + "] ";
if (mBits.length() < 100) {
for (int i = 0; i < mBits.length(); i++)
out += mBits.get(i) == true ? "1" : "0";
}
return out;
}
public int getStoredCount() {
return mInToStore;
}
public BloomFilter(int inNumBits, int inMaxToStore)
throws NoSuchAlgorithmException {
mBitsCapacity = inNumBits;
mBits = new BitArray(inNumBits);
mHashesCount = BloomFilter.computeHashes(inNumBits, inMaxToStore);
salts = new byte[mHashesCount][0];
/*
* use random salts to avoid any bloomfilter filling attacks
*/
for (int i = 0; i < mHashesCount; i++) {
byte[] salt = new byte[SALT_LENGTH];
rand.nextBytes(salt);
salts[i] = salt;
}
mInToStore = inMaxToStore;
buff = ByteBuffer.allocate(4);
mDigest = MessageDigest.getInstance("MD5");
}
private void readObject(java.io.ObjectInputStream in) throws IOException,
ClassNotFoundException {
try {
mDigest = MessageDigest.getInstance("MD5");
} catch (Exception e) {
System.err.println("** Couldn't get MD5 hasher!");
}
buff = ByteBuffer.allocate(4);
in.defaultReadObject();
}
private static int computeHashes(int inNumBits, int inMaxToStore) {
return (int) Math.ceil((Math.log(2) * ((double) inNumBits / (double) inMaxToStore)));
}
// Returns a bitset containing the values in bytes.
// The byte-ordering of bytes must be big-endian which means the most significant bit is in element 0.
public static BitSet fromByteArray(byte[] bytes) {
BitSet bits = new BitSet();
for (int i = 0; i < bytes.length * 8; i++) {
if ((bytes[bytes.length - i / 8 - 1] & (1 << (i % 8))) > 0) {
bits.set(i);
}
}
return bits;
}
private int[] getBits(byte[] inBytes) {
int[] bits = new int[mHashesCount];
for (int funcItr = 0; funcItr < mHashesCount; funcItr++) {
try {
// create many bloom filter hashes through consistent salting
mDigest.update(salts[funcItr]);
byte[] hash = mDigest.digest(inBytes);
buff.position(0);
buff.put(hash, 0, 4);
buff.position(0);
bits[funcItr] = Math.abs(buff.getInt() % mBitsCapacity);
} catch (Exception e) {
e.printStackTrace();
System.err.println("error hashing cert into bloom filter: " + e);
return null;
}
}
return bits;
}
public void insert(byte[] inBytes) {
int[] bits = getBits(inBytes);
boolean alreadyThere = true;
for (int i : bits) {
if (mBits.get(i) == false) {
alreadyThere = false;
}
mBits.set(i);
}
if (!alreadyThere) {
objectsStored++;
}
}
private int objectsStored = 0;
public int getUniqueObjectsStored() {
return objectsStored;
}
public boolean test(byte[] inBytes) {
int[] bits = getBits(inBytes);
// StringBuilder sb = new StringBuilder();
// for( int i : bits )
// sb.append(i + " ");
//
// System.out.println("testing: " + ByteFormatter.encodeString(inBytes) + " gives: " + sb.toString());
for (int i : bits) {
if (mBits.get(i) == false)
return false;
}
return true;
}
public void clear() {
mBits.clear();
}
static Random rand = new Random();
public static byte[] random_bytes(int inSize) {
byte[] b = new byte[inSize];
rand.nextBytes(b);
return b;
}
public double getPredictedFalsePositiveRate() {
return getPredictedFalsePositiveRate(mBitsCapacity,
getUniqueObjectsStored());
}
public static double getPredictedFalsePositiveRate(int size, int to_store) {
return Math.pow(0.6185, (double) size / (double) to_store);
}
public static final void main(String[] args) throws Exception {
int size = 512 * 1024, to_store = 20000;
BloomFilter bf = new BloomFilter(size, to_store);
Set<String> set = new HashSet<String>();
System.out.println("mbits size (ints)=" + bf.mBits.back.length);
byte[] to_test = null;
long start = System.currentTimeMillis();
for (int i = 0; i < to_store; i++) {
byte[] b = random_bytes(8);
set.add(new String(b));
bf.insert(b);
if (!bf.test(b)) {
System.err.println("inserted but test failed!");
}
if (to_test == null)
to_test = b;
}
System.out.println("inserting took: "
+ (System.currentTimeMillis() - start) + " ms");
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(
"/tmp/foo"));
oos.writeObject(bf);
oos.close();
int fps = 0, to_check = 100000;
for (int i = 0; i < to_check; i++) {
byte[] b = null;
do {
b = random_bytes(8);
} while (set.contains(new String(b)) == true);
if (bf.test(b) == true)
fps++;
}
System.out.println("false positives: " + fps + " of " + to_check + " / "
+ ((double) fps / (double) to_check) * 100.0 + "% / Predicted: "
+ (getPredictedFalsePositiveRate(size, to_store) * 100.0)
+ "% / based on current state="
+ (bf.getPredictedFalsePositiveRate() * 100) + "%");
// serialization
ByteArrayOutputStream baos = new ByteArrayOutputStream();
oos = new ObjectOutputStream(baos);
oos.writeObject(bf);
oos.close();
BloomFilter two = (BloomFilter) (new ObjectInputStream(
new ByteArrayInputStream(baos.toByteArray()))).readObject();
System.out.println(two);
two.test(to_test);
}
}