package com.twilio.jwt.validation;
import com.google.common.base.Charsets;
import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hashing;
import com.google.common.io.CharStreams;
import com.twilio.http.HttpMethod;
import com.twilio.jwt.Jwt;
import io.jsonwebtoken.SignatureAlgorithm;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.HttpRequest;
import org.apache.http.message.BasicHeader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.security.PrivateKey;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ValidationToken extends Jwt {
private static final HashFunction HASH_FUNCTION = Hashing.sha256();
private static final String CTY = "twilio-pkrv;v=1";
private static final String NEW_LINE = "\n";
private final String accountSid;
private final String credentialSid;
private final String signingKeySid;
private final String method;
private final String uri;
private final String queryString;
private final Header[] headers;
private final List<String> signedHeaders;
private final String requestBody;
private ValidationToken(Builder b) {
super(
SignatureAlgorithm.RS256,
b.privateKey,
b.credentialSid,
new Date(new Date().getTime() + b.ttl * 1000)
);
this.accountSid = b.accountSid;
this.credentialSid = b.credentialSid;
this.signingKeySid = b.signingKeySid;
this.method = b.method;
this.uri = b.uri;
this.queryString = b.queryString;
this.headers = b.headers;
this.signedHeaders = b.signedHeaders;
this.requestBody = b.requestBody;
}
@Override
public Map<String, Object> getHeaders() {
Map<String, Object> headers = new HashMap<>();
headers.put("cty", CTY);
headers.put("kid", this.credentialSid);
return headers;
}
@Override
public Map<String, Object> getClaims() {
Map<String, Object> payload = new HashMap<>();
payload.put("iss", this.signingKeySid);
payload.put("sub", this.accountSid);
// Sort the signed headers
Collections.sort(signedHeaders);
List<String> lowercaseSignedHeaders = Lists.transform(signedHeaders, LOWERCASE_STRING);
String includedHeaders = Joiner.on(";").join(lowercaseSignedHeaders);
payload.put("hrh", includedHeaders);
// Add the method and uri
StringBuilder signature = new StringBuilder();
signature.append(method).append(NEW_LINE);
signature.append(uri).append(NEW_LINE);
// Get the query args, sort and rejoin
String[] queryArgs = queryString.split("&");
Arrays.sort(queryArgs);
String sortedQueryString = Joiner.on("&").join(queryArgs);
signature.append(sortedQueryString).append(NEW_LINE);
// Normalize all the headers
Header[] lowercaseHeaders = LOWERCASE_KEYS.apply(headers);
Map<String, List<String>> combinedHeaders = COMBINE_HEADERS.apply(lowercaseHeaders);
// Add the headers that we care about
for (String header : lowercaseSignedHeaders) {
String lowercase = header.toLowerCase().trim();
if (combinedHeaders.containsKey(lowercase)) {
List<String> values = combinedHeaders.get(lowercase);
Collections.sort(values);
signature.append(lowercase)
.append(":")
.append(Joiner.on(',').join(values))
.append(NEW_LINE);
}
}
signature.append(NEW_LINE);
// Mark the headers that we care about
signature.append(includedHeaders).append(NEW_LINE);
// Hash and hex the request payload
if (!Strings.isNullOrEmpty(requestBody)) {
String hashedPayload = HASH_FUNCTION.hashString(requestBody, Charsets.UTF_8).toString();
signature.append(hashedPayload);
}
// Hash and hex the canonical request
String hashedSignature = HASH_FUNCTION.hashString(signature.toString(), Charsets.UTF_8).toString();
payload.put("rqh", hashedSignature);
return payload;
}
public static ValidationToken fromHttpRequest(
String accountSid,
String credentialSid,
String signingKeySid,
PrivateKey privateKey,
HttpRequest request,
List<String> signedHeaders
) throws IOException {
Builder builder = new Builder(accountSid, credentialSid, signingKeySid, privateKey);
String method = request.getRequestLine().getMethod();
builder.method(method);
String uri = request.getRequestLine().getUri();
if (uri.contains("?")) {
String[] uriParts = uri.split("\\?");
builder.uri(uriParts[0]);
builder.queryString(uriParts[1]);
} else {
builder.uri(uri);
}
builder.headers(request.getAllHeaders());
builder.signedHeaders(signedHeaders);
if (HttpMethod.POST.toString().equals(method.toUpperCase())) {
HttpEntity entity = ((HttpEntityEnclosingRequest)request).getEntity();
builder.requestBody(CharStreams.toString(new InputStreamReader(entity.getContent(), Charsets.UTF_8)));
}
return builder.build();
}
private static Function<Header[], Map<String, List<String>>> COMBINE_HEADERS = new Function<Header[], Map<String, List<String>>>() {
@Override
public Map<String, List<String>> apply(Header[] headers) {
Map<String, List<String>> combinedHeaders = new HashMap<>();
for (Header header : headers) {
if (combinedHeaders.containsKey(header.getName())) {
combinedHeaders.get(header.getName()).add(header.getValue());
} else {
combinedHeaders.put(header.getName(), Lists.newArrayList(header.getValue()));
}
}
return combinedHeaders;
}
};
private static Function<Header[], Header[]> LOWERCASE_KEYS = new Function<Header[], Header[]>() {
@Override
public Header[] apply(Header[] headers) {
Header[] lowercaseHeaders = new Header[headers.length];
for (int i = 0; i < headers.length; i++) {
lowercaseHeaders[i] = new BasicHeader(headers[i].getName().toLowerCase(), headers[i].getValue());
}
return lowercaseHeaders;
}
};
private static Function<String, String> LOWERCASE_STRING = new Function<String, String>() {
@Override
public String apply(String s) {
return s.toLowerCase();
}
};
public static class Builder {
private String accountSid;
private String credentialSid;
private String signingKeySid;
private PrivateKey privateKey;
private String method;
private String uri;
private String queryString = "";
private Header[] headers;
private List<String> signedHeaders = Collections.emptyList();
private String requestBody = "";
private int ttl = 300;
public Builder(String accountSid, String credentialSid, String signingKeySid, PrivateKey privateKey) {
this.accountSid = accountSid;
this.credentialSid = credentialSid;
this.signingKeySid = signingKeySid;
this.privateKey = privateKey;
}
public Builder method(String method) {
this.method = method;
return this;
}
public Builder uri(String uri) {
this.uri = uri;
return this;
}
public Builder queryString(String queryString) {
this.queryString = queryString;
return this;
}
public Builder headers(Header[] headers) {
this.headers = headers;
return this;
}
public Builder signedHeaders(List<String> signedHeaders) {
this.signedHeaders = signedHeaders;
return this;
}
public Builder requestBody(String requestBody) {
this.requestBody = requestBody;
return this;
}
public Builder ttl(int ttl) {
this.ttl = ttl;
return this;
}
public ValidationToken build() {
return new ValidationToken(this);
}
}
}