/*
* Copyright 2014-2017 Real Logic Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.aeron.driver;
import io.aeron.driver.media.*;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import io.aeron.driver.status.SystemCounters;
import io.aeron.logbuffer.FrameDescriptor;
import io.aeron.protocol.DataHeaderFlyweight;
import io.aeron.protocol.HeaderFlyweight;
import io.aeron.protocol.StatusMessageFlyweight;
import org.agrona.BitUtil;
import org.agrona.concurrent.status.AtomicCounter;
import org.agrona.concurrent.UnsafeBuffer;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.*;
/**
 * Exercises {@link SendChannelEndpoint} and {@link ReceiveChannelEndpoint} over real UDP sockets bound to
 * localhost, driving reads through a {@link DataTransportPoller} (receiver side) and a
 * {@link ControlTransportPoller} (sender side). The dispatcher, publication and counters are Mockito mocks.
 */
public class SelectorAndTransportTest
{
    private static final int RCV_PORT = 40123;
    private static final int SRC_PORT = 40124;
    private static final int SESSION_ID = 0xdeadbeef;
    private static final int STREAM_ID = 0x44332211;
    private static final int TERM_ID = 0x99887766;
    private static final int FRAME_LENGTH = 24;

    private static final UdpChannel SRC_DST =
        UdpChannel.parse("aeron:udp?interface=localhost:" + SRC_PORT + "|endpoint=localhost:" + RCV_PORT);
    private static final UdpChannel RCV_DST = UdpChannel.parse("aeron:udp?endpoint=localhost:" + RCV_PORT);

    private final ByteBuffer byteBuffer = ByteBuffer.allocateDirect(256);
    private final UnsafeBuffer buffer = new UnsafeBuffer(byteBuffer);
    private final DataHeaderFlyweight encodeDataHeader = new DataHeaderFlyweight();
    private final StatusMessageFlyweight statusMessage = new StatusMessageFlyweight();
    private final InetSocketAddress rcvRemoteAddress = new InetSocketAddress("localhost", SRC_PORT);

    private final SystemCounters mockSystemCounters = mock(SystemCounters.class);
    private final AtomicCounter mockStatusMessagesReceivedCounter = mock(AtomicCounter.class);
    private final AtomicCounter mockSendStatusIndicator = mock(AtomicCounter.class);
    private final AtomicCounter mockReceiveStatusIndicator = mock(AtomicCounter.class);
    private final DataPacketDispatcher mockDispatcher = mock(DataPacketDispatcher.class);
    private final NetworkPublication mockPublication = mock(NetworkPublication.class);

    private final DataTransportPoller dataTransportPoller = new DataTransportPoller();
    private final ControlTransportPoller controlTransportPoller = new ControlTransportPoller();
    private SendChannelEndpoint sendChannelEndpoint;
    private ReceiveChannelEndpoint receiveChannelEndpoint;
    private final MediaDriver.Context context = new MediaDriver.Context();

    /**
     * An action that may throw; used so each tear-down step can be attempted independently.
     */
    @FunctionalInterface
    private interface ThrowingAction
    {
        void run() throws Exception;
    }

    @Before
    public void setup()
    {
        // Every counter lookup returns the same mock, so SM receipt can be verified on it.
        when(mockSystemCounters.get(any())).thenReturn(mockStatusMessagesReceivedCounter);
        when(mockPublication.streamId()).thenReturn(STREAM_ID);
        when(mockPublication.sessionId()).thenReturn(SESSION_ID);

        context.systemCounters(mockSystemCounters);
        context.receiveChannelEndpointThreadLocals(new ReceiveChannelEndpointThreadLocals(context));
    }

    /**
     * Best-effort cleanup. Each resource is closed in its own step so that a failure closing one endpoint
     * does not leave the other endpoint or the pollers (and their bound ports) open for later tests.
     */
    @After
    public void tearDown()
    {
        runQuietly(() ->
        {
            if (null != sendChannelEndpoint)
            {
                sendChannelEndpoint.close();
                // Presumably lets the poller observe the closed channel before the poller itself is closed.
                processLoop(controlTransportPoller, 5);
            }
        });

        runQuietly(() ->
        {
            if (null != receiveChannelEndpoint)
            {
                receiveChannelEndpoint.close();
                processLoop(dataTransportPoller, 5);
            }
        });

        runQuietly(dataTransportPoller::close);
        runQuietly(controlTransportPoller::close);
    }

    @Test(timeout = 1000)
    public void shouldHandleBasicSetupAndTearDown() throws Exception
    {
        createEndpoints();
        openEndpoints();

        processLoop(dataTransportPoller, 5);
    }

    @Test(timeout = 1000)
    public void shouldSendEmptyDataFrameUnicastFromSourceToReceiver() throws Exception
    {
        final AtomicInteger dataHeadersReceived = new AtomicInteger(0);
        countDataPacketsInto(dataHeadersReceived);

        createEndpoints();
        openEndpoints();

        writeDataFrameHeader(0);
        byteBuffer.position(0).limit(FRAME_LENGTH);

        processLoop(dataTransportPoller, 5);
        sendChannelEndpoint.send(byteBuffer);

        pollUntil(dataTransportPoller, dataHeadersReceived, 1);

        assertThat(dataHeadersReceived.get(), is(1));
    }

    @Test(timeout = 1000)
    public void shouldSendMultipleDataFramesPerDatagramUnicastFromSourceToReceiver() throws Exception
    {
        final AtomicInteger dataHeadersReceived = new AtomicInteger(0);
        countDataPacketsInto(dataHeadersReceived);

        createEndpoints();
        openEndpoints();

        // Two back-to-back frames, the second starting at the next frame-aligned offset.
        final int alignedFrameLength = BitUtil.align(FRAME_LENGTH, FrameDescriptor.FRAME_ALIGNMENT);
        writeDataFrameHeader(0);
        writeDataFrameHeader(alignedFrameLength);
        byteBuffer.position(0).limit(2 * alignedFrameLength);

        processLoop(dataTransportPoller, 5);
        sendChannelEndpoint.send(byteBuffer);

        pollUntil(dataTransportPoller, dataHeadersReceived, 1);

        // NOTE(review): a single dispatch is expected for a datagram carrying two frames, which suggests the
        // dispatcher is invoked once per datagram rather than once per frame — confirm against ReceiveChannelEndpoint.
        assertThat(dataHeadersReceived.get(), is(1));
    }

    @Test(timeout = 1000)
    public void shouldHandleSmFrameFromReceiverToSender() throws Exception
    {
        final AtomicInteger controlMessagesReceived = new AtomicInteger(0);
        doAnswer(
            (invocation) ->
            {
                controlMessagesReceived.incrementAndGet();
                return null;
            })
            .when(mockPublication).onStatusMessage(any(), any());

        createEndpoints();
        // Must happen before the channels are opened, matching the original set-up order.
        sendChannelEndpoint.registerForSend(mockPublication);
        openEndpoints();

        statusMessage.wrap(buffer);
        statusMessage
            .streamId(STREAM_ID)
            .sessionId(SESSION_ID)
            .consumptionTermId(TERM_ID)
            .receiverWindowLength(1000)
            .consumptionTermOffset(0)
            .version(HeaderFlyweight.CURRENT_VERSION)
            .flags((short)0)
            .headerType(HeaderFlyweight.HDR_TYPE_SM)
            .frameLength(StatusMessageFlyweight.HEADER_LENGTH);
        byteBuffer.position(0).limit(statusMessage.frameLength());

        processLoop(dataTransportPoller, 5);
        receiveChannelEndpoint.sendTo(byteBuffer, rcvRemoteAddress);

        pollUntil(controlTransportPoller, controlMessagesReceived, 1);

        verify(mockStatusMessagesReceivedCounter, times(1)).orderedIncrement();
    }

    /**
     * Constructs (but does not open) the receive and send endpoints, receiver first.
     */
    private void createEndpoints() throws Exception
    {
        receiveChannelEndpoint = new ReceiveChannelEndpoint(
            RCV_DST, mockDispatcher, mockReceiveStatusIndicator, context);
        sendChannelEndpoint = new SendChannelEndpoint(SRC_DST, mockSendStatusIndicator, context);
    }

    /**
     * Opens both endpoints' channels and registers them for reads: receiver on the data poller,
     * sender on the control poller.
     */
    private void openEndpoints() throws Exception
    {
        receiveChannelEndpoint.openDatagramChannel(mockReceiveStatusIndicator);
        receiveChannelEndpoint.registerForRead(dataTransportPoller);
        sendChannelEndpoint.openDatagramChannel(mockSendStatusIndicator);
        sendChannelEndpoint.registerForRead(controlTransportPoller);
    }

    /**
     * Makes the mock dispatcher increment {@code counter} on every {@code onDataPacket} call.
     *
     * @param counter incremented once per dispatched data packet.
     */
    private void countDataPacketsInto(final AtomicInteger counter)
    {
        doAnswer(
            (invocation) ->
            {
                counter.incrementAndGet();
                return null;
            })
            .when(mockDispatcher).onDataPacket(
                any(ReceiveChannelEndpoint.class),
                any(DataHeaderFlyweight.class),
                any(UnsafeBuffer.class),
                anyInt(),
                any(InetSocketAddress.class));
    }

    /**
     * Encodes a BEGIN|END data frame header of {@link #FRAME_LENGTH} bytes into {@link #buffer} at {@code offset}.
     *
     * @param offset byte offset within the buffer at which the header starts.
     */
    private void writeDataFrameHeader(final int offset)
    {
        encodeDataHeader.wrap(buffer, offset, buffer.capacity() - offset);
        encodeDataHeader
            .version(HeaderFlyweight.CURRENT_VERSION)
            .flags(DataHeaderFlyweight.BEGIN_AND_END_FLAGS)
            .headerType(HeaderFlyweight.HDR_TYPE_DATA)
            .frameLength(FRAME_LENGTH);
        encodeDataHeader
            .sessionId(SESSION_ID)
            .streamId(STREAM_ID)
            .termId(TERM_ID);
    }

    /**
     * Polls {@code transportPoller} one iteration at a time until {@code counter} reaches {@code target}.
     * Termination relies on the enclosing test's timeout.
     */
    private void pollUntil(final UdpTransportPoller transportPoller, final AtomicInteger counter, final int target)
        throws Exception
    {
        while (counter.get() < target)
        {
            processLoop(transportPoller, 1);
        }
    }

    /**
     * Calls {@code pollTransports()} on the given poller {@code iterations} times.
     */
    private void processLoop(final UdpTransportPoller transportPoller, final int iterations) throws Exception
    {
        for (int i = 0; i < iterations; i++)
        {
            transportPoller.pollTransports();
        }
    }

    /**
     * Runs {@code action}, printing rather than propagating any exception so that later clean-up steps still run.
     */
    private static void runQuietly(final ThrowingAction action)
    {
        try
        {
            action.run();
        }
        catch (final Exception ex)
        {
            ex.printStackTrace();
        }
    }
}