package org.wikibrain.sr;
import com.typesafe.config.Config;
import java.util.*;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.ArrayUtils;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.LocalString;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.core.model.NameSpace;
import org.wikibrain.matrix.*;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.disambig.SimilarityDisambiguator;
import org.wikibrain.sr.normalize.IdentityNormalizer;
import org.wikibrain.sr.normalize.Normalizer;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.utils.SrNormalizers;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;
import org.wikibrain.utils.WpIOUtils;
import org.wikibrain.utils.WpThreadUtils;
import java.io.File;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Shared building blocks for monolingual semantic-relatedness (SR) metrics.
 *
 * <p>The class owns the pieces every concrete metric needs: the metric name and
 * language, the disambiguator used to map phrases to page ids, the similarity and
 * most-similar normalizers, the on-disk data directory, and an optional
 * "most similar" cache stored as a sparse matrix inside that directory.
 *
 * <p>Subclasses must implement {@link #getConfig()},
 * {@link #similarity(int, int, boolean)} and {@link #mostSimilar(int, int, TIntSet)}.
 */
public abstract class BaseSRMetric implements SRMetric {
    private static final Logger LOG = LoggerFactory.getLogger(BaseSRMetric.class);

    private final String name;
    private final Language language;
    private final Disambiguator disambiguator;
    private final LocalPageDao localPageDao;
    private final SrNormalizers normalizers;

    // Directory holding the normalizers and the most-similar matrix; null until setDataDir is called.
    private File dataDir;

    // When false, read() leaves the currently installed normalizers untouched.
    private boolean shouldReadNormalizers = true;

    // When true, trainMostSimilar() also (re)builds the most-similar cache.
    private boolean buildMostSimilarCache = false;

    // Currently opened most-similar cache, or null if none is available.
    private SparseMatrix mostSimilarCache = null;

    // Rows to materialise when the cache is built during training; null means "all article pages".
    private TIntSet mostSimilarCacheRowIds = null;

    // the number of senses to consider for each phrase
    // NOTE(review): not read anywhere in this class.
    private int numSenses = 5;

    /**
     * Static properties of a metric, supplied by subclasses through {@link #getConfig()}.
     */
    public static class SRConfig {
        // minimum and maximum scores BEFORE normalization
        public float minScore = -1.1f;
        public float maxScore = +1.1f;
    }

    /**
     * Creates a metric with a fresh, empty set of normalizers.
     *
     * @param name     name of the metric; also used as a path component by {@link #configureBase}
     * @param language the single language this metric operates in
     * @param dao      page dao used to enumerate and look up local pages
     * @param disambig disambiguator used to resolve phrases to page ids
     */
    public BaseSRMetric(String name, Language language, LocalPageDao dao, Disambiguator disambig) {
        this.name = name;
        this.language = language;
        this.disambiguator = disambig;
        this.localPageDao = dao;
        this.normalizers = new SrNormalizers();
    }

    /**
     * @return the configuration of this metric; its score bounds are used when the
     *         most-similar cache is written
     */
    public abstract SRConfig getConfig();

    @Override
    public File getDataDir() {
        return dataDir;
    }

    @Override
    public String getName() {
        return name;
    }

    @Override
    public void setDataDir(File dir) {
        this.dataDir = dir;
    }

    @Override
    public void setMostSimilarNormalizer(Normalizer n) {
        normalizers.setMostSimilarNormalizer(n);
    }

    @Override
    public void setSimilarityNormalizer(Normalizer n) {
        normalizers.setSimilarityNormalizer(n);
    }

    @Override
    public boolean similarityIsTrained() {
        return normalizers.getSimilarityNormalizer().isTrained();
    }

    @Override
    public boolean mostSimilarIsTrained() {
        return normalizers.getMostSimilarNormalizer().isTrained();
    }

    /**
     * Throws an IllegalStateException if the similarity normalizer has not been trained.
     */
    protected void ensureSimilarityTrained() {
        if (!similarityIsTrained()) {
            throw new IllegalStateException("Model similarity has not been trained.");
        }
    }

    /**
     * Throws an IllegalStateException if the most-similar normalizer has not been trained.
     */
    protected void ensureMostSimilarTrained() {
        if (!mostSimilarIsTrained()) {
            throw new IllegalStateException("Model mostSimilar has not been trained.");
        }
    }

    /**
     * Normalizes the score of a single result in place using the similarity normalizer.
     *
     * @param sr the result whose {@code score} field is overwritten
     * @return the same instance that was passed in
     * @throws IllegalStateException if the similarity normalizer has not been trained
     */
    protected SRResult normalize(SRResult sr) {
        sr.score = normalize(sr.score);
        return sr;
    }

    /**
     * Normalizes a result list using the most-similar normalizer.
     *
     * @param srl the list to normalize
     * @return whatever the most-similar normalizer returns for {@code srl}
     * @throws IllegalStateException if the most-similar normalizer has not been trained
     */
    protected SRResultList normalize(SRResultList srl) {
        ensureMostSimilarTrained();
        return normalizers.getMostSimilarNormalizer().normalize(srl);
    }

    /**
     * Normalizes a raw score using the similarity normalizer.
     *
     * @param score the raw score
     * @return the normalized score
     * @throws IllegalStateException if the similarity normalizer has not been trained
     */
    protected double normalize(double score) {
        ensureSimilarityTrained();
        return normalizers.getSimilarityNormalizer().normalize(score);
    }

    /**
     * Creates the data directory if necessary and writes the normalizers into it.
     */
    @Override
    public void write() throws IOException {
        WpIOUtils.mkdirsQuietly(dataDir);
        normalizers.write(dataDir);
    }

    /**
     * @param shouldRead whether {@link #read()} should load normalizers from the data directory
     */
    public void setReadNormalizers(boolean shouldRead) {
        this.shouldReadNormalizers = shouldRead;
    }

    /**
     * Loads the normalizers (unless disabled) and the most-similar cache from the data directory.
     * If the data directory is unset or is not a directory, a warning is logged and nothing changes.
     * Any previously opened cache is closed; the cache stays {@code null} when no matrix file exists.
     */
    @Override
    public void read() throws IOException {
        // A missing (null) data directory is handled like a non-existent one instead of failing with an NPE.
        if (dataDir == null || !dataDir.isDirectory()) {
            LOG.warn("directory " + dataDir + " does not exist; cannot read files");
            return;
        }
        if (shouldReadNormalizers && normalizers.hasReadableNormalizers(dataDir)) {
            normalizers.read(dataDir);
        }
        IOUtils.closeQuietly(mostSimilarCache);
        // Drop the reference so a closed matrix is never consulted by getCachedMostSimilar().
        mostSimilarCache = null;
        File matrixPath = getMostSimilarMatrixPath();
        if (matrixPath.isFile()) {
            mostSimilarCache = new SparseMatrix(matrixPath);
        }
    }

    /**
     * Trains the similarity normalizer on the given dataset.
     *
     * @throws IllegalArgumentException if the dataset's language differs from this metric's language
     */
    @Override
    public synchronized void trainSimilarity(Dataset dataset) throws DaoException {
        ensureSameLanguage(dataset);
        normalizers.trainSimilarity(this, dataset);
    }

    /**
     * Trains the most-similar normalizer and, if enabled through
     * {@link #setBuildMostSimilarCache(boolean)}, rebuilds the most-similar cache.
     * A failure while writing the cache is logged and does not fail training.
     *
     * @throws IllegalArgumentException if the dataset's language differs from this metric's language
     */
    @Override
    public synchronized void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
        ensureSameLanguage(dataset);
        normalizers.trainMostSimilar(this, disambiguator, dataset, validIds, numResults);
        try {
            if (buildMostSimilarCache) {
                writeMostSimilarCache(numResults, mostSimilarCacheRowIds, validIds);
            }
        } catch (Exception e) {
            // Deliberately best-effort: the cache is an optimisation, training itself succeeded.
            LOG.error("writing most similar cache failed:", e);
        }
    }

    /** Rejects datasets whose language does not match this metric's language. */
    private void ensureSameLanguage(Dataset dataset) {
        if (!dataset.getLanguage().equals(getLanguage())) {
            throw new IllegalArgumentException("SR metric has language " + getLanguage() + " but dataset has language " + dataset.getLanguage());
        }
    }

    @Override
    public abstract SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException;

    /**
     * Resolves both phrases with the disambiguator and delegates to
     * {@link #similarity(int, int, boolean)}.
     *
     * @return the similarity of the resolved pages, or a default-constructed
     *         {@link SRResult} if either phrase could not be resolved
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
        Language lang = getLanguage();
        List<LocalString> phrases = Arrays.asList(
                new LocalString(lang, phrase1),
                new LocalString(lang, phrase2));
        List<LocalId> resolutions = disambiguator.disambiguateTop(phrases, null);
        LocalId first = resolutions.get(0);
        LocalId second = resolutions.get(1);
        if (first == null || second == null) {
            return new SRResult();
        }
        return similarity(first.getId(), second.getId(), explanations);
    }

    /**
     * Debugging aid (currently not called): resolves the two phrases under every
     * {@link SimilarityDisambiguator.Criteria} except {@code SIMILARITY} and logs the
     * resolutions when the criteria disagree. Does nothing useful unless the configured
     * disambiguator is a {@link SimilarityDisambiguator}.
     */
    private void debugSimilarityDisambiguator(List<LocalString> phrases) throws DaoException {
        if (!(disambiguator instanceof SimilarityDisambiguator)) {
            LOG.warn("cannot debug disambiguation: disambiguator is not a SimilarityDisambiguator");
            return;
        }
        SimilarityDisambiguator simDisambiguator = (SimilarityDisambiguator) disambiguator;
        String last = null;
        boolean same = true;
        StringBuilder b = new StringBuilder("results for " + phrases.get(0).getString() + ", " + phrases.get(1).getString() + "\n");
        for (SimilarityDisambiguator.Criteria c : SimilarityDisambiguator.Criteria.values()) {
            if (c == SimilarityDisambiguator.Criteria.SIMILARITY) {
                continue;   // weird, so skip for now.
            }
            List<LocalId> resolutions;
            // The criteria is mutable state on the shared disambiguator, so set-and-resolve must be atomic.
            synchronized (disambiguator) {
                simDisambiguator.setCriteria(c);
                resolutions = disambiguator.disambiguateTop(phrases, null);
            }
            String page1 = resolutions.get(0) == null ? "null" : localPageDao.getById(language, resolutions.get(0).getId()).toString();
            String page2 = resolutions.get(1) == null ? "null" : localPageDao.getById(language, resolutions.get(1).getId()).toString();
            b.append("\t").append(c).append(": ").append(page1).append(", ").append(page2).append("\n");
            String pair = page1 + page2;
            if (last == null) {
                last = pair;
            } else if (!last.equals(pair)) {
                same = false;
            }
        }
        if (!same) {
            LOG.info(b.toString());
        }
    }

    @Override
    public SRResultList mostSimilar(int pageId, int maxResults) throws DaoException {
        return mostSimilar(pageId, maxResults, null);
    }

    @Override
    public abstract SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException;

    /**
     * Resolves the phrase and delegates to {@link #mostSimilar(int, int)}.
     *
     * @return the most similar pages, or a one-element list holding a default-constructed
     *         {@link SRResult} if the phrase could not be resolved
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults) throws DaoException {
        LocalId similar = disambiguator.disambiguateTop(new LocalString(getLanguage(), phrase), null);
        if (similar == null) {
            return unresolvedPhraseResults();
        }
        return mostSimilar(similar.getId(), maxResults);
    }

    /**
     * Resolves the phrase and delegates to {@link #mostSimilar(int, int, TIntSet)}.
     *
     * @return the most similar pages, or a one-element list holding a default-constructed
     *         {@link SRResult} if the phrase could not be resolved
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
        LocalId similar = disambiguator.disambiguateTop(new LocalString(getLanguage(), phrase), null);
        if (similar == null) {
            return unresolvedPhraseResults();
        }
        return mostSimilar(similar.getId(), maxResults, validIds);
    }

    /** Builds the placeholder returned by the phrase-based mostSimilar methods for unresolvable phrases. */
    private static SRResultList unresolvedPhraseResults() {
        SRResultList resultList = new SRResultList(1);
        resultList.set(0, new SRResult());
        return resultList;
    }

    /**
     * Computes the rows-by-columns similarity matrix for page ids. Cells whose row and
     * column id are equal are set to the normalized value of 1.0.
     */
    @Override
    public double[][] cosimilarity(int[] wpRowIds, int[] wpColIds) throws DaoException {
        double[][] cos = new double[wpRowIds.length][wpColIds.length];
        for (int i = 0; i < wpRowIds.length; i++) {
            for (int j = 0; j < wpColIds.length; j++) {
                if (wpRowIds[i] == wpColIds[j]) {
                    cos[i][j] = normalize(1.0);
                } else {
                    cos[i][j] = similarity(wpRowIds[i], wpColIds[j], false).getScore();
                }
            }
        }
        return cos;
    }

    /**
     * Computes the rows-by-columns similarity matrix for phrases. Cells whose row and
     * column phrase are equal strings are set to the normalized value of 1.0.
     */
    @Override
    public double[][] cosimilarity(String[] rowPhrases, String[] colPhrases) throws DaoException {
        double[][] cos = new double[rowPhrases.length][colPhrases.length];
        for (int i = 0; i < rowPhrases.length; i++) {
            for (int j = 0; j < colPhrases.length; j++) {
                if (rowPhrases[i].equals(colPhrases[j])) {
                    cos[i][j] = normalize(1.0);
                } else {
                    cos[i][j] = similarity(rowPhrases[i], colPhrases[j], false).getScore();
                }
            }
        }
        return cos;
    }

    /**
     * Computes the symmetric similarity matrix for the given page ids. The diagonal holds
     * the normalized value of 1.0; each unordered pair is scored once and mirrored.
     */
    @Override
    public double[][] cosimilarity(int[] ids) throws DaoException {
        double[][] cos = new double[ids.length][ids.length];
        for (int i = 0; i < ids.length; i++) {
            cos[i][i] = normalize(1.0);
        }
        for (int i = 0; i < ids.length; i++) {
            for (int j = i + 1; j < ids.length; j++) {
                cos[i][j] = similarity(ids[i], ids[j], false).getScore();
                cos[j][i] = cos[i][j];
            }
        }
        return cos;
    }

    /**
     * Computes the symmetric similarity matrix for the given phrases.
     *
     * <p>All phrases are disambiguated together. If every phrase resolves, the work is
     * delegated to {@link #cosimilarity(int[])}. If some phrase cannot be resolved, the
     * matrix is computed pairwise instead: pairs involving an unresolved phrase receive the
     * score of a default-constructed {@link SRResult} (the same value
     * {@link #similarity(String, String, boolean)} yields for unresolved phrases), and the
     * diagonal holds the normalized value of 1.0.
     */
    @Override
    public double[][] cosimilarity(String[] phrases) throws DaoException {
        List<LocalString> localStringList = new ArrayList<LocalString>(phrases.length);
        for (String phrase : phrases) {
            localStringList.add(new LocalString(getLanguage(), phrase));
        }
        List<LocalId> localIds = disambiguator.disambiguateTop(localStringList, null);

        int[] ids = new int[phrases.length];
        boolean allResolved = true;
        for (int i = 0; i < phrases.length; i++) {
            LocalId resolved = localIds.get(i);
            if (resolved == null) {
                allResolved = false;
            } else {
                ids[i] = resolved.getId();
            }
        }
        if (allResolved) {
            return cosimilarity(ids);
        }

        // Previously an unresolved phrase caused a NullPointerException here.
        double unresolvedScore = new SRResult().getScore();
        double[][] cos = new double[phrases.length][phrases.length];
        for (int i = 0; i < phrases.length; i++) {
            cos[i][i] = normalize(1.0);
            for (int j = i + 1; j < phrases.length; j++) {
                boolean bothResolved = localIds.get(i) != null && localIds.get(j) != null;
                cos[i][j] = bothResolved
                        ? similarity(ids[i], ids[j], false).getScore()
                        : unresolvedScore;
                cos[j][i] = cos[i][j];
            }
        }
        return cos;
    }

    /**
     * If the cache exists, and contains at least numResults valid ids for the requested id, return it.
     * Otherwise, return null.
     *
     * @param wpId       id of the page whose cached neighbours are requested
     * @param numResults minimum number of results that must be available
     * @param validIds   if non-null, only cached columns contained in this set are considered
     * @return the top results from the cached row, or null if the cache cannot satisfy the request
     * @throws DaoException if reading the cached row fails
     */
    protected SRResultList getCachedMostSimilar(int wpId, int numResults, TIntSet validIds) throws DaoException {
        // Work on a snapshot so a concurrent clear/rebuild cannot null the field between check and use.
        SparseMatrix cache = mostSimilarCache;
        if (cache == null) {
            return null;
        }
        MatrixRow row;
        try {
            row = cache.getRow(wpId);
        } catch (IOException e) {
            throw new DaoException(e);
        }
        if (row == null || row.getNumCols() < numResults) {
            return null;
        }
        Leaderboard leaderboard = new Leaderboard(numResults);
        for (int i = 0; i < row.getNumCols(); i++) {
            int wpId2 = row.getColIndex(i);
            if (validIds == null || validIds.contains(wpId2)) {
                leaderboard.tallyScore(wpId2, row.getColValue(i));
            }
        }
        SRResultList results = leaderboard.getTop();
        // After filtering by validIds the row may no longer hold enough entries.
        if (results.numDocs() < numResults) {
            return null;
        }
        return results;
    }

    /**
     * Writes the most-similar cache for all article pages, using all article pages as candidates.
     *
     * @param maxHits maximum number of results stored per page
     */
    public void writeMostSimilarCache(int maxHits) throws IOException, DaoException, WikiBrainException {
        writeMostSimilarCache(maxHits, null, null);
    }

    /**
     * Creates and writes a sparse matrix that records the top-k results for every requested page,
     * then reopens it as the active most-similar cache.
     *
     * <p>While rows are computed, both normalizers of this metric are temporarily replaced by
     * {@link IdentityNormalizer}s, and are restored afterwards even if the computation fails.
     * The previously opened cache is closed and detached before the computation starts.
     *
     * @param maxHits maximum number of results stored per row
     * @param rowIds  ids of the pages to compute rows for; null means all non-redirect,
     *                non-disambiguation article pages in this metric's language
     * @param colIds  ids allowed as results; null means the same set of all article pages
     * @throws IOException        if writing or reopening the matrix fails
     * @throws DaoException       if pages cannot be enumerated or a similarity lookup fails
     * @throws WikiBrainException declared by the original signature and kept for compatibility
     */
    public void writeMostSimilarCache(final int maxHits, TIntSet rowIds, TIntSet colIds) throws IOException, DaoException, WikiBrainException {
        if (rowIds == null || colIds == null) {
            TIntSet allPageIds = loadAllArticleIds();
            if (rowIds == null) rowIds = allPageIds;
            if (colIds == null) colIds = allPageIds;
        }

        getDataDir().mkdirs();
        IOUtils.closeQuietly(mostSimilarCache);
        // Detach the closed matrix so mostSimilar() implementations cannot consult it
        // while the file is being rewritten, nor after a failed rebuild.
        mostSimilarCache = null;

        SRConfig config = getConfig();
        final AtomicInteger idCounter = new AtomicInteger();
        final AtomicLong cellCounter = new AtomicLong();
        ValueConf vconf = new ValueConf(config.minScore, config.maxScore);
        final SparseMatrixWriter writer = new SparseMatrixWriter(getMostSimilarMatrixPath(), vconf);
        // Private copy: worker threads only ever see this snapshot of the column ids.
        final TIntSet colIdSet = new TIntHashSet(colIds);

        Normalizer simNormalizer = getSimilarityNormalizer();
        Normalizer mostSimNormalizer = getMostSimilarNormalizer();
        setMostSimilarNormalizer(new IdentityNormalizer());
        setSimilarityNormalizer(new IdentityNormalizer());
        try {
            ParallelForEach.loop(
                    Arrays.asList(ArrayUtils.toObject(rowIds.toArray())),
                    WpThreadUtils.getMaxThreads(),
                    new Procedure<Integer>() {
                        public void call(Integer wpId) throws IOException, DaoException {
                            writeSim(writer, wpId, colIdSet, maxHits, idCounter, cellCounter);
                        }
                    }, Integer.MAX_VALUE);
        } finally {
            setSimilarityNormalizer(simNormalizer);
            setMostSimilarNormalizer(mostSimNormalizer);
        }
        LOG.info("wrote " + cellCounter.get() + " non-zero similarity cells");
        writer.finish();
        mostSimilarCache = new SparseMatrix(getMostSimilarMatrixPath());
    }

    /** Collects the local ids of all non-redirect, non-disambiguation article pages in this metric's language. */
    private TIntSet loadAllArticleIds() throws DaoException {
        DaoFilter pageFilter = new DaoFilter()
                .setLanguages(getLanguage())
                .setNameSpaces(NameSpace.ARTICLE)
                .setDisambig(false)
                .setRedirect(false);
        Iterable<LocalPage> localPages = localPageDao.get(pageFilter);
        TIntSet pageIds = new TIntHashSet();
        for (LocalPage page : localPages) {
            if (page != null) {
                pageIds.add(page.getLocalId());
            }
        }
        return pageIds;
    }

    /**
     * @return the location of the most-similar matrix file inside the data directory
     */
    protected File getMostSimilarMatrixPath() {
        return new File(getDataDir(), "mostSimilar.matrix");
    }

    /**
     * Computes the most similar pages for one row id and appends them to the matrix writer.
     * Invoked concurrently by the worker threads of {@link #writeMostSimilarCache(int, TIntSet, TIntSet)}.
     */
    private void writeSim(SparseMatrixWriter writer, Integer wpId, TIntSet colIds, int maxSimsPerDoc, AtomicInteger idCounter, AtomicLong cellCounter) throws IOException, DaoException {
        // Capture the incremented value once so the logged number is the one that hit the threshold.
        int processed = idCounter.incrementAndGet();
        if (processed % 10000 == 0) {
            LOG.info("finding matches for page " + processed);
        }
        SRResultList scores = mostSimilar(wpId, maxSimsPerDoc, colIds);
        if (scores != null) {
            int ids[] = scores.getIds();
            // Count every stored cell, not just one per row, so the final log line is accurate.
            cellCounter.addAndGet(ids.length);
            writer.writeRow(new SparseMatrixRow(writer.getValueConf(), wpId, ids, scores.getScoresAsFloat()));
        }
    }

    public Language getLanguage() {
        return language;
    }

    public Disambiguator getDisambiguator() {
        return disambiguator;
    }

    public LocalPageDao getLocalPageDao() {
        return localPageDao;
    }

    @Override
    public Normalizer getMostSimilarNormalizer() {
        return normalizers.getMostSimilarNormalizer();
    }

    @Override
    public Normalizer getSimilarityNormalizer() {
        return normalizers.getSimilarityNormalizer();
    }

    /**
     * @return the currently opened most-similar cache, or null if there is none
     */
    public SparseMatrix getMostSimilarCache() {
        return mostSimilarCache;
    }

    /**
     * Closes the most-similar cache, deletes its matrix file and forgets the cache.
     */
    public void clearMostSimilarCache() {
        IOUtils.closeQuietly(mostSimilarCache);
        FileUtils.deleteQuietly(getMostSimilarMatrixPath());
        mostSimilarCache = null;
    }

    /**
     * @param buildMostSimilarCache whether {@link #trainMostSimilar} should also rebuild the cache
     */
    public void setBuildMostSimilarCache(boolean buildMostSimilarCache) {
        this.buildMostSimilarCache = buildMostSimilarCache;
    }

    /**
     * @param rowIds row ids used when the cache is built during training; null means all article pages
     */
    public void setMostSimilarCacheRowIds(TIntSet rowIds) {
        this.mostSimilarCacheRowIds = rowIds;
    }

    /**
     * Applies the configuration shared by all metrics to {@code sr}: sets the data directory to
     * {@code <sr.metric.path>/<metric name>/<language code>}, installs the configured normalizers,
     * disables reading stored normalizers when {@code sr.metric.training} is true, honours the
     * optional {@code buildMostSimilarCache} flag, and finally calls {@link #read()}.
     *
     * @param configurator source of the root configuration and of the normalizer instances
     * @param sr           the metric to configure
     * @param config       the metric-specific configuration section
     * @throws ConfigurationException if a component cannot be created or {@link #read()} fails
     */
    protected static void configureBase(Configurator configurator, BaseSRMetric sr, Config config) throws ConfigurationException {
        Config rootConfig = configurator.getConf().get();

        File path = new File(rootConfig.getString("sr.metric.path"));
        sr.setDataDir(FileUtils.getFile(path, sr.getName(), sr.getLanguage().getLangCode()));

        // initialize normalizers
        sr.setSimilarityNormalizer(configurator.get(Normalizer.class, config.getString("similaritynormalizer")));
        sr.setMostSimilarNormalizer(configurator.get(Normalizer.class, config.getString("mostsimilarnormalizer")));

        // While training, the freshly configured normalizers must not be replaced by stored ones.
        boolean isTraining = rootConfig.getBoolean("sr.metric.training");
        if (isTraining) {
            sr.setReadNormalizers(false);
        }
        if (config.hasPath("buildMostSimilarCache")) {
            sr.setBuildMostSimilarCache(config.getBoolean("buildMostSimilarCache"));
        }

        try {
            sr.read();
        } catch (IOException e) {
            throw new ConfigurationException(e);
        }
        LOG.info("finished base configuration of metric " + sr.getName());
    }
}