/*
* Copyright (c) 2016 OBiBa. All rights reserved.
*
* This program and the accompanying materials
* are made available under the terms of the GNU Public License v3.0.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.obiba.shiro.realm;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import org.apache.http.client.HttpClient;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.authc.AccountException;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAuthenticationInfo;
import org.apache.shiro.authc.UsernamePasswordToken;
import org.apache.shiro.authc.credential.AllowAllCredentialsMatcher;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.codec.Base64;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.subject.SimplePrincipalCollection;
import org.obiba.shiro.authc.TicketAuthenticationToken;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.ResourceAccessException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.Jwt;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import static java.net.URLEncoder.encode;
/**
* A realm for the CAS-like implementation protocol by Obiba.
*/
public class ObibaRealm extends AuthorizingRealm {
private final static Logger log = LoggerFactory.getLogger(ObibaRealm.class);
public static final String OBIBA_REALM = "obiba-realm";
public static final String TICKET_COOKIE_NAME = "obibaid";
public static final String APPLICATION_AUTH_HEADER = "X-App-Auth";
public static final String APPLICATION_AUTH_SCHEMA = "Basic";
public static final String DEFAULT_REST_PREFIX = "/ws";
public static final String DEFAULT_LOGIN_PATH = "/tickets";
public static final String DEFAULT_TICKET_PATH = "/ticket/{id}";
public static final String DEFAULT_VALIDATE_PATH = DEFAULT_TICKET_PATH + "/username";
public static final String DEFAULT_SUBJECT_PATH = DEFAULT_TICKET_PATH + "/subject";
private static final String SET_COOKIE_HEADER = "Set-Cookie";
private static final int DEFAULT_HTTPS_PORT = 443;
private HttpComponentsClientHttpRequestFactory httpRequestFactory;
private String baseUrl = "https://localhost:8444";
private String serviceName;
private String serviceKey;
public ObibaRealm() {
super(null, new AllowAllCredentialsMatcher());
}
@Override
public boolean supports(AuthenticationToken token) {
return token != null && (UsernamePasswordToken.class.isAssignableFrom(token.getClass()) ||
TicketAuthenticationToken.class.isAssignableFrom(token.getClass()));
}
@Override
protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) throws AuthenticationException {
if(UsernamePasswordToken.class.isAssignableFrom(token.getClass()))
return doGetUsernameAuthenticationInfo((UsernamePasswordToken) token);
else return doGetTicketAuthenticationInfo((TicketAuthenticationToken) token);
}
private synchronized AuthenticationInfo doGetUsernameAuthenticationInfo(UsernamePasswordToken token)
throws AuthenticationException {
String username = token.getUsername();
// Null username is invalid
if(Strings.isNullOrEmpty(username)) {
throw new AccountException("Empty usernames are not allowed by this realm.");
}
try {
RestTemplate template = newRestTemplate();
HttpHeaders headers = new HttpHeaders();
headers.set(APPLICATION_AUTH_HEADER, getApplicationAuth());
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
String form = "username=" + encode(username, "UTF-8") + "&password=" + encode(new String(token.getPassword()), "UTF-8");
HttpEntity<String> entity = new HttpEntity<String>(form, headers);
ResponseEntity<String> response = template.exchange(getLoginUrl(token), HttpMethod.POST, entity, String.class);
if (response.getStatusCode() == HttpStatus.CREATED) {
HttpHeaders responseHeaders = response.getHeaders();
String ticketId = getTicketIdFromHeaders(responseHeaders);
SecurityUtils.getSubject().getSession().setAttribute(TICKET_COOKIE_NAME, ticketId);
List<String> principals = Lists.newArrayList(username);
if (!Strings.isNullOrEmpty(ticketId)) principals.add(ticketId);
return new SimpleAuthenticationInfo(new SimplePrincipalCollection(principals, getName()), token.getCredentials());
}
// not an account in this realm
log.debug("Invalid credentials. Response status code [{}], response body [{}], credentials used [{}]", response.getStatusCode(), response.getBody(), token);
return null;
} catch(HttpClientErrorException e) {
if (HttpStatus.FORBIDDEN.equals(e.getStatusCode())) {
log.debug("Invalid credentials. Response status code [{}], response body [{}], credentials used [{}]", e.getStatusCode(), e.getResponseBodyAsString(), token);
return null;
}
if (log.isDebugEnabled())
log.error("Connection failure with identification server", e);
else
log.error(String.format("Connection failure with identification server: [%s]", e.getMessage()));
return null;
} catch(ResourceAccessException e) {
if (log.isDebugEnabled())
log.error("Connection failure with identification server", e);
else
log.error(String.format("Connection failure with identification server: [%s]", e.getMessage()));
return null;
} catch(Exception e) {
if (log.isDebugEnabled())
log.error("Authentication failure", e);
else
log.error(String.format("Authentication failure: [%s]", e.getMessage()));
throw new AuthenticationException("Failed authenticating on " + baseUrl, e);
}
}
private synchronized AuthenticationInfo doGetTicketAuthenticationInfo(TicketAuthenticationToken token)
throws AuthenticationException {
// Null ticket id is invalid
if(Strings.isNullOrEmpty(token.getTicketId())) {
throw new AccountException("Empty tickets are not allowed by this realm.");
}
try {
RestTemplate template = newRestTemplate();
HttpHeaders headers = new HttpHeaders();
headers.set(APPLICATION_AUTH_HEADER, getApplicationAuth());
HttpEntity<String> entity = new HttpEntity<String>(null, headers);
ResponseEntity<String> response = template.exchange(getValidateUrl(token.getTicketId()), HttpMethod.GET, entity, String.class);
if(response.getStatusCode() == HttpStatus.OK) {
HttpHeaders responseHeaders = response.getHeaders();
String ticketId = getTicketIdFromHeaders(responseHeaders);
SecurityUtils.getSubject().getSession().setAttribute(TICKET_COOKIE_NAME, ticketId);
List<String> principals = Lists.newArrayList(response.getBody());
if(!Strings.isNullOrEmpty(ticketId)) principals.add(ticketId);
return new SimpleAuthenticationInfo(new SimplePrincipalCollection(principals, getName()),token.getCredentials());
}
// not an account in this realm
log.info("Invalid ticket. Response status code [{}], response body [{}], ticket used [{}]", response.getStatusCode(), response.getBody(), token);
return null;
} catch(HttpClientErrorException|ResourceAccessException e) {
log.error(String.format("Impossible to contact identification server: [%s]", e.getMessage()), e);
return null;
} catch(Exception e) {
throw new AuthenticationException("Failed authenticating on " + baseUrl, e);
}
}
private String getTicketIdFromHeaders(HttpHeaders responseHeaders) {
String ticketId = null;
for(String cookieValue : responseHeaders.get(SET_COOKIE_HEADER)) {
if(cookieValue.startsWith(TICKET_COOKIE_NAME + "=")) {
// set in the subject's session the cookie that will allow to perform the single sign-on
SecurityUtils.getSubject().getSession().setAttribute(SET_COOKIE_HEADER, cookieValue);
// keep ticket reference for logout
ticketId = cookieValue.split(";")[0].substring(TICKET_COOKIE_NAME.length() + 1);
}
}
return ticketId;
}
@Override
protected synchronized AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
Collection<?> thisPrincipals = principals.fromRealm(getName());
if(thisPrincipals != null && !thisPrincipals.isEmpty()) {
try {
Jwt<Header, Claims> webToken = getWebTokenFromPrincipals(thisPrincipals);
if(webToken != null) {
TicketContextUser user = new ObjectMapper()
.convertValue(webToken.getBody().get("context", Map.class).get("user"),
TicketContextUser.class);
return new SimpleAuthorizationInfo(Sets.newHashSet(user.getGroups()));
} else { //backward compatibility. web token not found in principals.
RestTemplate template = newRestTemplate();
HttpHeaders headers = new HttpHeaders();
headers.set(APPLICATION_AUTH_HEADER, getApplicationAuth());
HttpEntity<String> entity = new HttpEntity<String>(null, headers);
ResponseEntity<Subject> response = template
.exchange(getSubjectUrl(getTicketFromSession()), HttpMethod.GET, entity, Subject.class);
if(response.getStatusCode().equals(HttpStatus.OK) && response.getBody().groups != null) {
return new SimpleAuthorizationInfo(Sets.newHashSet(response.getBody().groups));
}
}
} catch(HttpClientErrorException e) {
return new SimpleAuthorizationInfo();
} catch(Exception e) {
throw new AuthenticationException("Failed authorizing on " + baseUrl, e);
}
}
return new SimpleAuthorizationInfo();
}
private Jwt<Header, Claims> getWebTokenFromPrincipals(Collection<?> principals) {
for(Object principal : principals) {
try {
String[] webTokenParts = ((String) principal).split("\\.");
if(webTokenParts.length > 1) {
String webToken = String.format("%s.%s.", webTokenParts[0], webTokenParts[1]); //do not validate signature
return Jwts.parser().parse(webToken);
}
} catch(MalformedJwtException e) {
}
}
return null;
}
@Override
public void onLogout(PrincipalCollection principals) {
if (principals.getRealmNames().contains(OBIBA_REALM)) {
cleanTicket();
}
super.onLogout(principals);
}
private void cleanTicket() {
try {
String ticketId = getTicketFromSession();
if (ticketId != null) {
log.debug("Deleting ticket: {}", ticketId);
RestTemplate template = newRestTemplate();
HttpHeaders headers = new HttpHeaders();
headers.set(APPLICATION_AUTH_HEADER, getApplicationAuth());
HttpEntity<String> entity = new HttpEntity<String>(null, headers);
template.exchange(getTicketUrl(ticketId), HttpMethod.DELETE, entity, String.class);
}
} catch(Exception e) {
log.warn("Unable to clean Obiba session: " + e.getMessage(), e);
}
}
/**
* Extract ticket reference from the shiro session.
* @return null if not found
*/
@Nullable
private String getTicketFromSession() {
Object cookie = SecurityUtils.getSubject().getSession().getAttribute(TICKET_COOKIE_NAME);
return cookie != null && !Strings.isNullOrEmpty(cookie.toString()) ? cookie.toString() : null;
}
/**
* Base url of Agate application.
*
* @param baseUrl
*/
public void setBaseUrl(String baseUrl) {
if(baseUrl.endsWith("/")) {
this.baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
} else {
this.baseUrl = baseUrl;
}
}
/**
* Service name issuing credentials requests.
*
* @param serviceName
*/
public void setServiceName(String serviceName) {
this.serviceName = serviceName;
}
/**
* Service key issuing credentials requests.
*
* @param serviceKey
*/
public void setServiceKey(String serviceKey) {
this.serviceKey = serviceKey;
}
@Override
public String getName() {
return OBIBA_REALM;
}
private RestTemplate newRestTemplate() {
log.debug("Connecting to Agate: {}", baseUrl);
if (baseUrl.toLowerCase().startsWith("https://")) {
if(httpRequestFactory == null) {
httpRequestFactory = new HttpComponentsClientHttpRequestFactory(createHttpClient());
}
return new RestTemplate(httpRequestFactory);
} else {
return new RestTemplate();
}
}
private HttpClient createHttpClient() {
HttpClientBuilder builder = HttpClientBuilder.create();
try {
builder.setSSLSocketFactory(getSocketFactory());
} catch(NoSuchAlgorithmException | KeyManagementException e) {
throw new RuntimeException(e);
}
return builder.build();
}
/**
* Do not check anything from the remote host (Agate server is trusted).
* @return
* @throws NoSuchAlgorithmException
* @throws KeyManagementException
*/
private SSLConnectionSocketFactory getSocketFactory() throws NoSuchAlgorithmException, KeyManagementException {
// Accepts any SSL certificate
TrustManager tm = new X509TrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] arg0, String arg1) throws CertificateException {
}
@Override
public void checkServerTrusted(X509Certificate[] arg0, String arg1) throws CertificateException {
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return null;
}
};
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, new TrustManager[] { tm }, null);
return new SSLConnectionSocketFactory(sslContext, new NoopHostnameVerifier());
}
private String getLoginUrl(UsernamePasswordToken token) {
UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(baseUrl).path(DEFAULT_REST_PREFIX)
.path(DEFAULT_LOGIN_PATH);
builder.queryParam("rememberMe", token.isRememberMe());
return builder.build().toUriString();
}
private String getValidateUrl(String ticket) {
UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(baseUrl).path(DEFAULT_REST_PREFIX)
.path(DEFAULT_VALIDATE_PATH);
return builder.buildAndExpand(ticket).toUriString();
}
private String getSubjectUrl(String ticketId) {
UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(baseUrl).path(DEFAULT_REST_PREFIX)
.path(DEFAULT_SUBJECT_PATH);
return builder.buildAndExpand(ticketId).toUriString();
}
private String getTicketUrl(String id) {
UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(baseUrl).path(DEFAULT_REST_PREFIX)
.path(DEFAULT_TICKET_PATH);
return builder.buildAndExpand(id).toUriString();
}
private String getApplicationAuth() {
String token = serviceName + ":" + serviceKey;
return APPLICATION_AUTH_SCHEMA + " " + Base64.encodeToString(token.getBytes());
}
public static class Subject {
private String username;
private List<String> groups;
private List<Map<String, String>> attributes;
public String getUsername() {
return username;
}
public void setUsername(String username) {
this.username = username;
}
public List<String> getGroups() {
return groups;
}
public void setGroups(List<String> groups) {
this.groups = groups;
}
public List<Map<String, String>> getAttributes() {
return attributes;
}
public void setAttributes(List<Map<String, String>> attributes) {
this.attributes = attributes;
}
}
private static class TicketContextUser {
private List<String> groups;
private String name;
@JsonProperty("first_name")
private String firstName;
@JsonProperty("last_name")
private String lastName;
private String locale;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getFirstName() {
return firstName;
}
public void setFirstName(String firstName) {
this.firstName = firstName;
}
public String getLastName() {
return lastName;
}
public void setLastName(String lastName) {
this.lastName = lastName;
}
public String getLocale() {
return locale;
}
public void setLocale(String locale) {
this.locale = locale;
}
public List<String> getGroups() {
return groups;
}
public void setGroups(List<String> groups) {
this.groups = groups;
}
}
}