package org.wikibrain.sr.phrasesim;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.TLongFloatMap;
import gnu.trove.map.hash.TLongFloatHashMap;
import java.io.Serializable;
import java.util.Arrays;
/**
 * A sparse phrase vector stored as two parallel arrays: {@code ids}, sorted ascending with no
 * duplicates (the constructors copy the keys of a map and sort them), and {@code vals}, where
 * {@code vals[k]} is the value associated with {@code ids[k]}.
 *
 * <p>NOTE(review): {@link #equals(Object)} is overridden without {@code hashCode()}, which breaks
 * the {@code Object} contract. Equal vectors may hash differently, so instances must not be used
 * as keys in hash-based collections. Adding a public {@code hashCode()} is the proper fix.
 * However, this class is {@code Serializable} and declares no {@code serialVersionUID}, and any
 * new non-private member changes the JVM-computed default UID. That would make previously
 * serialized instances unreadable. If instances are persisted, first pin
 * {@code serialVersionUID} to the value {@code serialver} reports for the current class. Then
 * add {@code 31 * Arrays.hashCode(ids) + Arrays.hashCode(vals)}.
 *
 * @author Shilad Sen
 */
public class PhraseVector implements Serializable {
    // Field names and modifiers are intentionally unchanged: they feed the default
    // serialVersionUID computation (see class note).
    final long ids[];
    final float vals[];

    /**
     * Builds a vector from a long-keyed map.
     *
     * @param map source entries; each key becomes an id and its mapped value the id's value
     */
    public PhraseVector(TLongFloatMap map) {
        this.ids = map.keys();
        // Sorted ids are required by the merge in cosineSim.
        Arrays.sort(this.ids);
        this.vals = new float[ids.length];
        for (int k = 0; k < ids.length; k++) {
            vals[k] = map.get(ids[k]);
        }
    }

    /**
     * Builds a vector from an int-keyed map, widening each key to a long id.
     *
     * @param map source entries; each key becomes an id and its mapped value the id's value
     */
    public PhraseVector(TIntFloatMap map) {
        int[] keys = map.keys();
        this.ids = new long[keys.length];
        for (int k = 0; k < keys.length; k++) {
            ids[k] = keys[k];
        }
        // Sorted ids are required by the merge in cosineSim.
        Arrays.sort(this.ids);
        this.vals = new float[ids.length];
        for (int k = 0; k < ids.length; k++) {
            // Every id originated from an int key, so narrowing back is lossless.
            vals[k] = map.get((int) ids[k]);
        }
    }

    /**
     * Two vectors are equal when they have identical id arrays and identical value arrays
     * (value comparison follows {@link Arrays#equals(float[], float[])} semantics).
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof PhraseVector)) {
            return false;
        }
        PhraseVector that = (PhraseVector) other;
        return Arrays.equals(ids, that.ids) && Arrays.equals(vals, that.vals);
    }

    /**
     * Computes the cosine similarity between this vector and {@code that}.
     *
     * <p>Both id arrays are walked once in a sorted merge. Ids present in only one vector
     * contribute to that vector's squared norm but not to the dot product.
     *
     * @param that the other vector
     * @return the cosine similarity, or {@code 0.0} if either vector has a zero squared norm
     *     (which includes an empty vector)
     */
    public double cosineSim(PhraseVector that) {
        final long[] xIds = this.ids;
        final float[] xVals = this.vals;
        final long[] yIds = that.ids;
        final float[] yVals = that.vals;

        double xDotX = 0.0;
        double xDotY = 0.0;
        double yDotY = 0.0;
        int i = 0;
        int j = 0;

        // Each product is computed in double. A float*float product overflows to Infinity for
        // |v| > ~1.8e19 (yielding a NaN similarity) and keeps only ~7 significant digits.
        while (i < xIds.length && j < yIds.length) {
            long xId = xIds[i];
            long yId = yIds[j];
            if (xId < yId) {
                double x = xVals[i++];
                xDotX += x * x;
            } else if (xId > yId) {
                double y = yVals[j++];
                yDotY += y * y;
            } else {
                double x = xVals[i++];
                double y = yVals[j++];
                xDotX += x * x;
                yDotY += y * y;
                xDotY += x * y;
            }
        }
        // At most one of these tails is non-empty; ids in it have no partner in the other vector.
        xDotX += sumOfSquares(xVals, i);
        yDotY += sumOfSquares(yVals, j);

        if (xDotX == 0.0 || yDotY == 0.0) {
            return 0.0;
        }
        return xDotY / Math.sqrt(xDotX * yDotY);
    }

    /**
     * Returns the Euclidean (L2) norm of this vector's values.
     *
     * @return the square root of the sum of squared values; {@code 0.0} for an empty vector
     */
    public double norm2() {
        return Math.sqrt(sumOfSquares(vals, 0));
    }

    /**
     * Sums {@code values[k] * values[k]} for {@code k} in {@code [from, values.length)}.
     * Each value is widened to double before squaring to avoid float overflow and rounding.
     * Private so it does not affect the default serialVersionUID.
     */
    private static double sumOfSquares(float[] values, int from) {
        double sum = 0.0;
        for (int k = from; k < values.length; k++) {
            double v = values[k];
            sum += v * v;
        }
        return sum;
    }
}