package ru.denull.mtproto; import java.math.BigInteger; import java.nio.ByteBuffer; import java.security.*; import java.security.spec.*; import java.security.interfaces.RSAPublicKey; import java.util.Arrays; import java.util.Random; import java.util.zip.*; import javax.crypto.BadPaddingException; import javax.crypto.Cipher; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; import org.bouncycastle.crypto.*; import org.bouncycastle.crypto.engines.AESEngine; import org.bouncycastle.crypto.modes.IGEBlockCipher; import org.bouncycastle.crypto.params.KeyParameter; import org.bouncycastle.crypto.params.ParametersWithIV; public class CryptoUtils { public final static String TAG = "CryptoUtils"; /* Crypto functions */ public static int CRC32(byte[] buffer) { CRC32 crc = new CRC32(); crc.update(buffer); return (int) crc.getValue(); } public static int CRC32(String str) { return (int) CRC32(str.getBytes()); } public static byte[] SHA1(byte[] buf) throws NoSuchAlgorithmException { MessageDigest sha1 = MessageDigest.getInstance("SHA-1"); return sha1.digest(buf); } public static byte[] SHA1(ByteBuffer buf, int offset, int size) throws NoSuchAlgorithmException { MessageDigest sha1 = MessageDigest.getInstance("SHA-1"); for (int i = offset; i < offset + size; i++) { sha1.update(buf.get(i)); } return sha1.digest(); } public static byte[] RSAEncrypt(byte[] buf, BigInteger modulus, BigInteger pubExp) throws NoSuchAlgorithmException, InvalidKeySpecException, NoSuchPaddingException, InvalidKeyException, IllegalBlockSizeException, BadPaddingException { KeyFactory keyFactory = KeyFactory.getInstance("RSA"); RSAPublicKeySpec pubKeySpec = new RSAPublicKeySpec(modulus, pubExp); RSAPublicKey key = (RSAPublicKey) keyFactory.generatePublic(pubKeySpec); Cipher rsa = Cipher.getInstance("RSA/ECB/NoPadding"); rsa.init(Cipher.ENCRYPT_MODE, key); return rsa.doFinal(buf); } public static byte[] AES(boolean encrypt, Object buf, int offset, int size, byte[] key, byte[] iv) throws DataLengthException, IllegalStateException, InvalidCipherTextException { BufferedBlockCipher cipher = new BufferedBlockCipher(new IGEBlockCipher(new AESEngine())); cipher.init(encrypt, new ParametersWithIV(new KeyParameter(key), iv)); byte[] answer = new byte[cipher.getOutputSize(size)]; byte[] buffer = null; if (buf instanceof byte[]) { buffer = (byte[]) buf; } else { buffer = new byte[((ByteBuffer) buf).capacity()]; ((ByteBuffer) buf).rewind(); ((ByteBuffer) buf).get(buffer); } int len = cipher.processBytes(buffer, offset, size, answer, 0); cipher.doFinal(answer, len); return answer; } public static byte[] AESEncrypt(Object buf, int offset, int size, byte[] key, byte[] iv) throws DataLengthException, IllegalStateException, InvalidCipherTextException { return AES(true, buf, offset, size, key, iv); } public static byte[] AESDecrypt(Object buf, int offset, int size, byte[] key, byte[] iv) throws DataLengthException, IllegalStateException, InvalidCipherTextException { return AES(false, buf, offset, size, key, iv); } /* Misc utils */ public static byte[] substr(byte[] array, int start, int count) { return Arrays.copyOfRange(array, start, start + count); } public static byte[] concat(byte[] first, byte[]... rest) { int totalLength = first.length; for (byte[] array : rest) { totalLength += array.length; } byte[] result = Arrays.copyOf(first, totalLength); int offset = first.length; for (byte[] array : rest) { System.arraycopy(array, 0, result, offset, array.length); offset += array.length; } return result; } // returns such a prime number that x is divisible by // bad: 1524705608009140637 (shanks), 2291122014370375721 // 2012708483660954293 public static BigInteger factor(BigInteger x) { //Log.i(TAG, "Factorisation started (" + x + ")..."); if (x.compareTo(BigInteger.valueOf(Long.MAX_VALUE)) < 0) { //Log.i(TAG, "Shanks result: " + factor_shanks(x.longValue())); //Log.i(TAG, "Fermat result: " + factor_fermat(x.longValue())); //Log.i(TAG, "Pollard result: " + factor_pollard(x.longValue())); //return BigInteger.valueOf(factor_shanks(x.longValue())); return BigInteger.valueOf(findSmallMultiplierLopatin(x.longValue())); } Log.w(TAG, "Using long arithmetics"); long k = 1; while (true) { BigInteger n = x.multiply(BigInteger.valueOf(k)); BigInteger sq = sqrt(n); BigInteger p0 = sq; if (p0.multiply(p0).equals(n)) { return p0; } BigInteger q0 = BigInteger.ONE; BigInteger q1 = n.subtract(p0.multiply(p0)); BigInteger sq_q1 = sqrt(q1); while (!sq_q1.multiply(sq_q1).equals(q1)) { BigInteger b1 = sq.add(p0).divide(q1); BigInteger p1 = b1.multiply(q1).subtract(p0); BigInteger q2 = q0.add(b1.multiply(p0.subtract(p1))); p0 = p1; q0 = q1; q1 = q2; sq_q1 = sqrt(q1); } BigInteger b0 = sq.subtract(p0).divide(sq_q1); p0 = b0.multiply(sq_q1).add(p0); q0 = sq_q1; q1 = n.subtract(p0.multiply(p0)).divide(q0); while (true) { BigInteger b1 = sq.add(p0).divide(q1); BigInteger p1 = b1.multiply(q1).subtract(p0); BigInteger q2 = q0.add(b1.multiply(p0.subtract(p1))); if (p1.equals(p0)) break; p0 = p1; q0 = q1; q1 = q2; } BigInteger ans = x.gcd(p0); if (!ans.equals(BigInteger.ONE) && !ans.equals(x)) return ans; k++; } } public static long factor_shanks(long x) { long start_time = System.currentTimeMillis(); if (isPerfectSquare(x)) { return (long) Math.sqrt(x); } long k = 1; //if (x % 4 == 1) k = 2; while (true) { long n = k * x; /*long sq = (long) Math.sqrt(n); long p0 = sq; long q0 = 1; long q1 = n - p0 * p0; while (!isPerfectSquare(q1)) { long b1 = (sq + p0) / q1; long p1 = (b1 * q1) - p0; long q2 = q0 + b1 * (p0 - p1); p0 = p1; q0 = q1; q1 = q2; } long sq_q1 = (long) Math.sqrt(q1); long b0 = (sq - p0) / sq_q1; p0 = (b0 * sq_q1) + p0; q0 = sq_q1; q1 = (n - p0 * p0) / q0; while (true) { long b1 = (sq + p0) / q1; long p1 = (b1 * q1) - p0; long q2 = q0 + b1 * (p0 - p1); if (p0 == p1) break; p0 = p1; q0 = q1; q1 = q2; } long ans = gcd(p0, x);*/ long sq = (long) Math.sqrt(n); long P0 = 0; long Q0 = 1; long r0 = sq; long P1 = sq; long Q1 = n - r0 * r0; long r1 = 2 * r0 / Q1; while (!isPerfectSquare(Q1)) { long P2 = r1 * Q1 - P1; long Q2 = Q0 + (P1 - P2) * r1; long r2 = (P2 + sq) / Q2; P0 = P1; P1 = P2; Q0 = Q1; Q1 = Q2; r1 = r2; } P0 = -P1; Q0 = (long) Math.sqrt(Q1); r0 = (P0 + sq) / Q0; P1 = r0 * Q0 - P0; Q1 = (n - P1 * P1) / Q0; r1 = (P1 + sq) / Q1; while (P1 != P0) { long P2 = r1 * Q1 - P1; long Q2 = Q0 + (P1 - P2) * r1; long r2 = (P2 + sq) / Q2; P0 = P1; P1 = P2; Q0 = Q1; Q1 = Q2; r1 = r2; } long ans = gcd(x, Q0); //Log.i(TAG, "P0=" + P0 + ", Q0=" + Q0 + ", Q1=" + Q1); if (ans != 1 && ans != x) { //Log.d(TAG, "x % 4 = " + (x % 4) + ", used k = " + k); Log.i(TAG, "Shanks took " + (System.currentTimeMillis() - start_time) + "ms"); return ans; } if (k > 30) { return factor_pollard(x); } k++; } } public static long factor_pollard(long n) { long start_time = System.currentTimeMillis(); long x = 1 + (long) (Math.random() * (n - 3)); long y = 1, i = 0, stage = 2; while (gcd(n, Math.abs(x - y)) == 1) { if (i == stage) { y = x; stage <<= 1; } BigInteger t = BigInteger.valueOf(x); x = t.multiply(t).add(BigInteger.ONE).mod(BigInteger.valueOf(n)).longValue(); // x = (x * x + 1) % n; i++; } Log.i(TAG, "Pollard took " + (System.currentTimeMillis() - start_time) + "ms"); return gcd(n, Math.abs(x - y)); } public static long factor_fermat(long n) { long start_time = System.currentTimeMillis(); long x = (long) Math.sqrt(n); long y = 0; long r = x*x - y*y - n; while (true) { if (r == 0) { Log.i(TAG, "Fermat took " + (System.currentTimeMillis() - start_time) + "ms"); return (x != y) ? (x - y) : (x + y); } else if (r > 0) { r -= y + y + 1; y++; } else { r += x + x + 1; x++; } } } public static BigInteger sqrt(BigInteger x) { if (x.compareTo(BigInteger.ZERO) < 0) { Log.e(TAG, "Invalig argument for sqrt"); } if (x.equals(BigInteger.ZERO) || x.equals(BigInteger.ONE)) { return x; } BigInteger two = BigInteger.valueOf(2L); BigInteger y; for (y = x.divide(two); y.compareTo(x.divide(y)) > 0; y = ((x.divide(y)).add(y)).divide(two)); return y; } public static long gcd(long a, long b) { while (b != 0) { long t = b; b = a % t; a = t; } return a; } private final static boolean isPerfectSquare(long n) { if (n < 0) return false; switch((int)(n & 0x3F)) { case 0x00: case 0x01: case 0x04: case 0x09: case 0x10: case 0x11: case 0x19: case 0x21: case 0x24: case 0x29: case 0x31: case 0x39: long sqrt; if(n < 410881L) { //John Carmack hack, converted to Java. // See: http://www.codemaestro.com/reviews/9 int i; float x2, y; x2 = n * 0.5F; y = n; i = Float.floatToRawIntBits(y); i = 0x5f3759df - ( i >> 1 ); y = Float.intBitsToFloat(i); y = y * ( 1.5F - ( x2 * y * y ) ); sqrt = (long)(1.0F/y); } else { //Carmack hack gives incorrect answer for n >= 410881. sqrt = (long)Math.sqrt(n); } return sqrt*sqrt == n; default: return false; } } public static void debug(Object... rest) { /*for (Object o : rest) { if (o instanceof ByteBuffer) { for (int i = 0; i < ((ByteBuffer) o).capacity(); i++) { System.out.print(String.format("%2x ", ((ByteBuffer) o).get(i))); if (i % 16 == 15) System.out.println(); } System.out.println(); } else if (o instanceof byte[]) { for (int i = 0; i < ((byte[]) o).length; i++) { System.out.print(String.format("%2x ", ((byte[]) o)[i])); if (i % 16 == 15) System.out.println(); } System.out.println(); } else if (o instanceof long[]) { for (int i = 0; i < ((long[]) o).length; i++) { System.out.print(((long[]) o)[i] + " "); if (i % 16 == 15) System.out.println(); } System.out.println(); } else { System.out.println(o); } }*/ } public static BigInteger unsignedBigInt(byte[] buffer) { byte[] buf2 = new byte[buffer.length + 1]; System.arraycopy(buffer, 0, buf2, 1, buffer.length); return new BigInteger(buf2); } public static long GCD(long a, long b) { while (a != 0 && b != 0) { while ((b & 1) == 0) { b >>= 1; } while ((a & 1) == 0) { a >>= 1; } if (a > b) { a -= b; } else { b -= a; } } return b == 0 ? a : b; } public static long findSmallMultiplierLopatin(long what) { Random r = new Random(); long g = 0; int it = 0; for (int i = 0; i < 3; i++) { int q = (r.nextInt(128) & 15) + 17; long x = r.nextInt(1000000000) + 1, y = x; int lim = 1 << (i + 18); for (int j = 1; j < lim; j++) { ++it; long a = x, b = x, c = q; while (b != 0) { if ((b & 1) != 0) { c += a; if (c >= what) { c -= what; } } a += a; if (a >= what) { a -= what; } b >>= 1; } x = c; long z = x < y ? y - x : x - y; g = GCD(z, what); if (g != 1) { break; } if ((j & (j - 1)) == 0) { y = x; } } if (g > 1) { break; } } long p = what / g; return Math.min(p, g); } }