package edu.washington.cs.oneswarm.f2f.servicesharing;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import org.gudy.azureus2.core3.util.DirectByteBuffer;
/**
 * Multiplexes a stream of data and tracks which messages are in transit
 * across which channels.
 * <p>
 * Every message obtained from {@link #nextMsg()} gets a sequence number that
 * is unique within this multiplexer. It stays in the outstanding table until
 * it has been acknowledged and no channel still holds it.
 * <p>
 * This class performs no synchronization of its own.
 *
 * @author willscott
 *
 */
public class MessageStreamMultiplexer {
    public final static Logger logger = Logger.getLogger(MessageStreamMultiplexer.class.getName());

    /** Byte width of one acknowledged sequence number (an int) in an ack payload. */
    private static final int ACK_ENTRY_BYTES = 4;

    /**
     * Subsystem id passed to every DirectByteBuffer accessor.
     * NOTE(review): the meaning of 44 is defined by the buffer library; confirm.
     */
    private final static byte SS = 44;

    /** Sequence number that the next call to {@link #nextMsg()} hands out. */
    private int next;
    /** Flow id stamped on every SequenceNumber created by this multiplexer. */
    private final short flow;
    /** Registered channel endpoints, keyed by channel id. */
    private final HashMap<Integer, ServiceChannelEndpoint> channels;
    /** Messages not yet fully released, keyed by their integer sequence number. */
    private final HashMap<Integer, SequenceNumber> outstandingMessages;
    /** Per registered channel id: messages sent on it that it has not yet released. */
    private final HashMap<Integer, Set<SequenceNumber>> channelOutstanding;

    public MessageStreamMultiplexer(short flow) {
        this.channels = new HashMap<Integer, ServiceChannelEndpoint>();
        this.outstandingMessages = new HashMap<Integer, SequenceNumber>();
        this.channelOutstanding = new HashMap<Integer, Set<SequenceNumber>>();
        this.flow = flow;
        this.next = 0;
    }

    /**
     * Registers a channel endpoint under its channel id and starts an empty
     * in-flight set for it. Re-registering an id replaces both entries.
     */
    public void addChannel(ServiceChannelEndpoint s) {
        this.channels.put(s.getChannelId(), s);
        this.channelOutstanding.put(s.getChannelId(), new HashSet<SequenceNumber>());
    }

    /**
     * Processes an acknowledgement message.
     * <p>
     * Both the message's own sequence number and every int read from its
     * payload are treated as acknowledged sequence numbers. For each one that
     * is outstanding, this method:
     * <ul>
     * <li>marks it acked;</li>
     * <li>asks every channel that carried it to forget it, and detaches each
     * channel that does (or whose endpoint is no longer registered);</li>
     * <li>drops it from the outstanding table once no channel is left.</li>
     * </ul>
     * Numbers that are not outstanding are only logged.
     *
     * @param message the acknowledgement message
     * @return the number of distinct outstanding messages that were acknowledged
     */
    public int onAck(OSF2FServiceDataMsg message) {
        DirectByteBuffer payload = message.getPayload();
        // Keyed by the integer sequence number: outstandingMessages is keyed by
        // Integer, so that key is what must be used to remove released entries.
        HashMap<Integer, SequenceNumber> acked = new HashMap<Integer, SequenceNumber>();
        ArrayList<Integer> unknown = new ArrayList<Integer>();

        classifyAck(message.getSequenceNumber(), acked, unknown);
        // Only read while a whole int is available; a trailing fragment would
        // otherwise underflow the buffer.
        while (payload.remaining(SS) >= ACK_ENTRY_BYTES) {
            classifyAck(payload.getInt(SS), acked, unknown);
        }

        for (Map.Entry<Integer, SequenceNumber> entry : acked.entrySet()) {
            SequenceNumber seq = entry.getValue();
            seq.ack();
            // Iterate over a copy: detach() removes channels from seq itself.
            for (Integer channelId : new ArrayList<Integer>(seq.getChannels())) {
                ServiceChannelEndpoint endpoint = this.channels.get(channelId);
                if (endpoint == null || endpoint.forgetMessage(seq)) {
                    detach(seq, channelId);
                }
            }
            if (seq.getChannels().size() == 0) {
                outstandingMessages.remove(entry.getKey());
            }
        }
        for (Integer num : unknown) {
            logger.info("Non outstanding packet acked: " + num);
        }
        return acked.size();
    }

    /**
     * Sorts one acknowledged number into {@code acked} if it is outstanding,
     * otherwise into {@code unknown}.
     */
    private void classifyAck(int num, HashMap<Integer, SequenceNumber> acked, ArrayList<Integer> unknown) {
        SequenceNumber seq = outstandingMessages.get(num);
        if (seq != null) {
            acked.put(num, seq);
        } else {
            unknown.add(num);
        }
    }

    /** Removes {@code seq} from the channel's in-flight set and the channel from {@code seq}. */
    private void detach(SequenceNumber seq, Integer channelId) {
        Set<SequenceNumber> pending = channelOutstanding.get(channelId);
        if (pending != null) {
            pending.remove(seq);
        }
        seq.removeChannel(channelId);
    }

    /** Allocates the next sequence number for this flow and records it as outstanding. */
    public SequenceNumber nextMsg() {
        int num = next++;
        SequenceNumber n = new SequenceNumber(num, flow);
        outstandingMessages.put(num, n);
        return n;
    }

    /**
     * Records that {@code msg} was sent on {@code channel}.
     *
     * @throws IllegalArgumentException if the channel was never registered via
     *         {@link #addChannel} (or has since been removed)
     */
    public void sendMsg(SequenceNumber msg, ServiceChannelEndpoint channel) {
        int channelId = channel.getChannelId();
        Set<SequenceNumber> inFlight = channelOutstanding.get(channelId);
        // Validate before mutating msg so a bad call leaves no half-updated state.
        if (inFlight == null) {
            throw new IllegalArgumentException("Channel " + channelId + " is not registered");
        }
        msg.addChannel(channelId);
        inFlight.add(msg);
    }

    /**
     * Returns whether the channel has at least one in-flight message that has
     * not been acknowledged yet. Returns false for unregistered channels.
     */
    public boolean hasOutstanding(ServiceChannelEndpoint channel) {
        Set<SequenceNumber> inFlight = channelOutstanding.get(channel.getChannelId());
        if (inFlight == null) {
            return false;
        }
        for (SequenceNumber s : inFlight) {
            if (!s.isAcked()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the unacknowledged messages in flight on {@code channel} for which
     * the channel still returns a buffer from {@code getMessage}.
     *
     * @return a fresh map, empty if the channel is not registered
     */
    public Map<SequenceNumber, DirectByteBuffer> getOutstanding(final ServiceChannelEndpoint channel) {
        HashMap<SequenceNumber, DirectByteBuffer> mapping = new HashMap<SequenceNumber, DirectByteBuffer>();
        Set<SequenceNumber> outstanding = channelOutstanding.get(channel.getChannelId());
        if (outstanding == null) {
            return mapping;
        }
        for (SequenceNumber s : outstanding) {
            DirectByteBuffer msg = channel.getMessage(s);
            if (msg != null && !s.isAcked()) {
                mapping.put(s, msg);
            }
        }
        return mapping;
    }

    /**
     * Unregisters a channel and detaches it from every message still in flight
     * on it. The messages themselves stay in the outstanding table.
     */
    public void removeChannel(ServiceChannelEndpoint channel) {
        int channelId = channel.getChannelId();
        channels.remove(channelId);
        Set<SequenceNumber> inFlight = channelOutstanding.remove(channelId);
        if (inFlight != null) {
            for (SequenceNumber s : inFlight) {
                s.removeChannel(channelId);
            }
        }
    }
}