/* * Copyright 2014-2017 Real Logic Ltd. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.aeron; import org.junit.Before; import org.junit.Test; import io.aeron.command.ControlProtocolEvents; import io.aeron.command.CorrelatedMessageFlyweight; import io.aeron.command.ErrorResponseFlyweight; import io.aeron.command.PublicationBuffersReadyFlyweight; import io.aeron.exceptions.ConductorServiceTimeoutException; import io.aeron.exceptions.DriverTimeoutException; import io.aeron.exceptions.RegistrationException; import io.aeron.logbuffer.LogBufferDescriptor; import io.aeron.protocol.DataHeaderFlyweight; import org.agrona.ErrorHandler; import org.agrona.MutableDirectBuffer; import org.agrona.collections.Long2LongHashMap; import org.agrona.concurrent.*; import org.agrona.concurrent.broadcast.CopyBroadcastReceiver; import java.nio.channels.FileChannel; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Lock; import java.util.function.Function; import static java.lang.Boolean.TRUE; import static java.nio.ByteBuffer.allocateDirect; import static java.nio.channels.FileChannel.MapMode.READ_ONLY; import static java.nio.channels.FileChannel.MapMode.READ_WRITE; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.sameInstance; import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.*; import static io.aeron.ErrorCode.INVALID_CHANNEL; import static io.aeron.logbuffer.LogBufferDescriptor.*; public class ClientConductorTest { private static final int TERM_BUFFER_LENGTH = TERM_MIN_LENGTH; protected static final int SESSION_ID_1 = 13; protected static final int SESSION_ID_2 = 15; private static final String CHANNEL = "aeron:udp?endpoint=localhost:40124"; private static final int STREAM_ID_1 = 2; private static final int STREAM_ID_2 = 4; private static final int SEND_BUFFER_CAPACITY = 1024; private static final int COUNTER_BUFFER_LENGTH = 1024; private static final long CORRELATION_ID = 2000; private static final long CORRELATION_ID_2 = 2002; private static final long CLOSE_CORRELATION_ID = 2001; private static final long UNKNOWN_CORRELATION_ID = 3000; private static final long KEEP_ALIVE_INTERVAL = TimeUnit.MILLISECONDS.toNanos(500); private static final long AWAIT_TIMEOUT = 100; private static final long INTER_SERVICE_TIMEOUT_MS = 1000; private static final long PUBLICATION_CONNECTION_TIMEOUT_MS = 5000; private static final String SOURCE_INFO = "127.0.0.1:40789"; private final PublicationBuffersReadyFlyweight publicationReady = new PublicationBuffersReadyFlyweight(); private final CorrelatedMessageFlyweight correlatedMessage = new CorrelatedMessageFlyweight(); private final ErrorResponseFlyweight errorResponse = new ErrorResponseFlyweight(); private final UnsafeBuffer publicationReadyBuffer = new UnsafeBuffer(allocateDirect(SEND_BUFFER_CAPACITY)); private final UnsafeBuffer correlatedMessageBuffer = new UnsafeBuffer(allocateDirect(SEND_BUFFER_CAPACITY)); private final UnsafeBuffer errorMessageBuffer = new UnsafeBuffer(allocateDirect(SEND_BUFFER_CAPACITY)); private final CopyBroadcastReceiver mockToClientReceiver = mock(CopyBroadcastReceiver.class); private final UnsafeBuffer counterValuesBuffer = new UnsafeBuffer(allocateDirect(COUNTER_BUFFER_LENGTH)); private long timeMs = 0; private final EpochClock epochClock = () -> timeMs += 10; private long timeNs = 0; private final NanoClock nanoClock = () -> timeNs += 10_000_000; private final ErrorHandler mockClientErrorHandler = spy(new PrintError()); private DriverProxy driverProxy = mock(DriverProxy.class); private ClientConductor conductor; private AvailableImageHandler mockAvailableImageHandler = mock(AvailableImageHandler.class); private UnavailableImageHandler mockUnavailableImageHandler = mock(UnavailableImageHandler.class); private LogBuffersFactory logBuffersFactory = mock(LogBuffersFactory.class); private Lock mockClientLock = mock(Lock.class); private Long2LongHashMap subscriberPositionMap = new Long2LongHashMap(-1L); private boolean suppressPrintError = false; @Before public void setUp() throws Exception { final Aeron.Context ctx = new Aeron.Context() .clientLock(mockClientLock) .epochClock(epochClock) .nanoClock(nanoClock) .toClientBuffer(mockToClientReceiver) .driverProxy(driverProxy) .logBuffersFactory(logBuffersFactory) .errorHandler(mockClientErrorHandler) .availableImageHandler(mockAvailableImageHandler) .unavailableImageHandler(mockUnavailableImageHandler) .imageMapMode(READ_ONLY) .keepAliveInterval(KEEP_ALIVE_INTERVAL) .driverTimeoutMs(AWAIT_TIMEOUT) .interServiceTimeout(TimeUnit.MILLISECONDS.toNanos(INTER_SERVICE_TIMEOUT_MS)) .publicationConnectionTimeout(PUBLICATION_CONNECTION_TIMEOUT_MS); ctx.countersValuesBuffer(counterValuesBuffer); when(mockClientLock.tryLock()).thenReturn(TRUE); when(driverProxy.addPublication(CHANNEL, STREAM_ID_1)).thenReturn(CORRELATION_ID); when(driverProxy.addPublication(CHANNEL, STREAM_ID_2)).thenReturn(CORRELATION_ID_2); when(driverProxy.removePublication(CORRELATION_ID)).thenReturn(CLOSE_CORRELATION_ID); when(driverProxy.addSubscription(anyString(), anyInt())).thenReturn(CORRELATION_ID); when(driverProxy.removeSubscription(CORRELATION_ID)).thenReturn(CLOSE_CORRELATION_ID); conductor = new ClientConductor(ctx); publicationReady.wrap(publicationReadyBuffer, 0); correlatedMessage.wrap(correlatedMessageBuffer, 0); errorResponse.wrap(errorMessageBuffer, 0); publicationReady.correlationId(CORRELATION_ID); publicationReady.sessionId(SESSION_ID_1); publicationReady.streamId(STREAM_ID_1); publicationReady.logFileName(SESSION_ID_1 + "-log"); subscriberPositionMap.put(CORRELATION_ID, 0); correlatedMessage.correlationId(CLOSE_CORRELATION_ID); final UnsafeBuffer[] termBuffersSession1 = new UnsafeBuffer[PARTITION_COUNT]; final UnsafeBuffer[] termBuffersSession2 = new UnsafeBuffer[PARTITION_COUNT]; for (int i = 0; i < PARTITION_COUNT; i++) { termBuffersSession1[i] = new UnsafeBuffer(allocateDirect(TERM_BUFFER_LENGTH)); termBuffersSession2[i] = new UnsafeBuffer(allocateDirect(TERM_BUFFER_LENGTH)); } final UnsafeBuffer logMetaDataSession1 = new UnsafeBuffer(allocateDirect(TERM_BUFFER_LENGTH)); final UnsafeBuffer logMetaDataSession2 = new UnsafeBuffer(allocateDirect(TERM_BUFFER_LENGTH)); final MutableDirectBuffer header1 = DataHeaderFlyweight.createDefaultHeader(SESSION_ID_1, STREAM_ID_1, 0); final MutableDirectBuffer header2 = DataHeaderFlyweight.createDefaultHeader(SESSION_ID_2, STREAM_ID_2, 0); LogBufferDescriptor.storeDefaultFrameHeader(logMetaDataSession1, header1); LogBufferDescriptor.storeDefaultFrameHeader(logMetaDataSession2, header2); final LogBuffers logBuffersSession1 = mock(LogBuffers.class); final LogBuffers logBuffersSession2 = mock(LogBuffers.class); when(logBuffersFactory.map(SESSION_ID_1 + "-log", READ_WRITE)).thenReturn(logBuffersSession1); when(logBuffersFactory.map(SESSION_ID_2 + "-log", READ_WRITE)).thenReturn(logBuffersSession2); when(logBuffersFactory.map(SESSION_ID_1 + "-log", READ_ONLY)).thenReturn(logBuffersSession1); when(logBuffersFactory.map(SESSION_ID_2 + "-log", READ_ONLY)).thenReturn(logBuffersSession2); when(logBuffersSession1.termBuffers()).thenReturn(termBuffersSession1); when(logBuffersSession2.termBuffers()).thenReturn(termBuffersSession2); when(logBuffersSession1.metaDataBuffer()).thenReturn(logMetaDataSession1); when(logBuffersSession2.metaDataBuffer()).thenReturn(logMetaDataSession2); } // -------------------------------- // Publication related interactions // -------------------------------- @Test public void addPublicationShouldNotifyMediaDriver() throws Exception { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> publicationReady.length()); conductor.addPublication(CHANNEL, STREAM_ID_1); verify(driverProxy).addPublication(CHANNEL, STREAM_ID_1); } @Test public void addPublicationShouldMapLogFile() throws Exception { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> publicationReady.length()); conductor.addPublication(CHANNEL, STREAM_ID_1); verify(logBuffersFactory).map(SESSION_ID_1 + "-log", READ_WRITE); } @Test(expected = DriverTimeoutException.class, timeout = 5_000) public void addPublicationShouldTimeoutWithoutReadyMessage() { conductor.addPublication(CHANNEL, STREAM_ID_1); } @Test public void conductorShouldCachePublicationInstances() { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> publicationReady.length()); final Publication firstPublication = conductor.addPublication(CHANNEL, STREAM_ID_1); final Publication secondPublication = conductor.addPublication(CHANNEL, STREAM_ID_1); assertThat(firstPublication, sameInstance(secondPublication)); } @Test public void closingPublicationShouldNotifyMediaDriver() throws Exception { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> publicationReady.length()); final Publication publication = conductor.addPublication(CHANNEL, STREAM_ID_1); whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_OPERATION_SUCCESS, correlatedMessageBuffer, (buffer) -> CorrelatedMessageFlyweight.LENGTH); publication.close(); verify(driverProxy).removePublication(CORRELATION_ID); } @Test public void closingPublicationShouldPurgeCache() throws Exception { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> publicationReady.length()); final Publication firstPublication = conductor.addPublication(CHANNEL, STREAM_ID_1); whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_OPERATION_SUCCESS, correlatedMessageBuffer, (buffer) -> CorrelatedMessageFlyweight.LENGTH); firstPublication.close(); whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> publicationReady.length()); final Publication secondPublication = conductor.addPublication(CHANNEL, STREAM_ID_1); assertThat(firstPublication, not(sameInstance(secondPublication))); } @Test(expected = RegistrationException.class) public void shouldFailToClosePublicationOnMediaDriverError() { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> publicationReady.length()); final Publication publication = conductor.addPublication(CHANNEL, STREAM_ID_1); whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_ERROR, errorMessageBuffer, (buffer) -> { errorResponse.errorCode(INVALID_CHANNEL); errorResponse.errorMessage("channel unknown"); errorResponse.offendingCommandCorrelationId(CLOSE_CORRELATION_ID); return errorResponse.length(); }); publication.close(); } @Test(expected = RegistrationException.class) public void shouldFailToAddPublicationOnMediaDriverError() { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_ERROR, errorMessageBuffer, (buffer) -> { errorResponse.errorCode(INVALID_CHANNEL); errorResponse.errorMessage("invalid channel"); errorResponse.offendingCommandCorrelationId(CORRELATION_ID); return errorResponse.length(); }); conductor.addPublication(CHANNEL, STREAM_ID_1); } @Test public void publicationOnlyRemovedOnLastClose() throws Exception { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> publicationReady.length()); final Publication publication = conductor.addPublication(CHANNEL, STREAM_ID_1); conductor.addPublication(CHANNEL, STREAM_ID_1); publication.close(); verify(driverProxy, never()).removePublication(CORRELATION_ID); whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_OPERATION_SUCCESS, correlatedMessageBuffer, (buffer) -> CorrelatedMessageFlyweight.LENGTH); publication.close(); verify(driverProxy).removePublication(CORRELATION_ID); } @Test public void closingPublicationDoesNotRemoveOtherPublications() throws Exception { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> publicationReady.length()); final Publication publication = conductor.addPublication(CHANNEL, STREAM_ID_1); whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> { publicationReady.streamId(STREAM_ID_2); publicationReady.sessionId(SESSION_ID_2); publicationReady.logFileName(SESSION_ID_2 + "-log"); publicationReady.correlationId(CORRELATION_ID_2); return publicationReady.length(); }); conductor.addPublication(CHANNEL, STREAM_ID_2); whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_OPERATION_SUCCESS, correlatedMessageBuffer, (buffer) -> CorrelatedMessageFlyweight.LENGTH); publication.close(); verify(driverProxy).removePublication(CORRELATION_ID); verify(driverProxy, never()).removePublication(CORRELATION_ID_2); } @Test public void shouldNotMapBuffersForUnknownCorrelationId() throws Exception { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> { publicationReady.correlationId(UNKNOWN_CORRELATION_ID); return publicationReady.length(); }); whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_PUBLICATION_READY, publicationReadyBuffer, (buffer) -> { publicationReady.correlationId(CORRELATION_ID); return publicationReady.length(); }); final Publication publication = conductor.addPublication(CHANNEL, STREAM_ID_1); conductor.doWork(); verify(logBuffersFactory, times(1)).map(anyString(), any(FileChannel.MapMode.class)); assertThat(publication.registrationId(), is(CORRELATION_ID)); } // --------------------------------- // Subscription related interactions // --------------------------------- @Test public void addSubscriptionShouldNotifyMediaDriver() throws Exception { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_OPERATION_SUCCESS, correlatedMessageBuffer, (buffer) -> { correlatedMessage.correlationId(CORRELATION_ID); return CorrelatedMessageFlyweight.LENGTH; }); conductor.addSubscription(CHANNEL, STREAM_ID_1); verify(driverProxy).addSubscription(CHANNEL, STREAM_ID_1); } @Test public void closingSubscriptionShouldNotifyMediaDriver() { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_OPERATION_SUCCESS, correlatedMessageBuffer, (buffer) -> { correlatedMessage.correlationId(CORRELATION_ID); return CorrelatedMessageFlyweight.LENGTH; }); final Subscription subscription = conductor.addSubscription(CHANNEL, STREAM_ID_1); whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_OPERATION_SUCCESS, correlatedMessageBuffer, (buffer) -> { correlatedMessage.correlationId(CLOSE_CORRELATION_ID); return CorrelatedMessageFlyweight.LENGTH; }); subscription.close(); verify(driverProxy).removeSubscription(CORRELATION_ID); } @Test(expected = DriverTimeoutException.class, timeout = 5_000) public void addSubscriptionShouldTimeoutWithoutOperationSuccessful() { conductor.addSubscription(CHANNEL, STREAM_ID_1); } @Test(expected = RegistrationException.class) public void shouldFailToAddSubscriptionOnMediaDriverError() { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_ERROR, errorMessageBuffer, (buffer) -> { errorResponse.errorCode(INVALID_CHANNEL); errorResponse.errorMessage("invalid channel"); errorResponse.offendingCommandCorrelationId(CORRELATION_ID); return errorResponse.length(); }); conductor.addSubscription(CHANNEL, STREAM_ID_1); } @Test public void clientNotifiedOfNewImageShouldMapLogFile() { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_OPERATION_SUCCESS, correlatedMessageBuffer, (buffer) -> { correlatedMessage.correlationId(CORRELATION_ID); return CorrelatedMessageFlyweight.LENGTH; }); conductor.addSubscription(CHANNEL, STREAM_ID_1); conductor.onAvailableImage( STREAM_ID_1, SESSION_ID_1, subscriberPositionMap, SESSION_ID_1 + "-log", SOURCE_INFO, CORRELATION_ID); verify(logBuffersFactory).map(eq(SESSION_ID_1 + "-log"), any(FileChannel.MapMode.class)); } @Test public void clientNotifiedOfNewAndInactiveImages() { whenReceiveBroadcastOnMessage( ControlProtocolEvents.ON_OPERATION_SUCCESS, correlatedMessageBuffer, (buffer) -> { correlatedMessage.correlationId(CORRELATION_ID); return CorrelatedMessageFlyweight.LENGTH; }); final Subscription subscription = conductor.addSubscription(CHANNEL, STREAM_ID_1); conductor.onAvailableImage( STREAM_ID_1, SESSION_ID_1, subscriberPositionMap, SESSION_ID_1 + "-log", SOURCE_INFO, CORRELATION_ID); assertFalse(subscription.hasNoImages()); verify(mockAvailableImageHandler).onAvailableImage(any(Image.class)); conductor.onUnavailableImage(STREAM_ID_1, CORRELATION_ID); verify(mockUnavailableImageHandler).onUnavailableImage(any(Image.class)); assertTrue(subscription.hasNoImages()); } @Test public void shouldIgnoreUnknownNewImage() { conductor.onAvailableImage( STREAM_ID_2, SESSION_ID_2, subscriberPositionMap, SESSION_ID_2 + "-log", SOURCE_INFO, CORRELATION_ID_2); verify(logBuffersFactory, never()).map(anyString(), any(FileChannel.MapMode.class)); verify(mockAvailableImageHandler, never()).onAvailableImage(any(Image.class)); } @Test public void shouldIgnoreUnknownInactiveImage() { conductor.onUnavailableImage(STREAM_ID_2, CORRELATION_ID_2); verify(logBuffersFactory, never()).map(anyString(), any(FileChannel.MapMode.class)); verify(mockUnavailableImageHandler, never()).onUnavailableImage(any(Image.class)); } @Test public void shouldTimeoutInterServiceIfTooLongBetweenDoWorkCalls() throws Exception { suppressPrintError = true; conductor.doWork(); timeNs += (TimeUnit.MILLISECONDS.toNanos(INTER_SERVICE_TIMEOUT_MS) + 1); conductor.doWork(); verify(mockClientErrorHandler).onError(any(ConductorServiceTimeoutException.class)); } private void whenReceiveBroadcastOnMessage( final int msgTypeId, final MutableDirectBuffer buffer, final Function<MutableDirectBuffer, Integer> filler) { doAnswer( (invocation) -> { final int length = filler.apply(buffer); conductor.driverListenerAdapter().onMessage(msgTypeId, buffer, 0, length); return 1; }).when(mockToClientReceiver).receive(any()); } private class PrintError implements ErrorHandler { public void onError(final Throwable throwable) { if (!suppressPrintError) { throwable.printStackTrace(); } } } }