package won.matcher.utils.tensor;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * This class builds up the relations between needs and attributes.
 * It builds an internal tensor data structure (RESCAL three-way tensor).
 * The data structure can be built incrementally and, when finished, written to the file system
 * for further processing by the RESCAL algorithm or evaluation algorithms.
 *
 * User: hfriedrich
 * Date: 17.07.2014
 */
public class TensorMatchingData {
    private static final Logger logger = LoggerFactory.getLogger(TensorMatchingData.class);
    private static final int MAX_DIMENSION = 1000000;
    public static final String NEED_PREFIX = "Need: ";
    public static final String ATTRIBUTE_PREFIX = "Attr: ";
    public static final String HEADERS_FILE = "headers.txt";
    public static final String CONNECTION_SLICE_NAME = "connection";

    private ThirdOrderSparseTensor tensor;
    // Needs and attributes share one index space: for every used index i exactly one of
    // needs.get(i) / attributes.get(i) holds a value, the other one holds null.
    private ArrayList<String> needs;
    private ArrayList<String> attributes;
    private ArrayList<String> slices;
    // next free index in the shared needs/attributes index space
    private int nextIndex = 0;

    public TensorMatchingData() {
        tensor = new ThirdOrderSparseTensor(MAX_DIMENSION, MAX_DIMENSION);
        needs = new ArrayList<>();
        attributes = new ArrayList<>();
        slices = new ArrayList<>();
    }

    /**
     * Add a bidirectional connection between two needs to the connection slice.
     *
     * @param need1 first need name/uri, must not be null or empty
     * @param need2 second need name/uri, must not be null or empty
     * @param addOnlyIfNeedsExist if true, add the connection only when both needs are
     *                            already part of this data structure
     */
    public void addNeedConnection(String need1, String need2, boolean addOnlyIfNeedsExist) {
        checkAttributeOrNeedName(need1);
        checkAttributeOrNeedName(need2);
        if (!addOnlyIfNeedsExist || (needs.contains(need1) && needs.contains(need2))) {
            int x1 = addNeed(need1);
            int x2 = addNeed(need2);
            int x3 = addSlice(CONNECTION_SLICE_NAME);
            // connections are bidirectional => set both symmetric entries
            tensor.setEntry(1.0d, x1, x2, x3);
            tensor.setEntry(1.0d, x2, x1, x3);
        }
    }

    /**
     * Add an attribute value of a need to a (non-connection) slice.
     *
     * @param sliceName name of the attribute slice, must not be the connection slice name
     * @param needUri need name/uri, must not be null or empty
     * @param attributeValue attribute value, must not be null or empty
     */
    public void addNeedAttribute(String sliceName, String needUri, String attributeValue) {
        checkAttributeOrNeedName(needUri);
        checkAttributeOrNeedName(attributeValue);
        checkSliceName(sliceName, false);
        int x1 = addNeed(needUri);
        int x2 = addAttribute(attributeValue);
        int x3 = addSlice(sliceName);
        tensor.setEntry(1.0d, x1, x2, x3);
    }

    /**
     * Convenience variant of {@link #addNeedAttribute(String, String, String)} that reads
     * slice name, need uri and value from a {@link TensorEntry}.
     *
     * @param entry tensor entry to add
     */
    public void addNeedAttribute(TensorEntry entry) {
        addNeedAttribute(entry.getSliceName(), entry.getNeedUri(), entry.getValue());
    }

    /**
     * Return the first attribute found for a need in a given slice.
     *
     * @param need need name/uri
     * @param slice slice name
     * @return first attribute of the need in the slice, or null if the need is unknown or
     *         has no attribute in that slice
     */
    public String getFirstAttributeOfNeed(String need, String slice) {
        int needIndex = needs.indexOf(need);
        if (needIndex < 0) {
            return null;
        }
        Iterator<Integer> iter = tensor.getNonZeroIndicesOfRow(needIndex, slices.indexOf(slice)).iterator();
        if (iter.hasNext()) {
            return attributes.get(iter.next());
        }
        return null;
    }

    /**
     * A tensor is considered valid once it contains at least one need, one attribute and
     * the connection slice.
     *
     * @return true if the tensor has enough data to be processed
     */
    public boolean isValidTensor() {
        return (needs.size() > 0 && attributes.size() > 0 && slices.size() > 0
                && getSliceIndex(CONNECTION_SLICE_NAME) != -1);
    }

    public int[] getTensorDimensions() {
        return tensor.getDimensions();
    }

    /**
     * remove empty needs without attributes and their connections by building up a new
     * matching data object and add only non-empty needs and connections between those
     */
    protected TensorMatchingData removeEmptyNeedsAndConnections() {
        // build up a new tensor
        TensorMatchingData cleanedMatchingData = new TensorMatchingData();
        // first pass: add the non-empty needs so they exist before connections are added
        for (int i = 0; i < needs.size(); i++) {
            String need = needs.get(i);
            if ((need != null) && needHasAttributes(i)) {
                cleanedMatchingData.addNeed(need);
            }
        }
        // second pass: add all attributes and connections of non-empty needs; connections
        // are only kept when both endpoints survived the first pass (addOnlyIfNeedsExist)
        for (int i = 0; i < needs.size(); i++) {
            String need = needs.get(i);
            if ((need != null) && needHasAttributes(i)) {
                for (int sliceIndex = 0; sliceIndex < slices.size(); sliceIndex++) {
                    for (int attrIndex : tensor.getNonZeroIndicesOfRow(i, sliceIndex)) {
                        if (slices.get(sliceIndex).equals(CONNECTION_SLICE_NAME)) {
                            cleanedMatchingData.addNeedConnection(need, needs.get(attrIndex), true);
                        } else {
                            cleanedMatchingData.addNeedAttribute(slices.get(sliceIndex), need, attributes.get(attrIndex));
                        }
                    }
                }
            }
        }
        return cleanedMatchingData;
    }

    /**
     * Add a need to the need list (no-op if already present).
     *
     * @param need need name/uri
     * @return index of the need in the shared needs/attributes index space
     */
    private int addNeed(String need) {
        if (!needs.contains(need)) {
            // occupy the next shared index: need value here, null placeholder in attributes
            needs.add(nextIndex, need);
            attributes.add(nextIndex, null);
            nextIndex++;
        }
        return needs.indexOf(need);
    }

    /**
     * Add an attribute to the attribute list (no-op if already present).
     *
     * @param attr attribute value
     * @return index of the attribute in the shared needs/attributes index space
     */
    private int addAttribute(String attr) {
        if (!attributes.contains(attr)) {
            // occupy the next shared index: attribute value here, null placeholder in needs
            attributes.add(nextIndex, attr);
            needs.add(nextIndex, null);
            nextIndex++;
        }
        return attributes.indexOf(attr);
    }

    /**
     * Add a slice by name (no-op if already present).
     *
     * @param slice slice name
     * @return index of the slice
     */
    private int addSlice(String slice) {
        if (!slices.contains(slice)) {
            slices.add(slice);
        }
        return slices.indexOf(slice);
    }

    /**
     * check if names of needs/attributes are well-formed
     *
     * @param name name to check
     * @throws IllegalArgumentException if the name is null or empty
     */
    private void checkAttributeOrNeedName(String name) {
        if (name == null || name.isEmpty()) {
            throw new IllegalArgumentException("Need/Attribute is not allowed to be null or empty");
        }
    }

    /**
     * Check that a slice name is well-formed and that the reserved connection slice name
     * is used if and only if a connection slice is expected.
     *
     * @param name slice name to check
     * @param connectionSlice true if the name must be the connection slice name
     * @throws IllegalArgumentException if the name is null, empty, or misuses the
     *         reserved connection slice name
     */
    private void checkSliceName(String name, boolean connectionSlice) {
        if (name == null || name.isEmpty()) {
            throw new IllegalArgumentException("Slice is not allowed to be null or empty");
        }
        if ((connectionSlice && !name.equals(CONNECTION_SLICE_NAME)) ||
                (!connectionSlice && name.equals(CONNECTION_SLICE_NAME))) {
            throw new IllegalArgumentException("Only connection slice is allowed the name: '" + CONNECTION_SLICE_NAME + "' ");
        }
    }

    /**
     * Used for testing
     *
     * @return the internal tensor data structure
     */
    protected ThirdOrderSparseTensor getTensor() {
        return tensor;
    }

    /**
     * After all needs, connections and attributes have been added, this method is used
     * before writing the tensor out to disk to resize it to its final dimensions
     * (number of needs plus number of attributes).
     *
     * @return the resized tensor
     */
    protected ThirdOrderSparseTensor createFinalTensor() {
        int dim = getNeeds().size() + getAttributes().size();
        tensor.resize(dim, dim);
        return tensor;
    }

    protected int getSliceIndex(String sliceName) {
        return slices.indexOf(sliceName);
    }

    /**
     * check if a need with a certain index has any attributes
     * (entries in the connection slice do not count as attributes)
     *
     * @param needIndex index of the need
     * @return true if the need has at least one attribute in a non-connection slice
     */
    private boolean needHasAttributes(int needIndex) {
        for (int i = 0; i < slices.size(); i++) {
            if (tensor.hasNonZeroEntryInRow(needIndex, i) && getSliceIndex(CONNECTION_SLICE_NAME) != i) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return the raw need list, including null placeholders at attribute indices
     */
    public ArrayList<String> getNeedHeaders() {
        return needs;
    }

    /**
     * @return all needs without the null placeholders of the shared index space
     */
    public List<String> getNeeds() {
        ArrayList<String> continuousList = new ArrayList<String>();
        for (String need : needs) {
            if (need != null) {
                continuousList.add(need);
            }
        }
        return continuousList;
    }

    /**
     * @return all attributes without the null placeholders of the shared index space
     */
    public List<String> getAttributes() {
        ArrayList<String> continuousList = new ArrayList<String>();
        for (String attr : attributes) {
            if (attr != null) {
                continuousList.add(attr);
            }
        }
        return continuousList;
    }

    /**
     * @return a copy of the slice name list
     */
    public List<String> getSlices() {
        ArrayList<String> continuousList = new ArrayList<String>();
        continuousList.addAll(slices);
        return continuousList;
    }

    /**
     * @return number of (bidirectional) connections; each connection is stored as two
     *         symmetric tensor entries, hence the division by 2
     */
    public int getNumberOfConnections() {
        int connectionSlice = getSliceIndex(CONNECTION_SLICE_NAME);
        return (connectionSlice != -1) ? (getTensor().getNonZeroEntries(connectionSlice) / 2) : 0;
    }

    /**
     * Same as {@link #writeOutputFiles(String)} but removes empty needs and their
     * connections before writing the tensor.
     *
     * @param folder output folder
     * @return cleaned tensor data
     * @throws IOException if writing the output files fails
     * @throws IllegalStateException if the tensor does not contain enough data yet
     */
    public TensorMatchingData writeCleanedOutputFiles(String folder) throws IOException {
        if (!isValidTensor()) {
            throw new IllegalStateException("Tensor must be filled with data before it can be written");
        }
        logger.info("remove empty needs and connections ...");
        TensorMatchingData cleanedMatchingData = removeEmptyNeedsAndConnections();
        logger.info("Number of needs before cleaning: " + getNeeds().size());
        logger.info("Number of needs after cleaning: " + cleanedMatchingData.getNeeds().size());
        logger.info("Number of attributes before cleaning: " + getAttributes().size());
        logger.info("Number of attributes after cleaning: " + cleanedMatchingData.getAttributes().size());
        logger.info("Number of connections before cleaning: " + getNumberOfConnections());
        logger.info("Number of connections after cleaning: " + cleanedMatchingData.getNumberOfConnections());
        cleanedMatchingData.writeOutputFiles(folder);
        return cleanedMatchingData;
    }

    /**
     * Write the tensor out to the file system for further processing.
     * Create the following files:
     * - headers.txt file with the need/attribute names that correspond to the index in the tensor.
     * - {@code <Slice>.mtx} files for the different slices e.g. connections, need type, title and other attributes
     *
     * @param folder output folder (created if it does not exist)
     * @throws IOException if writing a file fails
     */
    public void writeOutputFiles(String folder) throws IOException {
        File outFolder = new File(folder);
        outFolder.mkdirs();
        if (!outFolder.isDirectory()) {
            return;
        }
        // resize the tensor to its final dimensions before writing
        logger.info("create final tensor ...");
        createFinalTensor();
        int dim = tensor.getDimensions()[0];
        if (dim > MAX_DIMENSION) {
            logger.error("Maximum Dimension {} exceeded: {}", MAX_DIMENSION, dim);
            return;
        }
        // write one matrix market file per slice
        logger.info("create tensor data in folder: {}", folder);
        for (int sliceIndex = 0; sliceIndex < slices.size(); sliceIndex++) {
            logger.info("- " + slices.get(sliceIndex) + ".mtx");
            tensor.writeSliceToFile(folder + "/" + slices.get(sliceIndex) + ".mtx", sliceIndex);
        }
        // write the headers file; try-with-resources closes the writer even on error
        try (OutputStreamWriter os = new OutputStreamWriter(
                new FileOutputStream(new File(folder + "/" + HEADERS_FILE)), StandardCharsets.UTF_8)) {
            for (int i = 0; i < nextIndex; i++) {
                String entity = (needs.get(i) != null) ? NEED_PREFIX + needs.get(i) : ATTRIBUTE_PREFIX + attributes.get(i);
                os.append(entity + "\n");
            }
        }
        logger.info("- needs: {}", getNeeds().size());
        logger.info("- attributes: {}", getAttributes().size());
        logger.info("- connections: {}", tensor.getNonZeroEntries(slices.indexOf(CONNECTION_SLICE_NAME)) / 2);
        logger.info("- tensor size: {} x {} x " + tensor.getDimensions()[2], tensor.getDimensions()[0],
                tensor.getDimensions()[1]);
    }
}