package org.zalando.riptide; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.springframework.core.io.ClassPathResource; import org.springframework.http.HttpStatus; import org.springframework.http.client.ClientHttpResponse; import org.springframework.test.web.client.MockRestServiceServer; import org.zalando.problem.MoreStatus; import org.zalando.problem.ThrowableProblem; import org.zalando.riptide.model.Message; import org.zalando.riptide.model.Problem; import org.zalando.riptide.model.Success; import java.io.IOException; import java.net.URI; import java.util.List; import java.util.concurrent.CompletionException; import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hobsoft.hamcrest.compose.ComposeMatchers.hasFeature; import static org.junit.Assert.assertThat; import static org.springframework.http.HttpStatus.ACCEPTED; import static org.springframework.http.HttpStatus.CREATED; import static org.springframework.http.HttpStatus.MOVED_PERMANENTLY; import static org.springframework.http.HttpStatus.Series.CLIENT_ERROR; import static org.springframework.http.HttpStatus.Series.SERVER_ERROR; import static org.springframework.http.HttpStatus.Series.SUCCESSFUL; import static org.springframework.http.HttpStatus.UNAUTHORIZED; import static org.springframework.http.HttpStatus.UNPROCESSABLE_ENTITY; import static org.springframework.http.MediaType.parseMediaType; import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; import static org.springframework.test.web.client.response.MockRestResponseCreators.withStatus; import static org.zalando.riptide.Bindings.anyContentType; import static org.zalando.riptide.Bindings.anySeries; import static org.zalando.riptide.Bindings.anyStatus; import static org.zalando.riptide.Bindings.anyStatusCode; import static org.zalando.riptide.Bindings.on; import static org.zalando.riptide.Navigators.contentType; import static org.zalando.riptide.Navigators.series; import static org.zalando.riptide.Navigators.status; import static org.zalando.riptide.Navigators.statusCode; import static org.zalando.riptide.Route.propagate; import static org.zalando.riptide.RoutingTree.dispatch; import static org.zalando.riptide.model.MediaTypes.ERROR; import static org.zalando.riptide.model.MediaTypes.PROBLEM; public final class NestedDispatchTest { @Rule public final ExpectedException exception = ExpectedException.none(); private final URI url = URI.create("http://localhost"); private final Rest unit; private final MockRestServiceServer server; public NestedDispatchTest() { final MockSetup setup = new MockSetup(); this.unit = setup.getRest(); this.server = setup.getServer(); } private <T> T perform(final Class<T> type) { final AtomicReference<Object> capture = new AtomicReference<>(); unit.get(url) .dispatch(series(), on(SUCCESSFUL) .dispatch(status(), on(CREATED).dispatch(contentType(), on(parseMediaType("application/messages+json")).call(Route.listOf(Message.class), capture::set), anyContentType().call(Success.class, capture::set)), on(ACCEPTED).call(Success.class, capture::set), anyStatus().call(this::fail)), on(CLIENT_ERROR) .dispatch(status(), on(UNAUTHORIZED).call(capture::set), anyStatus().call(problemHandling())), on(SERVER_ERROR) .dispatch(statusCode(), on(500).call(capture::set), on(503).call(capture::set), anyStatusCode().call(this::fail)), anySeries().call(this::fail)) .join(); return type.cast(capture.get()); } private Route problemHandling() { return dispatch(contentType(), on(PROBLEM).call(ThrowableProblem.class, propagate()), on(ERROR).call(ThrowableProblem.class, propagate()), anyContentType().call(this::fail)); } @SuppressWarnings("serial") private static final class Failure extends RuntimeException { private final HttpStatus status; private Failure(final HttpStatus status) { this.status = status; } public HttpStatus getStatus() { return status; } } private void fail(final ClientHttpResponse response) throws IOException { throw new Failure(response.getStatusCode()); } @Test public void shouldDispatchLevelOne() { server.expect(requestTo(url)).andRespond(withStatus(MOVED_PERMANENTLY)); exception.expect(CompletionException.class); exception.expectCause(instanceOf(Failure.class)); exception.expectCause(hasFeature("status", Failure::getStatus, equalTo(MOVED_PERMANENTLY))); perform(Void.class); } @Test public void shouldDispatchLevelTwo() { server.expect(requestTo(url)).andRespond( withStatus(CREATED) .body(new ClassPathResource("messages.json")) .contentType(parseMediaType("application/messages+json"))); @SuppressWarnings("unchecked") final List<Message> messages = perform(List.class); assertThat(messages.get(0).getMessage(), is("Foo")); assertThat(messages.get(1).getMessage(), is("Bar")); } @Test public void shouldDispatchLevelThree() { server.expect(requestTo(url)).andRespond( withStatus(UNPROCESSABLE_ENTITY) .body(new ClassPathResource("problem.json")) .contentType(ERROR)); try { perform(Problem.class); Assert.fail("Expected exception"); } catch (final CompletionException e) { assertThat(e.getCause(), is(instanceOf(ThrowableProblem.class))); final ThrowableProblem problem = (ThrowableProblem) e.getCause(); assertThat(problem.getType(), is(URI.create("http://httpstatus.es/422"))); assertThat(problem.getTitle(), is("Unprocessable Entity")); assertThat(problem.getStatus(), is(MoreStatus.UNPROCESSABLE_ENTITY)); assertThat(problem.getDetail(), is("A problem occurred.")); } } }