/* * Copyright 2014-2017 Real Logic Ltd. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.aeron.driver; import io.aeron.driver.buffer.*; import io.aeron.driver.cmd.*; import io.aeron.driver.media.*; import io.aeron.driver.reports.LossReport; import io.aeron.driver.status.SystemCounters; import io.aeron.logbuffer.*; import io.aeron.protocol.*; import org.agrona.ErrorHandler; import org.agrona.concurrent.*; import org.agrona.concurrent.status.*; import org.junit.*; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.DatagramChannel; import static io.aeron.logbuffer.LogBufferDescriptor.*; import static java.lang.Integer.numberOfTrailingZeros; import static junit.framework.TestCase.assertTrue; import static org.agrona.BitUtil.align; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertNotNull; import static org.mockito.Mockito.*; public class ReceiverTest { private static final int TERM_BUFFER_LENGTH = TERM_MIN_LENGTH; private static final String URI = "aeron:udp?endpoint=localhost:45678"; private static final UdpChannel UDP_CHANNEL = UdpChannel.parse(URI); private static final long CORRELATION_ID = 20; private static final int STREAM_ID = 10; private static final int INITIAL_TERM_ID = 3; private static final int ACTIVE_TERM_ID = 3; private static final int SESSION_ID = 1; private static final int INITIAL_TERM_OFFSET = 0; private static final int ACTIVE_INDEX = indexByTerm(ACTIVE_TERM_ID, ACTIVE_TERM_ID); private static final byte[] FAKE_PAYLOAD = "Hello there, message!".getBytes(); private static final int INITIAL_WINDOW_LENGTH = Configuration.INITIAL_WINDOW_LENGTH_DEFAULT; private static final long STATUS_MESSAGE_TIMEOUT = Configuration.STATUS_MESSAGE_TIMEOUT_DEFAULT_NS; private static final InetSocketAddress SOURCE_ADDRESS = new InetSocketAddress("localhost", 45679); private static final ReadablePosition POSITION = mock(ReadablePosition.class); private static final ReadablePosition[] POSITIONS = new ReadablePosition[]{ POSITION }; private final FeedbackDelayGenerator mockFeedbackDelayGenerator = mock(FeedbackDelayGenerator.class); private final DataTransportPoller mockDataTransportPoller = mock(DataTransportPoller.class); private final ControlTransportPoller mockControlTransportPoller = mock(ControlTransportPoller.class); private final SystemCounters mockSystemCounters = mock(SystemCounters.class); private final RawLogFactory mockRawLogFactory = mock(RawLogFactory.class); private final Position mockHighestReceivedPosition = spy(new AtomicLongPosition()); private final Position mockRebuildPosition = spy(new AtomicLongPosition()); private final Position mockSubscriberPosition = mock(Position.class); private final ByteBuffer dataFrameBuffer = ByteBuffer.allocateDirect(2 * 1024); private final UnsafeBuffer dataBuffer = new UnsafeBuffer(dataFrameBuffer); private final ByteBuffer setupFrameBuffer = ByteBuffer.allocateDirect(SetupFlyweight.HEADER_LENGTH); private final UnsafeBuffer setupBuffer = new UnsafeBuffer(setupFrameBuffer); private final ErrorHandler mockErrorHandler = mock(ErrorHandler.class); private final DataHeaderFlyweight dataHeader = new DataHeaderFlyweight(); private final StatusMessageFlyweight statusHeader = new StatusMessageFlyweight(); private final SetupFlyweight setupHeader = new SetupFlyweight(); private long currentTime = 0; private final NanoClock nanoClock = () -> currentTime; private final EpochClock epochClock = mock(EpochClock.class); private final LossReport lossReport = mock(LossReport.class); private final RawLog rawLog = LogBufferHelper.newTestLogBuffers(TERM_BUFFER_LENGTH); private final Header header = new Header(INITIAL_TERM_ID, TERM_BUFFER_LENGTH); private UnsafeBuffer[] termBuffers; private DatagramChannel senderChannel; private InetSocketAddress senderAddress = new InetSocketAddress("localhost", 40123); private Receiver receiver; private ReceiverProxy receiverProxy; private ManyToOneConcurrentArrayQueue<DriverConductorCmd> toConductorQueue; private MediaDriver.Context context = new MediaDriver.Context(); private ReceiveChannelEndpoint receiveChannelEndpoint; private CongestionControl congestionControl = mock(CongestionControl.class); // TODO rework test to use proxies rather than the command queues. @Before public void setUp() throws Exception { when(POSITION.getVolatile()) .thenReturn(computePosition(ACTIVE_TERM_ID, 0, numberOfTrailingZeros(TERM_BUFFER_LENGTH), ACTIVE_TERM_ID)); when(mockSystemCounters.get(any())).thenReturn(mock(AtomicCounter.class)); when(congestionControl.onTrackRebuild( anyLong(), anyLong(), anyLong(), anyLong(), anyLong(), anyLong(), anyBoolean())) .thenReturn(CongestionControlUtil.packOutcome(INITIAL_WINDOW_LENGTH, false)); when(congestionControl.initialWindowLength()).thenReturn(INITIAL_WINDOW_LENGTH); final MediaDriver.Context ctx = new MediaDriver.Context() .driverCommandQueue(new ManyToOneConcurrentArrayQueue<>(Configuration.CMD_QUEUE_CAPACITY)) .dataTransportPoller(mockDataTransportPoller) .controlTransportPoller(mockControlTransportPoller) .rawLogBuffersFactory(mockRawLogFactory) .systemCounters(mockSystemCounters) .receiverCommandQueue(new OneToOneConcurrentArrayQueue<>(Configuration.CMD_QUEUE_CAPACITY)) .nanoClock(() -> currentTime); toConductorQueue = ctx.driverCommandQueue(); final DriverConductorProxy driverConductorProxy = new DriverConductorProxy(ThreadingMode.DEDICATED, toConductorQueue, mock(AtomicCounter.class)); ctx.driverConductorProxy(driverConductorProxy); receiverProxy = new ReceiverProxy( ThreadingMode.DEDICATED, ctx.receiverCommandQueue(), mock(AtomicCounter.class)); ctx.receiveChannelEndpointThreadLocals(new ReceiveChannelEndpointThreadLocals(ctx)); receiver = new Receiver(ctx); senderChannel = DatagramChannel.open(); senderChannel.bind(senderAddress); senderChannel.configureBlocking(false); termBuffers = rawLog.termBuffers(); context.systemCounters(mockSystemCounters); context.receiveChannelEndpointThreadLocals(new ReceiveChannelEndpointThreadLocals(context)); receiveChannelEndpoint = new ReceiveChannelEndpoint( UdpChannel.parse(URI), new DataPacketDispatcher(driverConductorProxy, receiver), mock(AtomicCounter.class), context); } @After public void tearDown() throws Exception { receiveChannelEndpoint.close(); senderChannel.close(); receiver.onClose(); } @Test(timeout = 10000) public void shouldCreateRcvTermAndSendSmOnSetup() throws Exception { receiverProxy.registerReceiveChannelEndpoint(receiveChannelEndpoint); receiverProxy.addSubscription(receiveChannelEndpoint, STREAM_ID); receiver.doWork(); fillSetupFrame(setupHeader); receiveChannelEndpoint.onSetupMessage(setupHeader, setupBuffer, SetupFlyweight.HEADER_LENGTH, senderAddress); final PublicationImage image = new PublicationImage( CORRELATION_ID, Configuration.IMAGE_LIVENESS_TIMEOUT_NS, receiveChannelEndpoint, senderAddress, SESSION_ID, STREAM_ID, INITIAL_TERM_ID, ACTIVE_TERM_ID, INITIAL_TERM_OFFSET, rawLog, mockFeedbackDelayGenerator, POSITIONS, mockHighestReceivedPosition, mockRebuildPosition, nanoClock, epochClock, mockSystemCounters, SOURCE_ADDRESS, congestionControl, lossReport, true); final int messagesRead = toConductorQueue.drain( (e) -> { final CreatePublicationImageCmd cmd = (CreatePublicationImageCmd)e; assertThat(cmd.channelEndpoint().udpChannel(), is(UDP_CHANNEL)); assertThat(cmd.streamId(), is(STREAM_ID)); assertThat(cmd.sessionId(), is(SESSION_ID)); assertThat(cmd.termId(), is(ACTIVE_TERM_ID)); // pass in new term buffer from conductor, which should trigger SM receiverProxy.newPublicationImage(receiveChannelEndpoint, image); }); assertThat(messagesRead, is(1)); receiver.doWork(); image.trackRebuild(currentTime + (2 * STATUS_MESSAGE_TIMEOUT), STATUS_MESSAGE_TIMEOUT); image.sendPendingStatusMessage(); final ByteBuffer rcvBuffer = ByteBuffer.allocateDirect(256); InetSocketAddress rcvAddress; do { rcvAddress = (InetSocketAddress)senderChannel.receive(rcvBuffer); } while (null == rcvAddress); statusHeader.wrap(new UnsafeBuffer(rcvBuffer)); assertNotNull(rcvAddress); assertThat(rcvAddress.getPort(), is(UDP_CHANNEL.remoteData().getPort())); assertThat(statusHeader.headerType(), is(HeaderFlyweight.HDR_TYPE_SM)); assertThat(statusHeader.streamId(), is(STREAM_ID)); assertThat(statusHeader.sessionId(), is(SESSION_ID)); assertThat(statusHeader.consumptionTermId(), is(ACTIVE_TERM_ID)); assertThat(statusHeader.frameLength(), is(StatusMessageFlyweight.HEADER_LENGTH)); } @Test public void shouldInsertDataIntoLogAfterInitialExchange() throws Exception { receiverProxy.registerReceiveChannelEndpoint(receiveChannelEndpoint); receiverProxy.addSubscription(receiveChannelEndpoint, STREAM_ID); receiver.doWork(); fillSetupFrame(setupHeader); receiveChannelEndpoint.onSetupMessage(setupHeader, setupBuffer, SetupFlyweight.HEADER_LENGTH, senderAddress); final int commandsRead = toConductorQueue.drain( (e) -> { assertTrue(e instanceof CreatePublicationImageCmd); // pass in new term buffer from conductor, which should trigger SM receiverProxy.newPublicationImage( receiveChannelEndpoint, new PublicationImage( CORRELATION_ID, Configuration.IMAGE_LIVENESS_TIMEOUT_NS, receiveChannelEndpoint, senderAddress, SESSION_ID, STREAM_ID, INITIAL_TERM_ID, ACTIVE_TERM_ID, INITIAL_TERM_OFFSET, rawLog, mockFeedbackDelayGenerator, POSITIONS, mockHighestReceivedPosition, mockRebuildPosition, nanoClock, epochClock, mockSystemCounters, SOURCE_ADDRESS, congestionControl, lossReport, true)); }); assertThat(commandsRead, is(1)); receiver.doWork(); fillDataFrame(dataHeader, 0, FAKE_PAYLOAD); receiveChannelEndpoint.onDataPacket(dataHeader, dataBuffer, dataHeader.frameLength(), senderAddress); final int readOutcome = TermReader.read( termBuffers[ACTIVE_INDEX], INITIAL_TERM_OFFSET, (buffer, offset, length, header) -> { assertThat(header.type(), is(HeaderFlyweight.HDR_TYPE_DATA)); assertThat(header.termId(), is(ACTIVE_TERM_ID)); assertThat(header.streamId(), is(STREAM_ID)); assertThat(header.sessionId(), is(SESSION_ID)); assertThat(header.termOffset(), is(0)); assertThat(header.frameLength(), is(DataHeaderFlyweight.HEADER_LENGTH + FAKE_PAYLOAD.length)); }, Integer.MAX_VALUE, header, mockErrorHandler, 0, mockSubscriberPosition); assertThat(readOutcome, is(1)); } @Test public void shouldNotOverwriteDataFrameWithHeartbeat() throws Exception { receiverProxy.registerReceiveChannelEndpoint(receiveChannelEndpoint); receiverProxy.addSubscription(receiveChannelEndpoint, STREAM_ID); receiver.doWork(); fillSetupFrame(setupHeader); receiveChannelEndpoint.onSetupMessage(setupHeader, setupBuffer, SetupFlyweight.HEADER_LENGTH, senderAddress); final int commandsRead = toConductorQueue.drain( (e) -> { assertTrue(e instanceof CreatePublicationImageCmd); // pass in new term buffer from conductor, which should trigger SM receiverProxy.newPublicationImage( receiveChannelEndpoint, new PublicationImage( CORRELATION_ID, Configuration.IMAGE_LIVENESS_TIMEOUT_NS, receiveChannelEndpoint, senderAddress, SESSION_ID, STREAM_ID, INITIAL_TERM_ID, ACTIVE_TERM_ID, INITIAL_TERM_OFFSET, rawLog, mockFeedbackDelayGenerator, POSITIONS, mockHighestReceivedPosition, mockRebuildPosition, nanoClock, epochClock, mockSystemCounters, SOURCE_ADDRESS, congestionControl, lossReport, true)); }); assertThat(commandsRead, is(1)); receiver.doWork(); fillDataFrame(dataHeader, 0, FAKE_PAYLOAD); // initial data frame receiveChannelEndpoint.onDataPacket(dataHeader, dataBuffer, dataHeader.frameLength(), senderAddress); fillDataFrame(dataHeader, 0, FAKE_PAYLOAD); // heartbeat with same term offset receiveChannelEndpoint.onDataPacket(dataHeader, dataBuffer, dataHeader.frameLength(), senderAddress); final int readOutcome = TermReader.read( termBuffers[ACTIVE_INDEX], INITIAL_TERM_OFFSET, (buffer, offset, length, header) -> { assertThat(header.type(), is(HeaderFlyweight.HDR_TYPE_DATA)); assertThat(header.termId(), is(ACTIVE_TERM_ID)); assertThat(header.streamId(), is(STREAM_ID)); assertThat(header.sessionId(), is(SESSION_ID)); assertThat(header.termOffset(), is(0)); assertThat(header.frameLength(), is(DataHeaderFlyweight.HEADER_LENGTH + FAKE_PAYLOAD.length)); }, Integer.MAX_VALUE, header, mockErrorHandler, 0, mockSubscriberPosition); assertThat(readOutcome, is(1)); } @Test public void shouldOverwriteHeartbeatWithDataFrame() throws Exception { receiverProxy.registerReceiveChannelEndpoint(receiveChannelEndpoint); receiverProxy.addSubscription(receiveChannelEndpoint, STREAM_ID); receiver.doWork(); fillSetupFrame(setupHeader); receiveChannelEndpoint.onSetupMessage(setupHeader, setupBuffer, SetupFlyweight.HEADER_LENGTH, senderAddress); final int commandsRead = toConductorQueue.drain( (e) -> { assertTrue(e instanceof CreatePublicationImageCmd); // pass in new term buffer from conductor, which should trigger SM receiverProxy.newPublicationImage( receiveChannelEndpoint, new PublicationImage( CORRELATION_ID, Configuration.IMAGE_LIVENESS_TIMEOUT_NS, receiveChannelEndpoint, senderAddress, SESSION_ID, STREAM_ID, INITIAL_TERM_ID, ACTIVE_TERM_ID, INITIAL_TERM_OFFSET, rawLog, mockFeedbackDelayGenerator, POSITIONS, mockHighestReceivedPosition, mockRebuildPosition, nanoClock, epochClock, mockSystemCounters, SOURCE_ADDRESS, congestionControl, lossReport, true)); }); assertThat(commandsRead, is(1)); receiver.doWork(); fillDataFrame(dataHeader, 0, FAKE_PAYLOAD); // heartbeat with same term offset receiveChannelEndpoint.onDataPacket(dataHeader, dataBuffer, dataHeader.frameLength(), senderAddress); fillDataFrame(dataHeader, 0, FAKE_PAYLOAD); // initial data frame receiveChannelEndpoint.onDataPacket(dataHeader, dataBuffer, dataHeader.frameLength(), senderAddress); final int readOutcome = TermReader.read( termBuffers[ACTIVE_INDEX], INITIAL_TERM_OFFSET, (buffer, offset, length, header) -> { assertThat(header.type(), is(HeaderFlyweight.HDR_TYPE_DATA)); assertThat(header.termId(), is(ACTIVE_TERM_ID)); assertThat(header.streamId(), is(STREAM_ID)); assertThat(header.sessionId(), is(SESSION_ID)); assertThat(header.termOffset(), is(0)); assertThat(header.frameLength(), is(DataHeaderFlyweight.HEADER_LENGTH + FAKE_PAYLOAD.length)); }, Integer.MAX_VALUE, header, mockErrorHandler, 0, mockSubscriberPosition); assertThat(readOutcome, is(1)); } @Test public void shouldHandleNonZeroTermOffsetCorrectly() throws Exception { final int initialTermOffset = align(TERM_BUFFER_LENGTH / 16, FrameDescriptor.FRAME_ALIGNMENT); final int alignedDataFrameLength = align(DataHeaderFlyweight.HEADER_LENGTH + FAKE_PAYLOAD.length, FrameDescriptor.FRAME_ALIGNMENT); receiverProxy.registerReceiveChannelEndpoint(receiveChannelEndpoint); receiverProxy.addSubscription(receiveChannelEndpoint, STREAM_ID); receiver.doWork(); fillSetupFrame(setupHeader, initialTermOffset); receiveChannelEndpoint.onSetupMessage(setupHeader, setupBuffer, SetupFlyweight.HEADER_LENGTH, senderAddress); final int commandsRead = toConductorQueue.drain( (e) -> { assertTrue(e instanceof CreatePublicationImageCmd); // pass in new term buffer from conductor, which should trigger SM receiverProxy.newPublicationImage( receiveChannelEndpoint, new PublicationImage( CORRELATION_ID, Configuration.IMAGE_LIVENESS_TIMEOUT_NS, receiveChannelEndpoint, senderAddress, SESSION_ID, STREAM_ID, INITIAL_TERM_ID, ACTIVE_TERM_ID, initialTermOffset, rawLog, mockFeedbackDelayGenerator, POSITIONS, mockHighestReceivedPosition, mockRebuildPosition, nanoClock, epochClock, mockSystemCounters, SOURCE_ADDRESS, congestionControl, lossReport, true)); }); assertThat(commandsRead, is(1)); verify(mockHighestReceivedPosition).setOrdered(initialTermOffset); receiver.doWork(); fillDataFrame(dataHeader, initialTermOffset, FAKE_PAYLOAD); // initial data frame receiveChannelEndpoint.onDataPacket(dataHeader, dataBuffer, alignedDataFrameLength, senderAddress); verify(mockHighestReceivedPosition).setOrdered(initialTermOffset + alignedDataFrameLength); final int readOutcome = TermReader.read( termBuffers[ACTIVE_INDEX], initialTermOffset, (buffer, offset, length, header) -> { assertThat(header.type(), is(HeaderFlyweight.HDR_TYPE_DATA)); assertThat(header.termId(), is(ACTIVE_TERM_ID)); assertThat(header.streamId(), is(STREAM_ID)); assertThat(header.sessionId(), is(SESSION_ID)); assertThat(header.termOffset(), is(initialTermOffset)); assertThat(header.frameLength(), is(DataHeaderFlyweight.HEADER_LENGTH + FAKE_PAYLOAD.length)); }, Integer.MAX_VALUE, header, mockErrorHandler, 0, mockSubscriberPosition); assertThat(readOutcome, is(1)); } @Test public void shouldRemoveImageFromDispatcherWithNoActivity() throws Exception { receiverProxy.registerReceiveChannelEndpoint(receiveChannelEndpoint); receiverProxy.addSubscription(receiveChannelEndpoint, STREAM_ID); receiver.doWork(); fillSetupFrame(setupHeader); receiveChannelEndpoint.onSetupMessage(setupHeader, setupBuffer, SetupFlyweight.HEADER_LENGTH, senderAddress); final PublicationImage mockImage = mock(PublicationImage.class); when(mockImage.sessionId()).thenReturn(SESSION_ID); when(mockImage.streamId()).thenReturn(STREAM_ID); when(mockImage.checkForActivity(anyLong())).thenReturn(false); receiver.onNewPublicationImage(receiveChannelEndpoint, mockImage); receiver.doWork(); verify(mockImage).removeFromDispatcher(); } @Test public void shouldNotRemoveImageFromDispatcherOnRemoveSubscription() throws Exception { receiverProxy.registerReceiveChannelEndpoint(receiveChannelEndpoint); receiverProxy.addSubscription(receiveChannelEndpoint, STREAM_ID); receiver.doWork(); fillSetupFrame(setupHeader); receiveChannelEndpoint.onSetupMessage(setupHeader, setupBuffer, SetupFlyweight.HEADER_LENGTH, senderAddress); final PublicationImage mockImage = mock(PublicationImage.class); when(mockImage.sessionId()).thenReturn(SESSION_ID); when(mockImage.streamId()).thenReturn(STREAM_ID); when(mockImage.checkForActivity(anyLong())).thenReturn(true); receiver.onNewPublicationImage(receiveChannelEndpoint, mockImage); receiver.onRemoveSubscription(receiveChannelEndpoint, STREAM_ID); receiver.doWork(); verify(mockImage).ifActiveGoInactive(); verify(mockImage, never()).removeFromDispatcher(); } private void fillDataFrame(final DataHeaderFlyweight header, final int termOffset, final byte[] payload) { header.wrap(dataBuffer); header .termOffset(termOffset) .termId(ACTIVE_TERM_ID) .streamId(STREAM_ID) .sessionId(SESSION_ID) .frameLength(DataHeaderFlyweight.HEADER_LENGTH + payload.length) .headerType(HeaderFlyweight.HDR_TYPE_DATA) .flags(DataHeaderFlyweight.BEGIN_AND_END_FLAGS) .version(HeaderFlyweight.CURRENT_VERSION); if (0 < payload.length) { dataBuffer.putBytes(header.dataOffset(), payload); } } private void fillSetupFrame(final SetupFlyweight header) { fillSetupFrame(header, 0); } private void fillSetupFrame(final SetupFlyweight header, final int termOffset) { header.wrap(setupBuffer); header .streamId(STREAM_ID) .sessionId(SESSION_ID) .initialTermId(INITIAL_TERM_ID) .activeTermId(ACTIVE_TERM_ID) .termOffset(termOffset) .frameLength(SetupFlyweight.HEADER_LENGTH) .headerType(HeaderFlyweight.HDR_TYPE_SETUP) .flags((byte)0) .version(HeaderFlyweight.CURRENT_VERSION); } }