package com.aptoide.amethyst.websockets;
import android.os.Handler;
import android.os.HandlerThread;
import android.text.TextUtils;
import android.util.Base64;
import android.util.Log;
import org.apache.http.Header;
import org.apache.http.HttpException;
import org.apache.http.HttpStatus;
import org.apache.http.NameValuePair;
import org.apache.http.StatusLine;
import org.apache.http.client.HttpResponseException;
import org.apache.http.message.BasicLineParser;
import org.apache.http.message.BasicNameValuePair;
import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.net.Socket;
import java.net.URI;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.List;
import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
public class WebSocketClient {
private static final String TAG = "WebSocketClient";
private URI mURI;
private Listener mListener;
private Socket mSocket;
private Thread mThread;
private HandlerThread mHandlerThread;
private Handler mHandler;
private List<BasicNameValuePair> mExtraHeaders;
private HybiParser mParser;
private boolean mConnected;
private final Object mSendLock = new Object();
private static TrustManager[] sTrustManagers;
public static void setTrustManagers(TrustManager[] tm) {
sTrustManagers = tm;
}
public WebSocketClient(URI uri, Listener listener, List<BasicNameValuePair> extraHeaders) {
mURI = uri;
mListener = listener;
mExtraHeaders = extraHeaders;
mConnected = false;
mParser = new HybiParser(this);
mHandlerThread = new HandlerThread("websocket-thread");
mHandlerThread.start();
mHandler = new Handler(mHandlerThread.getLooper());
}
public Listener getListener() {
return mListener;
}
public void connect() {
if (mThread != null && mThread.isAlive()) {
return;
}
mThread = new Thread(new Runnable() {
@Override
public void run() {
try {
int port = (mURI.getPort() != -1) ? mURI.getPort() : ((mURI.getScheme().equals("wss") || mURI.getScheme().equals("https")) ? 443 : 80);
String path = TextUtils.isEmpty(mURI.getPath()) ? "/" : mURI.getPath();
if (!TextUtils.isEmpty(mURI.getQuery())) {
path += "?" + mURI.getQuery();
}
String originScheme = mURI.getScheme().equals("wss") ? "https" : "http";
URI origin = new URI(originScheme, "//" + mURI.getHost(), null);
SocketFactory factory = (mURI.getScheme().equals("wss") || mURI.getScheme().equals("https")) ? getSSLSocketFactory() : SocketFactory.getDefault();
mSocket = factory.createSocket(mURI.getHost(), port);
PrintWriter out = new PrintWriter(mSocket.getOutputStream());
out.print("GET " + path + " HTTP/1.1\r\n");
out.print("Upgrade: websocket\r\n");
out.print("Connection: Upgrade\r\n");
out.print("Host: " + mURI.getHost()+":"+mURI.getPort() + "\r\n");
out.print("Origin: " + origin.toString() + "\r\n");
out.print("Sec-WebSocket-Key: " + createSecret() + "\r\n");
out.print("Sec-WebSocket-Version: 13\r\n");
if (mExtraHeaders != null) {
for (NameValuePair pair : mExtraHeaders) {
out.print(String.format("%s: %s\r\n", pair.getName(), pair.getValue()));
}
}
out.print("\r\n");
out.flush();
HybiParser.HappyDataInputStream stream = new HybiParser.HappyDataInputStream(mSocket.getInputStream());
// Read HTTP response status line.
StatusLine statusLine = parseStatusLine(readLine(stream));
if (statusLine == null) {
throw new HttpException("Received no reply from server.");
} else if (statusLine.getStatusCode() != HttpStatus.SC_SWITCHING_PROTOCOLS) {
throw new HttpResponseException(statusLine.getStatusCode(), statusLine.getReasonPhrase());
}
// Read HTTP response headers.
String line;
while (!TextUtils.isEmpty(line = readLine(stream))) {
Header header = parseHeader(line);
if (header.getName().equals("Sec-WebSocket-Accept")) {
// FIXME: Verify the response...
}
}
mListener.onConnect();
mConnected = true;
// Now decode websocket frames.
mParser.start(stream);
} catch (EOFException ex) {
Log.d(TAG, "WebSocket EOF!", ex);
mListener.onDisconnect(0, "EOF");
mConnected = false;
} catch (SSLException ex) {
// Connection reset by peer
Log.d(TAG, "Websocket SSL error!", ex);
mListener.onDisconnect(0, "SSL");
mConnected = false;
} catch (Exception ex) {
mListener.onError(ex);
}
}
});
mThread.start();
}
public void disconnect() {
if (mSocket != null) {
mHandler.post(new Runnable() {
@Override
public void run() {
if (mSocket != null) {
try {
mSocket.close();
} catch (IOException ex) {
Log.d(TAG, "Error while disconnecting", ex);
mListener.onError(ex);
}
mSocket = null;
}
mConnected = false;
}
});
}
}
public void send(String data) {
sendFrame(mParser.frame(data));
}
public void send(byte[] data) {
sendFrame(mParser.frame(data));
}
public boolean isConnected() {
return mConnected;
}
private StatusLine parseStatusLine(String line) {
if (TextUtils.isEmpty(line)) {
return null;
}
return BasicLineParser.parseStatusLine(line, new BasicLineParser());
}
private Header parseHeader(String line) {
return BasicLineParser.parseHeader(line, new BasicLineParser());
}
// Can't use BufferedReader because it buffers past the HTTP data.
private String readLine(HybiParser.HappyDataInputStream reader) throws IOException {
int readChar = reader.read();
if (readChar == -1) {
return null;
}
StringBuilder string = new StringBuilder("");
while (readChar != '\n') {
if (readChar != '\r') {
string.append((char) readChar);
}
readChar = reader.read();
if (readChar == -1) {
return null;
}
}
return string.toString();
}
private String createSecret() {
byte[] nonce = new byte[16];
for (int i = 0; i < 16; i++) {
nonce[i] = (byte) (Math.random() * 256);
}
return Base64.encodeToString(nonce, Base64.DEFAULT).trim();
}
void sendFrame(final byte[] frame) {
mHandler.post(new Runnable() {
@Override
public void run() {
try {
synchronized (mSendLock) {
OutputStream outputStream = mSocket.getOutputStream();
outputStream.write(frame);
outputStream.flush();
}
} catch (IOException e) {
mListener.onError(e);
} catch (NullPointerException e){
mListener.onError(e);
}
}
});
}
public interface Listener {
public void onConnect();
public void onMessage(String message);
public void onMessage(byte[] data);
public void onDisconnect(int code, String reason);
public void onError(Exception error);
}
private SSLSocketFactory getSSLSocketFactory() throws NoSuchAlgorithmException, KeyManagementException {
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, sTrustManagers, null);
return context.getSocketFactory();
}
}