/* (c) 2014 Boundless, http://boundlessgeo.com
* This code is licensed under the GPL 2.0 license.
*/
package com.boundlessgeo.geoserver;
import com.google.common.base.Function;
import com.google.common.collect.Lists;
import org.geoserver.filters.GeoServerFilter;
import org.geoserver.platform.GeoServerExtensions;
import org.geoserver.security.GeoServerSecurityFilterChain;
import org.geoserver.security.GeoServerSecurityManager;
import org.geoserver.security.RequestFilterChain;
import org.geoserver.security.filter.GeoServerCompositeFilter;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.stereotype.Component;
import javax.annotation.Nullable;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;
import static org.geoserver.security.GeoServerSecurityFilterChain.DEFAULT_CHAIN;
import static org.geoserver.security.GeoServerSecurityFilterChain.DEFAULT_CHAIN_NAME;
import static org.geoserver.security.GeoServerSecurityFilterChain.WEB_CHAIN_NAME;
import static org.geoserver.security.GeoServerSecurityFilterChain.WEB_LOGIN_CHAIN_NAME;
import static org.geoserver.security.GeoServerSecurityFilterChain.WEB_LOGOUT_CHAIN_NAME;
/**
 * Authenticates the backend service for the webapp by reusing the web filter chain.
 * <p>
 * Only requests whose servlet path starts with {@code /app} and whose path info starts with
 * {@code /api} are handled here; every other request is passed straight down the chain.
 * Handled requests fall into three cases:
 * <ul>
 *   <li>{@code POST /api/login}: runs the web login security chain against a request that is
 *       disguised as a form login, then continues the chain on success or answers 401.</li>
 *   <li>{@code /api/logout}: invalidates the current HTTP session, if there is one.</li>
 *   <li>anything else: runs either the default chain (when an {@code Authorization} header is
 *       present) or the web chain, then continues the chain on success or answers 401.</li>
 * </ul>
 */
@Component
public class AppAuthFilter implements GeoServerFilter {

    /** Matches the path info of the login endpoint, with or without a trailing slash. */
    static final Pattern LOGIN_RE = Pattern.compile("/api/login/?");

    /** Matches the path info of the logout endpoint, with or without a trailing slash. */
    static final Pattern LOGOUT_RE = Pattern.compile("/api/logout/?");

    /** No initialization is required. */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    /**
     * Applies authentication to {@code /app/api/*} requests and lets all other requests through.
     *
     * @param servletRequest the incoming request; cast to {@link HttpServletRequest}
     * @param servletResponse the outgoing response; cast to {@link HttpServletResponse}
     * @param filterChain the remainder of the filter chain
     * @throws IOException if a security filter or the downstream chain fails with an I/O error
     * @throws ServletException if a security filter or the downstream chain fails
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) servletRequest;
        HttpServletResponse res = (HttpServletResponse) servletResponse;

        // getPathInfo() returns null when the request carries no extra path information;
        // such a request cannot be an API call, so it is passed through like any other.
        String path = req.getPathInfo();
        if (path == null || !req.getServletPath().startsWith("/app") || !path.startsWith("/api")) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }

        if ("POST".equalsIgnoreCase(req.getMethod()) && LOGIN_RE.matcher(path).matches()) {
            // hack: we have to jump through a few hoops to piggy back on the geoserver web auth:
            // 1. we fake the request path to fool the security filter
            // 2. we ignore redirects
            HttpServletRequest loginRequest = new HttpServletRequestWrapper(req) {
                @Override
                public String getServletPath() {
                    return "";
                }

                @Override
                public String getPathInfo() {
                    return "/j_spring_security_check";
                }
            };
            HttpServletResponse loginResponse = new HttpServletResponseWrapper(res) {
                @Override
                public void sendRedirect(String location) throws IOException {
                    // intentionally empty: redirects issued by the login filters are dropped
                }
            };
            boolean success = runSecurityFilters(loginRequest, loginResponse, WEB_LOGIN_CHAIN_NAME);
            // The downstream chain receives the original, unwrapped request and response.
            proceedOrReject(success, servletRequest, res, filterChain);
        }
        else if (LOGOUT_RE.matcher(path).matches()) {
            // invalidate the session if it exists; the downstream chain is not invoked and
            // no status is set for logout requests
            HttpSession session = req.getSession(false);
            if (session != null) {
                session.invalidate();
            }
        }
        else {
            // two modes of authentication, basic vs form.
            String chainName = req.getHeader("Authorization") != null ? DEFAULT_CHAIN_NAME : WEB_CHAIN_NAME;
            proceedOrReject(runSecurityFilters(req, res, chainName), servletRequest, res, filterChain);
        }
    }

    /**
     * Continues the filter chain when the request is authenticated, otherwise sets a 401 status.
     */
    private static void proceedOrReject(boolean authenticated, ServletRequest request,
            HttpServletResponse response, FilterChain chain) throws IOException, ServletException {
        if (authenticated) {
            chain.doFilter(request, response);
        }
        else {
            response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
        }
    }

    /**
     * Reports whether the current security context holds a real, authenticated principal.
     *
     * @return {@code true} if an authentication is present, is marked authenticated and is not
     *         an {@link AnonymousAuthenticationToken}
     */
    boolean isAuthenticated() {
        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        return auth != null && auth.isAuthenticated() && !(auth instanceof AnonymousAuthenticationToken);
    }

    /**
     * Runs the security filters of the named GeoServer request chains against the request.
     * <p>
     * The filters of all named chains are collected in order, wrapped in a single
     * {@link GeoServerCompositeFilter} and executed with a terminal chain that does nothing,
     * so only the security filters themselves run.
     *
     * @param req the request handed to the security filters
     * @param res the response handed to the security filters
     * @param chainNames names of the request filter chains whose filters should be run
     * @return the result of {@link #isAuthenticated()} after the filters have run
     * @throws IOException if a filter fails with an I/O error
     * @throws ServletException if a filter fails
     * @throws RuntimeException if a configured security filter cannot be loaded
     */
    boolean runSecurityFilters(HttpServletRequest req, HttpServletResponse res, String... chainNames)
            throws IOException, ServletException {
        final GeoServerSecurityManager secMgr = GeoServerExtensions.bean(GeoServerSecurityManager.class);
        GeoServerSecurityFilterChain secFilterChain =
                new GeoServerSecurityFilterChain(secMgr.getSecurityConfig().getFilterChain());

        List<Filter> filters = new ArrayList<Filter>();
        for (String chainName : chainNames) {
            RequestFilterChain reqFilterChain = secFilterChain.getRequestChainByName(chainName);
            filters.addAll(Lists.transform(reqFilterChain.getCompiledFilterNames(), new Function<String, Filter>() {
                @Override
                public Filter apply(@Nullable String s) {
                    try {
                        return secMgr.loadFilter(s);
                    } catch (IOException e) {
                        // keep the original failure as the cause so the stack trace is not lost
                        throw new RuntimeException("Unable to load security filter:" + s, e);
                    }
                }
            }));
        }

        GeoServerCompositeFilter compFilter = new GeoServerCompositeFilter();
        compFilter.setNestedFilters(filters);
        compFilter.doFilter(req, res, new FilterChain() {
            @Override
            public void doFilter(ServletRequest request, ServletResponse response)
                    throws IOException, ServletException {
                // intentionally empty: terminates the nested chain without further processing
            }
        });

        return isAuthenticated();
    }

    /** No resources to release. */
    @Override
    public void destroy() {
    }
}