package de.dfki.nlp.flow;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Charsets;
import com.google.common.base.MoreObjects;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimaps;
import de.dfki.nlp.config.AnnotatorConfig;
import de.dfki.nlp.config.MessagingConfig.ProcessingGateway;
import de.dfki.nlp.domain.IdList;
import de.dfki.nlp.domain.PredictionResult;
import de.dfki.nlp.domain.exceptions.Errors;
import de.dfki.nlp.domain.rest.ServerRequest;
import de.dfki.nlp.domain.rest.ServerResponse;
import de.dfki.nlp.errors.FailedMessage;
import de.dfki.nlp.io.BufferingClientHttpResponseWrapper;
import de.dfki.nlp.loader.MultiDocumentFetcher;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aopalliance.aop.Advice;
import org.apache.commons.lang3.StringUtils;
import org.springframework.amqp.core.Queue;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.amqp.rabbit.retry.RejectAndDontRequeueRecoverer;
import org.springframework.amqp.support.converter.Jackson2JsonMessageConverter;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Profile;
import org.springframework.core.env.Environment;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.integration.amqp.support.DefaultAmqpHeaderMapper;
import org.springframework.integration.dsl.Adapters;
import org.springframework.integration.dsl.IntegrationFlow;
import org.springframework.integration.dsl.IntegrationFlowBuilder;
import org.springframework.integration.dsl.IntegrationFlows;
import org.springframework.integration.dsl.amqp.Amqp;
import org.springframework.integration.dsl.core.MessageHandlerSpec;
import org.springframework.integration.dsl.support.Function;
import org.springframework.integration.handler.LoggingHandler;
import org.springframework.integration.handler.advice.RequestHandlerRetryAdvice;
import org.springframework.integration.http.outbound.HttpRequestExecutingMessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.retry.backoff.ExponentialBackOffPolicy;
import org.springframework.retry.interceptor.RetryOperationsInterceptor;
import org.springframework.retry.policy.SimpleRetryPolicy;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.FileCopyUtils;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.client.UnknownHttpStatusCodeException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.*;
import java.util.stream.Collectors;
import static org.springframework.amqp.rabbit.config.RetryInterceptorBuilder.stateless;
@Slf4j
@Component
@AllArgsConstructor
@Profile("backend")
public class FlowHandler {
private final MultiDocumentFetcher documentFetcher;
private final AnnotatorConfig annotatorConfig;
private final ObjectMapper objectMapper;
@Bean
IntegrationFlow errorSendingResults() {
return f -> f
.handle(MessagingException.class, (payload, headers) -> {
FailedMessage failedMessage = new FailedMessage();
Integer communicationId = (Integer) payload.getFailedMessage().getHeaders().getOrDefault("communication_id", -1);
failedMessage.setCommunicationId(communicationId);
log.error("Failure sending results [{}] {}", communicationId, payload.getMessage());
try {
failedMessage.setFailedMessagePayload(objectMapper.writeValueAsString(payload.getFailedMessage().getPayload()));
} catch (JsonProcessingException e) {
log.error("Could not serialize the error message {}", e.getMessage());
}
failedMessage.setServerErrorCause(payload.getMessage());
// try to replicate most of the message and the error
if (payload.getCause() instanceof HttpClientErrorException) {
HttpClientErrorException cause = (HttpClientErrorException) payload.getCause();
String serverPayload = cause.getResponseBodyAsString();
failedMessage.setServerErrorPayload(serverPayload);
}
try {
log.error("The complete failed message for retrying\n{}", objectMapper.writeValueAsString(failedMessage));
} catch (JsonProcessingException e) {
log.error("Double error ... {}", e.getMessage());
log.error("The complete failed message for retrying\n{}", failedMessage);
}
return null;
});
}
@Bean
public RetryOperationsInterceptor retryOperationsInterceptor() {
return stateless()
.maxAttempts(2)
.recoverer(new RejectAndDontRequeueRecoverer()).build();
}
@Bean
IntegrationFlow flow(ConnectionFactory connectionFactory, Queue input, Jackson2JsonMessageConverter messageConverter, Environment environment) {
IntegrationFlowBuilder flow =
IntegrationFlows
.from(
Amqp.inboundAdapter(connectionFactory, input)
// set concurrentConsumers - anything larger than 1 gives parallelism per annotator request
// but not the number of requests
.concurrentConsumers(annotatorConfig.getConcurrentConsumer())
.messageConverter(messageConverter)
.headerMapper(DefaultAmqpHeaderMapper.inboundMapper())
// retry the complete message
// if this fails ... forward to the error queue
.defaultRequeueRejected(false)
.adviceChain(retryOperationsInterceptor())
.errorChannel("errorSendingResults.input")
)
.enrichHeaders(headerEnricherSpec -> headerEnricherSpec.headerExpression("communication_id", "payload.parameters.communication_id"))
.enrichHeaders(headerEnricherSpec -> headerEnricherSpec.headerExpression("types", "payload.parameters.types"))
.split(ServerRequest.class, serverRequest -> {
// partition the input
ImmutableListMultimap<String, ServerRequest.Document> index = Multimaps.index(serverRequest.getParameters()
.getDocuments(), ServerRequest.Document::getSource);
List<IdList> idLists = new ArrayList<>();
// now split into X at most per source
for (Map.Entry<String, Collection<ServerRequest.Document>> entry : index.asMap().entrySet()) {
for (List<ServerRequest.Document> documentList : Iterables.partition(entry.getValue(), annotatorConfig.getRequestBulkSize())) {
idLists.add(new IdList(entry.getKey(), documentList.stream().map(ServerRequest.Document::getDocument_id).collect(Collectors.toList())));
}
}
return idLists;
}
)
// handle in parallel using an executor on a different channel
// .channel(c -> c.executor("Downloader", Executors.newFixedThreadPool(annotatorConfig.getConcurrentHandler())))
.transform(IdList.class, documentFetcher::load)
.split()
.channel("annotate")
.routeToRecipients(r ->
r.applySequence(true)
.defaultOutputToParentFlow()
.recipient("mirner", "headers['types'].contains(T(de.dfki.nlp.domain.PredictionType).MIRNA)")
.recipient("seth", "headers['types'].contains(T(de.dfki.nlp.domain.PredictionType).MUTATION)")
.recipient("diseases", "headers['types'].contains(T(de.dfki.nlp.domain.PredictionType).DISEASE)")
)
.channel("parsed")
.aggregate() // this aggregates annotations per document (from router)
.<List<Set<PredictionResult>>, Set<PredictionResult>>transform(s -> s.stream().flatMap(Collection::stream).collect(Collectors.toSet()))
.channel("aggregate")
.aggregate() // this aggregates all document per source group
.aggregate() // this aggregates all documents
// now merge the results by flattening
.channel("jointogether")
.<List<List<Set<PredictionResult>>>, Set<PredictionResult>>transform(source ->
source.stream().flatMap(Collection::stream).flatMap(Collection::stream).collect(Collectors.toSet()));
if (environment.acceptsProfiles("cloud")) {
// when cloud profile is active, send results via http
flow
.enrichHeaders(headerEnricherSpec -> headerEnricherSpec.headerExpression("Content-Type", "'application/json'"))
.log(LoggingHandler.Level.INFO, objectMessage ->
String.format("Sending Results [%s] after %d ms %s", objectMessage.getHeaders().get("communication_id"), (System.currentTimeMillis() - (long) objectMessage.getHeaders().get(ProcessingGateway.HEADER_REQUEST_TIME)), objectMessage.getHeaders().toString()))
//"respone", "headers['communication_id']")
// use retry advice - which tries to resend in case of failure
.handleWithAdapter(sendToBecalmServer(), e -> e.advice(retryAdvice()));
} else {
// for local deployment, just log
flow
.<Set<PredictionResult>>handle((parsed, headers) -> {
log.info(headers.toString());
log.info("Annotation request took [{}] {} ms", headers.get("communication_id"), System.currentTimeMillis() - (long) headers.get(ProcessingGateway.HEADER_REQUEST_TIME));
parsed
.stream()
.sorted((o1, o2) -> ComparisonChain
.start()
.compare(o1.getDocumentId(), o2.getDocumentId())
.compare(o1.getSection().name(), o2.getSection().name())
.compare(o1.getInit(), o2.getInit())
.result())
.forEach(r -> log.info(r.toString()));
return null;
});
}
return flow.get();
}
/**
* This bean is a retry advice to handle failures using retry
* finally it fails.
* It tries 20 (configurable) times using an exponential backoff strategy:
* - initial wait 100ms - default multiplier 2 - maximal wait between calls 30s
* <p>
* Thus it tries in case of a failure ...
* </p>
* <pre>
* Sleeping for 100
* Sleeping for 200
* ......
* Sleeping for 30000
* </pre>
*
* @return bean
*/
@Bean
public Advice retryAdvice() {
RequestHandlerRetryAdvice advice = new RequestHandlerRetryAdvice();
RetryTemplate retryTemplate = new RetryTemplate();
// use exponential backoff to wait between calls
ExponentialBackOffPolicy backOffPolicy = new ExponentialBackOffPolicy();
backOffPolicy.setInitialInterval(1000);
backOffPolicy.setMultiplier(2);
backOffPolicy.setMaxInterval(60000);
retryTemplate.setBackOffPolicy(backOffPolicy);
// try at most 20 times
retryTemplate.setRetryPolicy(new SimpleRetryPolicy(annotatorConfig.becalmSaveAnnotationRetries));
advice.setRetryTemplate(retryTemplate);
return advice;
}
/**
* Handles the BECALM post requests.
*
* @return the configured HTTP adapter
*/
private Function<Adapters, MessageHandlerSpec<?, HttpRequestExecutingMessageHandler>> sendToBecalmServer() {
return adapters -> {
RestTemplate restTemplate = new RestTemplate();
// we need to use a custom interceptor to allow multiple reading of the respons body
// once to check if we have an error, a second time to see the results (if there was no error)
restTemplate.getInterceptors().add(new BufferingClientHttpResponseWrapper());
restTemplate.setErrorHandler(errorHandler());
return adapters
// where to send the results to
.http(annotatorConfig.becalmSaveAnnotationLocation, restTemplate)
// use the post method
.httpMethod(HttpMethod.POST)
// replace URI placeholders using variables
.uriVariable("apikey", "'" + annotatorConfig.apiKey + "'")
.uriVariable("communicationId", "headers['communication_id']")
.expectedResponseType(ServerResponse.class);
};
}
@Bean
ResponseErrorHandler errorHandler() {
return new ResponseErrorHandler() {
@Override
public boolean hasError(ClientHttpResponse response) throws IOException {
// todo check success
return true;
}
@Override
public void handleError(ClientHttpResponse response) throws IOException {
// we have some sort of error
// try to inspect the result
HttpStatus statusCode = getHttpStatusCode(response);
byte[] responseBody = getResponseBody(response);
Charset charset = getCharset(response);
try {
HttpHeaders headers = response.getHeaders();
MediaType contentType = headers.getContentType();
if (!contentType.includes(MediaType.APPLICATION_JSON)) {
String serverResponse = new String(responseBody, MoreObjects.firstNonNull(charset, Charsets.UTF_8));
log.error("Server did not respond with JSON, but contentType {} - {}", contentType, StringUtils.replaceAll(serverResponse, "[\\r\\n]+", " "));
throw new HttpClientErrorException(statusCode, response.getStatusText(), response.getHeaders(), serverResponse.getBytes(), charset);
}
ServerResponse serverResponse = objectMapper.readValue(responseBody, ServerResponse.class);
// now check the serverResponse
if (serverResponse.isSuccess()) {
log.info("Success sending results: {}", serverResponse);
return;
}
switch (Errors.lookup(serverResponse.getErrorCode())) {
case REQUEST_CLOSED: // no error - we have sent it once .. continue
log.info("Posting results - assuming we don't have an error as server responded with {}", serverResponse.toString());
break;
default:
log.error("Error posting results to server: {} ", serverResponse);
throw new HttpClientErrorException(statusCode, response.getStatusText(), response.getHeaders(), serverResponse.toString().getBytes(), charset);
}
} catch (IOException e) {
log.error("Error parsing Server Response: {} {}", new String(responseBody), e.getMessage());
throw new HttpClientErrorException(statusCode, response.getStatusText(),
response.getHeaders(), responseBody, charset);
}
}
private HttpStatus getHttpStatusCode(ClientHttpResponse response) throws IOException {
HttpStatus statusCode;
try {
statusCode = response.getStatusCode();
} catch (IllegalArgumentException ex) {
throw new UnknownHttpStatusCodeException(response.getRawStatusCode(),
response.getStatusText(), response.getHeaders(), getResponseBody(response), getCharset(response));
}
return statusCode;
}
private byte[] getResponseBody(ClientHttpResponse response) {
try {
InputStream responseBody = response.getBody();
if (responseBody != null) {
return FileCopyUtils.copyToByteArray(responseBody);
}
} catch (IOException ex) {
// ignore
}
return new byte[0];
}
private Charset getCharset(ClientHttpResponse response) {
HttpHeaders headers = response.getHeaders();
MediaType contentType = headers.getContentType();
return contentType != null ? contentType.getCharset() : null;
}
};
}
}