package org.wikipedia.miner.examples;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;

import org.wikipedia.miner.annotation.ArticleCleaner.SnippetLength;
import org.wikipedia.miner.annotation.Disambiguator;
import org.wikipedia.miner.annotation.TopicDetector;
import org.wikipedia.miner.annotation.weighting.LinkDetector;
import org.wikipedia.miner.db.WDatabase.DatabaseType;
import org.wikipedia.miner.model.Wikipedia;
import org.wikipedia.miner.util.ArticleSet;
import org.wikipedia.miner.util.ArticleSetBuilder;
import org.wikipedia.miner.util.Result;
import org.wikipedia.miner.util.WikipediaConfiguration;

import weka.classifiers.Classifier;
import weka.core.Utils;

/**
 * A command-line workbench for training and evaluating the classifiers behind
 * Wikipedia Miner's annotation features (word-sense disambiguation and link detection).
 *
 * Expects two arguments: a writable data directory (for article sets, ARFF files
 * and classifier models) and a Wikipedia Miner configuration XML file.
 */
public class AnnotationWorkbench {

    private Wikipedia _wikipedia;

    //directory in which files will be stored
    private File _dataDir;

    //classes for performing annotation
    private Disambiguator _disambiguator;
    private TopicDetector _topicDetector;
    private LinkDetector _linkDetector;

    //article set files
    private File _artsTrain, _artsTestDisambig, _artsTestDetect;

    //feature data files
    private File _arffDisambig, _arffDetect;

    //model files
    private File _modelDisambig, _modelDetect;

    public AnnotationWorkbench(File dataDir, Wikipedia wikipedia) throws Exception {

        _dataDir = dataDir;
        _wikipedia = wikipedia;

        _disambiguator = new Disambiguator(_wikipedia);
        _topicDetector = new TopicDetector(_wikipedia, _disambiguator);
        _linkDetector = new LinkDetector(_wikipedia);

        _artsTrain = new File(_dataDir.getPath() + "/articlesTrain.csv");
        _artsTestDisambig = new File(_dataDir.getPath() + "/articlesTestDisambig.csv");
        _artsTestDetect = new File(_dataDir.getPath() + "/articlesTestDetect.csv");

        _arffDisambig = new File(_dataDir.getPath() + "/disambig.arff");
        _arffDetect = new File(_dataDir.getPath() + "/detect.arff");

        _modelDisambig = new File(_dataDir.getPath() + "/disambig.model");
        _modelDetect = new File(_dataDir.getPath() + "/detect.model");
    }

    /**
     * Gathers three mutually exclusive article sets (training, disambiguation testing
     * and detection testing) and saves them as CSV files.
     */
    private void gatherArticleSets() throws IOException {

        int[] sizes = {200, 100, 100};

        ArticleSet[] articleSets = new ArticleSetBuilder()
                .setMinOutLinks(15)
                .setMinInLinks(20)
                .setMaxListProportion(0.1)
                .setMinWordCount(200)
                .setMaxWordCount(2000)
                .buildExclusiveSets(sizes, _wikipedia);

        articleSets[0].save(_artsTrain);
        articleSets[1].save(_artsTestDisambig);
        articleSets[2].save(_artsTestDetect);
    }

    /**
     * Gathers training data from the training article set and saves it as ARFF files.
     */
    private void createArffFiles(String datasetName) throws IOException, Exception {

        if (!_artsTrain.canRead())
            throw new Exception("Article sets have not yet been created");

        ArticleSet trainingSet = new ArticleSet(_artsTrain, _wikipedia);

        _disambiguator.train(trainingSet, SnippetLength.full, datasetName + "_disambiguation", null);
        _disambiguator.saveTrainingData(_arffDisambig);

        //the link detector gathers its training data via the topic detector, which needs a
        //working disambiguator, so build a default disambiguation classifier first
        _disambiguator.buildDefaultClassifier();

        _linkDetector.train(trainingSet, SnippetLength.full, datasetName + "_detection", _topicDetector, null);
        _linkDetector.saveTrainingData(_arffDetect);
    }

    /**
     * Builds and saves the disambiguation and detection classifiers, either from
     * Weka option strings or using the default configurations.
     */
    private void createClassifiers(String configDisambig, String configDetect) throws Exception {

        if (!_arffDisambig.canRead() || !_arffDetect.canRead())
            throw new Exception("Arff files have not yet been created");

        _disambiguator.loadTrainingData(_arffDisambig);

        if (configDisambig == null || configDisambig.trim().length() == 0) {
            _disambiguator.buildDefaultClassifier();
        } else {
            Classifier classifier = buildClassifierFromOptString(configDisambig);
            _disambiguator.buildClassifier(classifier);
        }
        _disambiguator.saveClassifier(_modelDisambig);
        _linkDetector.loadTrainingData(_arffDetect);

        //use the detection config here (the original code mistakenly reused configDisambig)
        if (configDetect == null || configDetect.trim().length() == 0) {
            _linkDetector.buildDefaultClassifier();
        } else {
            Classifier classifier = buildClassifierFromOptString(configDetect);
            _linkDetector.buildClassifier(classifier);
        }
        _linkDetector.saveClassifier(_modelDetect);
    }

    /**
     * Instantiates a Weka classifier from an option string, e.g. "weka.classifiers.trees.J48 -C 0.3".
     */
    private Classifier buildClassifierFromOptString(String optString) throws Exception {

        String[] options = Utils.splitOptions(optString);
        String classname = options[0];
        options[0] = "";

        return (Classifier) Utils.forName(Classifier.class, classname, options);
    }

    /**
     * Evaluates the trained classifiers against the held-out test article sets
     * and prints the results.
     */
    private void evaluate() throws Exception {

        if (!_modelDisambig.canRead() || !_modelDetect.canRead())
            throw new Exception("Classifier models have not yet been created");

        if (!_artsTestDisambig.canRead() || !_artsTestDetect.canRead())
            throw new Exception("Article sets have not yet been created");

        ArticleSet disambigSet = new ArticleSet(_artsTestDisambig, _wikipedia);
        _disambiguator.loadClassifier(_modelDisambig);
        Result<Integer> disambigResults = _disambiguator.test(disambigSet, _wikipedia, SnippetLength.full, null);

        ArticleSet detectSet = new ArticleSet(_artsTestDetect, _wikipedia);
        _linkDetector.loadClassifier(_modelDetect);
        Result<Integer> detectResults = _linkDetector.test(detectSet, SnippetLength.full, _topicDetector, null);

        System.out.println();
        System.out.println("Disambig results: " + disambigResults);
        System.out.println("Detect results: " + detectResults);
    }

    public static void main(String[] args) throws Exception {

        File dataDir = new File(args[0]);

        WikipediaConfiguration conf = new WikipediaConfiguration(new File(args[1]));
        conf.addDatabaseToCache(DatabaseType.label);
        conf.addDatabaseToCache(DatabaseType.pageLinksInNoSentences);

        Wikipedia wikipedia = new Wikipedia(conf, false);

        AnnotationWorkbench trainer = new AnnotationWorkbench(dataDir, wikipedia);

        BufferedReader input = new BufferedReader(new InputStreamReader(System.in));

        while (true) {
            System.out.println("What would you like to do?");
            System.out.println(" - [1] create article sets.");
            System.out.println(" - [2] create arff files.");
            System.out.println(" - [3] create classifiers.");
            System.out.println(" - [4] evaluate classifiers.");
            System.out.println(" - or ENTER to quit.");

            String line = input.readLine();

            if (line.trim().length() == 0)
                break;

            int choice = 0;
            try {
                choice = Integer.parseInt(line);
            } catch (Exception e) {
                System.out.println("Invalid Input");
                continue;
            }

            switch (choice) {
                case 1:
                    trainer.gatherArticleSets();
                    break;
                case 2:
                    System.out.println("Dataset name:");
                    String datasetName = input.readLine();
                    trainer.createArffFiles(datasetName);
                    break;
                case 3:
                    System.out.println("Disambiguation classifier config (or ENTER to use default):");
                    String configDisambig = input.readLine();
                    System.out.println("Detection classifier config (or ENTER to use default):");
                    String configDetect = input.readLine();
                    trainer.createClassifiers(configDisambig, configDetect);
                    break;
                case 4:
                    trainer.evaluate();
                    break;
                default:
                    System.out.println("Invalid Input");
            }
        }
    }
}