/*
 *  Copyright (C) 2011 Roderick Baier
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *  	http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package de.roderick.weberknecht;

import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Random;


/**
 * Builds the client side of a WebSocket opening handshake that uses the
 * {@code Sec-WebSocket-Key1} / {@code Sec-WebSocket-Key2} headers plus eight
 * trailing random bytes, and verifies the server's answer to it.
 *
 * <p>The expected server response is the MD5 digest of the two key numbers
 * (big-endian, four bytes each) followed by the eight random bytes. This
 * matches the challenge scheme of draft-hixie-thewebsocketprotocol-76.
 *
 * <p>{@link #getHandshake()} must be called before
 * {@link #verifyServerHandshakeHeaders(HashMap)} because it determines the
 * origin the server is expected to echo back.
 */
public class WebSocketHandshake
{
    /** Charset used to encode the textual part of the handshake. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /** Shared source of randomness for key generation. */
    private static final Random RANDOM = new Random();

    /** Offsets of the three-digit status code in a line such as {@code "HTTP/1.1 101 ..."}. */
    private static final int STATUS_CODE_START = 9;
    private static final int STATUS_CODE_END = 12;

    /** Inclusive character ranges from which the noise characters in the keys are drawn. */
    private static final int NOISE_LOW_START = 0x21;
    private static final int NOISE_LOW_END = 0x2f;
    private static final int NOISE_HIGH_START = 0x3a;
    private static final int NOISE_HIGH_END = 0x7e;

    private static final String MISSING_HEADER_MESSAGE =
            "connection failed: missing header field in server handshake: ";

    private String key1 = null;
    private String key2 = null;
    private byte[] key3 = null;
    private byte[] expectedServerResponse = null;

    private final URI url;
    private String origin = null;
    private final String protocol;


    /**
     * Creates a handshake for the given endpoint and generates its random keys.
     *
     * @param url      the WebSocket URI to connect to
     * @param protocol the sub-protocol to request, or {@code null} for none
     */
    public WebSocketHandshake(URI url, String protocol)
    {
        this.url = url;
        // The original assigned null here, silently discarding the requested sub-protocol.
        this.protocol = protocol;
        generateKeys();
    }


    /**
     * Returns the complete client handshake: the HTTP upgrade request followed
     * by the eight random bytes of the third key.
     *
     * @return the bytes to send to the server
     */
    public byte[] getHandshake()
    {
        String path = url.getPath();
        if (path == null || path.length() == 0) {
            // An empty path would yield the invalid request line "GET  HTTP/1.1".
            path = "/";
        }
        String host = url.getHost();
        origin = "http://" + host;

        StringBuilder request = new StringBuilder();
        request.append("GET ").append(path).append(" HTTP/1.1\r\n");
        request.append("Host: ").append(host).append("\r\n");
        request.append("Connection: Upgrade\r\n");
        request.append("Sec-WebSocket-Key2: ").append(key2).append("\r\n");
        if (protocol != null) {
            request.append("Sec-WebSocket-Protocol: ").append(protocol).append("\r\n");
        }
        request.append("Upgrade: WebSocket\r\n");
        request.append("Sec-WebSocket-Key1: ").append(key1).append("\r\n");
        request.append("Origin: ").append(origin).append("\r\n");
        request.append("\r\n");

        // Encode once, with an explicit charset, so the result does not depend
        // on the platform default encoding.
        byte[] head = request.toString().getBytes(UTF_8);
        byte[] handshakeBytes = new byte[head.length + key3.length];
        System.arraycopy(head, 0, handshakeBytes, 0, head.length);
        System.arraycopy(key3, 0, handshakeBytes, head.length, key3.length);
        return handshakeBytes;
    }


    /**
     * Checks the server's challenge response against the expected MD5 digest.
     *
     * @param bytes the response bytes received from the server
     * @throws WebSocketException if the bytes do not match the expected digest
     */
    public void verifyServerResponse(byte[] bytes)
        throws WebSocketException
    {
        if (!Arrays.equals(bytes, expectedServerResponse)) {
            throw new WebSocketException("not a WebSocket Server");
        }
    }


    /**
     * Checks that the HTTP status line of the server handshake reports 101.
     *
     * @param statusLine the first line of the server response; the status code
     *                   is read from character positions 9 to 11
     * @throws WebSocketException if the line is too short, the status code is
     *                            not numeric, or the code is anything but 101
     */
    public void verifyServerStatusLine(String statusLine)
        throws WebSocketException
    {
        if (statusLine == null || statusLine.length() < STATUS_CODE_END) {
            throw new WebSocketException("connection failed: malformed status line: " + statusLine);
        }

        int statusCode;
        try {
            statusCode = Integer.parseInt(statusLine.substring(STATUS_CODE_START, STATUS_CODE_END));
        }
        catch (NumberFormatException e) {
            // Reported as a handshake failure instead of leaking an unchecked exception.
            throw new WebSocketException("connection failed: malformed status line: " + statusLine);
        }

        if (statusCode == 407) {
            throw new WebSocketException("connection failed: proxy authentication not supported");
        }
        else if (statusCode == 404) {
            throw new WebSocketException("connection failed: 404 not found");
        }
        else if (statusCode != 101) {
            throw new WebSocketException("connection failed: unknown status code " + statusCode);
        }
    }


    /**
     * Checks the {@code Upgrade}, {@code Connection} and
     * {@code Sec-WebSocket-Origin} headers of the server handshake.
     *
     * @param headers the header fields received from the server
     * @throws WebSocketException if one of the headers is absent or does not
     *                            carry the expected value
     */
    public void verifyServerHandshakeHeaders(HashMap<String, String> headers)
        throws WebSocketException
    {
        requireHeader(headers, "Upgrade", "WebSocket");
        requireHeader(headers, "Connection", "Upgrade");
        requireHeader(headers, "Sec-WebSocket-Origin", origin);
        // TODO see 4.1. step 41: also verify Sec-WebSocket-Location
        // (against url.toASCIIString()) and the negotiated protocol.
    }


    /**
     * Fails unless {@code headers} maps {@code name} to exactly {@code expected}.
     * An absent header is reported the same way as a mismatching one.
     */
    private void requireHeader(HashMap<String, String> headers, String name, String expected)
        throws WebSocketException
    {
        String actual = headers.get(name);
        if (actual == null || !actual.equals(expected)) {
            throw new WebSocketException(MISSING_HEADER_MESSAGE + name);
        }
    }


    /**
     * Generates the two textual keys, the eight random bytes and the digest
     * the server is expected to answer with.
     */
    private void generateKeys()
    {
        int spaces1 = rand(1, 12);
        int spaces2 = rand(1, 12);

        // Bound the numbers so that number * spaces still fits into an int.
        int max1 = Integer.MAX_VALUE / spaces1;
        int max2 = Integer.MAX_VALUE / spaces2;

        int number1 = rand(0, max1);
        int number2 = rand(0, max2);

        key1 = Integer.toString(number1 * spaces1);
        key2 = Integer.toString(number2 * spaces2);

        key1 = insertRandomCharacters(key1);
        key2 = insertRandomCharacters(key2);

        key1 = insertSpaces(key1, spaces1);
        key2 = insertSpaces(key2, spaces2);

        key3 = createRandomBytes();

        // ByteBuffer writes ints big-endian by default.
        byte[] challenge = ByteBuffer.allocate(16)
                .putInt(number1)
                .putInt(number2)
                .put(key3)
                .array();

        expectedServerResponse = md5(challenge);
    }


    /**
     * Inserts between one and twelve random non-digit, non-space characters
     * at random positions of {@code key}.
     */
    private String insertRandomCharacters(String key)
    {
        StringBuilder result = new StringBuilder(key);
        int count = rand(1, 12);
        for (int i = 0; i < count; i++) {
            result.insert(rand(0, result.length()), randomNoiseChar());
        }
        return result.toString();
    }


    /** Picks one character uniformly from the two allowed noise ranges. */
    private char randomNoiseChar()
    {
        int lowCount = NOISE_LOW_END - NOISE_LOW_START + 1;
        int highCount = NOISE_HIGH_END - NOISE_HIGH_START + 1;
        int pick = rand(0, lowCount + highCount - 1);
        if (pick < lowCount) {
            return (char) (NOISE_LOW_START + pick);
        }
        return (char) (NOISE_HIGH_START + (pick - lowCount));
    }


    /**
     * Inserts {@code spaces} space characters into {@code key}, never at the
     * very start or the very end. The key must be at least two characters long.
     */
    private String insertSpaces(String key, int spaces)
    {
        StringBuilder result = new StringBuilder(key);
        for (int i = 0; i < spaces; i++) {
            result.insert(rand(1, result.length() - 1), ' ');
        }
        return result.toString();
    }


    /** Returns eight random bytes covering the full byte range. */
    private byte[] createRandomBytes()
    {
        byte[] bytes = new byte[8];
        RANDOM.nextBytes(bytes);
        return bytes;
    }


    /**
     * Computes the MD5 digest of {@code bytes}.
     *
     * @throws IllegalStateException if the platform provides no MD5 implementation
     */
    private byte[] md5(byte[] bytes)
    {
        try {
            return MessageDigest.getInstance("MD5").digest(bytes);
        }
        catch (NoSuchAlgorithmException e) {
            // Returning null here would leave the expected response unset, and
            // Arrays.equals(null, null) would then accept a null server reply.
            throw new IllegalStateException("MD5 digest is not available", e);
        }
    }


    /**
     * Returns a uniformly distributed random integer in {@code [min, max]},
     * both bounds inclusive. The span is computed as a {@code long} so that
     * {@code rand(0, Integer.MAX_VALUE)} does not overflow.
     */
    private int rand(int min, int max)
    {
        long span = (long) max - min + 1;
        return (int) (min + (long) (RANDOM.nextDouble() * span));
    }
}