package io.dropwizard.servlets.tasks; import com.codahale.metrics.MetricRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; import java.io.StringWriter; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.PrintWriter; import java.nio.charset.StandardCharsets; import java.util.Collections; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class TaskServletTest { private final Task gc = mock(Task.class); private final PostBodyTask printJSON = mock(PostBodyTask.class); { when(gc.getName()).thenReturn("gc"); when(printJSON.getName()).thenReturn("print-json"); } private final TaskServlet servlet = new TaskServlet(new MetricRegistry()); private final HttpServletRequest request = mock(HttpServletRequest.class); private final HttpServletResponse response = mock(HttpServletResponse.class); @Before public void setUp() throws Exception { servlet.add(gc); servlet.add(printJSON); } @Test public void returnsA404WhenNotFound() throws Exception { when(request.getMethod()).thenReturn("POST"); when(request.getPathInfo()).thenReturn("/test"); servlet.service(request, response); verify(response).sendError(404); } @Test public void runsATaskWhenFound() throws Exception { final PrintWriter output = mock(PrintWriter.class); final ServletInputStream bodyStream = new TestServletInputStream(new ByteArrayInputStream("".getBytes(StandardCharsets.UTF_8))); when(request.getMethod()).thenReturn("POST"); when(request.getPathInfo()).thenReturn("/gc"); when(request.getParameterNames()).thenReturn(Collections.enumeration(ImmutableList.of())); when(response.getWriter()).thenReturn(output); when(request.getInputStream()).thenReturn(bodyStream); servlet.service(request, response); verify(gc).execute(ImmutableMultimap.of(), output); } @Test public void passesQueryStringParamsAlong() throws Exception { final PrintWriter output = mock(PrintWriter.class); final ServletInputStream bodyStream = new TestServletInputStream(new ByteArrayInputStream("".getBytes(StandardCharsets.UTF_8))); when(request.getMethod()).thenReturn("POST"); when(request.getPathInfo()).thenReturn("/gc"); when(request.getParameterNames()).thenReturn(Collections.enumeration(ImmutableList.of("runs"))); when(request.getParameterValues("runs")).thenReturn(new String[]{"1"}); when(request.getInputStream()).thenReturn(bodyStream); when(response.getWriter()).thenReturn(output); servlet.service(request, response); verify(gc).execute(ImmutableMultimap.of("runs", "1"), output); } @Test public void passesPostBodyAlongToPostBodyTasks() throws Exception { String body = "{\"json\": true}"; final PrintWriter output = mock(PrintWriter.class); final ServletInputStream bodyStream = new TestServletInputStream(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); when(request.getMethod()).thenReturn("POST"); when(request.getPathInfo()).thenReturn("/print-json"); when(request.getParameterNames()).thenReturn(Collections.enumeration(ImmutableList.of())); when(request.getInputStream()).thenReturn(bodyStream); when(response.getWriter()).thenReturn(output); servlet.service(request, response); verify(printJSON).execute(ImmutableMultimap.of(), body, output); } @Test @SuppressWarnings("unchecked") public void returnsA500OnExceptions() throws Exception { when(request.getMethod()).thenReturn("POST"); when(request.getPathInfo()).thenReturn("/gc"); when(request.getParameterNames()).thenReturn(Collections.enumeration(ImmutableList.of())); final PrintWriter output = mock(PrintWriter.class); when(response.getWriter()).thenReturn(output); final RuntimeException ex = new RuntimeException("whoops"); doThrow(ex).when(gc).execute(any(ImmutableMultimap.class), any(PrintWriter.class)); servlet.service(request, response); verify(response).setStatus(500); } /** * Add a test to make sure the signature of the Task class does not change as the TaskServlet * depends on this to perform record metrics on Tasks */ @Test public void verifyTaskExecuteMethod() { try { Task.class.getMethod("execute", ImmutableMultimap.class, PrintWriter.class); } catch (NoSuchMethodException e) { Assert.fail("Execute method for " + Task.class.getName() + " not found"); } } @Test public void verifyPostBodyTaskExecuteMethod() { try { PostBodyTask.class.getMethod("execute", ImmutableMultimap.class, String.class, PrintWriter.class); } catch (NoSuchMethodException e) { Assert.fail("Execute method for " + PostBodyTask.class.getName() + " not found"); } } @Test public void returnAllTaskNamesLexicallyOnGet() throws Exception { try (StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw)) { when(request.getMethod()).thenReturn("GET"); when(request.getPathInfo()).thenReturn(null); when(response.getWriter()).thenReturn(pw); servlet.service(request, response); final String newLine = System.lineSeparator(); assertThat(sw.toString()) .isEqualTo(gc.getName() + newLine + printJSON.getName() + newLine); } } @Test public void returnsA404WhenGettingUnknownTask() throws Exception { when(request.getMethod()).thenReturn("GET"); when(request.getPathInfo()).thenReturn("/absent"); servlet.service(request, response); verify(response).sendError(404); } @Test public void returnsA405WhenGettingTaskByName() throws Exception { when(request.getMethod()).thenReturn("GET"); when(request.getPathInfo()).thenReturn("/gc"); servlet.service(request, response); verify(response).sendError(405); } private static class TestServletInputStream extends ServletInputStream { private InputStream delegate; public TestServletInputStream(InputStream delegate) { this.delegate = delegate; } @Override public boolean isFinished() { return false; } @Override public boolean isReady() { return false; } @Override public void setReadListener(ReadListener readListener) { } @Override public int read() throws IOException { return delegate.read(); } } }