/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.crypto;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.RandomDatum;
import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.util.NativeCodeLoader;
import org.apache.hadoop.util.ReflectionUtils;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Test;
import com.google.common.primitives.Longs;
/**
 * Tests for the {@link CryptoCodec} implementations named by
 * {@code jceCodecClass} and {@code opensslCodecClass}.
 *
 * <p>The round-trip tests generate {@link RandomDatum} records, encrypt them
 * through a {@link CryptoOutputStream}, decrypt them through a
 * {@link CryptoInputStream} (record-wise, byte-wise, and after a seek) and
 * compare the result with the original data. {@link #testCalculateIV()}
 * checks {@code calculateIV} against a {@link BigInteger} reference.
 */
public class TestCryptoCodec {
  private static final Log LOG = LogFactory.getLog(TestCryptoCodec.class);

  /** Encryption key; refilled with random bytes before every test. */
  private static final byte[] key = new byte[16];
  /** Initialization vector; refilled with random bytes before every test. */
  private static final byte[] iv = new byte[16];
  /** Buffer size handed to the crypto streams. */
  private static final int bufferSize = 4096;

  private final Configuration conf = new Configuration();
  /** Number of key/value records generated for the non-empty round trips. */
  private final int count = 10000;
  /** Seed for the record generator; logged so a failing run can be replayed. */
  private final int seed = new Random().nextInt();
  private final String jceCodecClass =
      "org.apache.hadoop.crypto.JceAesCtrCryptoCodec";
  private final String opensslCodecClass =
      "org.apache.hadoop.crypto.OpensslAesCtrCryptoCodec";

  /** Fills the shared key and IV with fresh random bytes. */
  @Before
  public void setUp() throws IOException {
    Random random = new SecureRandom();
    random.nextBytes(key);
    random.nextBytes(iv);
  }

  /**
   * Encrypts with the JCE codec and decrypts with both the JCE and the
   * OpenSSL codec.
   */
  @Test(timeout=120000)
  public void testJceAesCtrCryptoCodec() throws Exception {
    assumeOpensslLoaded();
    runCodecPairTests(jceCodecClass, opensslCodecClass);
  }

  /**
   * Encrypts with the OpenSSL codec and decrypts with both the OpenSSL and
   * the JCE codec.
   */
  @Test(timeout=120000)
  public void testOpensslAesCtrCryptoCodec() throws Exception {
    assumeOpensslLoaded();
    runCodecPairTests(opensslCodecClass, jceCodecClass);
  }

  /**
   * Applies {@link GenericTestUtils#assumeInNativeProfile()}, skips the
   * calling test through a failed JUnit assumption when the build does not
   * support OpenSSL, and otherwise asserts that {@link OpensslCipher}
   * reports no loading failure.
   */
  private static void assumeOpensslLoaded() {
    GenericTestUtils.assumeInNativeProfile();
    if (!NativeCodeLoader.buildSupportsOpenssl()) {
      LOG.warn("Skipping test since openSSL library not loaded");
      Assume.assumeTrue(false);
    }
    Assert.assertNull(OpensslCipher.getLoadingFailureReason());
  }

  /**
   * Runs the round-trip checks with {@code encCodecClass} as the encrypting
   * codec: once with no records, then with {@code count} records decrypted by
   * the same codec and by {@code otherCodecClass}, and again after the low
   * eight bytes of the IV have been set to {@code 0xff}.
   *
   * @param encCodecClass class name of the codec used for encryption
   * @param otherCodecClass class name of the alternative decrypting codec
   */
  private void runCodecPairTests(String encCodecClass, String otherCodecClass)
      throws IOException, GeneralSecurityException {
    cryptoCodecTest(conf, seed, 0, encCodecClass, encCodecClass, iv);
    cryptoCodecTest(conf, seed, count, encCodecClass, encCodecClass, iv);
    cryptoCodecTest(conf, seed, count, encCodecClass, otherCodecClass, iv);
    // Overflow test, IV: xx xx xx xx xx xx xx xx ff ff ff ff ff ff ff ff
    Arrays.fill(iv, 8, 16, (byte) 0xff);
    cryptoCodecTest(conf, seed, count, encCodecClass, encCodecClass, iv);
    cryptoCodecTest(conf, seed, count, encCodecClass, otherCodecClass, iv);
  }

  /**
   * Instantiates the codec with the given class name through
   * {@link ReflectionUtils}.
   *
   * @param conf configuration used to resolve and configure the class
   * @param codecClass fully qualified codec class name
   * @return the new codec instance
   * @throws IOException if the class cannot be found; the original
   *     {@link ClassNotFoundException} is attached as the cause
   */
  private static CryptoCodec createCodec(Configuration conf, String codecClass)
      throws IOException {
    CryptoCodec codec;
    try {
      codec = (CryptoCodec) ReflectionUtils.newInstance(
          conf.getClassByName(codecClass), conf);
    } catch (ClassNotFoundException cnfe) {
      throw new IOException("Illegal crypto codec!", cnfe);
    }
    LOG.info("Created a Codec object of type: " + codecClass);
    return codec;
  }

  /**
   * Generates {@code count} random records from {@code seed}, encrypts them
   * with {@code encCodecClass}, decrypts them with {@code decCodecClass} and
   * verifies the decrypted output three ways: record by record, byte by
   * byte, and byte by byte after seeking to one third of the data length.
   * Finally exercises the secure random generator of the encrypting codec.
   *
   * @param conf configuration used to create the codecs
   * @param seed seed for the record generator
   * @param count number of key/value records to generate
   * @param encCodecClass class name of the encrypting codec
   * @param decCodecClass class name of the decrypting codec
   * @param iv initialization vector passed to both crypto streams
   */
  private void cryptoCodecTest(Configuration conf, int seed, int count,
      String encCodecClass, String decCodecClass, byte[] iv) throws IOException,
      GeneralSecurityException {
    CryptoCodec encCodec = createCodec(conf, encCodecClass);

    // Generate data
    DataOutputBuffer data = new DataOutputBuffer();
    RandomDatum.Generator generator = new RandomDatum.Generator(seed);
    for (int i = 0; i < count; ++i) {
      generator.next();
      // Named so they do not shadow the static encryption key field.
      RandomDatum datumKey = generator.getKey();
      RandomDatum datumValue = generator.getValue();
      datumKey.write(data);
      datumValue.write(data);
    }
    // The seed is random per test instance; log it so failures can be replayed.
    LOG.info("Generated " + count + " records with seed " + seed);

    // Encrypt data
    DataOutputBuffer encryptedDataBuffer = new DataOutputBuffer();
    CryptoOutputStream out = new CryptoOutputStream(encryptedDataBuffer,
        encCodec, bufferSize, key, iv);
    out.write(data.getData(), 0, data.getLength());
    out.flush();
    out.close();
    LOG.info("Finished encrypting data");

    CryptoCodec decCodec = createCodec(conf, decCodecClass);

    // Decrypt data
    DataInputBuffer decryptedDataBuffer = new DataInputBuffer();
    decryptedDataBuffer.reset(encryptedDataBuffer.getData(), 0,
        encryptedDataBuffer.getLength());
    CryptoInputStream in = new CryptoInputStream(decryptedDataBuffer,
        decCodec, bufferSize, key, iv);
    DataInputStream dataIn = new DataInputStream(new BufferedInputStream(in));

    // Check record by record
    DataInputBuffer originalData = new DataInputBuffer();
    originalData.reset(data.getData(), 0, data.getLength());
    DataInputStream originalIn = new DataInputStream(
        new BufferedInputStream(originalData));
    for (int i = 0; i < count; ++i) {
      RandomDatum k1 = new RandomDatum();
      RandomDatum v1 = new RandomDatum();
      k1.readFields(originalIn);
      v1.readFields(originalIn);

      RandomDatum k2 = new RandomDatum();
      RandomDatum v2 = new RandomDatum();
      k2.readFields(dataIn);
      v2.readFields(dataIn);
      assertTrue("original and encrypted-then-decrypted-output not equal",
          k1.equals(k2) && v1.equals(v2));

      // original and encrypted-then-decrypted-output have the same hashCode:
      // the decrypted datum must find the entry stored under the original.
      Map<RandomDatum, String> m = new HashMap<RandomDatum, String>();
      m.put(k1, k1.toString());
      m.put(v1, v1.toString());
      assertEquals("k1 and k2 hashcode not equal", k1.toString(), m.get(k2));
      assertEquals("v1 and v2 hashcode not equal", v1.toString(), m.get(v2));
    }

    // Decrypt data byte-at-a-time
    originalData.reset(data.getData(), 0, data.getLength());
    decryptedDataBuffer.reset(encryptedDataBuffer.getData(), 0,
        encryptedDataBuffer.getLength());
    in = new CryptoInputStream(decryptedDataBuffer,
        decCodec, bufferSize, key, iv);

    // Check; the final iteration also compares the two end-of-stream markers.
    originalIn = new DataInputStream(new BufferedInputStream(originalData));
    int expected;
    do {
      expected = originalIn.read();
      assertEquals("Decrypted stream read by byte does not match",
          expected, in.read());
    } while (expected != -1);

    // Seek to a certain position and decrypt
    originalData.reset(data.getData(), 0, data.getLength());
    decryptedDataBuffer.reset(encryptedDataBuffer.getData(), 0,
        encryptedDataBuffer.getLength());
    in = new CryptoInputStream(new TestCryptoStreams.FakeInputStream(
        decryptedDataBuffer), decCodec, bufferSize, key, iv);
    int seekPos = data.getLength() / 3;
    in.seek(seekPos);

    // Check
    TestCryptoStreams.FakeInputStream originalInput =
        new TestCryptoStreams.FakeInputStream(originalData);
    originalInput.seek(seekPos);
    do {
      expected = originalInput.read();
      assertEquals("Decrypted stream read by byte does not match",
          expected, in.read());
    } while (expected != -1);

    LOG.info("SUCCESS! Completed checking " + count + " records");

    // Check secure random generator
    testSecureRandom(encCodec);
  }

  /** Test secure random generator for buffers of 16, 32 and 128 bytes. */
  private void testSecureRandom(CryptoCodec codec) {
    int[] lengths = {16, 32, 128};
    for (int i = 0; i < lengths.length; i++) {
      checkSecureRandom(codec, lengths[i]);
    }
  }

  /**
   * Fills two buffers of {@code len} bytes through
   * {@code codec.generateSecureRandom} and asserts that their contents
   * differ.
   */
  private void checkSecureRandom(CryptoCodec codec, int len) {
    byte[] rand = new byte[len];
    byte[] rand1 = new byte[len];
    codec.generateSecureRandom(rand);
    codec.generateSecureRandom(rand1);

    Assert.assertEquals(len, rand.length);
    Assert.assertEquals(len, rand1.length);
    Assert.assertFalse(Arrays.equals(rand, rand1));
  }

  /**
   * Regression test for IV calculation, see HADOOP-11343
   */
  @Test(timeout=120000)
  public void testCalculateIV() throws Exception {
    JceAesCtrCryptoCodec codec = new JceAesCtrCryptoCodec();
    codec.setConf(conf);

    SecureRandom sr = new SecureRandom();
    byte[] initIV = new byte[16];
    byte[] IV = new byte[16];

    long iterations = 1000;
    long counter = 10000;

    // Overflow test, IV: 00 00 00 00 00 00 00 00 ff ff ff ff ff ff ff ff
    Arrays.fill(initIV, 8, 16, (byte) 0xff);
    for (long j = 0; j < counter; j++) {
      assertIVCalculation(codec, initIV, j, IV);
    }

    // Random IV and counter sequence test
    for (long i = 0; i < iterations; i++) {
      sr.nextBytes(initIV);
      for (long j = 0; j < counter; j++) {
        assertIVCalculation(codec, initIV, j, IV);
      }
    }

    // Random IV and random counter test
    for (long i = 0; i < iterations; i++) {
      sr.nextBytes(initIV);
      for (long j = 0; j < counter; j++) {
        long c = sr.nextLong();
        assertIVCalculation(codec, initIV, c, IV);
      }
    }
  }

  /**
   * Calls {@code codec.calculateIV(initIV, counter, IV)} and asserts that the
   * bytes written to {@code IV}, read as an unsigned big-endian integer,
   * equal the value returned by {@link #calculateRef(byte[], long)}.
   *
   * @param codec codec under test
   * @param initIV initial IV passed to the codec
   * @param counter counter passed to the codec
   * @param IV output buffer the codec writes the calculated IV into
   */
  private void assertIVCalculation(CryptoCodec codec, byte[] initIV,
      long counter, byte[] IV) {
    codec.calculateIV(initIV, counter, IV);

    BigInteger actual = new BigInteger(1, IV);
    BigInteger expected = calculateRef(initIV, counter);

    // assertEquals reports both values on failure, unlike assertTrue(equals).
    assertEquals("Calculated IV don't match with the reference",
        expected, actual);
  }

  /**
   * Computes the reference IV: {@code initIV} read as an unsigned big-endian
   * integer plus {@code counter} read as an unsigned 64-bit integer, reduced
   * modulo 2^(8 * initIV.length).
   *
   * @param initIV initial IV bytes
   * @param counter counter to add
   * @return the reference value as a non-negative {@link BigInteger}
   */
  private static BigInteger calculateRef(byte[] initIV, long counter) {
    byte[] cb = Longs.toByteArray(counter);
    BigInteger sum = new BigInteger(1, initIV).add(new BigInteger(1, cb));
    // A fixed-width IV cannot represent a carry out of its top byte, so the
    // reference has to wrap at the IV width to stay comparable.
    BigInteger modulus = BigInteger.ONE.shiftLeft(initIV.length * Byte.SIZE);
    return sum.mod(modulus);
  }
}