package com.yammer.telemetry.agent.handlers; import com.google.common.base.Optional; import com.yammer.telemetry.agent.test.SimpleServlet; import com.yammer.telemetry.test.TransformedTest; import com.yammer.telemetry.tracing.*; import javassist.ClassPool; import javassist.CtClass; import org.junit.After; import org.junit.Test; import javax.servlet.*; import javax.servlet.http.*; import java.io.*; import java.math.BigInteger; import java.security.Principal; import java.util.*; import static com.yammer.telemetry.test.TelemetryTestHelpers.runTransformed; import static org.junit.Assert.*; public class HttpServletClassHandlerTest { private HttpServletClassHandler handler = new HttpServletClassHandler(); @After public void clearSpanSinkRegistry() { SpanSinkRegistry.clear(); } @Test public void testNothingForNonHttpServletClasses() throws Exception { ClassPool cp = ClassPool.getDefault(); CtClass ctClass = cp.get("java.lang.String"); assertFalse(handler.transformed(ctClass, cp)); } @Test public void testTransformsHttpServletClasses() throws Exception { ClassPool cp = ClassPool.getDefault(); CtClass ctClass = cp.get("javax.servlet.http.HttpServlet"); assertTrue(handler.transformed(ctClass, cp)); } @Test public void testTransformsHttpServletSubclassesThatOverrideService() throws Exception { ClassPool cp = ClassPool.getDefault(); CtClass ctClass = cp.get("com.sun.jersey.spi.container.servlet.ServletContainer"); assertTrue(handler.transformed(ctClass, cp)); } @Test public void testNothingForHttpServletSubclassesWithoutServiceMethodOverride() throws Exception { ClassPool cp = ClassPool.getDefault(); CtClass ctClass = cp.get("com.yammer.dropwizard.tasks.TaskServlet"); assertFalse(handler.transformed(ctClass, cp)); } @Test public void testRunTransformedTests() throws Exception { runTransformed(TransformedTests.class, handler); } @SuppressWarnings("UnusedDeclaration") /** * This provides static methods which get invoked within the transformed classloader context. This means we can * largely just write code for tests as we would. Right now mockito doesn't play happily in this environment * however so instead rely on fake objects, defined below. */ public static class TransformedTests { @TransformedTest public static void testBaseBehaviour() throws Exception { InMemorySpanSinkSource sink = new InMemorySpanSinkSource(); Annotations.setServiceAnnotations(new ServiceAnnotations("testing")); SpanSinkRegistry.register(sink); StringWriter underlyingWriter = new StringWriter(); HttpServletRequest request = new FakeHttpServletRequest("GET", "http://localhost:8080/foo"); HttpServletResponse response = new FakeHttpServletResponse(underlyingWriter); SimpleServlet servlet = new SimpleServlet(); servlet.service(request, response); assertEquals("foof", underlyingWriter.toString()); Collection<Trace> traces = sink.getTraces(); assertEquals(1, traces.size()); Trace trace = traces.iterator().next(); SpanData root = trace.getRoot(); assertNotNull(root); List<AnnotationData> annotations = trace.getAnnotations(root.getSpanId()); assertEquals(3, annotations.size()); assertEquals(AnnotationNames.SERVER_RECEIVED, annotations.get(0).getName()); assertNull(annotations.get(0).getMessage()); assertEquals(AnnotationNames.SERVICE_NAME, annotations.get(1).getName()); assertEquals("testing", annotations.get(1).getMessage()); assertEquals(AnnotationNames.SERVER_SENT, annotations.get(2).getName()); assertNull(annotations.get(2).getMessage()); } @TransformedTest public void testRecordsRequestMethod() throws Exception { InMemorySpanSinkSource sink = new InMemorySpanSinkSource(); Annotations.setServiceAnnotations(new ServiceAnnotations("testing")); SpanSinkRegistry.register(sink); FakeHttpServletRequest request = new FakeHttpServletRequest("FOOF", "http://localhost:8080/foo"); HttpServletResponse response = new FakeHttpServletResponse(new StringWriter()); HttpServlet servlet = new HttpServlet() { @Override protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { resp.setStatus(200); } }; servlet.service(request, response); assertEquals(200, response.getStatus()); Collection<Trace> traces = sink.getTraces(); Trace trace = traces.iterator().next(); SpanData rootSpan = trace.getRoot(); assertEquals("FOOF http://localhost:8080/foo", rootSpan.getName()); } @TransformedTest public void testCreatesSpanBeneathIncomingSpanAndTraceId() throws Exception { InMemorySpanSinkSource sink = new InMemorySpanSinkSource(); Annotations.setServiceAnnotations(new ServiceAnnotations("testing")); SpanSinkRegistry.register(sink); FakeHttpServletRequest request = new FakeHttpServletRequest("GET", "http://localhost:8080/foo", Optional.of(BigInteger.ONE), Optional.of(BigInteger.TEN)); HttpServletResponse response = new FakeHttpServletResponse(new StringWriter()); HttpServlet servlet = new HttpServlet() { @Override protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { resp.setStatus(200); } }; servlet.service(request, response); Collection<Trace> traces = sink.getTraces(); Trace trace = traces.iterator().next(); assertEquals(BigInteger.ONE, trace.getTraceId()); SpanData root = trace.getRoot(); assertNull(root); List<SpanData> childSpans = trace.getChildren(BigInteger.TEN); assertEquals(1, childSpans.size()); SpanData spanData = childSpans.get(0); assertEquals(Optional.of(BigInteger.TEN), spanData.getParentSpanId()); assertEquals("GET http://localhost:8080/foo", spanData.getName()); assertTrue(trace.getAnnotations(BigInteger.TEN).isEmpty()); List<AnnotationData> annotations = trace.getAnnotations(spanData.getSpanId()); assertEquals(3, annotations.size()); assertEquals(AnnotationNames.SERVER_RECEIVED, annotations.get(0).getName()); assertNull(annotations.get(0).getMessage()); assertEquals(AnnotationNames.SERVICE_NAME, annotations.get(1).getName()); assertEquals("testing", annotations.get(1).getMessage()); assertEquals(AnnotationNames.SERVER_SENT, annotations.get(2).getName()); assertNull(annotations.get(2).getMessage()); } } public static class FakeHttpServletResponse implements HttpServletResponse { private final PrintWriter underlying; private int status; public FakeHttpServletResponse(StringWriter underlying) { this.underlying = new PrintWriter(underlying); } @Override public void addCookie(Cookie cookie) { throw new UnsupportedOperationException(); } @Override public boolean containsHeader(String name) { throw new UnsupportedOperationException(); } @Override public String encodeURL(String url) { throw new UnsupportedOperationException(); } @Override public String encodeRedirectURL(String url) { throw new UnsupportedOperationException(); } @Override public String encodeUrl(String url) { throw new UnsupportedOperationException(); } @Override public String encodeRedirectUrl(String url) { throw new UnsupportedOperationException(); } @Override public void sendError(int sc, String msg) throws IOException { throw new UnsupportedOperationException(); } @Override public void sendError(int sc) throws IOException { throw new UnsupportedOperationException(); } @Override public void sendRedirect(String location) throws IOException { throw new UnsupportedOperationException(); } @Override public void setDateHeader(String name, long date) { throw new UnsupportedOperationException(); } @Override public void addDateHeader(String name, long date) { throw new UnsupportedOperationException(); } @Override public void setHeader(String name, String value) { throw new UnsupportedOperationException(); } @Override public void addHeader(String name, String value) { throw new UnsupportedOperationException(); } @Override public void setIntHeader(String name, int value) { throw new UnsupportedOperationException(); } @Override public void addIntHeader(String name, int value) { throw new UnsupportedOperationException(); } @Override public void setStatus(int sc) { this.status = sc; } @Override public void setStatus(int sc, String sm) { this.status = sc; } @Override public int getStatus() { return status; } @Override public String getHeader(String name) { throw new UnsupportedOperationException(); } @Override public Collection<String> getHeaders(String name) { throw new UnsupportedOperationException(); } @Override public Collection<String> getHeaderNames() { throw new UnsupportedOperationException(); } @Override public String getCharacterEncoding() { throw new UnsupportedOperationException(); } @Override public String getContentType() { throw new UnsupportedOperationException(); } @Override public ServletOutputStream getOutputStream() throws IOException { throw new UnsupportedOperationException(); } @Override public PrintWriter getWriter() throws IOException { return underlying; } @Override public void setCharacterEncoding(String charset) { throw new UnsupportedOperationException(); } @Override public void setContentLength(int len) { throw new UnsupportedOperationException(); } @Override public void setContentType(String type) { throw new UnsupportedOperationException(); } @Override public void setBufferSize(int size) { throw new UnsupportedOperationException(); } @Override public int getBufferSize() { throw new UnsupportedOperationException(); } @Override public void flushBuffer() throws IOException { throw new UnsupportedOperationException(); } @Override public void resetBuffer() { throw new UnsupportedOperationException(); } @Override public boolean isCommitted() { throw new UnsupportedOperationException(); } @Override public void reset() { throw new UnsupportedOperationException(); } @Override public void setLocale(Locale loc) { throw new UnsupportedOperationException(); } @Override public Locale getLocale() { throw new UnsupportedOperationException(); } } public static class FakeHttpServletRequest implements HttpServletRequest { private final String method; private final StringBuffer requestURL; private final Optional<BigInteger> traceId; private final Optional<BigInteger> spanId; public FakeHttpServletRequest(String method, String requestURL) { this(method, requestURL, Optional.<BigInteger>absent(), Optional.<BigInteger>absent()); } public FakeHttpServletRequest(String method, String requestURL, Optional<BigInteger> traceId, Optional<BigInteger> spanId) { this.method = method; this.requestURL = new StringBuffer(requestURL); this.traceId = traceId; this.spanId = spanId; } @Override public String getAuthType() { throw new UnsupportedOperationException(); } @Override public Cookie[] getCookies() { throw new UnsupportedOperationException(); } @Override public long getDateHeader(String name) { throw new UnsupportedOperationException(); } @Override public String getHeader(String name) { if (HttpHeaderNames.TRACE_ID.equalsIgnoreCase(name) && traceId.isPresent()) { return traceId.get().toString(); } if (HttpHeaderNames.SPAN_ID.equalsIgnoreCase(name) && spanId.isPresent()) { return spanId.get().toString(); } return null; } @Override public Enumeration<String> getHeaders(String name) { throw new UnsupportedOperationException(); } @Override public Enumeration<String> getHeaderNames() { throw new UnsupportedOperationException(); } @Override public int getIntHeader(String name) { throw new UnsupportedOperationException(); } @Override public String getMethod() { return method; } @Override public String getPathInfo() { throw new UnsupportedOperationException(); } @Override public String getPathTranslated() { throw new UnsupportedOperationException(); } @Override public String getContextPath() { throw new UnsupportedOperationException(); } @Override public String getQueryString() { throw new UnsupportedOperationException(); } @Override public String getRemoteUser() { throw new UnsupportedOperationException(); } @Override public boolean isUserInRole(String role) { throw new UnsupportedOperationException(); } @Override public Principal getUserPrincipal() { throw new UnsupportedOperationException(); } @Override public String getRequestedSessionId() { throw new UnsupportedOperationException(); } @Override public String getRequestURI() { throw new UnsupportedOperationException(); } @Override public StringBuffer getRequestURL() { return requestURL; } @Override public String getServletPath() { throw new UnsupportedOperationException(); } @Override public HttpSession getSession(boolean create) { throw new UnsupportedOperationException(); } @Override public HttpSession getSession() { throw new UnsupportedOperationException(); } @Override public boolean isRequestedSessionIdValid() { throw new UnsupportedOperationException(); } @Override public boolean isRequestedSessionIdFromCookie() { throw new UnsupportedOperationException(); } @Override public boolean isRequestedSessionIdFromURL() { throw new UnsupportedOperationException(); } @Override public boolean isRequestedSessionIdFromUrl() { throw new UnsupportedOperationException(); } @Override public boolean authenticate(HttpServletResponse response) throws IOException, ServletException { throw new UnsupportedOperationException(); } @Override public void login(String username, String password) throws ServletException { throw new UnsupportedOperationException(); } @Override public void logout() throws ServletException { throw new UnsupportedOperationException(); } @Override public Collection<Part> getParts() throws IOException, ServletException { throw new UnsupportedOperationException(); } @Override public Part getPart(String name) throws IOException, ServletException { throw new UnsupportedOperationException(); } @Override public Object getAttribute(String name) { throw new UnsupportedOperationException(); } @Override public Enumeration<String> getAttributeNames() { throw new UnsupportedOperationException(); } @Override public String getCharacterEncoding() { throw new UnsupportedOperationException(); } @Override public void setCharacterEncoding(String env) throws UnsupportedEncodingException { throw new UnsupportedOperationException(); } @Override public int getContentLength() { throw new UnsupportedOperationException(); } @Override public String getContentType() { throw new UnsupportedOperationException(); } @Override public ServletInputStream getInputStream() throws IOException { throw new UnsupportedOperationException(); } @Override public String getParameter(String name) { throw new UnsupportedOperationException(); } @Override public Enumeration<String> getParameterNames() { throw new UnsupportedOperationException(); } @Override public String[] getParameterValues(String name) { throw new UnsupportedOperationException(); } @Override public Map<String, String[]> getParameterMap() { throw new UnsupportedOperationException(); } @Override public String getProtocol() { throw new UnsupportedOperationException(); } @Override public String getScheme() { throw new UnsupportedOperationException(); } @Override public String getServerName() { throw new UnsupportedOperationException(); } @Override public int getServerPort() { throw new UnsupportedOperationException(); } @Override public BufferedReader getReader() throws IOException { throw new UnsupportedOperationException(); } @Override public String getRemoteAddr() { throw new UnsupportedOperationException(); } @Override public String getRemoteHost() { throw new UnsupportedOperationException(); } @Override public void setAttribute(String name, Object o) { throw new UnsupportedOperationException(); } @Override public void removeAttribute(String name) { throw new UnsupportedOperationException(); } @Override public Locale getLocale() { throw new UnsupportedOperationException(); } @Override public Enumeration<Locale> getLocales() { throw new UnsupportedOperationException(); } @Override public boolean isSecure() { throw new UnsupportedOperationException(); } @Override public RequestDispatcher getRequestDispatcher(String path) { throw new UnsupportedOperationException(); } @Override public String getRealPath(String path) { throw new UnsupportedOperationException(); } @Override public int getRemotePort() { throw new UnsupportedOperationException(); } @Override public String getLocalName() { throw new UnsupportedOperationException(); } @Override public String getLocalAddr() { throw new UnsupportedOperationException(); } @Override public int getLocalPort() { throw new UnsupportedOperationException(); } @Override public ServletContext getServletContext() { throw new UnsupportedOperationException(); } @Override public AsyncContext startAsync() throws IllegalStateException { throw new UnsupportedOperationException(); } @Override public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException { throw new UnsupportedOperationException(); } @Override public boolean isAsyncStarted() { throw new UnsupportedOperationException(); } @Override public boolean isAsyncSupported() { throw new UnsupportedOperationException(); } @Override public AsyncContext getAsyncContext() { throw new UnsupportedOperationException(); } @Override public DispatcherType getDispatcherType() { throw new UnsupportedOperationException(); } } }