/*
 * Copyright 2017 Real Logic Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.aeron.driver.ext;

import io.aeron.driver.CongestionControl;
import io.aeron.driver.MediaDriver;
import io.aeron.driver.media.UdpChannel;
import io.aeron.driver.status.PerImageIndicator;
import org.agrona.CloseHelper;
import org.agrona.concurrent.NanoClock;
import org.agrona.concurrent.status.AtomicCounter;
import org.agrona.concurrent.status.CountersManager;

import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;

import static io.aeron.driver.CongestionControlUtil.packOutcome;

/**
 * CUBIC congestion control manipulation of the receiver window length.
 * <p>
 * https://research.csc.ncsu.edu/netsrv/?q=content/bic-and-cubic
 * <pre>
 * W_cubic = C(T - K)^3 + w_max
 *
 * K = cbrt(w_max * B / C)
 * w_max = window size before reduction
 * T = time since last decrease
 *
 * C = scaling constant (default 0.4)
 * B = multiplicative decrease (default 0.2)
 * </pre>
 * at MTU=4K, max window=128KB (w_max = 32 MTUs), then K ~= 2.5 seconds.
 * <p>
 * The congestion window ({@code cwnd}) is tracked in units of MTUs; the window length handed back to the
 * caller is {@code cwnd * mtu}.
 */
public class CubicCongestionControl implements CongestionControl
{
    private static final boolean RTT_MEASUREMENT = CubicCongestionControlConfiguration.MEASURE_RTT;
    private static final boolean TCP_MODE = CubicCongestionControlConfiguration.TCP_MODE;

    private static final long RTT_MEASUREMENT_TIMEOUT_NS = TimeUnit.MILLISECONDS.toNanos(10);
    private static final long SECOND_IN_NS = TimeUnit.SECONDS.toNanos(1);
    private static final long RTT_MAX_TIMEOUT_NS = SECOND_IN_NS;
    private static final int MAX_OUTSTANDING_RTT_MEASUREMENTS = 1;

    // CUBIC scaling constant.
    private static final double C = 0.4;
    // CUBIC multiplicative decrease factor.
    private static final double B = 0.2;

    private final int minWindow;
    private final int mtu;
    private final int maxCwnd;

    private long lastLossTimestampNs;
    private long lastUpdateTimestampNs;
    private long lastRttTimestamp = 0;
    private long windowUpdateTimeout;
    private long rttInNs;
    private double k;
    private int cwnd;
    private int w_max;

    private int outstandingRttMeasurements = 0;

    private final AtomicCounter rttIndicator;
    private final AtomicCounter windowIndicator;

    /**
     * Create the congestion control state for one image and allocate its RTT and window counters.
     *
     * @param registrationId  passed to the counter allocation.
     * @param udpChannel      whose original URI string is used in the counter allocation.
     * @param streamId        passed to the counter allocation.
     * @param sessionId       passed to the counter allocation.
     * @param termLength      half of which bounds the maximum window.
     * @param senderMtuLength used as both the MTU and the minimum (initial) window length.
     * @param clock           sampled once to initialise the last-loss and last-update timestamps.
     * @param context         supplying {@code initialWindowLength()}, which also bounds the maximum window.
     * @param countersManager used to allocate the two indicators.
     */
    CubicCongestionControl(
        final long registrationId,
        final UdpChannel udpChannel,
        final int streamId,
        final int sessionId,
        final int termLength,
        final int senderMtuLength,
        final NanoClock clock,
        final MediaDriver.Context context,
        final CountersManager countersManager)
    {
        mtu = senderMtuLength;
        minWindow = senderMtuLength;
        final int maxWindow = Math.min(termLength / 2, context.initialWindowLength());

        maxCwnd = maxWindow / mtu;
        cwnd = 1;
        // initially set w_max to max window and act in the TCP and concave region initially
        w_max = maxCwnd;
        k = Math.cbrt((double)w_max * B / C);

        // determine interval for adjustment based on heuristic of MTU, max window, and/or RTT estimate
        rttInNs = CubicCongestionControlConfiguration.INITIAL_RTT_NS;
        windowUpdateTimeout = rttInNs;

        rttIndicator = PerImageIndicator.allocate(
            "rcv-cc-cubic-rtt",
            countersManager,
            registrationId,
            sessionId,
            streamId,
            udpChannel.originalUriString(),
            "");

        windowIndicator = PerImageIndicator.allocate(
            "rcv-cc-cubic-wnd",
            countersManager,
            registrationId,
            sessionId,
            streamId,
            udpChannel.originalUriString(),
            "");

        rttIndicator.setOrdered(0);
        windowIndicator.setOrdered(minWindow);

        lastLossTimestampNs = clock.nanoTime();
        lastUpdateTimestampNs = lastLossTimestampNs;
    }

    /**
     * Decide whether an RTT measurement should be started now.
     * <p>
     * When this returns {@code true} the call has already recorded {@code nowNs} as the last RTT timestamp
     * and incremented the outstanding measurement count, so it is not a pure query.
     *
     * @param nowNs current time in nanoseconds.
     * @return true if RTT measurement is enabled, fewer than {@link #MAX_OUTSTANDING_RTT_MEASUREMENTS} are
     * outstanding, and a measurement timeout has elapsed since the last RTT timestamp.
     */
    public boolean shouldMeasureRtt(final long nowNs)
    {
        if (RTT_MEASUREMENT && outstandingRttMeasurements < MAX_OUTSTANDING_RTT_MEASUREMENTS)
        {
            final boolean maxTimeoutElapsed = nowNs > (lastRttTimestamp + RTT_MAX_TIMEOUT_NS);
            final boolean measurementTimeoutElapsed = nowNs > (lastRttTimestamp + RTT_MEASUREMENT_TIMEOUT_NS);

            // Either timeout triggers the same action.
            if (maxTimeoutElapsed || measurementTimeoutElapsed)
            {
                lastRttTimestamp = nowNs;
                outstandingRttMeasurements++;

                return true;
            }
        }

        return false;
    }

    /**
     * Record a completed RTT measurement and publish it to the RTT indicator.
     *
     * @param nowNs      current time in nanoseconds, stored as the last RTT timestamp.
     * @param rttNs      measured round trip time in nanoseconds.
     * @param srcAddress not used by this implementation.
     */
    public void onRttMeasurement(final long nowNs, final long rttNs, final InetSocketAddress srcAddress)
    {
        outstandingRttMeasurements--;
        lastRttTimestamp = nowNs;
        this.rttInNs = rttNs;
        rttIndicator.setOrdered(rttNs);
    }

    /**
     * Update the congestion window and report the resulting receiver window length.
     * <p>
     * On loss, the window before reduction is remembered as {@code w_max}, {@code K} is recomputed and the
     * window is multiplicatively decreased by {@code B}, never going below one MTU. Otherwise, if the window
     * is below its maximum and the update timeout has elapsed, the window is recomputed from the CUBIC
     * function and, in TCP mode while still below {@code w_max}, raised to the TCP estimate if that is larger.
     *
     * @param nowNs                   current time in nanoseconds.
     * @param newConsumptionPosition  not used by this implementation.
     * @param lastSmPosition          not used by this implementation.
     * @param hwmPosition             not used by this implementation.
     * @param startingRebuildPosition not used by this implementation.
     * @param endingRebuildPosition   not used by this implementation.
     * @param lossOccurred            true if loss was detected, which triggers the multiplicative decrease.
     * @return the value produced by {@code packOutcome} for the window length ({@code cwnd * mtu}) and a flag
     * which is true only when loss occurred.
     */
    public long onTrackRebuild(
        final long nowNs,
        final long newConsumptionPosition,
        final long lastSmPosition,
        final long hwmPosition,
        final long startingRebuildPosition,
        final long endingRebuildPosition,
        final boolean lossOccurred)
    {
        boolean forceStatusMessage = false;

        if (lossOccurred)
        {
            w_max = cwnd;
            k = Math.cbrt((double)w_max * B / C);
            // Floor at one MTU: the decreased window must never collapse to zero.
            cwnd = Math.max(1, (int)(cwnd * (1.0 - B)));
            lastLossTimestampNs = nowNs;
            forceStatusMessage = true;
        }
        else if (cwnd < maxCwnd && nowNs > (lastUpdateTimestampNs + windowUpdateTimeout))
        {
            // W_cubic = C(T - K)^3 + w_max
            final double durationSinceDecr = (double)(nowNs - lastLossTimestampNs) / (double)SECOND_IN_NS;
            final double diffToK = durationSinceDecr - k;
            final double incr = C * diffToK * diffToK * diffToK;

            cwnd = Math.min(maxCwnd, w_max + (int)incr);

            // if using TCP mode, then check to see if we are in the TCP region
            if (TCP_MODE && cwnd < w_max)
            {
                // W_tcp(t) = w_max*(1-B) + 3*B/(2-B)* t/RTT
                final double rttInSeconds = (double)rttInNs / (double)SECOND_IN_NS;
                final double wTcp =
                    (double)w_max * (1.0 - B) + ((3.0 * B / (2.0 - B)) * (durationSinceDecr / rttInSeconds));

                // Keep the TCP estimate under the same ceiling as the cubic window; a very small RTT would
                // otherwise push cwnd past maxCwnd.
                cwnd = Math.min(maxCwnd, Math.max(cwnd, (int)wTcp));
            }

            lastUpdateTimestampNs = nowNs;
        }

        final int window = cwnd * mtu;
        windowIndicator.setOrdered(window);

        return packOutcome(window, forceStatusMessage);
    }

    /**
     * The initial receiver window length.
     *
     * @return the minimum window, which the constructor sets to the sender MTU length.
     */
    public int initialWindowLength()
    {
        return minWindow;
    }

    /**
     * Close the RTT and window indicators via {@link CloseHelper#close}.
     */
    public void close()
    {
        CloseHelper.close(rttIndicator);
        CloseHelper.close(windowIndicator);
    }
}