package org.fenixedu.bennu.alerts; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; import java.nio.file.Paths; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.Multimap; class FlashMapManager { private static final Object DEFAULT_FLASH_MAPS_MUTEX = new Object(); private static final String FLASH_MAPS_SESSION_ATTRIBUTE = FlashMapManager.class.getName() + ".FLASH_MAPS"; private static final String SESSION_MUTEX_ATTRIBUTE = FlashMapManager.class.getName() + ".MUTEX"; protected final Logger logger = LoggerFactory.getLogger(FlashMapManager.class); private int flashMapTimeout = 180; private void setFlashMapTimeout(int flashMapTimeout) { this.flashMapTimeout = flashMapTimeout; } private int getFlashMapTimeout() { return this.flashMapTimeout; } /** * Looks into the current request and tries to find FlashMaps that match the current request * * @param request the current request * @return a FlashMap that matches the current request */ public final FlashMap retrieveAndUpdate(HttpServletRequest request) { List<FlashMap> allFlashMaps = retrieveFlashMaps(request); if (allFlashMaps == null || allFlashMaps.isEmpty()) { return null; } if (logger.isDebugEnabled()) { logger.debug("Retrieved FlashMap(s): " + allFlashMaps); } List<FlashMap> mapsToRemove = getExpiredFlashMaps(allFlashMaps); FlashMap match = getMatchingFlashMap(allFlashMaps, request); if (match != null) { mapsToRemove.add(match); } if (!mapsToRemove.isEmpty()) { if (logger.isDebugEnabled()) { logger.debug("Removing FlashMap(s): " + mapsToRemove); } Object mutex = getFlashMapsMutex(request); if (mutex != null) { synchronized (mutex) { allFlashMaps = retrieveFlashMaps(request); if (allFlashMaps != null) { allFlashMaps.removeAll(mapsToRemove); updateFlashMaps(allFlashMaps, request); } } } else { allFlashMaps.removeAll(mapsToRemove); updateFlashMaps(allFlashMaps, request); } } return match; } private List<FlashMap> getExpiredFlashMaps(List<FlashMap> allMaps) { List<FlashMap> result = new LinkedList<FlashMap>(); for (FlashMap map : allMaps) { if (map.isExpired()) { result.add(map); } } return result; } private FlashMap getMatchingFlashMap(List<FlashMap> allMaps, HttpServletRequest request) { List<FlashMap> result = new LinkedList<FlashMap>(); for (FlashMap flashMap : allMaps) { if (isFlashMapForRequest(flashMap, request)) { result.add(flashMap); } } if (!result.isEmpty()) { Collections.sort(result); if (logger.isDebugEnabled()) { logger.debug("Found matching FlashMap(s): " + result); } return result.get(0); } return null; } /** * Helper method to process and transform a query string into a {@link com.google.common.collect.Multimap}. * * @param query the query String * @return the parameters in a Multimap */ protected Multimap<String, String> splitQuery(String query) { try { final Multimap<String, String> query_pairs = ArrayListMultimap.create(); if (query != null) { String[] pairs = query.split("&"); for (String pair : pairs) { final int idx = pair.indexOf("="); final String key = idx > 0 ? URLDecoder.decode(pair.substring(0, idx), "UTF-8") : pair; final String value = idx > 0 && pair.length() > idx + 1 ? URLDecoder.decode(pair.substring(idx + 1), "UTF-8") : null; query_pairs.put(key, value); } } return query_pairs; } catch (Exception e) { throw new RuntimeException(e); } } private boolean isFlashMapForRequest(FlashMap flashMap, HttpServletRequest request) { String expectedPath = flashMap.getTargetRequestPath(); if (expectedPath != null) { String requestUri = getOriginatingRequestUri(request); if (!requestUri.equals(expectedPath) && !requestUri.equals(expectedPath + "/")) { return false; } } Multimap<String, String> actualParams = splitQuery(request.getQueryString()); Multimap<String, String> expectedParams = flashMap.getTargetRequestParams(); for (String expectedName : expectedParams.keySet()) { Collection<String> actualValues = actualParams.get(expectedName); if (actualValues == null) { return false; } for (String expectedValue : expectedParams.get(expectedName)) { if (!actualValues.contains(expectedValue)) { return false; } } } return true; } /** * Saves the current FlashMap * * @param flashMap FlashMap to attach * @param request the current request */ public final void saveOutputFlashMap(FlashMap flashMap, HttpServletRequest request) { if (flashMap == null || flashMap.isEmpty()) { return; } String path = decodeAndNormalizePath(flashMap.getTargetRequestPath(), request); flashMap.setTargetRequestPath(path); if (logger.isDebugEnabled()) { logger.debug("Saving FlashMap=" + flashMap); } flashMap.startExpirationPeriod(getFlashMapTimeout()); Object mutex = getFlashMapsMutex(request); if (mutex != null) { synchronized (mutex) { List<FlashMap> allFlashMaps = retrieveFlashMaps(request); allFlashMaps = (allFlashMaps != null ? allFlashMaps : new CopyOnWriteArrayList<FlashMap>()); allFlashMaps.add(flashMap); updateFlashMaps(allFlashMaps, request); } } else { List<FlashMap> allFlashMaps = retrieveFlashMaps(request); allFlashMaps = (allFlashMaps != null ? allFlashMaps : new LinkedList<FlashMap>()); allFlashMaps.add(flashMap); updateFlashMaps(allFlashMaps, request); } } private String decodeRequestString(HttpServletRequest request, String source) { return decodeInternal(request, source); } private String determineEncoding(HttpServletRequest request) { String enc = request.getCharacterEncoding(); if (enc == null) { enc = "ISO-8859-1"; } return enc; } private String decodeInternal(HttpServletRequest request, String source) { String enc = determineEncoding(request); try { return URLDecoder.decode(source, enc); } catch (UnsupportedEncodingException e) { return null; } } private String decodeAndNormalizePath(String path, HttpServletRequest request) { if (path != null) { path = decodeRequestString(request, path); if (path.charAt(0) != '/') { String requestUri = getRequestUri(request); path = requestUri.substring(0, requestUri.lastIndexOf('/') + 1) + path; path = Paths.get(path).normalize().toString(); } } return path; } private String getOriginatingRequestUri(HttpServletRequest request) { String uri = (String) request.getAttribute("javax.servlet.forward.request_uri"); if (uri == null) { uri = request.getRequestURI(); } return decodeAndCleanUriString(request, uri); } private String decodeAndCleanUriString(HttpServletRequest request, String uri) { uri = removeSemicolonContent(uri); uri = decodeRequestString(request, uri); return uri; } private String removeSemicolonContent(String requestUri) { return removeSemicolonContentInternal(requestUri); } private String removeJsessionid(String requestUri) { int startIndex = requestUri.toLowerCase().indexOf(";jsessionid="); if (startIndex != -1) { int endIndex = requestUri.indexOf(';', startIndex + 12); String start = requestUri.substring(0, startIndex); requestUri = (endIndex != -1) ? start + requestUri.substring(endIndex) : start; } return requestUri; } private String removeSemicolonContentInternal(String requestUri) { int semicolonIndex = requestUri.indexOf(';'); while (semicolonIndex != -1) { int slashIndex = requestUri.indexOf('/', semicolonIndex); String start = requestUri.substring(0, semicolonIndex); requestUri = (slashIndex != -1) ? start + requestUri.substring(slashIndex) : start; semicolonIndex = requestUri.indexOf(';', semicolonIndex); } return requestUri; } private String getRequestUri(HttpServletRequest request) { String uri = (String) request.getAttribute("javax.servlet.include.request_uri"); if (uri == null) { uri = request.getRequestURI(); } return decodeAndCleanUriString(request, uri); } private List<FlashMap> retrieveFlashMaps(HttpServletRequest request) { HttpSession session = request.getSession(false); return (session != null ? (List<FlashMap>) session.getAttribute(FLASH_MAPS_SESSION_ATTRIBUTE) : null); } private void updateFlashMaps(List<FlashMap> flashMaps, HttpServletRequest request) { if (flashMaps.isEmpty()) { HttpSession session = request.getSession(false); if (session != null) { session.removeAttribute(FLASH_MAPS_SESSION_ATTRIBUTE); } } else { request.getSession().setAttribute(FLASH_MAPS_SESSION_ATTRIBUTE, flashMaps); } } private Object getFlashMapsMutex(HttpServletRequest request) { HttpSession session = request.getSession(); Object mutex = session.getAttribute(SESSION_MUTEX_ATTRIBUTE); if (mutex == null) { mutex = session; } return mutex; } }