/* * Copyright (C) 2010-2012 Stichting Akvo (Akvo Foundation) * * This file is part of Akvo FLOW. * * Akvo FLOW is free software: you can redistribute it and modify it under the terms of * the GNU Affero General Public License (AGPL) as published by the Free Software Foundation, * either version 3 of the License or any later version. * * Akvo FLOW is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. * See the GNU Affero General Public License included below for more details. * * The full license text can also be seen at <http://www.gnu.org/licenses/agpl.html>. */ package com.gallatinsystems.framework.servlet; import java.io.IOException; import java.net.URLEncoder; import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.TimeZone; import java.util.logging.Logger; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletResponse; import com.gallatinsystems.common.util.MD5Util; import com.gallatinsystems.common.util.PropertyUtil; import com.gallatinsystems.framework.rest.RestRequest; /** * Handles verifying that the incoming request is authorized by checking the hash. * * @author Christopher Fagiani */ public class RestAuthFilter implements Filter { private static final long MAX_TIME = 60 * 10 * 1000; // 10 minutes private static final Logger log = Logger.getLogger(RestAuthFilter.class .getName()); private static final String ENABLED_PROP = "enableRestSecurity"; private static final String REST_PRIVATE_KEY_PROP = "restPrivateKey"; private String privateKey; private boolean isEnabled = false; /** * checks to see if auth is */ @Override public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException { if (isEnabled) { try { if (isAuthorized(req)) { chain.doFilter(req, res); } else { HttpServletResponse response = (HttpServletResponse) res; response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Authorization failed"); } } catch (Exception e) { log.severe("Auth failure " + e.getMessage()); HttpServletResponse response = (HttpServletResponse) res; response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Authorization failed"); } } else { chain.doFilter(req, res); } } @SuppressWarnings({ "unchecked", "rawtypes" }) private boolean isAuthorized(ServletRequest req) throws Exception { Map paramMap = req.getParameterMap(); String incomingHash = null; long incomingTimestamp = 0; List<String> names = new ArrayList<String>(); if (paramMap != null) { names.addAll(paramMap.keySet()); Collections.sort(names); StringBuilder builder = new StringBuilder(); for (String name : names) { if (!RestRequest.HASH_PARAM.equals(name)) { if (builder.length() > 0) { builder.append("&"); } if (RestRequest.TIMESTAMP_PARAM.equals(name)) { String timestamp = ((String[]) paramMap.get(name))[0]; try { DateFormat df = new SimpleDateFormat( "yyyy/MM/dd HH:mm:ss"); df.setTimeZone(TimeZone.getTimeZone("GMT")); incomingTimestamp = df.parse(timestamp).getTime(); } catch (Exception e) { log.warning("Recived rest api request with invalid timestamp"); return false; } } String[] vals = ((String[]) paramMap.get(name)); int count = 0; for (String v : vals) { if (count > 0) { builder.append("&"); } builder.append(name).append("=").append(URLEncoder.encode(v, "UTF-8")); count++; } } else { incomingHash = ((String[]) paramMap.get(name))[0]; incomingHash = incomingHash.replaceAll(" ", "+"); } } if (incomingHash != null) { String ourHash = MD5Util.generateHMAC(builder.toString(), privateKey); if (ourHash == null) { // Do something but for now return false; return false; } if (ourHash.equals(incomingHash)) { return isTimestampValid(incomingTimestamp); } else { return false; } } else { return false; } } return false; } private boolean isTimestampValid(long theirTime) { long time = System.currentTimeMillis(); if (Math.abs(time - theirTime) > MAX_TIME) { return false; } else { return true; } } @Override public void init(FilterConfig arg) throws ServletException { String enabledFlag = PropertyUtil.getProperty(ENABLED_PROP); if (enabledFlag != null) { try { isEnabled = Boolean.parseBoolean(enabledFlag.trim()); } catch (Exception e) { log.severe("Could not parse " + ENABLED_PROP + " value of " + enabledFlag); isEnabled = false; } } privateKey = PropertyUtil.getProperty(REST_PRIVATE_KEY_PROP); } @Override public void destroy() { } }