package com.kixeye.kixmpp;

/*
 * #%L
 * KIXMPP Parent
 * %%
 * Copyright (C) 2014 KIXEYE, Inc
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;

import java.nio.charset.StandardCharsets;
import java.util.List;

import javax.xml.stream.XMLStreamConstants;

import org.jdom2.Element;
import org.jdom2.output.XMLOutputter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.aalto.AsyncInputFeeder;
import com.fasterxml.aalto.AsyncXMLStreamReader;
import com.fasterxml.aalto.stax.InputFactoryImpl;
import com.kixeye.kixmpp.jdom.StAXElementBuilder;

/**
 * An XMPP codec for the client.
 *
 * <p>Inbound bytes are fed into a non-blocking XML stream reader. The element at
 * depth 1 is treated as the stream itself and is reported as
 * {@link KixmppStreamStart} / {@link KixmppStreamEnd}; every complete element at
 * depth {@value #STANZA_ELEMENT_DEPTH} is emitted as a JDOM {@link Element}.</p>
 *
 * <p>Outbound {@link Element}, {@link KixmppStreamStart}, {@link KixmppStreamEnd},
 * {@link String} and {@link ByteBuf} messages are serialized as UTF-8; any other
 * message type writes nothing.</p>
 *
 * <p>The codec keeps parser state in instance fields, so an instance holds the
 * state of exactly one XML stream.</p>
 *
 * @author ebahtijaragic
 */
public class KixmppCodec extends ByteToMessageCodec<Object> {
    /** Depth at which an element is a stanza (a direct child of the stream element). */
    private static final int STANZA_ELEMENT_DEPTH = 2;

    private static final Logger logger = LoggerFactory.getLogger(KixmppCodec.class);

    /** Builder for the stanza currently being assembled; replaced on every stanza start. */
    private StAXElementBuilder elementBuilder = null;

    private final InputFactoryImpl inputFactory = new InputFactoryImpl();

    // Both are created in the constructor, AFTER the factory has been configured,
    // and are set to null once the end of the stream has been seen.
    private AsyncXMLStreamReader streamReader;
    private AsyncInputFeeder asyncInputFeeder;

    /**
     * The parser optimization profiles that can be selected for this codec. Each
     * constant maps to the matching {@code configureFor...()} method of the
     * input factory.
     */
    public enum XMLStreamReaderConfiguration {
        SPEED,
        LOW_MEMORY_USAGE,
        ROUND_TRIPPING,
        CONVENIENCE,
        XML_CONFORMANCE
    }

    /**
     * Creates a new codec and optimizes the parser for speed.
     */
    public KixmppCodec() {
        this(XMLStreamReaderConfiguration.SPEED);
    }

    /**
     * Creates a new codec and optimizes the parser based on the configuration flag.
     *
     * @param configuration tells the codec how to optimize the XMLStreamReader
     */
    public KixmppCodec(XMLStreamReaderConfiguration configuration) {
        switch (configuration) {
            case CONVENIENCE:
                inputFactory.configureForConvenience();
                break;
            case LOW_MEMORY_USAGE:
                inputFactory.configureForLowMemUsage();
                break;
            case ROUND_TRIPPING:
                inputFactory.configureForRoundTripping();
                break;
            case SPEED:
                inputFactory.configureForSpeed();
                break;
            case XML_CONFORMANCE:
                inputFactory.configureForXmlConformance();
                break;
        }

        // The reader must be created only after the factory has been configured.
        // Previously it was created in a field initializer (which runs before the
        // constructor body), so the requested configuration was never applied to
        // the first reader.
        streamReader = inputFactory.createAsyncXMLStreamReader();
        asyncInputFeeder = streamReader.getInputFeeder();
    }

    /**
     * Drains all readable bytes from {@code in}, feeds them to the XML reader and
     * adds every decoded message to {@code out}.
     *
     * <p>If parsing throws, the reader is closed and replaced by a fresh one and
     * the same bytes are fed once more; if the second attempt throws as well, that
     * exception is rethrown. A {@code stream:stream} element showing up at stanza
     * depth deliberately goes through this path so that it is re-read by a fresh
     * reader as a stream start.</p>
     *
     * <p>Once the end of the stream has been seen the reader is discarded and any
     * further input is consumed and ignored.</p>
     *
     * @throws Exception the last parsing failure if recovery did not succeed
     * @see io.netty.handler.codec.ByteToMessageCodec#decode(io.netty.channel.ChannelHandlerContext, io.netty.buffer.ByteBuf, java.util.List)
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        if (logger.isDebugEnabled()) {
            logger.debug("Received: [{}]", in.toString(StandardCharsets.UTF_8));
        }

        int retryCount = 0;
        Exception thrownException = null;

        // feed the data into the async xml input feeder; the buffer is always
        // drained, even when the stream has already ended
        byte[] data = new byte[in.readableBytes()];
        in.readBytes(data);

        if (streamReader != null) {
            while (retryCount < 2) {
                try {
                    asyncInputFeeder.feedInput(data, 0, data.length);

                    int event = -1;

                    while (isValidEvent(event = streamReader.next())) {
                        // handle stream start/end
                        if (streamReader.getDepth() == STANZA_ELEMENT_DEPTH - 1) {
                            if (event == XMLStreamConstants.END_ELEMENT) {
                                out.add(new KixmppStreamEnd());

                                // the stream is over: release the parser for good
                                asyncInputFeeder.endOfInput();
                                streamReader.close();

                                streamReader = null;
                                asyncInputFeeder = null;

                                break;
                            } else if (event == XMLStreamConstants.START_ELEMENT) {
                                // NOTE(review): the element built here is not used;
                                // the stream start is emitted without it.
                                StAXElementBuilder streamElementBuilder = new StAXElementBuilder(true);
                                streamElementBuilder.process(streamReader);

                                out.add(new KixmppStreamStart(null, true));
                            }
                        // only handle events that have element depth of 2 and above (everything under <stream:stream>..)
                        } else if (streamReader.getDepth() >= STANZA_ELEMENT_DEPTH) {
                            // if this is the beginning of the element and this is at stanza depth
                            if (event == XMLStreamConstants.START_ELEMENT && streamReader.getDepth() == STANZA_ELEMENT_DEPTH) {
                                elementBuilder = new StAXElementBuilder(true);
                                elementBuilder.process(streamReader);

                                // get the constructed element
                                Element element = elementBuilder.getElement();

                                // a nested stream start forces a parser reset through the catch block
                                if ("stream:stream".equals(element.getQualifiedName())) {
                                    throw new RuntimeException("Starting a new stream.");
                                }
                            // if this is the ending of the element and this is at stanza depth
                            } else if (event == XMLStreamConstants.END_ELEMENT && streamReader.getDepth() == STANZA_ELEMENT_DEPTH) {
                                elementBuilder.process(streamReader);

                                // get the constructed element
                                Element element = elementBuilder.getElement();

                                out.add(element);
                            // just process the event
                            } else {
                                elementBuilder.process(streamReader);
                            }
                        }
                    }

                    break;
                } catch (Exception e) {
                    retryCount++;

                    logger.info("Attempting to recover from impropper XML: {}", e.getMessage());

                    thrownException = e;

                    // start over with a fresh reader even if closing the old one fails
                    try {
                        streamReader.close();
                    } finally {
                        streamReader = inputFactory.createAsyncXMLStreamReader();
                        asyncInputFeeder = streamReader.getInputFeeder();
                    }
                }
            }

            if (retryCount > 1) {
                throw thrownException;
            }
        }
    }

    /**
     * @param event the event id
     * @return <b>true</b> if this event is not the end of a document event and it is not an event incomplete event
     */
    private boolean isValidEvent(int event) {
        return (event != XMLStreamConstants.END_DOCUMENT && event != AsyncXMLStreamReader.EVENT_INCOMPLETE);
    }

    /**
     * Serializes an outbound message into {@code out}.
     *
     * <ul>
     * <li>{@link Element}: written with a default {@link XMLOutputter}</li>
     * <li>{@link KixmppStreamStart}: optional XML header followed by the opening
     *     {@code <stream:stream ...>} tag with the non-null id/from/to attributes</li>
     * <li>{@link KixmppStreamEnd}: the closing {@code </stream:stream>} tag</li>
     * <li>{@link String}: its UTF-8 bytes</li>
     * <li>{@link ByteBuf}: its readable bytes, without moving its reader index</li>
     * </ul>
     *
     * Any other message type writes nothing.
     *
     * @see io.netty.handler.codec.ByteToMessageCodec#encode(io.netty.channel.ChannelHandlerContext, java.lang.Object, io.netty.buffer.ByteBuf)
     */
    @Override
    protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws Exception {
        if (msg instanceof Element) {
            new XMLOutputter().output((Element)msg, new ByteBufOutputStream(out));
        } else if (msg instanceof KixmppStreamStart) {
            KixmppStreamStart streamStart = (KixmppStreamStart)msg;

            if (streamStart.doesIncludeXmlHeader()) {
                writeUtf8(out, "<?xml version='1.0' encoding='UTF-8'?>");
            }

            writeUtf8(out, "<stream:stream ");

            // NOTE(review): attribute values are written without XML escaping
            if (streamStart.getId() != null) {
                writeUtf8(out, String.format("id=\"%s\" ", streamStart.getId()));
            }
            if (streamStart.getFrom() != null) {
                writeUtf8(out, String.format("from=\"%s\" ", streamStart.getFrom().getFullJid()));
            }
            if (streamStart.getTo() != null) {
                writeUtf8(out, String.format("to=\"%s\" ", streamStart.getTo().getFullJid()));
            }

            writeUtf8(out, "version=\"1.0\" xmlns=\"jabber:client\" xmlns:stream=\"http://etherx.jabber.org/streams\">");
        } else if (msg instanceof KixmppStreamEnd) {
            writeUtf8(out, "</stream:stream>");
        } else if (msg instanceof String) {
            writeUtf8(out, (String)msg);
        } else if (msg instanceof ByteBuf) {
            ByteBuf buf = (ByteBuf)msg;

            // copy from the reader index, not from absolute index 0: the readable
            // region of a buffer does not necessarily start at 0
            out.writeBytes(buf, buf.readerIndex(), buf.readableBytes());
        }

        if (logger.isDebugEnabled()) {
            logger.debug("Sending: [{}]", out.toString(StandardCharsets.UTF_8));
        }
    }

    /**
     * Appends the UTF-8 encoding of {@code text} to {@code out}.
     */
    private static void writeUtf8(ByteBuf out, String text) {
        out.writeBytes(text.getBytes(StandardCharsets.UTF_8));
    }
}