package org.pac4j.saml.transport;
import com.google.common.base.Strings;
import net.shibboleth.utilities.java.support.codec.Base64Support;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.xml.ParserPool;
import net.shibboleth.utilities.java.support.xml.XMLParserException;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.io.UnmarshallingException;
import org.opensaml.core.xml.util.XMLObjectSupport;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.messaging.decoder.AbstractMessageDecoder;
import org.opensaml.messaging.decoder.MessageDecodingException;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.common.binding.SAMLBindingSupport;
import org.opensaml.saml.common.messaging.context.SAMLBindingContext;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.pac4j.core.context.HttpConstants;
import org.pac4j.core.context.WebContext;
import org.pac4j.core.exception.TechnicalException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
/**
 * Pac4j implementation extending directly the {@link AbstractMessageDecoder}, as the intermediate
 * OpenSAML classes use the J2EE HTTP request.
 * It's mostly a copy/paste of the source code of these intermediate OpenSAML classes, reading the
 * request data from the pac4j {@link WebContext} instead.
 *
 * <p>A {@link ParserPool} must be supplied through {@link #setParserPool(ParserPool)} before the
 * decoder is initialized.</p>
 *
 * @author Misagh Moayyed
 * @since 1.8
 */
public class Pac4jHTTPPostDecoder extends AbstractMessageDecoder<SAMLObject> {

    private static final Logger logger = LoggerFactory.getLogger(Pac4jHTTPPostDecoder.class);

    /** Parser pool used to deserialize the message. */
    private ParserPool parserPool;

    /** Web context the request method and request parameters are read from. */
    private final WebContext context;

    /**
     * Creates a decoder bound to the given web context.
     *
     * @param context the web context to read the SAML message from
     * @throws TechnicalException if {@code context} is {@code null}
     */
    public Pac4jHTTPPostDecoder(final WebContext context) {
        if (context == null) {
            throw new TechnicalException("Context cannot be null");
        }
        this.context = context;
    }

    /**
     * Decodes the SAML message carried by the current HTTP POST request: reads the relay state,
     * Base64-decodes and unmarshalls the message, populates the binding context and finally
     * publishes the resulting message context on this decoder.
     *
     * @throws MessageDecodingException if the request is not a POST, carries no SAML message, or the
     *         message cannot be decoded into a {@link SAMLObject}
     */
    @Override
    protected void doDecode() throws MessageDecodingException {
        if (!HttpConstants.HTTP_METHOD.POST.name().equalsIgnoreCase(this.context.getRequestMethod())) {
            throw new MessageDecodingException("This message decoder only supports the HTTP POST method");
        }

        final MessageContext<SAMLObject> messageContext = new MessageContext<>();

        final String relayState = this.context.getRequestParameter("RelayState");
        logger.debug("Decoded SAML relay state of: {}", relayState);
        SAMLBindingSupport.setRelayState(messageContext, relayState);

        final InputStream base64DecodedMessage = this.getBase64DecodedMessage();
        final XMLObject unmarshalled = this.unmarshallMessage(base64DecodedMessage);
        // The payload comes from the request: well-formed XML that is not a SAML object must be
        // reported as a decoding failure rather than surfacing as an unchecked ClassCastException.
        if (!(unmarshalled instanceof SAMLObject)) {
            throw new MessageDecodingException("Decoded message is not a SAML object"
                    + (unmarshalled == null ? "" : ": " + unmarshalled.getElementQName()));
        }
        messageContext.setMessage((SAMLObject) unmarshalled);
        logger.debug("Decoded SAML message");

        this.populateBindingContext(messageContext);
        this.setMessageContext(messageContext);
    }

    /**
     * Reads the Base64-encoded SAML message from the {@code SAMLRequest} request parameter, falling
     * back to {@code SAMLResponse} when the former is absent or empty, and Base64-decodes it.
     *
     * @return a stream over the decoded message bytes
     * @throws MessageDecodingException if neither parameter is present or the value cannot be
     *         Base64-decoded
     */
    protected InputStream getBase64DecodedMessage()
            throws MessageDecodingException {
        logger.debug("Getting Base64 encoded message from context, ignoring the given request");
        String encodedMessage = this.context.getRequestParameter("SAMLRequest");
        if (Strings.isNullOrEmpty(encodedMessage)) {
            encodedMessage = this.context.getRequestParameter("SAMLResponse");
        }
        if (Strings.isNullOrEmpty(encodedMessage)) {
            throw new MessageDecodingException("Request did not contain either a SAMLRequest or SAMLResponse parameter. Invalid request for SAML 2 HTTP POST binding.");
        }

        logger.trace("Base64 decoding SAML message:\n{}", encodedMessage);
        final byte[] decodedBytes = Base64Support.decode(encodedMessage);
        if (decodedBytes == null) {
            throw new MessageDecodingException("Unable to Base64 decode SAML message");
        }

        // Only pay for the byte[] -> String conversion when trace logging is actually enabled.
        if (logger.isTraceEnabled()) {
            try {
                logger.trace("Decoded SAML message:\n{}", new String(decodedBytes, HttpConstants.UTF8_ENCODING));
            } catch (final UnsupportedEncodingException e) {
                throw new TechnicalException(e);
            }
        }
        return new ByteArrayInputStream(decodedBytes);
    }

    /** Releases the parser pool reference, then delegates to the parent implementation. */
    @Override
    protected void doDestroy() {
        parserPool = null;
        super.doDestroy();
    }

    /**
     * Initializes the decoder and verifies that a parser pool has been configured.
     *
     * @throws ComponentInitializationException if no parser pool has been set
     */
    @Override
    protected void doInitialize() throws ComponentInitializationException {
        super.doInitialize();
        // Validate first so that "Initialized" is only logged when initialization really succeeded.
        if (parserPool == null) {
            throw new ComponentInitializationException("Parser pool cannot be null");
        }
        logger.debug("Initialized {}", this.getClass().getSimpleName());
    }

    /**
     * Populate the context which carries information specific to this binding: the binding URI,
     * the absence of a binding-level signature, and whether the intended destination endpoint URI
     * is required (it is when the message itself is signed).
     *
     * @param messageContext the current message context
     */
    protected void populateBindingContext(MessageContext<SAMLObject> messageContext) {
        final SAMLBindingContext bindingContext = messageContext.getSubcontext(SAMLBindingContext.class, true);
        bindingContext.setBindingUri(getBindingURI());
        bindingContext.setHasBindingSignature(false);
        bindingContext.setIntendedDestinationEndpointURIRequired(SAMLBindingSupport.isMessageSigned(messageContext));
    }

    /**
     * Gets the URI of the binding handled by this decoder.
     *
     * @return {@link SAMLConstants#SAML2_POST_BINDING_URI}
     */
    public String getBindingURI() {
        return SAMLConstants.SAML2_POST_BINDING_URI;
    }

    /**
     * Helper method that deserializes and unmarshalls the message from the given stream.
     *
     * @param messageStream input stream containing the message
     *
     * @return the inbound message
     *
     * @throws MessageDecodingException thrown if there is a problem deserializing and unmarshalling the message
     */
    protected XMLObject unmarshallMessage(InputStream messageStream) throws MessageDecodingException {
        try {
            return XMLObjectSupport.unmarshallFromInputStream(getParserPool(), messageStream);
        } catch (XMLParserException | UnmarshallingException e) {
            throw new MessageDecodingException("Error unmarshalling message from input stream", e);
        }
    }

    /**
     * Gets the parser pool used to deserialize incoming messages.
     *
     * @return parser pool used to deserialize incoming messages
     */
    @Nonnull
    public ParserPool getParserPool() {
        return parserPool;
    }

    /**
     * Sets the parser pool used to deserialize incoming messages.
     *
     * @param pool parser pool used to deserialize incoming messages
     */
    public void setParserPool(@Nonnull final ParserPool pool) {
        Constraint.isNotNull(pool, "ParserPool cannot be null");
        parserPool = pool;
    }
}