package org.wikibrain.sr.word2vec;

import org.apache.commons.cli.*;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.DefaultOptionBuilder;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.cmd.EnvBuilder;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.nlp.Dictionary;
import org.wikibrain.core.nlp.StringTokenizer;
import org.wikibrain.phrases.AnchorTextPhraseAnalyzer;
import org.wikibrain.phrases.PhraseAnalyzer;
import org.wikibrain.phrases.PhraseAnalyzerDao;
import org.wikibrain.utils.Scoreboard;
import org.wikibrain.utils.WpIOUtils;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.util.*;

/**
 * Given a corpus of words and phrases, iteratively concatenates bigrams that are
 * likely to be more meaningful together than apart.
 *
 * e.g. "speed of light" => "speed_of light" (pass 1) => "speed_of_light" (pass 2)
 *
 * @author Shilad Sen
 */
public class Word2Phrase {
    private static final Logger LOG = LoggerFactory.getLogger(Word2Phrase.class);

    private final PhraseAnalyzerDao phraseDao;
    private final Language language;

    private int minReduce = 2;    // currently unused
    private int minCount = 5;
    private int threshold = -1;
    private StringTokenizer tokenizer = new StringTokenizer();

    /**
     * Counts of words and phrases.
     */
    private Dictionary dictionary;

    public Word2Phrase(Language language, PhraseAnalyzerDao phraseDao) {
        this.language = language;
        this.phraseDao = phraseDao;
    }
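    /*
     * A sketch of the file layout produced by concatenateBigrams for
     * maxWords = 4, derived from the ioPairs construction below:
     *
     *   pass 0: inputDir/corpus.txt     -> outputDir/phrases.txt.2  (bigrams joined)
     *   pass 1: outputDir/phrases.txt.2 -> outputDir/phrases.txt.3  (trigrams joined)
     *   pass 2: outputDir/phrases.txt.3 -> outputDir/corpus.txt     (4-grams joined)
     */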
    public void concatenateBigrams(File inputDir, File outputDir, int maxWords) throws IOException {
        // Pass i reads phrases of up to (i+1) words and writes phrases of up to (i+2) words.
        File[][] ioPairs = new File[maxWords - 1][2];
        for (int i = 0; i < ioPairs.length; i++) {
            ioPairs[i][0] = new File(outputDir, "phrases.txt." + (i + 1));
            ioPairs[i][1] = new File(outputDir, "phrases.txt." + (i + 2));
        }
        ioPairs[0][0] = new File(inputDir, "corpus.txt");
        ioPairs[ioPairs.length - 1][1] = new File(outputDir, "corpus.txt");

        for (int i = 0; i < ioPairs.length; i++) {
            LOG.info("pass " + i + ": joining phrases of length " + (i + 1) + " to " + (i + 2));
            File in = ioPairs[i][0];
            File out = ioPairs[i][1];
            dictionary = new Dictionary(language, Dictionary.WordStorage.ON_DISK);
            dictionary.setCountBigrams(true);
            dictionary.countNormalizedFile(in);
            if (i == 0 && this.threshold < 0) {
                this.threshold = learnThreshold(i + 2);    // pass 0 goes from unigrams to bigrams
            } else {
                this.threshold = Math.max(5, this.threshold / 3);
            }
            processFile(in, out, i + 2);
        }
        dictionary.write(new File(outputDir, "dictionary.txt"));
    }

    private void processFile(File input, File output, int maxWords) throws IOException {
        BufferedReader reader = WpIOUtils.openBufferedReader(input);
        BufferedWriter writer = WpIOUtils.openWriter(output);
        while (true) {
            String line = reader.readLine();
            if (line == null) {
                break;
            }
            String[] tokens = line.trim().split(" +");
            StringBuilder newLine = new StringBuilder();
            for (int i = 0; i < tokens.length; i++) {
                if (i > 0) {
                    if (shouldConcatenate(tokens[i - 1], tokens[i], maxWords)) {
                        newLine.append('_');
                    } else {
                        newLine.append(' ');
                    }
                }
                newLine.append(tokens[i]);
            }
            newLine.append('\n');
            writer.write(newLine.toString());
        }
        reader.close();
        writer.close();    // the original leaked the writer; closing it flushes the last buffer
    }

    private boolean shouldConcatenate(String token1, String token2, int maxWords) {
        int numWordsInToken1 = StringUtils.countMatches(token1, "_") + 1;
        int numWordsInToken2 = StringUtils.countMatches(token2, "_") + 1;    // was token1: a copy-paste bug
        return (
                numWordsInToken1 == maxWords - 1
                && numWordsInToken2 == 1
                && scoreBigram(token1, token2) >= threshold
        );
    }

    private double scoreBigram(String token1, String token2) {
        if (Word2VecUtils.PATTERN_ID.matcher(token1).matches()) {
            return 0.0;
        }
        int pa = dictionary.getUnigramCount(token1);
        int pb = dictionary.getUnigramCount(token2);
        if (pa < minCount || pb < minCount) {
            return 0.0;
        }
        int pab = dictionary.getBigramCount(token1, token2);
        // The same statistic used by word2vec's word2phrase tool:
        // (count(ab) - minCount) * N / (count(a) * count(b)).
        // The cast avoids int overflow when both unigram counts are large.
        return (pab - minCount) * (double) dictionary.getTotalCount() / ((double) pa * pb);
    }
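    /*
     * Worked example for scoreBigram, with made-up counts: given minCount = 5,
     * a 1,000,000-token corpus, count("speed_of") = 100, count("light") = 2000,
     * and count("speed_of light") = 55, the score is
     * (55 - 5) * 1000000 / (100 * 2000) = 250, so the pair is joined on any
     * pass whose threshold is at most 250.
     */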
    public int learnThreshold(int numWords) {
        List<String[]> bigrams = getKnownBigrams(numWords);
        if (bigrams.isEmpty()) {
            throw new IllegalStateException("Found no anchor texts of length " + numWords);
        }
        List<String[]> nonBigrams = getNonBigramSample(bigrams, bigrams.size());

        List<Double> bigramScores = new ArrayList<Double>();
        List<Double> nonBigramScores = new ArrayList<Double>();
        for (String[] bigram : bigrams) {
            bigramScores.add(scoreBigram(bigram[0], bigram[1]));
        }
        for (String[] bigram : nonBigrams) {
            nonBigramScores.add(scoreBigram(bigram[0], bigram[1]));
        }
        Collections.sort(bigramScores);
        Collections.sort(nonBigramScores);

        double bestScore = 0.0;
        int bestThreshold = 0;
        for (int t = 0; t < 1000; t++) {
            // Find the first index at or above the candidate threshold in each sorted list.
            // binarySearch returns (-(insertion point) - 1) on a miss; recover the insertion
            // point (the original used Math.abs, which is off by one on misses).
            int i1 = Collections.binarySearch(bigramScores, (double) t);
            int i2 = Collections.binarySearch(nonBigramScores, (double) t);
            if (i1 < 0) i1 = -i1 - 1;
            if (i2 < 0) i2 = -i2 - 1;

            // Number of bigrams / non-bigrams that threshold t would concatenate.
            int n1 = bigramScores.size() - i1;
            int n2 = nonBigramScores.size() - i2;
            double precision = 0.0;
            double recall = 0.0;
            if (n1 + n2 > 0) {
                precision = 1.0 * n1 / (n1 + n2);
                recall = 1.0 * n1 / bigramScores.size();
                // An F0.5-style measure that weights precision more heavily than recall.
                double score = (precision * recall) / (0.25 * precision + recall);
                if (score > bestScore) {
                    bestThreshold = t;
                    bestScore = score;
                }
            }
        }
        LOG.info("learned threshold " + bestThreshold + " for words of length " + numWords
                + " with " + bigrams.size() + " known bigrams");
        return bestThreshold;
    }

    private List<String[]> getKnownBigrams(int numWords) {
        List<String[]> bigrams = new ArrayList<String[]>();
        Iterator<String> iter = phraseDao.getAllPhrases(language);
        while (iter.hasNext()) {
            String phrase = iter.next();
            List<String> tokens = tokenizer.getWords(language, phrase);
            if (tokens.size() == numWords) {
                String first = StringUtils.join(tokens.subList(0, tokens.size() - 1), '_');
                String second = tokens.get(tokens.size() - 1);
                int n1 = dictionary.getUnigramCount(first);
                int n2 = dictionary.getUnigramCount(second);
                if (n1 >= minCount && n2 >= minCount) {
                    bigrams.add(new String[]{ first, second });
                }
            }
        }
        return bigrams;
    }

    private List<String[]> getNonBigramSample(List<String[]> bigrams, int n) {
        Scoreboard<String> topFirst = new Scoreboard<String>(1000);
        Scoreboard<String> topSecond = new Scoreboard<String>(1000);
        Set<String> known = new HashSet<String>();
        for (String[] bigram : bigrams) {
            topFirst.add(bigram[0], dictionary.getUnigramCount(bigram[0]));
            topSecond.add(bigram[1], dictionary.getUnigramCount(bigram[1]));
            known.add(bigram[0] + "_" + bigram[1]);
        }
        // Pair frequent first and second components at random; any pairing that is
        // not a known anchor-text phrase is treated as a negative example.
        Random random = new Random();
        ArrayList<String[]> nonBigrams = new ArrayList<String[]>();
        while (nonBigrams.size() < n) {
            String first = topFirst.getElement(random.nextInt(topFirst.size()));
            String second = topSecond.getElement(random.nextInt(topSecond.size()));
            if (!known.contains(first + "_" + second)) {
                nonBigrams.add(new String[]{ first, second });
            }
        }
        return nonBigrams;
    }
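    /*
     * Example invocation (directory names are hypothetical):
     *
     *   Word2Phrase -i /data/corpus -o /data/phrases -g 4
     *
     * reads /data/corpus/corpus.txt, writes the joined corpus to
     * /data/phrases/corpus.txt, and leaves dictionary.txt and the
     * intermediate phrases.txt.N files alongside it.
     */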
    public static void main(String[] args) throws IOException, ConfigurationException {
        Options options = new Options();
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .isRequired()
                        .withLongOpt("input")
                        .withDescription("corpus input directory")
                        .create("i"));
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .isRequired()
                        .withLongOpt("output")
                        .withDescription("corpus output directory (existing data will be lost)")
                        .create("o"));
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("minCount")
                        .withDescription("minimum frequency for unigrams that should be collapsed")
                        .create("m"));    // note: declared, but minCount is not currently read from it
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("maxngram")
                        .withDescription("maximum number of words that should be concatenated together")
                        .create("g"));
        EnvBuilder.addStandardOptions(options);

        CommandLineParser parser = new PosixParser();
        CommandLine cmd;
        try {
            cmd = parser.parse(options, args);
        } catch (ParseException e) {
            System.err.println("Invalid option usage: " + e.getMessage());
            new HelpFormatter().printHelp("Word2Phrase", options);
            return;
        }
        Env env = new EnvBuilder(cmd).build();
        AnchorTextPhraseAnalyzer phraseAnalyzer =
                (AnchorTextPhraseAnalyzer) env.getConfigurator().get(PhraseAnalyzer.class, "anchortext");
        PhraseAnalyzerDao dao = phraseAnalyzer.getDao();
        Word2Phrase w2p = new Word2Phrase(env.getLanguages().getDefaultLanguage(), dao);
        w2p.concatenateBigrams(
                new File(cmd.getOptionValue("i")),
                new File(cmd.getOptionValue("o")),
                Integer.valueOf(cmd.getOptionValue("g", "4")));
    }
}