/**
* This file is part of lavagna.
*
* lavagna is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* lavagna is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with lavagna. If not, see <http://www.gnu.org/licenses/>.
*/
package io.lavagna.web.security;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.UUID;
import java.util.regex.Pattern;
import static org.apache.commons.lang3.tuple.ImmutablePair.of;
/**
 * Filter that protects state-changing requests against cross-site request forgery.
 * <p>
 * On every request it makes sure a token is associated with the request through {@link CSRFToken}
 * (generating a random UUID when none exists yet) and echoes it in the {@code X-CSRF-TOKEN}
 * response header. Requests whose method is not GET, HEAD or OPTIONS must send the same token back,
 * either in the {@code X-CSRF-TOKEN} header or in the {@code _csrf} request parameter; otherwise they
 * are rejected with HTTP 403. POST requests to the websocket fallback path ({@code /api/socket/...},
 * relative to the context path) are exempt.
 */
public class CSFRFilter extends AbstractBaseFilter {

    /** Header used both to publish the token to the client and to receive it back. */
    private static final String CSRF_TOKEN_HEADER = "X-CSRF-TOKEN";

    /** Request parameter consulted when the header is absent. */
    private static final String CSRF_FORM_PARAMETER = "_csrf";

    /**
     * Methods exempt from the check. The alternatives are grouped so that both anchors apply to
     * every method name, instead of only to the first and the last alternative.
     */
    private static final Pattern CSRF_METHOD_DONT_CHECK = Pattern.compile("^(GET|HEAD|OPTIONS)$");

    /** Websocket fallback path (relative to the context path); POSTs to it are not checked. */
    private static final Pattern WEBSOCKET_FALLBACK = Pattern.compile("^/api/socket/.*$");

    private static final Logger LOG = LogManager.getLogger();

    /**
     * Ensures a token exists for the request, publishes it in the response header and, when the
     * request requires it, rejects it unless it carries the matching token.
     *
     * @param req the incoming request
     * @param resp the response; receives the token header, or an error when the check fails
     * @param chain the remaining filter chain, invoked only when the request is allowed through
     * @throws IOException if sending the error or continuing the chain fails
     * @throws ServletException if a downstream filter or servlet fails
     */
    @Override
    protected void doFilterInternal(HttpServletRequest req, HttpServletResponse resp, FilterChain chain)
            throws IOException, ServletException {
        String token = CSRFToken.getToken(req);
        if (token == null) {
            token = UUID.randomUUID().toString();
            CSRFToken.setToken(req, token);
        }
        // Always expose the current token so the client can send it back on unsafe requests.
        resp.setHeader(CSRF_TOKEN_HEADER, token);

        if (mustCheckCSRF(req)) {
            ImmutablePair<Boolean, ImmutablePair<Integer, String>> res = checkCSRF(req);
            if (!res.left) {
                // Log the fixed rejection reason (never the token itself) to ease troubleshooting.
                LOG.info("wrong csrf: {}", res.right.right);
                resp.sendError(res.right.left, res.right.right);
                return;
            }
        }
        chain.doFilter(req, resp);
    }

    /**
     * Decides whether the request must carry a valid CSRF token.
     *
     * @param request the incoming request
     * @return {@code false} for POST requests to the websocket fallback path and for GET, HEAD and
     *         OPTIONS requests; {@code true} for every other request
     */
    private boolean mustCheckCSRF(HttpServletRequest request) {
        String method = request.getMethod();
        // Ignore the websocket fallback: POSTs whose context-relative path is under /api/socket/.
        if ("POST".equals(method)) {
            String path = StringUtils.removeStart(request.getRequestURI(), request.getContextPath());
            if (WEBSOCKET_FALLBACK.matcher(path).matches()) {
                return false;
            }
        }
        return !CSRF_METHOD_DONT_CHECK.matcher(method).matches();
    }

    /**
     * Compares the token sent by the client with the one associated with the request.
     * <p>
     * The client token is read from the {@code X-CSRF-TOKEN} header, falling back to the
     * {@code _csrf} request parameter.
     *
     * @param request the incoming request
     * @return {@code (true, null)} when the tokens match; otherwise {@code (false, (status, message))}
     *         holding the HTTP status code and message to send back
     */
    private static ImmutablePair<Boolean, ImmutablePair<Integer, String>> checkCSRF(HttpServletRequest request) {
        String expectedToken = CSRFToken.getToken(request);
        String token = request.getHeader(CSRF_TOKEN_HEADER);
        if (token == null) {
            token = request.getParameter(CSRF_FORM_PARAMETER);
        }
        if (token == null) {
            return of(false, of(HttpServletResponse.SC_FORBIDDEN, "missing token in header or parameter"));
        }
        if (expectedToken == null) {
            return of(false, of(HttpServletResponse.SC_FORBIDDEN, "missing token from session"));
        }
        // Constant-time comparison so response timing does not reveal how much of the token matched.
        byte[] provided = token.getBytes(StandardCharsets.UTF_8);
        byte[] expected = expectedToken.getBytes(StandardCharsets.UTF_8);
        if (!MessageDigest.isEqual(provided, expected)) {
            return of(false, of(HttpServletResponse.SC_FORBIDDEN, "token is not equal to expected"));
        }
        return of(true, null);
    }
}