package threshold.mr04;
import static threshold.mr04.Util.calculateMPrime;
import static threshold.mr04.Util.getBytes;
import static threshold.mr04.Util.isElementOfZn;
import static threshold.mr04.Util.randomFromZn;
import static threshold.mr04.Util.randomFromZnStar;
import static threshold.mr04.Util.sha256Hash;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.math.BigInteger;
import java.security.SecureRandom;
import org.bouncycastle.asn1.sec.SECNamedCurves;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.math.ec.ECPoint;
import threshold.mr04.data.PublicParameters;
import threshold.mr04.data.Round1Message;
import threshold.mr04.data.Round2Message;
import threshold.mr04.data.Round3Message;
import threshold.mr04.data.Round4Message;
/**
 * Alice's side of a two-party threshold ECDSA signing protocol on secp256k1.
 *
 * <p>NOTE(review): the package name {@code mr04} suggests MacKenzie-Reiter (2004); confirm.
 * Alice holds a multiplicative share of the signing key and her own Paillier key pair. A
 * signature is produced by calling {@link #aliceToBobRound1}, {@link #aliceToBobRound3} and
 * {@link #aliceOutput} in that order, exchanging messages with Bob in between. Protocol checks
 * that fail (invalid points, rejected zero-knowledge proofs) throw {@link AssertionError}.
 *
 * <p>Per-signature state is kept in mutable fields without synchronization, so one instance
 * must not be used by concurrent signing sessions.
 */
public class Alice implements Serializable {
    /** secp256k1 domain parameters; not serialized and rebuilt on deserialization. */
    private transient ECDomainParameters CURVE;
    /** Modulus for scalar arithmetic (from {@code params.q}); presumably the group order. */
    private final BigInteger q;
    /** Encoded base point G. */
    private final byte[] gRaw;
    /** Encoded public key Q supplied to the constructor. */
    private final byte[] qRaw;
    /** Bit length of the Paillier randomness sampled in round 1. */
    private final int kPrime;
    // Parameters h1, h2 (mod nHat) used in the zero-knowledge proof commitments.
    private BigInteger h1;
    private BigInteger h2;
    // Alice's Paillier public key (g, N), the proof modulus nHat, and N^2.
    private BigInteger g;
    private BigInteger N;
    private BigInteger nHat;
    private BigInteger Nsquared;
    /** Alice's Paillier instance; used to decrypt Bob's final ciphertext. */
    Paillier paillier;
    // Bob's Paillier public key (g', N') and N'^2.
    BigInteger gPrime;
    BigInteger nPrime;
    BigInteger nPrimeSquared;
    /** Alice's share of the signing key. */
    private final BigInteger keyShare;
    private final SecureRandom rand;
    private final PaillierPublicKey alicesPaillierPubKey;
    // Per-signature state produced in round 1.
    private BigInteger kAlice;
    private BigInteger ciphertext1;
    private BigInteger ciphertext2;
    /** Inverse of kAlice modulo q. */
    private BigInteger zAlice;
    // The random values used for ciphertext1 and ciphertext2 respectively.
    private BigInteger r1;
    private BigInteger r2;
    /** Encoded point R = rBob * kAlice, set in round 3. */
    private byte[] rRaw;
    /** x-coordinate of R mod q; returned as the first element of {@link #aliceOutput}. */
    BigInteger rPrime;
    /** Value derived from the message by {@code calculateMPrime} in round 1. */
    BigInteger mPrime;

    /**
     * Creates Alice's protocol state.
     *
     * @param keyShare  Alice's share of the signing key
     * @param publicKey encoded public key; decoded on secp256k1 and stored re-encoded
     * @param rand      randomness source for nonces, Paillier randomness and proof values
     * @param paillier  Alice's Paillier instance (must be able to decrypt)
     * @param params    shared public parameters (q, G, h1, h2, nHat, kPrime, Paillier keys)
     */
    public Alice(BigInteger keyShare, byte[] publicKey, SecureRandom rand, Paillier paillier,
            PublicParameters params) {
        this.rand = rand;
        this.keyShare = keyShare;
        this.paillier = paillier;
        this.CURVE = secp256k1Domain();
        this.q = params.q;
        this.gRaw = params.G(this.CURVE.getCurve()).getEncoded();
        this.kPrime = params.kPrime;
        this.h1 = params.h1;
        this.h2 = params.h2;
        this.g = params.alicesPaillierPubKey.g;
        this.N = params.alicesPaillierPubKey.N;
        this.nHat = params.nHat;
        this.Nsquared = N.pow(2);
        this.gPrime = params.otherPaillierPubKey.g;
        this.nPrime = params.otherPaillierPubKey.N;
        this.nPrimeSquared = nPrime.pow(2);
        // Round-trip through decodePoint so the stored encoding is one the curve accepts.
        this.qRaw = CURVE.getCurve().decodePoint(publicKey).getEncoded();
        this.alicesPaillierPubKey = params.alicesPaillierPubKey;
    }

    /** Builds the secp256k1 domain parameters used for all curve arithmetic in this class. */
    private static ECDomainParameters secp256k1Domain() {
        X9ECParameters curveParams = SECNamedCurves.getByName("secp256k1");
        return new ECDomainParameters(curveParams.getCurve(), curveParams.getG(),
                curveParams.getN(), curveParams.getH());
    }

    /**
     * Always treat de-serialization as a full-blown constructor, by
     * validating the final state of the de-serialized object.
     *
     * <p>Rebuilds the transient curve parameters, then checks that the state every
     * constructed instance has is present and that stored point encodings decode.
     */
    private void readObject(ObjectInputStream aInputStream) throws ClassNotFoundException, IOException {
        // Always perform the default de-serialization first.
        aInputStream.defaultReadObject();
        this.CURVE = secp256k1Domain();
        validateDeserializedState();
    }

    /** Rejects a deserialized instance whose required state is missing or malformed. */
    private void validateDeserializedState() throws InvalidObjectException {
        // The constructor always sets these to non-null values.
        if (gRaw == null || qRaw == null || alicesPaillierPubKey == null) {
            throw new InvalidObjectException("Alice: missing required state after deserialization");
        }
        try {
            getG();
            getQ();
            if (rRaw != null) {
                getR();
            }
        } catch (RuntimeException e) {
            // Deserialization boundary: any decoding failure means the stream is corrupt.
            InvalidObjectException invalid =
                    new InvalidObjectException("Alice: invalid curve point encoding");
            invalid.initCause(e);
            throw invalid;
        }
    }

    /**
     * This is the default implementation of writeObject.
     * Customize if necessary.
     */
    private void writeObject(ObjectOutputStream aOutputStream) throws IOException {
        // Perform the default serialization for all non-transient, non-static fields.
        aOutputStream.defaultWriteObject();
    }

    /**
     * Round 1: samples Alice's nonce share and sends Bob Paillier encryptions derived from it.
     *
     * <p>Samples {@code kAlice} uniformly from [1, q) and sets {@code zAlice = kAlice^-1 mod q}.
     * Encrypts {@code zAlice} and {@code keyShare * zAlice mod q} under Alice's Paillier public
     * key with fresh {@code kPrime}-bit randomness.
     *
     * @param message the message bytes, converted by {@code calculateMPrime}
     * @return a message carrying {@code mPrime} and both ciphertexts
     */
    public Round1Message aliceToBobRound1(byte[] message) {
        // Rejection sampling; zero is excluded because it has no inverse mod q.
        do {
            kAlice = new BigInteger(q.bitLength(), rand);
        } while (kAlice.signum() == 0 || kAlice.compareTo(q) >= 0);
        zAlice = kAlice.modInverse(q);
        r1 = new BigInteger(kPrime, rand);
        r2 = new BigInteger(kPrime, rand);
        ciphertext1 = Paillier.encrypt(zAlice, alicesPaillierPubKey, r1);
        ciphertext2 = Paillier.encrypt(keyShare.multiply(zAlice).mod(q), alicesPaillierPubKey, r2);
        mPrime = calculateMPrime(q, message);
        return new Round1Message(mPrime, ciphertext1, ciphertext2);
    }

    /**
     * Round 3: validates Bob's point and computes {@code R}. Also proves in zero knowledge that
     * the round-1 ciphertexts are consistent with {@code R} and Alice's key share.
     *
     * <p>Computes {@code R = rBob * kAlice} and {@code rPrime = x(R) mod q}. The proof is made
     * non-interactive by taking the challenge {@code e} as the SHA-256 hash of the statement
     * and the commitments.
     *
     * @param input Bob's round-2 message
     * @return a message carrying R, the proof commitments z1, z2, y, the challenge and responses
     * @throws IllegalStateException if {@link #aliceToBobRound1} has not been called
     * @throws AssertionError        if Bob's point is rejected
     */
    public Round3Message aliceToBobRound3(Round2Message input) {
        if (kAlice == null) {
            throw new IllegalStateException("aliceToBobRound1 must be called before aliceToBobRound3");
        }
        ECPoint rBob = input.getrBob(CURVE.getCurve());
        // rBob must be a non-identity point with rBob * q = O. The identity passes the order
        // test trivially but would make R the identity.
        if (rBob.isInfinity() || !rBob.multiply(q).isInfinity()) {
            throw new AssertionError("rBob must be a non-identity point of order q");
        }
        // NOTE(review, per Rosario): intended as the analogue of checking membership in Zp*.
        if (rBob.getCurve() != CURVE.getCurve()) {
            throw new AssertionError("rBob is not on the expected curve");
        }
        setR(rBob.multiply(kAlice));
        rPrime = getR().getX().toBigInteger().mod(q);
        long startTime = System.nanoTime();
        // First zkp: random masking values for the commitments.
        BigInteger alpha = randomFromZn(q.pow(3), rand);
        BigInteger beta = randomFromZnStar(N, rand);
        BigInteger gamma = randomFromZn(q.pow(3).multiply(nHat), rand);
        BigInteger rho1 = randomFromZn(q.multiply(nHat), rand);
        BigInteger delta = randomFromZn(q.pow(3), rand);
        BigInteger mu = randomFromZnStar(N, rand);
        BigInteger nu = randomFromZn(q.pow(3).multiply(nHat), rand);
        BigInteger rho2 = randomFromZn(q.multiply(nHat), rand);
        BigInteger rho3 = randomFromZn(q, rand);
        BigInteger epsilon = randomFromZn(q, rand);
        // Witnesses: the plaintexts of ciphertext1 and ciphertext2.
        BigInteger x1 = zAlice;
        BigInteger x2 = zAlice.multiply(keyShare).mod(q);
        // Statement: points c, d, w1, w2 and ciphertexts m1, m2.
        ECPoint c = getR();
        ECPoint d = getG();
        ECPoint w1 = getR().multiply(zAlice);
        ECPoint w2 = getG().multiply(keyShare);
        BigInteger m1 = ciphertext1;
        BigInteger m2 = ciphertext2;
        // Commitments.
        BigInteger z1 = h1.modPow(x1, nHat).multiply(h2.modPow(rho1, nHat)).mod(nHat);
        ECPoint u1 = c.multiply(alpha);
        BigInteger u2 = g.modPow(alpha, Nsquared).multiply(beta.modPow(N, Nsquared)).mod(Nsquared);
        BigInteger u3 = h1.modPow(alpha, nHat).multiply(h2.modPow(gamma, nHat)).mod(nHat);
        BigInteger z2 = h1.modPow(x2, nHat).multiply(h2.modPow(rho2, nHat)).mod(nHat);
        ECPoint y = d.multiply(x2.add(rho3));
        ECPoint v1 = d.multiply(delta.add(epsilon));
        ECPoint v2 = w2.multiply(alpha).add(getG().multiply(epsilon));
        BigInteger v3 = g.modPow(delta, Nsquared).multiply(mu.modPow(N, Nsquared)).mod(Nsquared);
        BigInteger v4 = h1.modPow(delta, nHat).multiply(h2.modPow(nu, nHat)).mod(nHat);
        // Fiat-Shamir challenge over statement and commitments.
        byte[] digest = sha256Hash(getBytes(c), getBytes(w1), getBytes(d), getBytes(w2),
                getBytes(m1), getBytes(m2), getBytes(z1), getBytes(u1), getBytes(u2), getBytes(u3),
                getBytes(z2), getBytes(y), getBytes(v1), getBytes(v2), getBytes(v3), getBytes(v4));
        if (digest == null) {
            throw new AssertionError("sha256Hash returned null");
        }
        BigInteger e = new BigInteger(1, digest);
        // Responses.
        BigInteger s1 = e.multiply(x1).add(alpha);
        BigInteger s2 = r1.modPow(e, N).multiply(beta).mod(N);
        BigInteger s3 = e.multiply(rho1).add(gamma);
        BigInteger t1 = e.multiply(x2).add(delta);
        BigInteger t2 = e.multiply(rho3).add(epsilon).mod(q);
        BigInteger t3 = r2.modPow(e, Nsquared).multiply(mu).mod(Nsquared);
        BigInteger t4 = e.multiply(rho2).add(nu);
        System.out.println("create zkp1: " + (System.nanoTime() - startTime));
        return new Round3Message(getR(), z1, z2, y, e, s1, s2, s3, t1, t2, t3, t4);
    }

    /**
     * Final step: verifies Bob's proof, then decrypts Bob's ciphertext to obtain {@code s}.
     *
     * @param input Bob's round-4 message
     * @return {@code {rPrime, s}} where {@code s = Dec(u) mod q}
     * @throws IllegalStateException if {@link #aliceToBobRound3} has not been called
     * @throws AssertionError        if Bob's proof does not verify
     */
    public BigInteger[] aliceOutput(Round4Message input) {
        if (rRaw == null) {
            throw new IllegalStateException("aliceToBobRound3 must be called before aliceOutput");
        }
        verifyZkp2(input);
        BigInteger u = input.getU();
        BigInteger s = paillier.decrypt(u).mod(q);
        return new BigInteger[] { rPrime, s };
    }

    /**
     * Verifies Bob's non-interactive zero-knowledge proof from round 4.
     *
     * <p>Range-checks s1, t1 (in Z_{q^3}) and t5 (in Z_{q^7}). It then recomputes every
     * commitment from the responses and the challenge, rehashes, and requires the result to
     * equal the challenge Bob sent.
     *
     * @throws AssertionError if a range check fails or the recomputed challenge differs.
     *     {@code BigInteger.modPow} with a negative exponent may additionally throw
     *     {@link ArithmeticException} if a proof value is not invertible modulo its modulus.
     */
    private void verifyZkp2(Round4Message input) {
        long startTime = System.nanoTime();
        ECPoint c = getR().multiply(zAlice); // G.multiply(kBob);
        ECPoint d = getG();
        ECPoint w1 = getG();
        ECPoint w2 = getQ().multiply(keyShare.modInverse(q)); // G.multiply(bobShare);
        BigInteger m1 = input.getUPrime();
        BigInteger m2 = input.getU();
        BigInteger m3 = ciphertext1.modPow(mPrime, paillier.nSquared);
        BigInteger m4 = ciphertext2.modPow(rPrime, paillier.nSquared);
        BigInteger z1 = input.getZ1();
        BigInteger z2 = input.getZ2();
        BigInteger z3 = input.getZ3();
        ECPoint y = input.getY(CURVE.getCurve());
        BigInteger e = input.getE();
        BigInteger s1 = input.getS1();
        BigInteger s2 = input.getS2();
        BigInteger s3 = input.getS3();
        BigInteger t1 = input.getT1();
        BigInteger t2 = input.getT2();
        BigInteger t3 = input.getT3();
        BigInteger t4 = input.getT4();
        BigInteger t5 = input.getT5();
        BigInteger t6 = input.getT6();
        // Range checks on the responses.
        if (!isElementOfZn(s1, q.pow(3))) {
            throw new AssertionError("zkp2: s1 out of range");
        }
        if (!isElementOfZn(t1, q.pow(3))) {
            throw new AssertionError("zkp2: t1 out of range");
        }
        if (!isElementOfZn(t5, q.pow(7))) {
            throw new AssertionError("zkp2: t5 out of range");
        }
        // Recompute the commitments from the responses and the challenge.
        ECPoint u1Recovered = c.multiply(s1).add(w1.negate().multiply(e));
        BigInteger u2Recovered = gPrime.modPow(s1, nPrimeSquared)
                .multiply(s2.modPow(nPrime, nPrimeSquared))
                .multiply(m1.modPow(e.negate(), nPrimeSquared)).mod(nPrimeSquared);
        BigInteger u3Recovered = h1.modPow(s1, nHat).multiply(h2.modPow(s3, nHat))
                .multiply(z1.modPow(e.negate(), nHat)).mod(nHat);
        ECPoint v1Recovered = d.multiply(t1.add(t2)).add(y.negate().multiply(e));
        ECPoint v2Recovered = w2.multiply(s1).add(d.multiply(t2)).add(y.negate().multiply(e));
        BigInteger v3Recovered = m3.modPow(s1, Nsquared).multiply(m4.modPow(t1, Nsquared))
                .multiply(g.modPow(q.multiply(t5), Nsquared)).multiply(t3.modPow(N, Nsquared))
                .multiply(m2.modPow(e.negate(), Nsquared)).mod(Nsquared);
        BigInteger v4Recovered = h1.modPow(t1, nHat).multiply(h2.modPow(t4, nHat))
                .multiply(z2.modPow(e.negate(), nHat)).mod(nHat);
        BigInteger v5Recovered = h1.modPow(t5, nHat).multiply(h2.modPow(t6, nHat))
                .multiply(z3.modPow(e.negate(), nHat)).mod(nHat);
        byte[] digestRecovered = sha256Hash(getBytes(c), getBytes(w1), getBytes(d), getBytes(w2),
                getBytes(m1), getBytes(m2), getBytes(z1), getBytes(u1Recovered),
                getBytes(u2Recovered), getBytes(u3Recovered), getBytes(z2), getBytes(z3),
                getBytes(y), getBytes(v1Recovered), getBytes(v2Recovered), getBytes(v3Recovered),
                getBytes(v4Recovered), getBytes(v5Recovered));
        if (digestRecovered == null) {
            throw new AssertionError("sha256Hash returned null");
        }
        BigInteger eRecovered = new BigInteger(1, digestRecovered);
        if (!e.equals(eRecovered)) {
            throw new AssertionError("zkp2: challenge mismatch");
        }
        System.out.println("verifyZkp2: " + (System.nanoTime() - startTime));
    }

    /** Stores R in encoded form; points are kept as bytes so the class stays serializable. */
    private void setR(ECPoint r) {
        rRaw = r.getEncoded();
    }

    /** Decodes the stored R on the secp256k1 curve. */
    private ECPoint getR() {
        return CURVE.getCurve().decodePoint(rRaw);
    }

    /** Decodes the stored base point G on the secp256k1 curve. */
    private ECPoint getG() {
        return CURVE.getCurve().decodePoint(gRaw);
    }

    /**
     * Returns the public key supplied at construction.
     *
     * @return the public key, decoded on the secp256k1 curve
     */
    public ECPoint getQ() {
        return CURVE.getCurve().decodePoint(qRaw);
    }
}