package com.github.kristofa.brave.grpc;
import static com.github.kristofa.brave.grpc.GrpcKeys.GRPC_STATUS_CODE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.failBecauseExceptionWasNotThrown;
import com.github.kristofa.brave.Brave;
import com.github.kristofa.brave.LocalTracer;
import com.github.kristofa.brave.Sampler;
import com.github.kristofa.brave.SpanId;
import com.github.kristofa.brave.ThreadLocalServerClientAndLocalSpanState;
import com.github.kristofa.brave.internal.InternalSpan;
import com.github.kristofa.brave.internal.Util;
import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.GreeterGrpc.GreeterBlockingStub;
import io.grpc.examples.helloworld.GreeterGrpc.GreeterFutureStub;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.List;
import java.util.concurrent.ExecutionException;
import zipkin.Constants;
import zipkin.Span;
import zipkin.storage.InMemoryStorage;
import zipkin.storage.QueryRequest;
public class BraveGrpcInterceptorsTest {
static {
InternalSpan.initializeInstanceForTests();
}
static final HelloRequest HELLO_REQUEST = HelloRequest.newBuilder()
.setName("brave")
.build();
Server server;
ManagedChannel channel;
InMemoryStorage storage = new InMemoryStorage();
Brave brave = new Brave.Builder()
.traceSampler(new ExplicitSampler())
.reporter(s -> storage.spanConsumer().accept(Collections.singletonList(s))).build();
boolean enableSampling;
@Before
public void before() throws Exception {
enableSampling = true;
ThreadLocalServerClientAndLocalSpanState.clear();
int serverPort = pickUnusedPort();
server = ServerBuilder.forPort(serverPort)
.addService(ServerInterceptors.intercept(new GreeterImpl(), BraveGrpcServerInterceptor.create(brave)))
.build()
.start();
channel = ManagedChannelBuilder.forAddress("localhost", serverPort)
.intercept(BraveGrpcClientInterceptor.create(brave))
.usePlaintext(true)
.build();
}
@Test
public void testBlockingUnaryCall() throws Exception {
GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel);
HelloReply reply = stub.sayHello(HELLO_REQUEST);
assertThat(reply.getMessage()).isEqualTo("Hello brave");
// TODO: only validate client side as for some reason this flakes in travis on server side
zipkin.Span clientSpan = validateOnlyOneSpan();
assertThat(clientSpan.annotations).extracting(a -> a.value)
.contains("cs", "cr"); // not contains exactly
}
@Test
public void testAsyncUnaryCall() throws Exception {
GreeterFutureStub futureStub = GreeterGrpc.newFutureStub(channel);
ListenableFuture<HelloReply> helloReplyListenableFuture = futureStub.sayHello(HELLO_REQUEST);
HelloReply reply = helloReplyListenableFuture.get();
assertThat(reply.getMessage()).isEqualTo("Hello brave");
assertClientServerSpan();
}
@Test
public void statusCodeAddedOnError() throws Exception {
tearDown(); // kill the server
GreeterFutureStub futureStub = GreeterGrpc.newFutureStub(channel);
ListenableFuture<HelloReply> helloReplyListenableFuture = futureStub.sayHello(HELLO_REQUEST);
try {
helloReplyListenableFuture.get();
failBecauseExceptionWasNotThrown(ExecutionException.class);
} catch (ExecutionException expected) {
List<List<zipkin.Span>> traces = storage.spanStore()
.getTraces(QueryRequest.builder().build());
assertThat(traces).hasSize(1);
List<zipkin.Span> spans = traces.get(0);
assertThat(spans.size()).isEqualTo(1);
assertThat(spans.get(0).binaryAnnotations)
.filteredOn(ba -> ba.key.equals(GRPC_STATUS_CODE))
.extracting(ba -> new String(ba.value, Util.UTF_8))
.containsExactly(Status.UNAVAILABLE.getCode().name());
}
}
@Test
public void usesExistingTraceId() throws Exception {
LocalTracer localTracer = brave.localTracer();
SpanId spanId = localTracer.startNewSpan("localSpan", "myop");
GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel);
//This call will be made using the context of the localTracer as it's parent
HelloReply reply = stub.sayHello(HELLO_REQUEST);
assertThat(reply.getMessage()).isEqualTo("Hello brave");
Span clientServerSpan = validateOnlyOneSpan();
assertThat(clientServerSpan.traceIdHigh).isEqualTo(spanId.traceIdHigh);
assertThat(clientServerSpan.traceId).isEqualTo(spanId.traceId);
assertThat(clientServerSpan.parentId).isEqualTo(spanId.spanId);
}
/**
* This test verifies that the sampling rate determined by the client is correctly propagated to the server.
* Since sampling is disabled in by the client, then there should be no span generated by the server.
*/
@Test
public void noSamplesWhenSamplingDisabled() throws Exception {
enableSampling = false;
GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel);
HelloReply reply = stub.sayHello(HELLO_REQUEST);
assertThat(reply.getMessage()).isEqualTo("Hello brave");
assertThat(storage.spanStore().getRawTraces()).isEmpty();
}
@Test
public void propagatesAndReads128BitTraceId() throws Exception {
SpanId spanId = SpanId.builder().traceIdHigh(1).traceId(2).spanId(3).parentId(2L).build();
brave.localSpanThreadBinder().setCurrentSpan(InternalSpan.instance.toSpan(spanId));
GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel);
//This call will be made using hte context of the localTracer as it's parent
HelloReply reply = stub.sayHello(HELLO_REQUEST);
assertThat(reply.getMessage()).isEqualTo("Hello brave");
//Verify that the 128-bit trace id and span id were propagated to the server
zipkin.Span clientServerSpan = validateOnlyOneSpan();
assertThat(clientServerSpan.traceIdHigh).isEqualTo(spanId.traceIdHigh);
assertThat(clientServerSpan.traceId).isEqualTo(spanId.traceId);
assertThat(clientServerSpan.parentId).isEqualTo(spanId.spanId);
}
/**
* Validating that two spans were generated indicates that a span was generated by both the
* server and the client.
*/
void assertClientServerSpan() throws Exception {
zipkin.Span clientServerSpan = validateOnlyOneSpan();
assertThat(clientServerSpan.annotations).extracting(a -> a.value)
.containsExactly("cs", "sr", "ss", "cr");
//Server spans should have the client address binary annotation
assertThat(clientServerSpan.binaryAnnotations)
.filteredOn(ba -> ba.key.equals(Constants.CLIENT_ADDR))
.extracting(b -> b.endpoint.port)
.doesNotContainNull(); // if we got a port, we also got an ipv4 or ipv6
}
zipkin.Span validateOnlyOneSpan() throws InterruptedException {
List<List<Span>> traces = storage.spanStore()
.getTraces(QueryRequest.builder().build());
assertThat(traces).hasSize(1);
assertThat(traces.get(0))
.withFailMessage("Expected client and server to share ids: " + traces.get(0))
.hasSize(1);
return traces.get(0).get(0);
}
@After
public void tearDown() throws InterruptedException {
if (!channel.isShutdown()) {
channel.shutdownNow();
channel.awaitTermination(1, TimeUnit.SECONDS);
}
if (!server.isShutdown()) {
server.shutdownNow();
server.awaitTermination(1, TimeUnit.SECONDS);
}
}
static int pickUnusedPort() {
try {
ServerSocket serverSocket = new ServerSocket(0);
int port = serverSocket.getLocalPort();
serverSocket.close();
return port;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
class ExplicitSampler extends Sampler {
@Override
public boolean isSampled(long traceId) {
return enableSampling;
}
}
}