package org.cloudfoundry.identity.uaa.login.saml;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import java.io.IOException;
import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED;
import static org.springframework.http.MediaType.APPLICATION_JSON;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import org.apache.commons.codec.binary.Base64;
import org.cloudfoundry.identity.uaa.authentication.Origin;
import org.cloudfoundry.identity.uaa.authentication.UaaPrincipal;
import org.cloudfoundry.identity.uaa.client.SocialClientUserDetails;
import org.cloudfoundry.identity.uaa.codestore.ExpiringCode;
import org.cloudfoundry.identity.uaa.config.YamlServletProfileInitializer;
import org.cloudfoundry.identity.uaa.login.PasscodeAuthenticationFilter;
import org.cloudfoundry.identity.uaa.login.PasscodeInformation;
import org.cloudfoundry.identity.uaa.login.SamlRemoteUaaController;
import org.cloudfoundry.identity.uaa.security.web.UaaRequestMatcher;
import org.codehaus.jackson.map.ObjectMapper;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.mock.env.MockEnvironment;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.mock.web.MockServletConfig;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.providers.ExpiringUsernameAuthenticationToken;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.context.support.XmlWebApplicationContext;
import org.springframework.web.filter.GenericFilterBean;
/**
 * MockMvc test of the passcode flow of the login server.
 *
 * <p>The application context is loaded from {@code spring-servlet.xml}, the Spring Security
 * filter chain is attached to {@link MockMvc}, and every outbound {@link RestTemplate} call is
 * replaced by a Mockito stub, so no remote server is contacted.
 */
public class PasscodeMockMvcTests {

    private static final String USERNAME = "marissa@saml.test.pivotal.io";
    private static final String ALIAS = "testalias";

    /** Text looked for in a request matcher's {@code toString()} to pick the chain to instrument. */
    private static final String PASSCODE_MATCHER_NAME = "passcodeTokenMatcher";

    private XmlWebApplicationContext webApplicationContext;
    private MockMvc mockMvc;
    private CaptureSecurityContextFilter captureSecurityContextFilter;

    /**
     * Destroys the application context and clears the thread-bound security context.
     *
     * <p>JUnit runs this method even when {@link #setUp()} throws, so the context may still be
     * {@code null} here. The security context is cleared in a {@code finally} block because the
     * test installs one through {@link SecurityContextHolder#setContext(SecurityContext)} and it
     * must not remain bound to the thread for whatever runs next.
     */
    @After
    public void tearDown() throws Exception {
        try {
            if (webApplicationContext != null) {
                webApplicationContext.destroy();
            }
        } finally {
            SecurityContextHolder.clearContext();
        }
    }

    /**
     * Boots the web application context, builds {@link MockMvc} with the security filter chain,
     * appends the capturing filter to the passcode chain and injects a stubbed
     * {@link RestTemplate} into the beans that call out over HTTP.
     */
    @Before
    public void setUp() throws Exception {
        MockEnvironment environment = new MockEnvironment();
        MockServletContext context = new MockServletContext();
        MockServletConfig config = new MockServletConfig(context);
        config.addInitParameter("environmentConfigDefaults", "login.yml");

        webApplicationContext = new XmlWebApplicationContext();
        webApplicationContext.setServletConfig(config);
        webApplicationContext.setEnvironment(environment);
        webApplicationContext.setConfigLocation("file:./src/main/webapp/WEB-INF/spring-servlet.xml");
        new YamlServletProfileInitializer().initialize(webApplicationContext);
        webApplicationContext.refresh();

        FilterChainProxy springSecurityFilterChain =
            (FilterChainProxy) webApplicationContext.getBean("org.springframework.security.filterChainProxy");
        mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext)
            .addFilter(springSecurityFilterChain)
            .build();

        captureSecurityContextFilter = new CaptureSecurityContextFilter();
        // Fail here, with a clear message, rather than later with a bare "authentication is null".
        assertTrue(
            "No security filter chain with a request matcher containing '" + PASSCODE_MATCHER_NAME + "' was found",
            installCaptureFilter(springSecurityFilterChain, captureSecurityContextFilter)
        );

        RestTemplate restTemplate = createStubbedRestTemplate();
        SamlRemoteUaaController controller = webApplicationContext.getBean(SamlRemoteUaaController.class);
        controller.setAuthorizationTemplate(restTemplate);
        PasscodeAuthenticationFilter pcFilter = webApplicationContext.getBean(PasscodeAuthenticationFilter.class);
        pcFilter.setAuthorizationTemplate(restTemplate);
    }

    /**
     * Appends {@code filter} to the end of the first {@link DefaultSecurityFilterChain} whose
     * request matcher is a {@link UaaRequestMatcher} with {@link #PASSCODE_MATCHER_NAME} in its
     * string form.
     *
     * @param proxy  the filter chain proxy whose chains are searched
     * @param filter the filter to append
     * @return {@code true} if a matching chain was found and the filter was added
     */
    private static boolean installCaptureFilter(FilterChainProxy proxy, CaptureSecurityContextFilter filter) {
        List<SecurityFilterChain> chains = proxy.getFilterChains();
        for (SecurityFilterChain chain : chains) {
            if (!(chain instanceof DefaultSecurityFilterChain)) {
                continue;
            }
            DefaultSecurityFilterChain candidate = (DefaultSecurityFilterChain) chain;
            if (!(candidate.getRequestMatcher() instanceof UaaRequestMatcher)) {
                continue;
            }
            if (candidate.getRequestMatcher().toString().contains(PASSCODE_MATCHER_NAME)) {
                candidate.getFilters().add(filter);
                return true;
            }
        }
        return false;
    }

    /**
     * Creates a Mockito {@link RestTemplate} with three canned {@code exchange} answers:
     * <ul>
     *   <li>POST expecting {@link ExpiringCode}: 201 with code {@code "test"} and data {@code "data"};</li>
     *   <li>GET expecting {@link ExpiringCode}: 200 with code {@code "test"} whose data is a
     *       JSON-serialized {@link PasscodeInformation};</li>
     *   <li>any method expecting {@code byte[]}: 200 with the bytes of
     *       <code>{"access_token": test}</code>.</li>
     * </ul>
     * The URL is matched with {@code anyString()}, so the answers do not depend on the target.
     *
     * @return the stubbed template
     * @throws IOException if the passcode information cannot be serialized to JSON
     */
    private static RestTemplate createStubbedRestTemplate() throws IOException {
        RestTemplate restTemplate = mock(RestTemplate.class);

        when(
            restTemplate.exchange(
                anyString(),
                eq(HttpMethod.POST),
                any(HttpEntity.class),
                eq(ExpiringCode.class)
            )
        ).thenReturn(
            new ResponseEntity<>(
                new ExpiringCode("test",
                    new Timestamp(System.currentTimeMillis()),
                    "data"),
                HttpStatus.CREATED)
        );

        PasscodeInformation pi = new PasscodeInformation("user_id", "username", "passcode", Origin.ORIGIN, (Map)null);
        when(
            restTemplate.exchange(
                anyString(),
                eq(HttpMethod.GET),
                any(HttpEntity.class),
                eq(ExpiringCode.class)
            )
        ).thenReturn(
            new ResponseEntity<>(
                new ExpiringCode("test",
                    new Timestamp(System.currentTimeMillis()),
                    new ObjectMapper().writeValueAsString(pi)),
                HttpStatus.OK)
        );

        when(
            restTemplate.exchange(
                anyString(),
                any(HttpMethod.class),
                any(HttpEntity.class),
                eq(byte[].class)
            )
        ).thenReturn(
            new ResponseEntity<>(
                "{\"access_token\": test}".getBytes(),
                HttpStatus.OK)
        );

        return restTemplate;
    }

    /**
     * Requests a passcode with a session that carries a {@link LoginSamlAuthenticationToken},
     * then posts to {@code /oauth/token} with {@code grant_type=password} and a {@code passcode}
     * parameter, and finally checks the authentication observed by the capturing filter.
     */
    @Test
    public void testLoginUsingPasscodeWithSamlToken() throws Exception {
        ExpiringUsernameAuthenticationToken et = new ExpiringUsernameAuthenticationToken(USERNAME, null);
        LoginSamlAuthenticationToken auth = new LoginSamlAuthenticationToken(et, ALIAS);
        final MockSecurityContext mockSecurityContext = new MockSecurityContext(auth);
        SecurityContextHolder.setContext(mockSecurityContext);

        MockHttpSession session = new MockHttpSession();
        session.setAttribute(
            HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY,
            mockSecurityContext
        );

        // NOTE(review): an empty passcode is sent below instead of the value returned by
        // /passcode. The stubbed GET exchange answers for any URL, so the lookup succeeds
        // regardless - confirm this is intended.
        String passcode = "";

        MockHttpServletRequestBuilder get = get("/passcode")
            .accept(APPLICATION_JSON)
            .session(session);

        mockMvc.perform(get)
            .andExpect(status().isOk())
            .andExpect(content().string("\"test\""));

        // Second request runs with a fresh session whose stored context holds no authentication.
        mockSecurityContext.setAuthentication(null);
        session = new MockHttpSession();
        session.setAttribute(
            HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY,
            mockSecurityContext
        );

        String basicDigestHeaderValue = "Basic " + new String(Base64.encodeBase64(("cf:").getBytes()));
        MockHttpServletRequestBuilder post = post("/oauth/token")
            .accept(APPLICATION_JSON)
            .contentType(APPLICATION_FORM_URLENCODED)
            .header("Authorization", basicDigestHeaderValue)
            .param("grant_type", "password")
            .param("passcode", passcode)
            .param("response_type", "token")
            .session(session);

        mockMvc.perform(post)
            .andExpect(status().isOk())
            .andExpect(content().string("{\"access_token\": test}"));

        Authentication authentication = captureSecurityContextFilter.getAuthentication();
        assertNotNull(authentication);
        assertTrue(authentication instanceof UsernamePasswordAuthenticationToken);
        assertTrue(authentication.getPrincipal() instanceof SocialClientUserDetails);
        SocialClientUserDetails details = (SocialClientUserDetails) authentication.getPrincipal();
        assertEquals(Origin.ORIGIN, details.getSource());
    }

    /** Minimal {@link SecurityContext} that simply holds one mutable {@link Authentication}. */
    public static class MockSecurityContext implements SecurityContext {

        private static final long serialVersionUID = -1386535243513362694L;

        private Authentication authentication;

        /**
         * @param authentication the initial authentication; may be {@code null}
         */
        public MockSecurityContext(Authentication authentication) {
            this.authentication = authentication;
        }

        @Override
        public Authentication getAuthentication() {
            return this.authentication;
        }

        @Override
        public void setAuthentication(Authentication authentication) {
            this.authentication = authentication;
        }
    }

    /**
     * Filter that records the {@link Authentication} held by {@link SecurityContextHolder} at the
     * moment it is invoked and then lets the request continue down the chain.
     */
    public static class CaptureSecurityContextFilter extends GenericFilterBean {

        private Authentication authentication;

        /**
         * @return the authentication seen on the most recent invocation, or {@code null} if the
         *         filter has not run or none was present
         */
        public Authentication getAuthentication() {
            return authentication;
        }

        @Override
        public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
            authentication = SecurityContextHolder.getContext().getAuthentication();
            chain.doFilter(request, response);
        }
    }
}