/* * Copyright 2012 Research Studios Austria Forschungsges.m.b.H. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package won.matcher.utils.tensor; import org.la4j.Matrices; import org.la4j.matrix.sparse.CCSMatrix; import org.la4j.vector.functor.VectorProcedure; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.*; /** * Sparse third order tensor based on la4j implementation of sparse matrices. * <p/> * User: hfriedrich * Date: 09.07.2014 */ public class ThirdOrderSparseTensor { //private CCSMatrix[] slices; private ArrayList<CCSMatrix> slices; private int[] dims; public ThirdOrderSparseTensor(int dimX1, int dimX2) { dims = null; slices = new ArrayList<>(); resize(dimX1, dimX2); } public void resize(int dimX1, int dimX2) { for (int x3 = 0; x3 < slices.size(); x3++) { if (slices.get(x3) != null) { slices.set(x3, slices.get(x3).copyOfShape(dimX1, dimX2).to(Matrices.CCS)); } } dims = new int[]{dimX1, dimX2, slices.size()}; } public void setEntry(double value, int x1, int x2, int x3) { if (slices.size() <= x3) { for (int i = slices.size(); i <= x3; i++) { slices.add(i, CCSMatrix.zero(dims[0], dims[1])); } dims = new int[]{dims[0], dims[1], slices.size()}; } slices.get(x3).set(x1, x2, value); } public double getEntry(int x1, int x2, int x3) { return slices.get(x3).get(x1, x2); } public int getNonZeroEntries(int dimX3) { return slices.get(dimX3).cardinality(); } public int[] getDimensions() { return dims; } public void writeSliceToFile(String fileName, int slice) throws IOException { // write the mtx file (remove the column-major specification cause python mm does not read it) OutputStream os = new FileOutputStream(new File(fileName)); NumberFormat format = DecimalFormat.getInstance(Locale.US); os.write(slices.get(slice).toMatrixMarket(format).replace("column-major", "").getBytes()); } public Collection<Integer> getNonZeroIndicesOfRow(int x1, int x3) { NonZeroVectorProcedure nz = new NonZeroVectorProcedure(); slices.get(x3).eachNonZeroInRow(x1, nz); return nz.getNonZeroIndices(); } public boolean hasNonZeroEntryInRow(int x1, int x3) { return (slices.get(x3).getRow(x1).max() > 0.0d); } // class used to return all non-zero indices of a Vector private class NonZeroVectorProcedure implements VectorProcedure { private List<Integer> nonZeroIndices; public NonZeroVectorProcedure() { nonZeroIndices = new LinkedList<>(); } @Override public void apply(final int i, final double value) { nonZeroIndices.add(i); } public Collection<Integer> getNonZeroIndices() { return nonZeroIndices; } } }