/*
* $Id$
*
* Authors:
* Jeff Buchbinder <jeff@freemedsoftware.org>
*
* CXF Interceptor that provides HTTP Basic Authentication validation.
*
* Based on the concepts outlined here:
* http://chrisdail.com/2008/03/31/apache-cxf-with-http-basic-authentication
*
* REMITT Electronic Medical Information Translation and Transmission
* Copyright (C) 1999-2014 FreeMED Software Foundation
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
package org.remitt.server.cxf;
import java.io.IOException;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MultivaluedMap;
import org.apache.cxf.binding.soap.interceptor.SoapHeaderInterceptor;
import org.apache.cxf.configuration.security.AuthorizationPolicy;
import org.apache.cxf.endpoint.Endpoint;
import org.apache.cxf.interceptor.Fault;
import org.apache.cxf.jaxrs.ext.MessageContext;
import org.apache.cxf.jaxrs.ext.MessageContextImpl;
import org.apache.cxf.jaxrs.impl.MetadataMap;
import org.apache.cxf.message.Exchange;
import org.apache.cxf.message.Message;
import org.apache.cxf.transport.Conduit;
import org.apache.cxf.ws.addressing.EndpointReferenceType;
import org.apache.log4j.Logger;
import org.remitt.server.Configuration;
import org.remitt.server.DbUtil;
/**
 * CXF interceptor that validates HTTP Basic Authentication credentials
 * against the REMITT user table.
 * <p>
 * Credentials (username to MD5 password hash) are cached in a static map and
 * reloaded from the database when the cache is older than five minutes.
 * Requests with no credentials get a 401, and requests with bad credentials
 * get a 403. In both cases the interceptor chain is aborted.
 * <p>
 * NOTE(review): the stored hashes are unsalted MD5, which is not a suitable
 * password hash. Changing that requires a schema/data migration.
 */
public class BasicAuthAuthorizationInterceptor extends SoapHeaderInterceptor {

    /** JAX-RS message context, used only if the container injects one. */
    @Context
    private MessageContext messageContext;

    /** When true, {@link #debug(String)} writes trace output to stdout. */
    protected boolean DEBUG = true;

    /** Realm advertised in the WWW-Authenticate challenge. */
    public static final String REALM = "REMITT Services";

    /**
     * Selects username and password hash for every user holding the
     * 'default' role. (The WHERE clause on r makes the LEFT OUTER JOIN behave
     * as an inner join.)
     */
    public static final String SQL_GET_USERS = "SELECT "
            + " u.username AS username, " + " u.passhash AS passhash "
            + " FROM tUser u "
            + " LEFT OUTER JOIN tRole r ON r.username = u.username "
            + " WHERE r.rolename = 'default' " + " GROUP BY u.username;";

    protected Logger log = Logger
            .getLogger(BasicAuthAuthorizationInterceptor.class);

    private static final char[] HEX_CHARS = { '0', '1', '2', '3', '4', '5',
            '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', };

    /** Fixed charset so hashing does not depend on the platform default. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /** Header/message key carrying the authenticated user name downstream. */
    private static final String PRINCIPAL_HEADER = "X-Principal-Username";

    /** Guards cache refreshes of {@link #users} and {@link #lastCached}. */
    private static final Object USERS_LOCK = new Object();

    /**
     * Map of allowed users to their password hashes. It is volatile because it
     * is replaced wholesale on refresh and read by concurrent requests without
     * the lock.
     */
    private static volatile Map<String, String> users = null;

    /** Time of the last load attempt in ms; null if never loaded. */
    private static Long lastCached = null;

    private static final long MAX_CACHE_AGE_IN_MS = 5 * 60 * 1000L; // 5 min

    /* -Required */
    public void setUsers(Map<String, String> u) {
        users = u;
    }

    /**
     * Cache user credentials in the system. The database is queried only when
     * the cache has never been loaded or is older than the maximum age.
     * <p>
     * The new map is fully built before it is published, so concurrent requests
     * never see a half-filled map. If loading fails, the previous credentials
     * are kept; if there were none, an empty map is installed.
     */
    protected void loadUsers() {
        debug("BasicAuthAuthorizationInterceptor.loadUsers() called");
        synchronized (USERS_LOCK) {
            long now = System.currentTimeMillis();
            if (lastCached != null
                    && now - lastCached.longValue() <= MAX_CACHE_AGE_IN_MS) {
                return;
            }
            // Record the attempt first so a failing database is retried at
            // most once per cache period rather than on every request.
            lastCached = Long.valueOf(now);
            debug("BasicAuthAuthorizationInterceptor.loadUsers(): users not loaded, loading");
            Map<String, String> fresh = fetchUsers();
            if (fresh != null) {
                users = fresh;
            } else if (users == null) {
                users = new HashMap<String, String>();
            }
        }
    }

    /**
     * Query the database for all users allowed to authenticate.
     *
     * @return a new username to password-hash map, or null if loading failed
     */
    private Map<String, String> fetchUsers() {
        Configuration.loadConfiguration();
        Connection conn = Configuration.getConnection();
        if (conn == null) {
            log.error("Unable to obtain a database connection to load users");
            debug("Unable to obtain a database connection to load users");
            return null;
        }
        Map<String, String> result = new HashMap<String, String>();
        PreparedStatement cStmt = null;
        ResultSet r = null;
        try {
            cStmt = conn.prepareStatement(SQL_GET_USERS);
            if (cStmt.execute()) {
                r = cStmt.getResultSet();
                while (r != null && r.next()) {
                    String user = r.getString("username");
                    result.put(user, r.getString("passhash"));
                    // Never print the stored hash.
                    debug("Found user = " + user);
                }
            }
            return result;
        } catch (SQLException e) {
            log.error("Caught SQLException", e);
            debug("Caught SQLException: " + e.toString());
            return null;
        } finally {
            closeResultSet(r);
            DbUtil.closeSafely(cStmt);
            DbUtil.closeSafely(conn);
        }
    }

    /** Close a ResultSet, logging (not throwing) any failure. */
    private void closeResultSet(ResultSet r) {
        if (r == null) {
            return;
        }
        try {
            r.close();
        } catch (SQLException e) {
            log.warn("Failed to close ResultSet", e);
        }
    }

    /**
     * Authenticate the incoming message against the cached credentials.
     * <p>
     * If there are no credentials, a 401 is sent. If the credentials are
     * invalid, a 403 is sent. In both cases the chain is aborted and this
     * method returns. On success the user name is exposed to later processing
     * in three places: the "principal" message-context entry, the message
     * property, and the request header {@code X-Principal-Username}.
     */
    @Override
    public void handleMessage(Message message) throws Fault {
        debug("BasicAuthAuthorizationInterceptor.handleMessage() called");
        // Refresh the credential cache if it is missing or stale.
        loadUsers();

        // This is set by CXF from the Authorization header.
        AuthorizationPolicy policy = message.get(AuthorizationPolicy.class);
        if (policy == null) {
            if (log.isDebugEnabled()) {
                log.debug("User attempted to log in with no credentials");
                debug("User attempted to log in with no credentials");
            }
            sendErrorResponse(message, HttpURLConnection.HTTP_UNAUTHORIZED);
            return;
        }

        String userName = policy.getUserName();
        if (log.isDebugEnabled()) {
            log.debug("Logging in user: " + userName);
        }

        if (!isValidPassword(userName, policy.getPassword())) {
            log.warn("Invalid username or password for user: " + userName);
            debug("Invalid username or password for user: " + userName);
            sendErrorResponse(message, HttpURLConnection.HTTP_FORBIDDEN);
            // Without this return, an unauthenticated request would continue
            // and be granted a principal.
            return;
        }

        debug("Message should be clear to finish being handled, auth succeeded");
        // Use a local context. Caching a MessageContextImpl in the field
        // would tie every later request to the first request's message.
        MessageContext ctx = messageContext;
        if (ctx != null) {
            debug("MessageContext object set");
        } else {
            ctx = new MessageContextImpl(message);
        }
        ctx.put("principal", userName);

        message.put(PRINCIPAL_HEADER, userName);
        addPrincipalHeader(message, userName);
        // NOTE(review): kept from the original; presumably a no-op while the
        // chain is already executing - confirm against the CXF version in use.
        message.getInterceptorChain().resume();
    }

    /**
     * Check a plaintext password against the cached MD5 hash for a user.
     *
     * @return true only if the user exists and the hashes match; false for any
     *         null input or missing MD5 support
     */
    private boolean isValidPassword(String userName, String password) {
        Map<String, String> snapshot = users;
        if (snapshot == null || userName == null || password == null) {
            return false;
        }
        String expected = snapshot.get(userName);
        if (expected == null) {
            return false;
        }
        String actual = md5hash(password);
        if (actual == null) {
            return false;
        }
        // Constant-time comparison avoids leaking hash prefixes via timing.
        return MessageDigest.isEqual(expected.getBytes(UTF_8),
                actual.getBytes(UTF_8));
    }

    /**
     * Install the protocol-headers map with X-Principal-Username set.
     * <p>
     * Existing request headers are copied so they are not discarded. Any
     * client-supplied X-Principal-Username is overwritten.
     */
    private void addPrincipalHeader(Message message, String userName) {
        MultivaluedMap<String, Object> headers = new MetadataMap<String, Object>();
        Object existing = message.get(Message.PROTOCOL_HEADERS);
        if (existing instanceof Map) {
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) existing).entrySet()) {
                if (entry.getKey() != null && entry.getValue() instanceof List) {
                    headers.put(entry.getKey().toString(),
                            new ArrayList<Object>((List<?>) entry.getValue()));
                }
            }
        }
        headers.putSingle(PRINCIPAL_HEADER, userName);
        message.put(Message.PROTOCOL_HEADERS, headers);
    }

    /**
     * Abort the chain and write an empty response with the given status and a
     * Basic authentication challenge.
     */
    private void sendErrorResponse(Message message, int responseCode) {
        Message outMessage = getOutMessage(message);
        outMessage.put(Message.RESPONSE_CODE, responseCode);

        // Use a fresh header map. getOutMessage() copies the request's header
        // map by reference, and reusing it would echo request headers
        // (including Authorization) back to the client.
        Map<String, List<String>> responseHeaders = new HashMap<String, List<String>>();
        // RFC 7617: the realm value is a quoted-string.
        responseHeaders.put("WWW-Authenticate",
                Arrays.asList(new String[] { "Basic realm=\"" + REALM + "\"" }));
        responseHeaders.put("Content-length",
                Arrays.asList(new String[] { "0" }));
        outMessage.put(Message.PROTOCOL_HEADERS, responseHeaders);

        message.getInterceptorChain().abort();
        try {
            getConduit(message).prepare(outMessage);
            close(outMessage);
        } catch (IOException e) {
            log.warn(e.getMessage(), e);
        }
    }

    /**
     * Return the exchange's outbound message, creating it from the endpoint
     * binding if needed, after copying all in-message properties into it.
     */
    private Message getOutMessage(Message inMessage) {
        Exchange exchange = inMessage.getExchange();
        Message outMessage = exchange.getOutMessage();
        if (outMessage == null) {
            Endpoint endpoint = exchange.get(Endpoint.class);
            outMessage = endpoint.getBinding().createMessage();
            exchange.setOutMessage(outMessage);
        }
        outMessage.putAll(inMessage);
        return outMessage;
    }

    /** Obtain the destination's back-channel conduit and register it on the exchange. */
    private Conduit getConduit(Message inMessage) throws IOException {
        Exchange exchange = inMessage.getExchange();
        EndpointReferenceType target = exchange
                .get(EndpointReferenceType.class);
        Conduit conduit = exchange.getDestination().getBackChannel(inMessage,
                null, target);
        exchange.setConduit(conduit);
        return conduit;
    }

    /** Flush and close the outbound message's stream, if it has one. */
    private void close(Message outMessage) throws IOException {
        OutputStream os = outMessage.getContent(OutputStream.class);
        if (os == null) {
            return;
        }
        os.flush();
        os.close();
    }

    /**
     * Get the MD5 hash of a string, encoded as UTF-8, as lowercase hex.
     * The plaintext is never logged.
     *
     * @param original
     *            text to hash; may be null
     * @return 32-character lowercase hex digest, or null if {@code original}
     *         is null or MD5 is unavailable
     */
    protected String md5hash(String original) {
        if (original == null) {
            return null;
        }
        try {
            MessageDigest digest = MessageDigest.getInstance("MD5");
            return asHex(digest.digest(original.getBytes(UTF_8)));
        } catch (NoSuchAlgorithmException e) {
            log.error("Could not find MD5 algorithm", e);
            debug("Could not find MD5 algorithm: " + e.toString());
        }
        return null;
    }

    /**
     * Turns array of bytes into string representing each byte as unsigned hex
     * number.
     *
     * @param hash
     *            Array of bytes to convert to hex-string
     * @return Generated hex string
     */
    protected String asHex(byte hash[]) {
        char buf[] = new char[hash.length * 2];
        for (int i = 0, x = 0; i < hash.length; i++) {
            buf[x++] = HEX_CHARS[(hash[i] >>> 4) & 0xf];
            buf[x++] = HEX_CHARS[hash[i] & 0xf];
        }
        return new String(buf);
    }

    /** Print a trace line to stdout, prefixed with the class name, when DEBUG is set. */
    protected void debug(String st) {
        if (DEBUG) {
            System.out.println(this.getClass().getName() + "| " + st);
        }
    }
}