package org.wikibrain.sr.word2vec;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.TLongIntMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TLongIntHashMap;
import org.apache.commons.cli.*;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.lang3.ArrayUtils;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.conf.DefaultOptionBuilder;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.cmd.EnvBuilder;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.core.nlp.Dictionary;
import org.wikibrain.sr.wikify.WBCorpusDocReader;
import org.wikibrain.sr.wikify.WbCorpusLineReader;
import org.wikibrain.utils.*;
import java.io.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.regex.Matcher;
/**
* @author Shilad Sen
*
* Heavily adapted from https://github.com/piskvorky/gensim/blob/develop/gensim/models/word2vec.py
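 *
 * <p>Trains word and article vectors from a WikiBrain corpus using the skip-gram
 * word2vec model with hierarchical softmax. Article mentions (e.g.
 * "/w/en/1000/Hercule_Poirot") are treated as additional vocabulary entries.</p>
 *
 * <p>A rough usage sketch, mirroring {@link #main(String[])}. It assumes a configured
 * WikiBrain environment; the corpus directory and output path below are illustrative
 * placeholders, and the directory must contain dictionary.txt and corpus.txt:</p>
 * <pre>{@code
 * Env env = new EnvBuilder().build();
 * Word2VecTrainer trainer = new Word2VecTrainer(
 *         env.getConfigurator().get(LocalPageDao.class),
 *         env.getLanguages().getDefaultLanguage());
 * trainer.setLayer1Size(200);   // vector dimensionality
 * trainer.setWindow(5);         // context window size
 * trainer.train(new File("path/to/corpus_dir"));
 * trainer.save(new File("path/to/vectors.bin"));
 * }</pre>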
*/
public class Word2VecTrainer {
private static final Logger LOG = LoggerFactory.getLogger(Word2VecTrainer.class);
private static final int ARTICLE_COUNT_BONUS = 10;
private static final int MAX_EXP = 6;
private static final int EXP_TABLE_SIZE = 1000;
private final Language language;
private final LocalPageDao pageDao;
    // Vocabulary indexes and occurrence counts, keyed by the hash of each word.
private final TLongIntMap wordIndexes = new TLongIntHashMap();
private final TLongIntMap wordCounts = new TLongIntHashMap();
    // Mapping from article id to the vocabulary index of its string representation ("/w/en/1000/Hercule_Poirot").
private final TIntIntMap articleIndexes = new TIntIntHashMap();
    // Total number of word occurrences in the corpus (repeats counted each time).
private long totalWords;
    /**
     * Minimum frequency for a word to be included in the model.
     */
private int minWordFrequency = 5;
/**
* Minimum number of times an article must be mentioned for it to be included in the model.
*/
private int minMentionFrequency = 5;
    /**
     * Maximum number of distinct words (and mentions) retained in the vocabulary.
     */
private int maxWords = 5000000;
private double startingAlpha = 0.025;
private double alpha = startingAlpha;
private int window = 5;
private int layer1Size = 200;
private float syn0[][];
private float syn1[][];
    /**
     * Precomputed sigmoid lookup table: EXP_TABLE[i] holds sigmoid(x) for x evenly
     * spaced over [-MAX_EXP, MAX_EXP).
     */
private static final double[] EXP_TABLE = new double[EXP_TABLE_SIZE];
static {
for (int i = 0; i < EXP_TABLE_SIZE; i++) {
EXP_TABLE[i] = Math.exp(((i / (double) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP));
EXP_TABLE[i] = EXP_TABLE[i] / (EXP_TABLE[i] + 1);
}
}
private AtomicLong wordsTrainedSoFar = new AtomicLong();
private Random random = new Random();
private byte[][] wordCodes;
private int[][] wordParents;
private String[] words = null;
private boolean keepAllArticles = false;
private int iterations = 2;
public Word2VecTrainer(LocalPageDao pageDao, Language language) {
this.pageDao = pageDao;
this.language = language;
}
public void train(File directory) throws IOException, DaoException {
train(directory, true);
}
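    /**
     * Trains the model over the corpus in the given directory, which is expected to
     * contain a dictionary.txt word-count file and a corpus.txt text file. The corpus
     * is scanned once per iteration.
     *
     * @param wikibrainFormat if true, corpus.txt is read as a WikiBrain corpus whose
     *                        documents carry article ids; if false, it is read as plain
     *                        text, one sentence per line.
     */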
public void train(File directory, boolean wikibrainFormat) throws IOException, DaoException {
LOG.info("counting word frequencies.");
readWords(new File(directory, "dictionary.txt"));
buildTree();
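        // syn0 holds the word/article vectors (randomly initialized);
        // syn1 holds the Huffman inner-node vectors (left zero-initialized).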
syn0 = new float[wordIndexes.size()][layer1Size];
        for (float[] row : syn0) {
for (int i = 0; i < row.length; i++) {
row[i] = (random.nextFloat() - 0.5f) / layer1Size;
}
}
syn1 = new float[wordIndexes.size()][layer1Size];
for (int it = 0; it < iterations; it++) {
if (wikibrainFormat) {
WBCorpusDocReader reader = new WBCorpusDocReader(new File(directory, "corpus.txt"));
ParallelForEach.iterate(reader.iterator(),
WpThreadUtils.getMaxThreads(),
1000,
new Procedure<WBCorpusDocReader.Doc>() {
@Override
public void call(WBCorpusDocReader.Doc doc) throws Exception {
int n = 0;
for (String line : doc.getLines()) {
n += trainSentence(doc.getDoc().getId(), line);
}
wordsTrainedSoFar.addAndGet(n);
// update the learning rate
alpha = Math.max(
startingAlpha * (1 - wordsTrainedSoFar.get() / (iterations * totalWords + 1.0)),
startingAlpha * 0.0001);
}
},
10000);
} else {
LineIterator iterator = FileUtils.lineIterator(new File(directory, "corpus.txt"));
ParallelForEach.iterate(iterator,
WpThreadUtils.getMaxThreads(),
1000,
new Procedure<String>() {
@Override
public void call(String sentence) throws Exception {
int n = trainSentence(null, sentence);
wordsTrainedSoFar.addAndGet(n);
// update the learning rate
alpha = Math.max(
startingAlpha * (1 - wordsTrainedSoFar.get() / (iterations * totalWords + 1.0)),
startingAlpha * 0.0001);
}
},
10000);
iterator.close();
}
}
}
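    /**
     * Reads the dictionary and builds the vocabulary: frequent unigrams and article
     * mentions are hashed and assigned consecutive indexes, with article counts boosted
     * by ARTICLE_COUNT_BONUS. If keepAllArticles is set, every normal page in the
     * language is added even if it was never mentioned.
     */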
public void readWords(File dictionary) throws IOException, DaoException {
LOG.info("reading word counts");
Dictionary dict = new Dictionary(language, Dictionary.WordStorage.IN_MEMORY);
dict.setCountBigrams(false);
dict.setContainsMentions(true);
dict.read(dictionary, maxWords, minWordFrequency);
totalWords = dict.getTotalCount();
List<String> top = dict.getFrequentUnigramsAndMentions(pageDao, maxWords, minWordFrequency, minMentionFrequency);
for (int i = 0; i < top.size(); i++) {
String w = top.get(i);
long h = hashWord(w);
wordIndexes.put(h, i);
if (w.startsWith("/w/")) {
int wpId = Integer.valueOf(w.split("/", 5)[3]);
articleIndexes.put(wpId, i);
wordCounts.put(h, dict.getMentionCount(wpId) + ARTICLE_COUNT_BONUS);
} else {
wordCounts.put(h, dict.getUnigramCount(w));
}
}
if (keepAllArticles) {
for (LocalPage page : pageDao.get(DaoFilter.normalPageFilter(language))) {
if (!articleIndexes.containsKey(page.getLocalId())) {
String w = page.getCompactUrl();
long h = hashWord(w);
if (wordIndexes.containsKey(h)) {
LOG.warn("hash collision on " + w + " with hash " + h);
} else {
int i = top.size();
wordIndexes.put(h, i);
top.add(w);
articleIndexes.put(page.getLocalId(), i);
wordCounts.put(h, ARTICLE_COUNT_BONUS);
}
}
}
}
words = top.toArray(new String[top.size()]);
LOG.info("retained " + dict.getNumUnigrams() + " words and " + articleIndexes.size() + " articles");
}
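    /**
     * Trains the model on a single sentence using skip-gram with hierarchical softmax.
     * A token that carries a mention suffix (":/w/...") contributes both its surface
     * word and the mentioned article, in random order.
     *
     * @param wpId the article id of the containing document, or null for plain text
     * @return the number of tokens (and expanded mentions) seen, including out-of-vocabulary ones
     */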
private int trainSentence(Integer wpId, String sentence) {
int wpIdIndex = (wpId != null && articleIndexes.containsKey(wpId)) ? articleIndexes.get(wpId) : -1;
String words[] = sentence.trim().split(" +");
TIntList indexList = new TIntArrayList(words.length * 3 / 2);
for (int i = 0; i < words.length; i++) {
int wordIndex = -1;
int mentionIndex = -1;
int mentionStart = words[i].indexOf(":/w/");
if (mentionStart >= 0) {
Matcher m = Dictionary.PATTERN_MENTION.matcher(words[i].substring(mentionStart));
if (m.matches()) {
int wpId2 = Integer.valueOf(m.group(3));
if (articleIndexes.containsKey(wpId2)) {
mentionIndex = articleIndexes.get(wpId2);
}
words[i] = words[i].substring(0, mentionStart);
}
}
if (words[i].length() > 0) {
long h = hashWord(words[i]);
if (wordIndexes.containsKey(h)) {
wordIndex = wordIndexes.get(h);
}
}
if (mentionIndex >= 0) {
if (random.nextDouble() >= 0.5) {
indexList.add(wordIndex);
indexList.add(mentionIndex);
} else {
indexList.add(mentionIndex);
indexList.add(wordIndex);
}
} else {
indexList.add(wordIndex);
}
}
int indexes[] = indexList.toArray();
float[] neu1e = new float[layer1Size];
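        // Skip-gram with hierarchical softmax: for each in-vocabulary center entry, nudge the
        // vectors of its context entries (and the Huffman inner-node vectors) so the context
        // better predicts the center entry's Huffman code.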
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] < 0) {
continue; // skip out of vocabulary words
}
byte [] code = wordCodes[indexes[i]];
int [] parents = wordParents[indexes[i]];
if (code.length != parents.length) {
throw new IllegalStateException();
}
// now go over all words from the (reduced) window, predicting each one in turn
int reducedWindow = random.nextInt(window);
int start = Math.max(0, i - window + reducedWindow);
int end = Math.min(indexes.length, i + window + 1 - reducedWindow);
for (int j = start; j < end; j++) {
int q;
if (i == j) {
                    // Hack: when j == i, train against the containing article (if any)
                    // instead of skipping the center position, which is what standard word2vec does.
q = wpIdIndex;
} else {
q = indexes[j];
}
if (q < 0) {
continue;
}
Arrays.fill(neu1e, 0f);
float l1[] = syn0[q];
for (int k = 0; k < parents.length; k++) {
float l2[] = syn1[parents[k]];
double f = WbMathUtils.dot(l1, l2);
if (f <= -MAX_EXP || f >= MAX_EXP) {
continue;
}
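                    // s approximates sigmoid(f); g is the (label - prediction) error scaled by
                    // the learning rate, where the label is 1 - code[k].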
double s = EXP_TABLE[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
double g = (1 - code[k] - s) * alpha;
for (int c = 0; c < layer1Size; c++) {
neu1e[c] += g * l2[c];
l2[c] += g * l1[c];
}
}
for (int c = 0; c < layer1Size; c++) {
l1[c] += neu1e[c];
}
}
}
return indexes.length;
}
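    /**
     * A node in the Huffman tree used for hierarchical softmax. Leaves correspond to
     * vocabulary entries (hash != 0); internal nodes only aggregate counts.
     */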
private class Node implements Comparable<Node> {
long hash;
int index;
int count;
Node left;
Node right;
private Node(long hash, int count, int index) {
this.hash = hash;
this.count = count;
this.index = index;
}
private Node(long hash, int count, int index, Node left, Node right) {
this.hash = hash;
this.count = count;
this.index = index;
this.left = left;
this.right = right;
}
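        /** Recursively assigns each leaf its Huffman code: 0 for a left branch, 1 for a right branch. */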
public void setCode(byte [] code) {
if (hash != 0) wordCodes[index] = code;
if (left != null) left.setCode(ArrayUtils.add(code, (byte)0));
if (right != null) right.setCode(ArrayUtils.add(code, (byte)1));
}
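        /** Recursively records, for each leaf, the inner-node offsets (indexes into syn1) along its path from the root. */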
public void setPoints(int [] points) {
if (hash != 0) wordParents[index] = points;
points = ArrayUtils.add(points, index - wordIndexes.size());
if (left != null) left.setPoints(points);
if (right != null) right.setPoints(points);
}
public int getHeight() {
int height = 0;
if (left != null) height = Math.max(height, left.getHeight());
if (right != null) height = Math.max(height, right.getHeight());
return height + 1;
}
@Override
public int compareTo(Node o) {
return count - o.count;
}
}
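    /**
     * Builds the Huffman tree over the vocabulary (more frequent entries get shorter
     * codes) and fills in wordCodes and wordParents for every vocabulary entry.
     */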
private void buildTree() {
LOG.info("creating initial heap");
PriorityQueue<Node> heap = new PriorityQueue<Node>();
for (long hash : wordIndexes.keys()) {
heap.add(new Node(hash, wordCounts.get(hash), wordIndexes.get(hash)));
}
LOG.info("creating huffman tree");
for (int i = 0; heap.size() > 1; i++) {
Node n1 = heap.poll();
Node n2 = heap.poll();
Node n = new Node(0, n1.count + n2.count, i + wordIndexes.size(), n1, n2);
heap.add(n);
}
Node root = heap.poll();
if (!heap.isEmpty()) {
throw new IllegalStateException();
}
this.wordParents = new int[wordIndexes.size()][];
this.wordCodes = new byte[wordIndexes.size()][];
root.setPoints(new int[0]);
root.setCode(new byte[0]);
LOG.info("built tree of height " + root.getHeight());
}
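    /**
     * Writes the model in the binary word2vec format: a header line with the vocabulary
     * size and vector dimensionality, then each word followed by a space and its
     * normalized vector as 4-byte little-endian floats.
     */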
public void save(File path) throws IOException {
FileUtils.deleteQuietly(path);
path.getParentFile().mkdirs();
OutputStream stream = new BufferedOutputStream(new FileOutputStream(path));
stream.write((words.length + " " + layer1Size + "\n").getBytes());
for (String w : words) {
stream.write(w.getBytes("UTF-8"));
stream.write(' ');
float[] vector = syn0[wordIndexes.get(Word2VecUtils.hashWord(w))];
WbMathUtils.normalize(vector);
for (float f : vector) {
stream.write(floatToBytes(f));
}
}
stream.close();
}
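    /** Unused debugging helper: prints dot-product similarities between "person" and other vocabulary entries. */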
private void test() {
long h = hashWord("person");
float [] v1 = syn0[wordIndexes.get(h)];
WbMathUtils.normalize(v1);
Map<String, Double> sims = new HashMap<String, Double>();
for (int i = 0; i < words.length; i++) {
float [] v2 = syn0[i];
WbMathUtils.normalize(v2);
double sim = WbMathUtils.dot(v1, v2);
sims.put(words[i], sim);
}
List<String> keys = new ArrayList<String>(sims.keySet());
Collections.sort(keys, new MapValueComparator(sims, false));
keys = keys.subList(0, 100);
for (String k : keys) {
System.out.println(sims.get(k) + " " + k);
}
}
public void setMaxWords(int maxWords) {
this.maxWords = maxWords;
}
public void setLayer1Size(int layer1Size) {
this.layer1Size = layer1Size;
}
public void setWindow(int window) {
this.window = window;
}
public void setKeepAllArticles(boolean keepAllArticles) {
this.keepAllArticles = keepAllArticles;
}
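    /** Encodes a float as four bytes, least-significant byte first (the layout written to the model file). */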
private static byte[] floatToBytes(float value) {
int bits = Float.floatToIntBits(value);
byte[] bytes = new byte[4];
bytes[0] = (byte)(bits & 0xff);
bytes[1] = (byte)((bits >> 8) & 0xff);
bytes[2] = (byte)((bits >> 16) & 0xff);
bytes[3] = (byte)((bits >> 24) & 0xff);
return bytes;
}
private static long hashWord(String word) {
return Word2VecUtils.hashWord(word);
}
public static void main(String args[]) throws ConfigurationException, IOException, DaoException {
Options options = new Options();
options.addOption(
new DefaultOptionBuilder()
.hasArg()
.isRequired()
.withLongOpt("output")
.withDescription("model output file")
.create("o"));
options.addOption(
new DefaultOptionBuilder()
.hasArg()
.isRequired()
.withLongOpt("input")
.withDescription("corpus input directory (as generated by WikiTextCorpusCreator)")
.create("i"));
options.addOption(
new DefaultOptionBuilder()
.hasArg()
.withLongOpt("layer1size")
.withDescription("size of the layer 1 neural network")
.create("z"));
options.addOption(
new DefaultOptionBuilder()
.hasArg()
.withLongOpt("window")
.withDescription("size of the sliding window")
.create("w"));
options.addOption(
new DefaultOptionBuilder()
.hasArg()
.withLongOpt("minfreq")
.withDescription("minimum word frequency")
.create("f"));
EnvBuilder.addStandardOptions(options);
CommandLineParser parser = new PosixParser();
CommandLine cmd;
try {
cmd = parser.parse(options, args);
} catch (ParseException e) {
System.err.println( "Invalid option usage: " + e.getMessage());
new HelpFormatter().printHelp("WikiTextCorpusCreator", options);
return;
}
Env env = new EnvBuilder(cmd).build();
Configurator c = env.getConfigurator();
Word2VecTrainer trainer = new Word2VecTrainer(
c.get(LocalPageDao.class),
env.getLanguages().getDefaultLanguage()
);
if (cmd.hasOption("f")) {
trainer.minWordFrequency = Integer.valueOf(cmd.getOptionValue("f"));
}
if (cmd.hasOption("w")) {
trainer.window = Integer.valueOf(cmd.getOptionValue("w"));
}
if (cmd.hasOption("z")) {
trainer.layer1Size = Integer.valueOf(cmd.getOptionValue("z"));
}
trainer.train(new File(cmd.getOptionValue("i")));
trainer.save(new File(cmd.getOptionValue("o")));
}
}