/**
 * This file is part of lavagna.
 *
 * lavagna is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * lavagna is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with lavagna.  If not, see <http://www.gnu.org/licenses/>.
 */
package io.lavagna.web.security.login.oauth;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.lavagna.web.security.Redirector;
import io.lavagna.web.security.SecurityConfiguration.SessionHandler;
import io.lavagna.web.security.SecurityConfiguration.User;
import io.lavagna.web.security.SecurityConfiguration.Users;
import org.scribe.model.*;
import org.scribe.oauth.OAuthService;
import org.springframework.web.util.UriUtils;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Collections;
import java.util.List;
import java.util.UUID;

import static org.apache.commons.lang3.StringUtils.removeStart;

/**
 * Handles the two HTTP steps of an OAuth login: sending the browser to the
 * provider's authorization URL and processing the provider's callback.
 */
public interface OAuthResultHandler {

    /**
     * Redirects the client to the provider's authorization URL.
     *
     * @param req  the current request
     * @param resp the response used to send the redirect
     * @throws IOException if the redirect cannot be sent
     */
    void handleAuthorizationUrl(HttpServletRequest req, HttpServletResponse resp) throws IOException;

    /**
     * Processes the callback request issued after the provider's authorization step.
     *
     * @param req  the callback request
     * @param resp the response used to send the follow-up redirect
     * @throws IOException if the redirect cannot be sent
     */
    void handleCallback(HttpServletRequest req, HttpServletResponse resp) throws IOException;

    /**
     * Base implementation that stores a random {@code state} value in the HTTP
     * session, checks it on callback, exchanges the verifier for an access token,
     * fetches the remote profile as JSON and, when the profile is valid, binds the
     * matching user to the session.
     */
    class OAuthResultHandlerAdapter implements OAuthResultHandler {

        private static final Gson GSON = new GsonBuilder().serializeNulls().create();

        private final String provider;
        private final String profileUrl;
        private final Class<? extends RemoteUserProfile> profileClass;
        private final String verifierParamName;
        private final Users users;
        private final String errorPage;
        private final SessionHandler sessionHandler;
        protected final OAuthService oauthService;
        private final OAuthRequestBuilder reqBuilder;

        /**
         * @param provider          provider name; used in the session attribute key and for user lookup
         * @param profileUrl        URL queried (GET) for the remote user profile
         * @param profileClass      class the profile JSON response is deserialized into
         * @param verifierParamName name of the callback request parameter carrying the verifier
         * @param users             user lookup used to validate the profile and load the user
         * @param sessionHandler    receives the user once the login succeeds
         * @param errorPage         path (relative to the context path) to redirect to on failure
         * @param oauthService      scribe service used for the authorization URL, token exchange and signing
         * @param reqBuilder        factory for the profile request
         */
        OAuthResultHandlerAdapter(String provider, String profileUrl,
                                  Class<? extends RemoteUserProfile> profileClass, String verifierParamName,
                                  Users users, SessionHandler sessionHandler, String errorPage,
                                  OAuthService oauthService, OAuthRequestBuilder reqBuilder) {
            this.provider = provider;
            this.profileUrl = profileUrl;
            this.profileClass = profileClass;
            this.verifierParamName = verifierParamName;
            this.users = users;
            this.sessionHandler = sessionHandler;
            this.errorPage = errorPage;
            this.oauthService = oauthService;
            this.reqBuilder = reqBuilder;
        }

        /** Returns the session attribute name under which the expected state is stored for this provider. */
        private String stateForAttribute() {
            return "EXPECTED_STATE_FOR_" + provider;
        }

        /**
         * Generates a random state, saves it (together with the {@code rememberMe}
         * and {@code reqUrl} request parameters) in the session and redirects to the
         * authorization URL with the state appended as a query parameter.
         */
        @Override
        public void handleAuthorizationUrl(HttpServletRequest req, HttpServletResponse resp) throws IOException {
            // scribe does not support out of the box the state parameter, must
            // be overridden to be removed
            String state = UUID.randomUUID().toString();
            saveStateAndRequestUrlParameter(req, state);
            resp.sendRedirect(oauthService.getAuthorizationUrl(null) + "&state=" + state);
        }

        /**
         * Stores in the session the expected state, the {@code rememberMe} request
         * parameter and, when present, the URL-decoded {@code reqUrl} request
         * parameter. The latter two are keyed by the state value.
         *
         * @param req   the current request
         * @param state the state value generated for this login attempt
         * @throws UnsupportedEncodingException declared by the UTF-8 decoding of {@code reqUrl}
         */
        protected void saveStateAndRequestUrlParameter(HttpServletRequest req, String state)
                throws UnsupportedEncodingException {
            req.getSession().setAttribute(stateForAttribute(), state);
            req.getSession().setAttribute("rememberMe-" + state, req.getParameter("rememberMe"));
            String reqUrl = req.getParameter("reqUrl");
            if (reqUrl != null) {
                req.getSession().setAttribute("reqUrl-" + state, UriUtils.decode(reqUrl, "UTF-8"));
            }
        }

        /**
         * Compares the {@code state} request parameter with the value stored in the
         * session. The stored value is removed from the session in every case, so it
         * can be checked only once.
         *
         * @param req the callback request
         * @return {@code true} only if a state was stored and it equals the request parameter
         */
        // only for services that support the state parameter, must be
        // overridden to be ignored
        protected boolean validateStateParam(HttpServletRequest req) {
            String stateParam = req.getParameter("state");
            String expectedState = (String) req.getSession().getAttribute(stateForAttribute());
            req.getSession().removeAttribute(stateForAttribute());
            return expectedState != null && expectedState.equals(stateParam);
        }

        /**
         * Validates the state, exchanges the verifier for an access token, fetches
         * the remote profile and either logs the user in and redirects to the
         * originally requested URL, or redirects to the error page.
         */
        @Override
        public void handleCallback(HttpServletRequest req, HttpServletResponse resp) throws IOException {
            // The per-attempt values must be read before validateStateParam runs,
            // because that call removes the expected state from the session.
            String state = (String) req.getSession().getAttribute(stateForAttribute());
            String reqUrl = (String) req.getSession().getAttribute("reqUrl-" + state);
            req.setAttribute("rememberMe", req.getSession().getAttribute("rememberMe-" + state));
            req.getSession().removeAttribute("reqUrl-" + state);
            req.getSession().removeAttribute("rememberMe-" + state);

            if (!validateStateParam(req)) {
                redirectToErrorPage(req, resp);
                return;
            }

            // verify token
            Verifier verifier = new Verifier(req.getParameter(verifierParamName));
            Token accessToken = oauthService.getAccessToken(reqToken(req), verifier);

            // fetch user profile
            OAuthRequest oauthRequest = reqBuilder.req(Verb.GET, profileUrl);
            oauthService.signRequest(accessToken, oauthRequest);
            Response oauthResponse = oauthRequest.send();
            RemoteUserProfile profile = GSON.fromJson(oauthResponse.getBody(), profileClass);

            // Gson returns null for a null or empty body: treat that as an invalid
            // profile instead of failing with a NullPointerException.
            if (profile != null && profile.valid(users, provider)) {
                String url = Redirector.cleanupRequestedUrl(reqUrl, req);
                User user = users.findUserByName(provider, profile.username());
                sessionHandler.setUser(user.getId(), user.isAnonymous(), req, resp);
                Redirector.sendRedirect(req, resp, url, Collections.<String, List<String>> emptyMap());
            } else {
                redirectToErrorPage(req, resp);
            }
        }

        /** Redirects to the configured error page, resolved against the request's context path. */
        private void redirectToErrorPage(HttpServletRequest req, HttpServletResponse resp) throws IOException {
            Redirector.sendRedirect(req, resp, req.getContextPath() + "/" + removeStart(errorPage, "/"),
                    Collections.<String, List<String>> emptyMap());
        }

        /**
         * Returns the request token passed to the access token exchange. This
         * implementation always returns {@code null}.
         * NOTE(review): presumably overridden by providers that need a request token - confirm.
         *
         * @param req the callback request
         * @return {@code null}
         */
        protected Token reqToken(HttpServletRequest req) {
            return null;
        }
    }

    /** Factory for {@link OAuthRequest} instances. */
    public static class OAuthRequestBuilder {

        /**
         * Creates a new request.
         *
         * @param verb the HTTP verb
         * @param url  the target URL
         * @return a new {@link OAuthRequest} for the given verb and URL
         */
        public OAuthRequest req(Verb verb, String url) {
            return new OAuthRequest(verb, url);
        }
    }

    /** Profile deserialized from the provider's profile endpoint response. */
    interface RemoteUserProfile {

        /**
         * @param users    the user lookup
         * @param provider the provider name
         * @return whether this profile is accepted for login
         */
        boolean valid(Users users, String provider);

        /** @return the user name used to look up the local user */
        String username();
    }
}