package org.pac4j.saml.transport;
import net.shibboleth.utilities.java.support.codec.Base64Support;
import net.shibboleth.utilities.java.support.codec.HTMLEncoder;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.apache.velocity.VelocityContext;
import org.apache.velocity.app.VelocityEngine;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.core.xml.util.XMLObjectSupport;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.messaging.encoder.AbstractMessageEncoder;
import org.opensaml.messaging.encoder.MessageEncodingException;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.common.binding.BindingException;
import org.opensaml.saml.common.binding.SAMLBindingSupport;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.RequestAbstractType;
import org.opensaml.saml.saml2.core.StatusResponseType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.net.URI;
/**
* Pac4j implementation extending directly the {@link AbstractMessageEncoder} as intermediate classes use the J2E HTTP response.
* It's mostly a copy/paste of the source code of these intermediate opensaml classes.
*
* @author Misagh Moayyed
* @since 1.8
*/
public class Pac4jHTTPPostEncoder extends AbstractMessageEncoder<SAMLObject> {
private final static Logger log = LoggerFactory.getLogger(Pac4jHTTPPostEncoder.class);
/** Default template ID. */
public static final String DEFAULT_TEMPLATE_ID = "/templates/saml2-post-binding.vm";
/** Velocity engine used to evaluate the template when performing POST encoding. */
private VelocityEngine velocityEngine;
/** ID of the Velocity template used when performing POST encoding. */
private String velocityTemplateId;
private final Pac4jSAMLResponse responseAdapter;
public Pac4jHTTPPostEncoder(final Pac4jSAMLResponse responseAdapter) {
this.responseAdapter = responseAdapter;
setVelocityTemplateId(DEFAULT_TEMPLATE_ID);
}
/**
* Set the Velocity template id.
*
* <p>Defaults to {@link #DEFAULT_TEMPLATE_ID}.</p>
*
* @param newVelocityTemplateId the new Velocity template id
*/
public void setVelocityTemplateId(String newVelocityTemplateId) {
ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
ComponentSupport.ifDestroyedThrowDestroyedComponentException(this);
velocityTemplateId = newVelocityTemplateId;
}
@Override
protected void doDestroy() {
velocityEngine = null;
velocityTemplateId = null;
super.doDestroy();
}
@Override
protected void doInitialize() throws ComponentInitializationException {
super.doInitialize();
log.debug("Initialized {}", this.getClass().getSimpleName());
if (velocityEngine == null) {
throw new ComponentInitializationException("VelocityEngine must be supplied");
}
if (velocityTemplateId == null) {
throw new ComponentInitializationException("Velocity template id must be supplied");
}
}
@Override
protected void doEncode() throws MessageEncodingException {
MessageContext<SAMLObject> messageContext = getMessageContext();
SAMLObject outboundMessage = messageContext.getMessage();
if (outboundMessage == null) {
throw new MessageEncodingException("No outbound SAML message contained in message context");
}
String endpointURL = getEndpointURL(messageContext).toString();
postEncode(messageContext, endpointURL);
}
/**
* Gets the response URL from the message context.
*
* @param messageContext current message context
*
* @return response URL from the message context
*
* @throws MessageEncodingException throw if no relying party endpoint is available
*/
protected URI getEndpointURL(MessageContext<SAMLObject> messageContext) throws MessageEncodingException {
try {
return SAMLBindingSupport.getEndpointURL(messageContext);
} catch (BindingException e) {
throw new MessageEncodingException("Could not obtain message endpoint URL", e);
}
}
protected void postEncode(final MessageContext<SAMLObject> messageContext, final String endpointURL) throws MessageEncodingException {
log.debug("Invoking Velocity template to create POST body");
try {
final VelocityContext e = new VelocityContext();
this.populateVelocityContext(e, messageContext, endpointURL);
responseAdapter.setContentType("text/html");
responseAdapter.init();
final OutputStreamWriter out = responseAdapter.getOutputStreamWriter();
this.getVelocityEngine().mergeTemplate(this.getVelocityTemplateId(), "UTF-8", e, out);
out.flush();
} catch (Exception var6) {
throw new MessageEncodingException("Error creating output document", var6);
}
}
/**
* Get the Velocity template id.
*
* <p>Defaults to {@link #DEFAULT_TEMPLATE_ID}.</p>
*
* @return return the Velocity template id
*/
public String getVelocityTemplateId() {
return velocityTemplateId;
}
/**
* Get the VelocityEngine instance.
*
* @return return the VelocityEngine instance
*/
public VelocityEngine getVelocityEngine() {
return velocityEngine;
}
/**
* Populate the Velocity context instance which will be used to render the POST body.
*
* @param velocityContext the Velocity context instance to populate with data
* @param messageContext the SAML message context source of data
* @param endpointURL endpoint URL to which to encode message
* @throws MessageEncodingException thrown if there is a problem encoding the message
*/
protected void populateVelocityContext(VelocityContext velocityContext, MessageContext<SAMLObject> messageContext,
String endpointURL) throws MessageEncodingException {
String encodedEndpointURL = HTMLEncoder.encodeForHTMLAttribute(endpointURL);
log.debug("Encoding action url of '{}' with encoded value '{}'", endpointURL, encodedEndpointURL);
velocityContext.put("action", encodedEndpointURL);
velocityContext.put("binding", getBindingURI());
SAMLObject outboundMessage = messageContext.getMessage();
log.debug("Marshalling and Base64 encoding SAML message");
Element domMessage = marshallMessage(outboundMessage);
try {
String messageXML = SerializeSupport.nodeToString(domMessage);
log.trace("Output XML message: {}", messageXML);
String encodedMessage = Base64Support.encode(messageXML.getBytes("UTF-8"), Base64Support.UNCHUNKED);
if (outboundMessage instanceof RequestAbstractType) {
velocityContext.put("SAMLRequest", encodedMessage);
} else if (outboundMessage instanceof StatusResponseType) {
velocityContext.put("SAMLResponse", encodedMessage);
} else {
throw new MessageEncodingException(
"SAML message is neither a SAML RequestAbstractType or StatusResponseType");
}
} catch (UnsupportedEncodingException e) {
throw new MessageEncodingException("Unable to encode message, UTF-8 encoding is not supported");
}
String relayState = SAMLBindingSupport.getRelayState(messageContext);
if (SAMLBindingSupport.checkRelayState(relayState)) {
String encodedRelayState = HTMLEncoder.encodeForHTMLAttribute(relayState);
log.debug("Setting RelayState parameter to: '{}', encoded as '{}'", relayState, encodedRelayState);
velocityContext.put("RelayState", encodedRelayState);
}
}
/**
* Set the VelocityEngine instance.
*
* @param newVelocityEngine the new VelocityEngine instane
*/
public void setVelocityEngine(VelocityEngine newVelocityEngine) {
ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
ComponentSupport.ifDestroyedThrowDestroyedComponentException(this);
velocityEngine = newVelocityEngine;
}
/** {@inheritDoc} */
public String getBindingURI() {
return SAMLConstants.SAML2_POST_BINDING_URI;
}
/**
* Helper method that marshalls the given message.
*
* @param message message the marshall and serialize
*
* @return marshalled message
*
* @throws MessageEncodingException thrown if the give message can not be marshalled into its DOM representation
*/
protected Element marshallMessage(XMLObject message) throws MessageEncodingException {
log.debug("Marshalling message");
try {
return XMLObjectSupport.marshall(message);
} catch (MarshallingException e) {
throw new MessageEncodingException("Error marshalling message", e);
}
}
}