/*
* Copyright (c) 2012-2013 Spotify AB
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.spotify.netty4.handler.codec.zmtp;
import java.util.ArrayList;
import java.util.List;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.ReferenceCountUtil;
/**
* Netty ZMTP encoder.
*/
class ZMTPFramingEncoder extends ChannelOutboundHandlerAdapter {
private final ZMTPEncoder encoder;
private final List<Object> messages = new ArrayList<Object>();
private final List<ChannelPromise> promises = new ArrayList<ChannelPromise>();
private ZMTPWriter writer;
private ZMTPEstimator estimator;
ZMTPFramingEncoder(final ZMTPSession session, final ZMTPEncoder encoder) {
if (session == null) {
throw new NullPointerException("session");
}
if (encoder == null) {
throw new NullPointerException("encoder");
}
this.encoder = encoder;
this.writer = ZMTPWriter.create(session.negotiatedVersion());
this.estimator = ZMTPEstimator.create(session.negotiatedVersion());
}
public ZMTPFramingEncoder(final ZMTPWireFormat wireFormat, final ZMTPEncoder encoder) {
if (wireFormat == null) {
throw new NullPointerException("wireFormat");
}
if (encoder == null) {
throw new NullPointerException("encoder");
}
this.encoder = encoder;
this.writer = new ZMTPWriter(wireFormat);
this.estimator = new ZMTPEstimator(wireFormat);
}
@Override
public void handlerRemoved(final ChannelHandlerContext ctx) {
encoder.close();
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg,
final ChannelPromise promise) {
messages.add(msg);
promises.add(promise);
}
@Override
public void flush(final ChannelHandlerContext ctx) throws Exception {
if (messages == null) {
return;
}
estimator.reset();
for (final Object message : messages) {
encoder.estimate(message, estimator);
}
final ByteBuf output = ctx.alloc().buffer(estimator.size());
writer.reset(output);
for (final Object message : messages) {
encoder.encode(message, writer);
ReferenceCountUtil.release(message);
}
final ChannelPromise aggregate = new AggregatePromise(ctx.channel(), promises);
messages.clear();
promises.clear();
ctx.write(output, aggregate);
ctx.flush();
}
private static class AggregatePromise extends DefaultChannelPromise {
private final ChannelPromise[] promises;
private AggregatePromise(final Channel channel,
final List<ChannelPromise> promises) {
super(channel);
this.promises = promises.toArray(new ChannelPromise[promises.size()]);
}
@Override
public ChannelPromise setSuccess(final Void result) {
super.setSuccess(result);
for (final ChannelPromise promise : promises) {
promise.setSuccess(result);
}
return this;
}
@Override
public boolean trySuccess() {
final boolean result = super.trySuccess();
for (final ChannelPromise promise : promises) {
promise.trySuccess();
}
return result;
}
@Override
public ChannelPromise setFailure(final Throwable cause) {
super.setFailure(cause);
for (final ChannelPromise promise : promises) {
promise.setFailure(cause);
}
return this;
}
}
}