/**
* This file is part of lavagna.
*
* lavagna is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* lavagna is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with lavagna. If not, see <http://www.gnu.org/licenses/>.
*/
package io.lavagna.web.security.login;
import io.lavagna.web.security.LoginHandler.AbstractLoginHandler;
import io.lavagna.web.security.SecurityConfiguration.SessionHandler;
import io.lavagna.web.security.SecurityConfiguration.Users;
import io.lavagna.web.security.login.oauth.*;
import io.lavagna.web.security.login.oauth.OAuthResultHandler.OAuthRequestBuilder;
import org.scribe.builder.ServiceBuilder;
import org.springframework.util.StringUtils;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;
/**
 * {@link AbstractLoginHandler} that authenticates users through OAuth providers.
 * <p>
 * Built-in providers (bitbucket, gitlab, github, google, twitter) are registered statically.
 * In addition, configured providers reporting {@code isHasCustomBaseAndProfileUrl()} are exposed
 * under their own name and reuse the factory of their base provider.
 * The configuration comes from an {@link OauthConfigurationFetcher}, which may return {@code null};
 * in that case this handler handles no request and adds nothing to the login page model.
 */
public class OAuthLogin extends AbstractLoginHandler {

    /** Built-in provider name to handler factory; insertion order is the order used on the login page. */
    private static final Map<String, OAuthResultHandlerFactory> SUPPORTED_OAUTH_HANDLER;

    /** Prefix used for the names returned by {@link #getAllHandlerNames()}, e.g. {@code "oauth.github"}. */
    private static final String USER_PROVIDER = "oauth";

    static {
        Map<String, OAuthResultHandlerFactory> r = new LinkedHashMap<>();
        //TODO: move the strings directly in the factory.
        r.put("bitbucket", BitbucketHandler.FACTORY);
        r.put("gitlab", GitlabHandler.FACTORY);
        r.put("github", GithubHandler.FACTORY);
        r.put("google", GoogleHandler.FACTORY);
        r.put("twitter", TwitterHandler.FACTORY);
        SUPPORTED_OAUTH_HANDLER = Collections.unmodifiableMap(r);
    }

    private final OauthConfigurationFetcher oauthConfigurationFetcher;
    private final String errorPage;
    private final ServiceBuilder serviceBuilder;
    private final OAuthRequestBuilder reqBuilder = new OAuthRequestBuilder();

    public OAuthLogin(Users users, SessionHandler sessionHandler, OauthConfigurationFetcher oauthConfigurationFetcher, ServiceBuilder serviceBuilder, String errorPage) {
        super(users, sessionHandler);
        this.oauthConfigurationFetcher = oauthConfigurationFetcher;
        this.serviceBuilder = serviceBuilder;
        this.errorPage = errorPage;
    }

    /**
     * Dispatches OAuth requests.
     * <p>
     * A POST whose URI matches a provider's authorization URL is delegated to
     * {@code handleAuthorizationUrl}. Any request whose URI matches a provider's callback URL is
     * delegated to {@code handleCallback}.
     *
     * @return {@code true} if the request was handled, {@code false} if no OAuth configuration is
     *         available or the URI matches no configured provider
     * @throws IllegalArgumentException if the matched provider has no registered handler factory
     */
    @Override
    public boolean doAction(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        OAuthConfiguration conf = oauthConfigurationFetcher.fetch();
        // the fetcher is allowed to return null: without a configuration there is nothing to match
        if (conf == null) {
            return false;
        }
        // resolve the factories from the same configuration snapshot used for matching
        Map<String, OAuthResultHandlerFactory> handlers = handlersFor(conf);
        String requestURI = req.getRequestURI();

        if ("POST".equals(req.getMethod())) {
            OAuthProvider authHandler = conf.matchAuthorization(requestURI);
            if (authHandler != null) {
                buildHandler(handlers, authHandler, conf.baseUrl, users, sessionHandler, errorPage).handleAuthorizationUrl(req, resp);
                return true;
            }
        }

        OAuthProvider callbackHandler = conf.matchCallback(requestURI);
        if (callbackHandler != null) {
            buildHandler(handlers, callbackHandler, conf.baseUrl, users, sessionHandler, errorPage).handleCallback(req, resp);
            return true;
        }
        return false;
    }

    /**
     * Extends the parent login page model with the OAuth providers that are both supported and
     * configured.
     * <p>
     * Adds {@code loginOauthProviders} (the list of provider names) and {@code loginOauth}
     * (the value {@code "block"}). The model is returned unchanged when no configuration is available.
     */
    @Override
    public Map<String, Object> modelForLoginPage(HttpServletRequest request) {
        Map<String, Object> m = super.modelForLoginPage(request);
        OAuthConfiguration conf = oauthConfigurationFetcher.fetch();
        if (conf == null) {
            return m;
        }
        List<String> loginOauthProviders = new ArrayList<>();
        for (String p : handlersFor(conf).keySet()) {
            if (conf.hasProvider(p)) {
                loginOauthProviders.add(p);
            }
        }
        m.put("loginOauthProviders", loginOauthProviders);
        m.put("loginOauth", "block");
        return m;
    }

    /**
     * OAuth configuration: the application base URL plus the configured providers.
     */
    public static class OAuthConfiguration {
        private final String baseUrl;
        private final List<OAuthProvider> providers;

        /**
         * @param baseUrl   application base URL, used to build the callback URLs
         * @param providers configured providers; {@code null} is treated as an empty list
         */
        public OAuthConfiguration(String baseUrl, List<OAuthProvider> providers) {
            this.baseUrl = baseUrl;
            // normalize a missing list so the lookups below never dereference null
            this.providers = providers != null ? providers : Collections.<OAuthProvider>emptyList();
        }

        /** @return {@code true} if a configured provider has the given name */
        public boolean hasProvider(String provider) {
            return getProviderWithName(provider) != null;
        }

        /** @return the first configured provider with the given name, or {@code null} if none */
        public OAuthProvider getProviderWithName(String provider) {
            for (OAuthProvider o : providers) {
                if (provider.equals(o.getProvider())) {
                    return o;
                }
            }
            return null;
        }

        /** @return the first provider whose authorization URL matches {@code requestURI}, or {@code null} */
        public OAuthProvider matchAuthorization(String requestURI) {
            for (OAuthProvider o : providers) {
                if (o.matchAuthorization(requestURI)) {
                    return o;
                }
            }
            return null;
        }

        /** @return the first provider whose callback URL matches {@code requestURI}, or {@code null} */
        public OAuthProvider matchCallback(String requestURI) {
            for (OAuthProvider o : providers) {
                if (o.matchCallback(requestURI)) {
                    return o;
                }
            }
            return null;
        }
    }

    /** Source of the current {@link OAuthConfiguration}. */
    public interface OauthConfigurationFetcher {
        /**
         * Fetches the OAuth configuration.
         *
         * @return the configuration, or {@code null} (callers must handle the absent case)
         */
        OAuthConfiguration fetch();
    }

    /**
     * Builds the result handler for {@code oauthProvider}.
     * <p>
     * The callback URL is {@code <confBaseUrl>/login/oauth/<provider>/callback}, with trailing
     * {@code '/'} characters removed from {@code confBaseUrl} first.
     *
     * @throws IllegalArgumentException if no handler factory is registered for the provider name
     */
    public OAuthResultHandler from(OAuthProvider oauthProvider, String confBaseUrl, Users users, SessionHandler sessionHandler, String errorPage) {
        return buildHandler(getAllHandlers(), oauthProvider, confBaseUrl, users, sessionHandler, errorPage);
    }

    /** Same as {@link #from} but uses an already-resolved factory map. */
    private OAuthResultHandler buildHandler(Map<String, OAuthResultHandlerFactory> handlers, OAuthProvider oauthProvider,
            String confBaseUrl, Users users, SessionHandler sessionHandler, String errorPage) {
        String providerName = oauthProvider.getProvider();
        OAuthResultHandlerFactory factory = handlers.get(providerName);
        if (factory == null) {
            throw new IllegalArgumentException("type " + providerName + " is not supported");
        }
        String baseUrl = StringUtils.trimTrailingCharacter(confBaseUrl, '/');
        String callbackUrl = baseUrl + "/login/oauth/" + providerName + "/callback";
        return factory.build(serviceBuilder, reqBuilder, oauthProvider, callbackUrl, users, sessionHandler, errorPage);
    }

    /**
     * Returns all usable handler factories keyed by provider name.
     * <p>
     * The built-in providers come first, in declaration order. They are followed by the custom
     * provider instances of the currently fetched configuration. The returned map is a fresh,
     * mutable copy.
     */
    public Map<String, OAuthResultHandlerFactory> getAllHandlers() {
        return handlersFor(oauthConfigurationFetcher.fetch());
    }

    /** Computes the handler factories for the given (possibly {@code null}) configuration. */
    private static Map<String, OAuthResultHandlerFactory> handlersFor(OAuthConfiguration conf) {
        // LinkedHashMap keeps the declared order of the built-ins; a HashMap would make it arbitrary
        Map<String, OAuthResultHandlerFactory> res = new LinkedHashMap<>(SUPPORTED_OAUTH_HANDLER);
        if (conf == null || conf.providers == null) {
            return res;
        }
        for (OAuthProvider provider : conf.providers) {
            if (!provider.isHasCustomBaseAndProfileUrl()) {
                continue;
            }
            OAuthResultHandlerFactory base = SUPPORTED_OAUTH_HANDLER.get(provider.getBaseProvider());
            // an unknown base provider would be wrapped as a null factory and fail later with an NPE;
            // leaving it out makes from(...) reject it as an unsupported type instead
            if (base != null) {
                res.put(provider.getProvider(), new CustomOAuthResultHandlerFactory(base));
            }
        }
        return res;
    }

    /**
     * Factory for a custom provider instance. It delegates to the factory of its base provider and
     * reports itself as a configurable instance.
     */
    private static class CustomOAuthResultHandlerFactory implements OAuthResultHandlerFactory {
        private final OAuthResultHandlerFactory factory;

        private CustomOAuthResultHandlerFactory(OAuthResultHandlerFactory factory) {
            this.factory = factory;
        }

        @Override
        public OAuthResultHandler build(ServiceBuilder serviceBuilder,
                OAuthRequestBuilder reqBuilder, OAuthProvider oauthProvider,
                String callback, Users users, SessionHandler sessionHandler,
                String errorPage) {
            return factory.build(serviceBuilder, reqBuilder, oauthProvider, callback, users, sessionHandler, errorPage);
        }

        @Override
        public boolean hasConfigurableBaseUrl() {
            return factory.hasConfigurableBaseUrl();
        }

        @Override
        public boolean isConfigurableInstance() {
            return true;
        }
    }

    /** @return every handler name prefixed with {@code "oauth."}, e.g. {@code "oauth.github"} */
    @Override
    public List<String> getAllHandlerNames() {
        List<String> res = new ArrayList<>();
        for (String sub : getAllHandlers().keySet()) {
            res.add(USER_PROVIDER + "." + sub);
        }
        return res;
    }

    /** @return {@code "oauth"} */
    @Override
    public String getBaseProviderName() {
        return USER_PROVIDER;
    }
}