/* * Copyright 2012 SURFnet bv, The Netherlands * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package nl.surfnet.coin.api.client; import nl.surfnet.coin.api.client.domain.Group; import nl.surfnet.coin.api.client.domain.Group20; import nl.surfnet.coin.api.client.domain.Person; import nl.surfnet.coin.api.client.internal.OAuth2Grant; import nl.surfnet.coin.api.client.internal.OpenConextApi20AuthorizationCode; import nl.surfnet.coin.api.client.internal.OpenConextApi20ClientCredentials; import org.apache.commons.io.IOUtils; import org.scribe.builder.ServiceBuilder; import org.scribe.builder.api.Api; import org.scribe.model.*; import org.scribe.oauth.OAuthService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import javax.servlet.http.HttpServletRequest; import java.io.IOException; import java.io.InputStream; import java.util.List; /** * Implementation of OpenConextOAuthClient */ public class OpenConextOAuthClientImpl implements OpenConextOAuthClient, InitializingBean { private static final Logger LOG = LoggerFactory.getLogger(OpenConextOAuthClientImpl.class); private static final int MAX_ACCESS_TOKEN_REQUESTS = 3; private OAuthEnvironment environment; private OAuthRepository repository; private OpenConextJsonParser parser = new OpenConextJsonParser(); public OpenConextOAuthClientImpl() { this.environment = new OAuthEnvironment(); this.repository = new InMemoryOAuthRepositoryImpl(); } @Override public boolean isAccessTokenGranted(String userId) { return repository.getToken(userId) != null; } private String doGetAuthorizationUrl(HttpServletRequest request) { OAuthService service = getService(OAuth2Grant.authorizationCode); return service.getAuthorizationUrl(null); } @Override public String getAuthorizationUrl() { return doGetAuthorizationUrl(null); } @Override public void oauthCallback(HttpServletRequest request, String onBehalfOf) { String oAuthVerifier; Token requestToken = null; oAuthVerifier = request.getParameter("code"); Verifier verifier = new Verifier(oAuthVerifier); OAuthService service = getService(OAuth2Grant.authorizationCode); String accessToken = service.getAccessToken(requestToken, verifier).getToken(); repository.storeToken(accessToken, onBehalfOf); } @Override public Person getPerson(String userId, String onBehalfOf) { OAuthRequest request = new OAuthRequest(Verb.GET, environment.getEndpointBaseUrl() + "social/rest/people/" + userId); InputStream in = execute(onBehalfOf, request); return parser.parsePerson(in).getEntry(); } @Override public List<Person> getGroupMembers(String groupId, String onBehalfOf) { if (!StringUtils.hasText(onBehalfOf)) { throw new IllegalArgumentException( "For retrieving group members the onBehalfOf may not be empty"); } OAuthRequest request = new OAuthRequest(Verb.GET, environment.getEndpointBaseUrl() + "social/rest/people/" + onBehalfOf + "/" + groupId); InputStream in = execute(onBehalfOf, request); return parser.parseTeamMembers(in).getEntry(); } @Override public List<Group> getGroups(String userId, String onBehalfOf) { OAuthRequest request = new OAuthRequest(Verb.GET, environment.getEndpointBaseUrl() + "social/rest/groups/" + userId); InputStream in = execute(onBehalfOf, request); return parser.parseGroups(in).getEntry(); } @Override public List<Group20> getGroups20(String userId, String onBehalfOf) { OAuthRequest request = new OAuthRequest(Verb.GET, environment.getEndpointBaseUrl() + "social/rest/groups/" + userId); InputStream in = execute(onBehalfOf, request); return parser.parseGroups20(in).getEntry(); } @Override public Group20 getGroup20(String userId, String groupId, String onBehalfOf) { final String url = String.format("%ssocial/rest/groups/%s/%s", environment.getEndpointBaseUrl(), userId, groupId); OAuthRequest request = new OAuthRequest(Verb.GET, url); InputStream in = execute(onBehalfOf, request); final List<Group20> entry = parser.parseGroups20(in).getEntry(); if (entry != null && entry.size() > 0) { return entry.get(0); } return null; } private InputStream execute(String onBehalfOf, OAuthRequest request) { String token; OAuthService service; token = repository.getToken(onBehalfOf); if (onBehalfOf == null) { int retries = 0; while (token == null && retries < MAX_ACCESS_TOKEN_REQUESTS) { getClientAccessToken(); token = repository.getToken(onBehalfOf); } service = getService(OAuth2Grant.clientCredentials); } else { if (token == null) { throw new RuntimeException("No access token present for user('" + onBehalfOf + "'). First obtain an accesstoken."); } service = getService(OAuth2Grant.authorizationCode); } service.signRequest(new Token(token, ""), request); if (LOG.isDebugEnabled()) { LOG.debug("Will send request '{}'", request.toString()); } Response oAuthResponse = request.send(); if (oAuthResponse.getCode() >= 400) { if (oAuthResponse.getCode() == 401 && oAuthResponse.getStream() != null && oAuthResponse.getBody().contains("invalid_token")) { repository.removeToken(onBehalfOf); throw new InvalidTokenException(oAuthResponse.getBody()); } else { // This could be refined to include other cases, and throw according exceptions. throw new RuntimeException(String.format("Error response: %d, body: %s", oAuthResponse.getCode(), oAuthResponse.getStream() == null ? null : oAuthResponse.getBody())); } } InputStream stream = oAuthResponse.getStream(); if (LOG.isDebugEnabled()) { stream = logInputStream(stream); } return stream; } private InputStream logInputStream(InputStream stream) { String json; try { json = IOUtils.toString(stream); } catch (IOException e) { throw new RuntimeException(e); } LOG.debug(json); stream = IOUtils.toInputStream(json); return stream; } private OAuthService getService(OAuth2Grant grantType) { String baseUrl = environment.getEndpointBaseUrl(); Api api; api = grantType.equals(OAuth2Grant.clientCredentials) ? new OpenConextApi20ClientCredentials(baseUrl) : new OpenConextApi20AuthorizationCode(baseUrl); return new ServiceBuilder() .provider(api) .apiKey(environment.getOauthKey()) .scope("read") .apiSecret(environment.getOauthSecret()) .callback(environment.getCallbackUrl()).build(); } public void setCallbackUrl(String url) { environment.setCallbackUrl(url); } public void setConsumerSecret(String secret) { environment.setOauthSecret(secret); } public void setConsumerKey(String key) { environment.setOauthKey(key); } public void setEndpointBaseUrl(String url) { environment.setEndpointBaseUrl(url); } public void setVersion(OAuthVersion v) { } @Override public void afterPropertiesSet() throws Exception { Assert.notNull(environment); Assert.notNull(environment.getEndpointBaseUrl(), "endpoint base url cannot be null"); Assert.notNull(repository); } public void setRepository(OAuthRepository repository) { this.repository = repository; } public void getClientAccessToken() { Token accessToken = getService(OAuth2Grant.clientCredentials).getAccessToken(new Token("", ""), new Verifier("")); LOG.debug("Received access token from OAuth 2 server: " + accessToken); repository.storeToken(accessToken.getToken(), null); } }