package won.matcher.utils.tensor;/* * Copyright 2012 Research Studios Austria Forschungsges.m.b.H. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import org.junit.Assert; import org.junit.Before; import org.junit.Test; import java.util.ArrayList; import java.util.Collection; /** * User: hfriedrich * Date: 09.07.2014 */ public class ThirdOrderSparseTensorTest { private static final double DELTA = 0.001d; private ThirdOrderSparseTensor testTensor1; @Before public void initTestTensor() { testTensor1 = new ThirdOrderSparseTensor(4, 4); } @Test public void tensorCreation() { ThirdOrderSparseTensor tensor = new ThirdOrderSparseTensor(4, 3); int[] dim = {4, 3, 0}; Assert.assertArrayEquals(dim, tensor.getDimensions()); for (int x3 = 0; x3 < dim[2]; x3++) { for (int x2 = 0; x2 < dim[1]; x2++) { for (int x1 = 0; x1 < dim[0]; x1++) { Assert.assertEquals(0.0d, tensor.getEntry(x1, x2, x3), 0.0d); } } } } @Test public void setGetEntry() { testTensor1.setEntry(0.5d, 0, 0, 0); testTensor1.setEntry(1.0d, 0, 0, 0); testTensor1.setEntry(2.0d, 1, 0, 1); testTensor1.setEntry(3.0d, 0, 2, 2); testTensor1.setEntry(4.0d, 3, 3, 2); Assert.assertEquals(1.0d, testTensor1.getEntry(0, 0, 0), DELTA); Assert.assertEquals(2.0d, testTensor1.getEntry(1, 0, 1), DELTA); Assert.assertEquals(3.0d, testTensor1.getEntry(0, 2, 2), DELTA); Assert.assertEquals(4.0d, testTensor1.getEntry(3, 3, 2), DELTA); testTensor1.setEntry(0.0d, 3, 3, 2); Assert.assertEquals(0.0d, testTensor1.getEntry(3, 3, 2), DELTA); } @Test public void resizeUp() { int[] dim = testTensor1.getDimensions(); testTensor1.setEntry(1.0d, 3, 1, 2); int[] newDim = {dim[0] + 1, dim[1] + 2, dim[2] + 3}; testTensor1.resize(newDim[0], newDim[1]); Assert.assertArrayEquals(newDim, testTensor1.getDimensions()); Assert.assertEquals(1.0d, testTensor1.getEntry(3, 1, 2), DELTA); for (int x3 = 0; x3 < newDim[2]; x3++) { for (int x2 = 0; x2 < newDim[1]; x2++) { for (int x1 = 0; x1 < newDim[0]; x1++) { if (x1 != 3 || x2 != 1 || x3 != 2) { Assert.assertEquals(0.0d, testTensor1.getEntry(x1, x2, x3), 0.0d); } } } } } @Test public void resizeDown() { int[] dim = testTensor1.getDimensions(); testTensor1.setEntry(1.0d, 3, 1, 2); int[] newDim = {dim[0] - 1, dim[1] - 1, 3}; testTensor1.resize(newDim[0], newDim[1]); Assert.assertArrayEquals(newDim, testTensor1.getDimensions()); for (int x3 = 0; x3 < newDim[2]; x3++) { for (int x2 = 0; x2 < newDim[1]; x2++) { for (int x1 = 0; x1 < newDim[0]; x1++) { Assert.assertEquals(0.0d, testTensor1.getEntry(x1, x2, x3), 0.0d); } } } } @Test public void scaleSlicesDynamically() { Assert.assertEquals(0, testTensor1.getDimensions()[2]); testTensor1.setEntry(1.0, 1, 1, 0); Assert.assertEquals(1, testTensor1.getDimensions()[2]); Assert.assertEquals(1.0d, testTensor1.getEntry(1, 1, 0), DELTA); Assert.assertEquals(0.0d, testTensor1.getEntry(0, 0, 0), DELTA); testTensor1.setEntry(1.0, 2, 3, 0); Assert.assertEquals(1, testTensor1.getDimensions()[2]); testTensor1.setEntry(1.0, 2, 3, 2); Assert.assertEquals(1.0d, testTensor1.getEntry(2, 3, 2), DELTA); Assert.assertEquals(3, testTensor1.getDimensions()[2]); testTensor1.setEntry(1.0, 2, 3, 1); Assert.assertEquals(3, testTensor1.getDimensions()[2]); testTensor1.resize(5, 5); testTensor1.setEntry(1.0, 2, 3, 3); Assert.assertEquals(4, testTensor1.getDimensions()[2]); } @Test public void getNonZeroIndicesOfRow() { testTensor1.setEntry(0.5d, 0, 0, 0); testTensor1.setEntry(1.0d, 0, 1, 0); testTensor1.setEntry(0.5d, 1, 0, 0); testTensor1.setEntry(1.0d, 1, 1, 1); Collection<Integer> indices = new ArrayList<>(); indices.add(0); indices.add(1); Assert.assertEquals(indices, testTensor1.getNonZeroIndicesOfRow(0, 0)); indices.clear(); indices.add(0); Assert.assertEquals(indices, testTensor1.getNonZeroIndicesOfRow(1, 0)); } @Test public void hasNonZeroEntryInRow() { testTensor1.setEntry(0.5d, 0, 0, 0); testTensor1.setEntry(1.0d, 0, 0, 0); testTensor1.setEntry(2.0d, 1, 0, 1); testTensor1.setEntry(3.0d, 0, 2, 2); testTensor1.setEntry(3.0d, 1, 2, 2); testTensor1.setEntry(4.0d, 3, 3, 2); testTensor1.setEntry(4.0d, 3, 3, 0); Assert.assertTrue(testTensor1.hasNonZeroEntryInRow(0, 0)); Assert.assertTrue(testTensor1.hasNonZeroEntryInRow(1, 1)); Assert.assertTrue(testTensor1.hasNonZeroEntryInRow(0, 2)); Assert.assertTrue(testTensor1.hasNonZeroEntryInRow(1, 2)); Assert.assertTrue(testTensor1.hasNonZeroEntryInRow(3, 2)); Assert.assertTrue(testTensor1.hasNonZeroEntryInRow(3, 0)); Assert.assertFalse(testTensor1.hasNonZeroEntryInRow(0, 1)); Assert.assertFalse(testTensor1.hasNonZeroEntryInRow(1, 0)); Assert.assertFalse(testTensor1.hasNonZeroEntryInRow(1, 0)); Assert.assertFalse(testTensor1.hasNonZeroEntryInRow(2, 0)); Assert.assertFalse(testTensor1.hasNonZeroEntryInRow(2, 1)); Assert.assertFalse(testTensor1.hasNonZeroEntryInRow(2, 2)); } }