/*
* Copyright 2017 Real Logic Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.aeron.driver.ext;
import io.aeron.driver.CongestionControl;
import io.aeron.driver.MediaDriver;
import io.aeron.driver.media.UdpChannel;
import io.aeron.driver.status.PerImageIndicator;
import org.agrona.CloseHelper;
import org.agrona.concurrent.NanoClock;
import org.agrona.concurrent.status.AtomicCounter;
import org.agrona.concurrent.status.CountersManager;
import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;
import static io.aeron.driver.CongestionControlUtil.packOutcome;
/**
 * CUBIC congestion control manipulation of the receiver window length.
 *
 * https://research.csc.ncsu.edu/netsrv/?q=content/bic-and-cubic
 *
 * W_cubic = C(T - K)^3 + w_max
 *
 * K = cbrt(w_max * B / C)
 * w_max = window size before reduction
 * T = time since last decrease
 *
 * C = scaling constant (default 0.4)
 * B = multiplicative decrease (default 0.2)
 *
 * When TCP mode is enabled and W_cubic is below w_max, the window is not allowed to fall below the
 * TCP-friendly estimate:
 *
 * W_tcp(T) = w_max * (1 - B) + 3 * B / (2 - B) * T / RTT
 *
 * Windows are tracked in units of MTUs. Growth is capped at maxWindow / MTU, where maxWindow is the smaller of
 * half the term length and the context's initial window length; a loss never reduces the window below 1 MTU.
 *
 * at MTU=4K, max window=128KB (w_max = 32 MTUs), then K ~= 2.5 seconds.
 */
public class CubicCongestionControl implements CongestionControl
{
    private static final boolean RTT_MEASUREMENT = CubicCongestionControlConfiguration.MEASURE_RTT;
    private static final boolean TCP_MODE = CubicCongestionControlConfiguration.TCP_MODE;

    // Minimum spacing between RTT measurements while below the outstanding-measurement limit.
    private static final long RTT_MEASUREMENT_TIMEOUT_NS = TimeUnit.MILLISECONDS.toNanos(10);
    private static final long SECOND_IN_NS = TimeUnit.SECONDS.toNanos(1);
    // How long to wait for an RTT reply before presuming the outstanding measurement was lost.
    private static final long RTT_MAX_TIMEOUT_NS = SECOND_IN_NS;
    private static final int MAX_OUTSTANDING_RTT_MEASUREMENTS = 1;

    private static final double C = 0.4;
    private static final double B = 0.2;

    private final int minWindow;
    private final int mtu;
    // Upper bound on cwnd, in MTUs.
    private final int maxCwnd;
    private long lastLossTimestampNs;
    private long lastUpdateTimestampNs;
    private long lastRttTimestampNs = 0;
    // Minimum interval between window growth steps; tracks the latest RTT.
    private long windowUpdateTimeoutNs;
    private long rttInNs;
    private double k;
    // Congestion window, in MTUs.
    private int cwnd;
    // Window (in MTUs) at the time of the last reduction.
    private int w_max;
    private int outstandingRttMeasurements = 0;

    private final AtomicCounter rttIndicator;
    private final AtomicCounter windowIndicator;

    /**
     * Create a CUBIC congestion control instance for one image and allocate its RTT and window counters.
     *
     * @param registrationId  registration id used when allocating the per-image counters.
     * @param udpChannel      channel whose original URI labels the counters.
     * @param streamId        stream id used when allocating the per-image counters.
     * @param sessionId       session id used when allocating the per-image counters.
     * @param termLength      term length; half of it bounds the maximum window.
     * @param senderMtuLength sender MTU; used as the window unit and as the initial window.
     * @param clock           clock used to seed the loss and update timestamps.
     * @param context         driver context supplying the initial window length bound.
     * @param countersManager manager the per-image counters are allocated from.
     */
    CubicCongestionControl(
        final long registrationId,
        final UdpChannel udpChannel,
        final int streamId,
        final int sessionId,
        final int termLength,
        final int senderMtuLength,
        final NanoClock clock,
        final MediaDriver.Context context,
        final CountersManager countersManager)
    {
        mtu = senderMtuLength;
        minWindow = senderMtuLength;
        final int maxWindow = Math.min(termLength / 2, context.initialWindowLength());

        maxCwnd = maxWindow / mtu;
        cwnd = 1;
        w_max = maxCwnd; // initially set w_max to max window and act in the TCP and concave region initially
        k = Math.cbrt((double)w_max * B / C);

        // determine interval for adjustment based on heuristic of MTU, max window, and/or RTT estimate
        rttInNs = CubicCongestionControlConfiguration.INITIAL_RTT_NS;
        windowUpdateTimeoutNs = rttInNs;

        rttIndicator = PerImageIndicator.allocate(
            "rcv-cc-cubic-rtt",
            countersManager,
            registrationId,
            sessionId,
            streamId,
            udpChannel.originalUriString(),
            "");

        windowIndicator = PerImageIndicator.allocate(
            "rcv-cc-cubic-wnd",
            countersManager,
            registrationId,
            sessionId,
            streamId,
            udpChannel.originalUriString(),
            "");

        rttIndicator.setOrdered(0);
        windowIndicator.setOrdered(minWindow);

        lastLossTimestampNs = clock.nanoTime();
        lastUpdateTimestampNs = lastLossTimestampNs;
    }

    /**
     * Decide whether an RTT measurement should be initiated now; if so, it is recorded as outstanding.
     * <p>
     * A measurement is allowed when RTT measurement is enabled and either:
     * <ul>
     * <li>fewer than {@code MAX_OUTSTANDING_RTT_MEASUREMENTS} are outstanding and more than
     * {@code RTT_MEASUREMENT_TIMEOUT_NS} has passed since the last measurement event, or</li>
     * <li>the limit is reached but more than {@code RTT_MAX_TIMEOUT_NS} has passed without a reply. In that case
     * the outstanding measurement is presumed lost, so a dropped reply cannot disable measurement forever.</li>
     * </ul>
     *
     * @param nowNs current time in nanoseconds.
     * @return true if an RTT measurement should be sent now.
     */
    public boolean shouldMeasureRtt(final long nowNs)
    {
        if (!RTT_MEASUREMENT)
        {
            return false;
        }

        // Subtraction form stays correct if nanoTime values wrap.
        final long elapsedNs = nowNs - lastRttTimestampNs;

        if (outstandingRttMeasurements >= MAX_OUTSTANDING_RTT_MEASUREMENTS)
        {
            if (elapsedNs <= RTT_MAX_TIMEOUT_NS)
            {
                return false;
            }

            // No reply within the max timeout: treat the outstanding measurement as lost.
            outstandingRttMeasurements = 0;
        }
        else if (elapsedNs <= RTT_MEASUREMENT_TIMEOUT_NS)
        {
            return false;
        }

        lastRttTimestampNs = nowNs;
        outstandingRttMeasurements++;

        return true;
    }

    /**
     * Record an RTT measurement. Updates the RTT estimate, its counter, and the window update interval.
     *
     * @param nowNs      current time in nanoseconds.
     * @param rttNs      measured round trip time in nanoseconds.
     * @param srcAddress address the measurement came from (unused).
     */
    public void onRttMeasurement(final long nowNs, final long rttNs, final InetSocketAddress srcAddress)
    {
        // A late reply to a measurement already presumed lost must not drive the count negative.
        if (outstandingRttMeasurements > 0)
        {
            outstandingRttMeasurements--;
        }

        lastRttTimestampNs = nowNs;
        rttInNs = rttNs;
        // Window growth is paced at roughly one step per RTT.
        windowUpdateTimeoutNs = rttNs;
        rttIndicator.setOrdered(rttNs);
    }

    /**
     * Update the congestion window after a loss-tracker rebuild.
     * <p>
     * On loss, w_max is set to the current window, K is recomputed, and the window shrinks by factor (1 - B)
     * but never below one MTU. Otherwise, at most once per window update interval, the window grows along the
     * cubic curve (optionally lifted to the TCP-friendly estimate), capped at the maximum window.
     *
     * @param nowNs                   current time in nanoseconds.
     * @param newConsumptionPosition  unused by this strategy.
     * @param lastSmPosition          unused by this strategy.
     * @param hwmPosition             unused by this strategy.
     * @param startingRebuildPosition unused by this strategy.
     * @param endingRebuildPosition   unused by this strategy.
     * @param lossOccurred            whether loss was detected during the rebuild.
     * @return the window length in bytes and a force-status-message flag (set on loss), combined by
     * {@code packOutcome}.
     */
    public long onTrackRebuild(
        final long nowNs,
        final long newConsumptionPosition,
        final long lastSmPosition,
        final long hwmPosition,
        final long startingRebuildPosition,
        final long endingRebuildPosition,
        final boolean lossOccurred)
    {
        boolean forceStatusMessage = false;

        if (lossOccurred)
        {
            w_max = cwnd;
            k = Math.cbrt((double)w_max * B / C);
            // Multiplicative decrease, but never below one MTU (a zero window would stall the stream).
            cwnd = Math.max(1, (int)(cwnd * (1.0 - B)));
            lastLossTimestampNs = nowNs;
            forceStatusMessage = true;
        }
        else if (cwnd < maxCwnd && (nowNs - lastUpdateTimestampNs) > windowUpdateTimeoutNs)
        {
            // W_cubic = C(T - K)^3 + w_max
            final double durationSinceDecr = (double)(nowNs - lastLossTimestampNs) / (double)SECOND_IN_NS;
            final double diffToK = durationSinceDecr - k;
            final double incr = C * diffToK * diffToK * diffToK;

            // Cap incr before narrowing so a long gap since loss cannot overflow the sum into a negative window.
            final long wCubic = w_max + (long)Math.min(incr, (double)maxCwnd);
            cwnd = (int)Math.min(maxCwnd, wCubic);

            // if using TCP mode, then check to see if we are in the TCP region
            if (TCP_MODE && cwnd < w_max)
            {
                // W_tcp(t) = w_max*(1-B) + 3*B/(2-B)* t/RTT
                final double rttInSeconds = (double)rttInNs / (double)SECOND_IN_NS;
                final double wTcp =
                    (double)w_max * (1.0 - B) + ((3.0 * B / (2.0 - B)) * (durationSinceDecr / rttInSeconds));

                // The TCP estimate can be very large (or saturate when RTT is ~0), so keep it within maxCwnd.
                cwnd = Math.min(maxCwnd, Math.max(cwnd, (int)wTcp));
            }

            lastUpdateTimestampNs = nowNs;
        }

        final int window = cwnd * mtu;
        windowIndicator.setOrdered(window);

        return packOutcome(window, forceStatusMessage);
    }

    /**
     * Initial window length, equal to one sender MTU.
     *
     * @return the initial window length in bytes.
     */
    public int initialWindowLength()
    {
        return minWindow;
    }

    /**
     * Close the RTT and window counters allocated by the constructor.
     */
    public void close()
    {
        CloseHelper.close(rttIndicator);
        CloseHelper.close(windowIndicator);
    }
}