/*
* Copyright 2010 Martin Grotzke
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package de.javakaffee.web.msm;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentMap;
import javax.servlet.http.HttpServletRequest;
import org.apache.catalina.Session;
import org.apache.catalina.authenticator.Constants;
import org.apache.catalina.authenticator.SavedRequest;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.buf.ByteChunk;
import de.javakaffee.web.msm.MemcachedSessionService.SessionManager;
/**
* This service is responsible for serializing/deserializing session data
* so that this can be stored in / loaded from memcached.
*
* @author <a href="mailto:martin.grotzke@javakaffee.de">Martin Grotzke</a>
*/
public class TranscoderService {
/** Logger used by this class (e.g. to report session data with an unsupported version). */
private static final Log LOG = LogFactory.getLog( TranscoderService.class );
/** Serialization format version 1: the session fields end with the serialized principal. */
public static final short VERSION_1 = 1;
/**
 * Serialization format version 2: like version 1, additionally followed by the
 * serialized saved request and saved principal notes.
 */
public static final short VERSION_2 = 2;
/**
 * The number of bytes taken by the fixed-size session fields in the serialized form.
 * The version, the length headers, the id, the auth type and the principal are
 * accounted for separately when the session fields are serialized.
 */
static final int NUM_BYTES = 8 // creationTime: long
+ 8 // lastAccessedTime: long
+ 4 // maxInactiveInterval: int
+ 1 // isNew: boolean
+ 1 // isValid: boolean
+ 8 // thisAccessedTime: long
+ 8; // lastBackupTime: long
/** The strategy that session attribute (de)serialization is delegated to. */
private final SessionAttributesTranscoder _attributesTranscoder;
/**
 * Builds a {@link TranscoderService} that delegates the (de)serialization of
 * session attributes to the given strategy.
 *
 * @param attributesTranscoder the {@link SessionAttributesTranscoder} to delegate to.
 */
public TranscoderService( final SessionAttributesTranscoder attributesTranscoder ) {
    this._attributesTranscoder = attributesTranscoder;
}
/**
 * Serializes the given session (its fields and its attributes) to a single byte array.
 * This is a shortcut for
 * <pre><code>
 * final byte[] attributesData = serializeAttributes( session, session.getAttributesInternal() );
 * serialize( session, attributesData );
 * </code></pre>
 * The result can be turned back into a session with {@link #deserialize(byte[], SessionManager)}.
 *
 * @see #serializeAttributes(MemcachedBackupSession, ConcurrentMap)
 * @see #serialize(MemcachedBackupSession, byte[])
 * @see #deserialize(byte[], SessionManager)
 * @param session the session to serialize.
 * @return the serialized session data.
 */
public byte[] serialize( final MemcachedBackupSession session ) {
    return serialize( session, serializeAttributes( session, session.getAttributesInternal() ) );
}
/**
 * Restores a session from data produced by {@link #serialize(MemcachedBackupSession)}
 * (or by {@link #serializeAttributes(MemcachedBackupSession, ConcurrentMap)} combined with
 * {@link #serialize(MemcachedBackupSession, byte[])}).
 * <p>
 * On the returned session the attributes are set, the hash code of the serialized
 * attributes is stored (via {@link MemcachedBackupSession#setDataHashCode(int)}), the
 * manager is set and {@link MemcachedBackupSession#doAfterDeserialization()} has been invoked.
 * </p>
 *
 * @param data the serialized session fields followed by the serialized session attributes.
 *            Can be <code>null</code>.
 * @param manager the manager to set on the deserialized session.
 *
 * @return the deserialized {@link MemcachedBackupSession}, or <code>null</code> if
 *         <code>data</code> was <code>null</code> or was written with an unsupported version.
 */
public MemcachedBackupSession deserialize( final byte[] data, final SessionManager manager ) {
    if ( data == null ) {
        return null;
    }
    try {
        final DeserializationResult fields = deserializeSessionFields( data, manager );
        final MemcachedBackupSession restored = fields.getSession();
        final byte[] serializedAttributes = fields.getAttributesData();

        restored.setAttributesInternal( deserializeAttributes( serializedAttributes ) );
        restored.setDataHashCode( Arrays.hashCode( serializedAttributes ) );
        restored.setManager( manager );
        restored.doAfterDeserialization();
        return restored;
    } catch( final InvalidVersionException e ) {
        // data written in an unknown format cannot be interpreted, treat it like a missing session
        LOG.info( "Got session data from memcached with an unsupported version: " + e.getVersion() );
        return null;
    }
}
/**
 * Turns the given session attributes into a byte array. The work is done by the
 * {@link SessionAttributesTranscoder} that was passed to the constructor, via
 * {@link SessionAttributesTranscoder#serializeAttributes(MemcachedBackupSession, ConcurrentMap)}.
 *
 * @param session the session that owns the given attributes.
 * @param attributes the attributes to serialize.
 * @return the byte array representing the serialized attributes.
 *
 * @see de.javakaffee.web.msm.SessionAttributesTranscoder#serializeAttributes(MemcachedBackupSession, ConcurrentMap)
 */
public byte[] serializeAttributes( final MemcachedBackupSession session, final ConcurrentMap<String, Object> attributes ) {
    final byte[] serialized = _attributesTranscoder.serializeAttributes( session, attributes );
    return serialized;
}
/**
 * Turns the given byte array back into session attributes. The work is done by the
 * {@link SessionAttributesTranscoder} that was passed to the constructor, via
 * {@link SessionAttributesTranscoder#deserializeAttributes(byte[])}.
 *
 * @param data the serialized attributes.
 * @return the deserialized attributes.
 *
 * @see de.javakaffee.web.msm.SessionAttributesTranscoder#deserializeAttributes(byte[])
 */
public ConcurrentMap<String, Object> deserializeAttributes(final byte[] data ) {
    final ConcurrentMap<String, Object> attributes = _attributesTranscoder.deserializeAttributes( data );
    return attributes;
}
/**
 * Serializes the fields of the given session and appends the already serialized
 * session attributes, so that the result holds the session fields data directly
 * followed by the attributes data.
 *
 * @param session the session whose fields are serialized.
 * @param attributesData the serialized session attributes
 *            (e.g. from {@link #serializeAttributes(MemcachedBackupSession, ConcurrentMap)}).
 * @return a byte[] containing the serialized session fields followed by the given attributes data.
 */
public byte[] serialize( final MemcachedBackupSession session, final byte[] attributesData ) {
    final byte[] fieldsData = serializeSessionFields( session );
    // start with a copy of the fields that already has room for the attributes at its end
    final byte[] combined = Arrays.copyOf( fieldsData, fieldsData.length + attributesData.length );
    System.arraycopy( attributesData, 0, combined, fieldsData.length, attributesData.length );
    return combined;
}
// --------------------- private/protected helper methods -------------------
/**
 * Serializes the fields of the given session (everything but the session attributes)
 * using the current format version {@link #VERSION_2}.
 *
 * @param session the session whose fields are serialized.
 * @return the serialized session fields.
 * @throws IllegalArgumentException see {@link #serializeSessionFields(MemcachedBackupSession, int)}.
 */
static byte[] serializeSessionFields( final MemcachedBackupSession session ) {
    return serializeSessionFields( session, VERSION_2 );
}

/**
 * Serializes the fields of the given session (everything but the session attributes)
 * in the given format version. The layout is (numbers are written big-endian):
 * <pre>
 *  2 bytes  version
 *  2 bytes  total length of the session fields data (this whole structure)
 *  8 bytes  creationTime
 *  8 bytes  lastAccessedTime
 *  4 bytes  maxInactiveInterval
 *  1 byte   isNew
 *  1 byte   isValid
 *  8 bytes  thisAccessedTime
 *  8 bytes  lastBackupTime
 *  2 bytes  length of the id, followed by the id bytes
 *  2 bytes  id of the auth type
 *  2 bytes  length of the principal data, followed by the principal data
 *  -- only for versions greater than {@link #VERSION_1} --
 *  2 bytes  length of the saved request data, followed by the saved request data
 *  2 bytes  length of the saved principal data, followed by the saved principal data
 * </pre>
 *
 * @param session the session whose fields are serialized.
 * @param version the format version to write.
 * @return the serialized session fields.
 * @throws IllegalArgumentException if the serialized fields take more than 65535 bytes
 *             (the maximum that fits into the 2-byte length headers), if the auth type of the
 *             session is unknown, or if the principal or saved request cannot be serialized.
 */
static byte[] serializeSessionFields( final MemcachedBackupSession session, final int version ) {
    final byte[] idData = serializeId( session.getIdInternal() );
    final byte[] principalData = serializePrincipal( session.getPrincipal(), session.getManager() );
    final int principalDataLength = principalData != null ? principalData.length : 0;
    final byte[] savedRequestData = serializeSavedRequest(session.getNote(Constants.FORM_REQUEST_NOTE));
    final int savedRequestDataLength = savedRequestData != null ? savedRequestData.length : 0;
    final byte[] savedPrincipalData = serializePrincipal((Principal) session.getNote(Constants.FORM_PRINCIPAL_NOTE), session.getManager());
    final int savedPrincipalDataLength = savedPrincipalData != null ? savedPrincipalData.length : 0;

    int sessionFieldsDataLength = 2 // short value for the version
        // the following might change with other versions, refactoring needed then
        + 2 // short value that stores the dataLength
        + NUM_BYTES // bytes that store all session attributes but the id
        + 2 // short value that stores the idData length
        + idData.length // the number of bytes for the id
        + 2 // short value for the authType
        + 2 // short value that stores the principalData length
        + principalDataLength; // the number of bytes for the principal
    if ( version > VERSION_1 ) {
        sessionFieldsDataLength = sessionFieldsDataLength
            + 2 // short value that stores the savedRequestData length
            + savedRequestDataLength // the number of bytes for the savedRequest
            + 2 // short value that stores the savedPrincipalData length
            + savedPrincipalDataLength; // the number of bytes for the savedPrincipal
    }

    // Every length below is written with 2 bytes. A larger value would silently lose its
    // high bits and yield data that cannot be read back, so fail loudly instead. Checking
    // the total is enough, as every single part is smaller than the total.
    if ( sessionFieldsDataLength > 0xFFFF ) {
        throw new IllegalArgumentException( "The serialized session fields take " + sessionFieldsDataLength
                + " bytes, but at most 65535 bytes are supported (principal: " + principalDataLength
                + ", saved request: " + savedRequestDataLength
                + ", saved principal: " + savedPrincipalDataLength + " bytes)" );
    }

    final byte[] data = new byte[sessionFieldsDataLength];
    int idx = 0;
    idx = encodeNum( version, data, idx, 2 );
    idx = encodeNum( sessionFieldsDataLength, data, idx, 2 );
    idx = encodeNum( session.getCreationTimeInternal(), data, idx, 8 );
    idx = encodeNum( session.getLastAccessedTimeInternal(), data, idx, 8 );
    idx = encodeNum( session.getMaxInactiveInterval(), data, idx, 4 );
    idx = encodeBoolean( session.isNewInternal(), data, idx );
    idx = encodeBoolean( session.isValidInternal(), data, idx );
    idx = encodeNum( session.getThisAccessedTimeInternal(), data, idx, 8 );
    idx = encodeNum( session.getLastBackupTime(), data, idx, 8 );
    idx = encodeNum( idData.length, data, idx, 2 );
    idx = copy( idData, data, idx );
    idx = encodeNum( AuthType.valueOfValue( session.getAuthType() ).getId(), data, idx, 2 );
    idx = encodeNum( principalDataLength, data, idx, 2 );
    idx = copy( principalData, data, idx );
    if ( version > VERSION_1 ) {
        idx = encodeNum( savedRequestDataLength, data, idx, 2 );
        idx = copy( savedRequestData, data, idx );
        idx = encodeNum( savedPrincipalDataLength, data, idx, 2 );
        idx = copy( savedPrincipalData, data, idx );
    }
    return data;
}
/**
 * Reads the session fields from the beginning of the given data and splits off the
 * remaining bytes, which are the serialized session attributes.
 * <p>
 * The session is created via {@link SessionManager#newMemcachedBackupSession()}. For data
 * written with a version greater than {@link #VERSION_1} the saved request and the saved
 * principal are restored as session notes.
 * </p>
 *
 * @param data the serialized session fields followed by the serialized session attributes.
 * @param manager used to create the session and to read the principals.
 * @return the session with its fields set, together with the still serialized attributes.
 * @throws InvalidVersionException if the data was written with a version other than
 *             {@link #VERSION_1} or {@link #VERSION_2}.
 */
static DeserializationResult deserializeSessionFields( final byte[] data, final SessionManager manager ) throws InvalidVersionException {
    final MemcachedBackupSession result = manager.newMemcachedBackupSession();

    final short version = (short) decodeNum( data, 0, 2 );
    if ( version != VERSION_1 && version != VERSION_2 ) {
        throw new InvalidVersionException( "The version " + version + " does not match the current version " + VERSION_2, version );
    }

    // The 2-byte length headers hold unsigned values (0..65535), so they must be kept as int:
    // narrowing them to short turns everything above 32767 into a negative number, which
    // breaks the offsets and array sizes computed from them.
    final int sessionFieldsDataLength = (int) decodeNum( data, 2, 2 );

    result.setCreationTimeInternal( decodeNum( data, 4, 8 ) );
    result.setLastAccessedTimeInternal( decodeNum( data, 12, 8 ) );
    result.setMaxInactiveInterval( (int) decodeNum( data, 20, 4 ) );
    result.setIsNewInternal( decodeBoolean( data, 24 ) );
    result.setIsValidInternal( decodeBoolean( data, 25 ) );
    result.setThisAccessedTimeInternal( decodeNum( data, 26, 8 ) );
    result.setLastBackupTime( decodeNum( data, 34, 8 ) );

    final int idLength = (int) decodeNum( data, 42, 2 );
    result.setIdInternal( decodeString( data, 44, idLength ) );

    final short authTypeId = (short) decodeNum( data, 44 + idLength, 2 );
    result.setAuthTypeInternal( AuthType.valueOfId( authTypeId ).getValue() );

    // from here on the position depends on the variable-length parts read so far
    int currentIdx = 44 + idLength + 2;
    final int principalDataLength = (int) decodeNum( data, currentIdx, 2 );
    if ( principalDataLength > 0 ) {
        final byte[] principalData = new byte[principalDataLength];
        System.arraycopy( data, currentIdx + 2, principalData, 0, principalDataLength );
        result.setPrincipalInternal( deserializePrincipal( principalData, manager ) );
    }

    if ( version > VERSION_1 ) {
        currentIdx += 2 + principalDataLength;
        final int savedRequestDataLength = (int) decodeNum( data, currentIdx, 2 );
        if ( savedRequestDataLength > 0 ) {
            final byte[] savedRequestData = new byte[savedRequestDataLength];
            System.arraycopy( data, currentIdx + 2, savedRequestData, 0, savedRequestDataLength );
            result.setNote( Constants.FORM_REQUEST_NOTE, deserializeSavedRequest( savedRequestData ) );
        }

        currentIdx += 2 + savedRequestDataLength;
        final int savedPrincipalDataLength = (int) decodeNum( data, currentIdx, 2 );
        if ( savedPrincipalDataLength > 0 ) {
            final byte[] savedPrincipalData = new byte[savedPrincipalDataLength];
            System.arraycopy( data, currentIdx + 2, savedPrincipalData, 0, savedPrincipalDataLength );
            result.setNote( Constants.FORM_PRINCIPAL_NOTE, deserializePrincipal( savedPrincipalData, manager ) );
        }
    }

    // everything after the session fields belongs to the serialized attributes
    final byte[] attributesData = new byte[ data.length - sessionFieldsDataLength ];
    System.arraycopy( data, sessionFieldsDataLength, attributesData, 0, data.length - sessionFieldsDataLength );

    return new DeserializationResult( result, attributesData );
}
/**
 * Holds what reading the session fields yields: the session with its fields set
 * and the bytes of the session attributes that are still serialized.
 */
static class DeserializationResult {

    private final MemcachedBackupSession _session;
    private final byte[] _attributesData;

    DeserializationResult( final MemcachedBackupSession session, final byte[] attributesData ) {
        this._session = session;
        this._attributesData = attributesData;
    }

    /**
     * The session that was read.
     * @return the session with fields initialized apart from the attributes.
     */
    MemcachedBackupSession getSession() {
        return this._session;
    }

    /**
     * The session attributes in their serialized form.
     * @return the byte array representing the serialized session attributes.
     */
    byte[] getAttributesData() {
        return this._attributesData;
    }

}
/**
 * Encodes the given session id as UTF-8.
 *
 * @param id the session id, must not be <code>null</code>.
 * @return the UTF-8 bytes of the id.
 */
private static byte[] serializeId( final String id ) {
    final byte[] encoded;
    try {
        encoded = id.getBytes( "UTF-8" );
    } catch ( final UnsupportedEncodingException e ) {
        throw new RuntimeException( e );
    }
    return encoded;
}
/**
 * Serializes the given principal by letting the manager write it to an
 * {@link ObjectOutputStream}.
 *
 * @param principal the principal to serialize, can be <code>null</code>.
 * @param manager the manager that knows how to write the principal.
 * @return the serialized principal, or <code>null</code> if the principal was <code>null</code>.
 * @throws IllegalArgumentException if writing the principal fails with an {@link IOException}.
 */
private static byte[] serializePrincipal( final Principal principal, final SessionManager manager ) {
    if ( principal == null ) {
        return null;
    }
    final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    ObjectOutputStream objectOut = null;
    try {
        objectOut = new ObjectOutputStream( buffer );
        manager.writePrincipal( principal, objectOut );
        // make sure everything has reached the buffer before it is read
        objectOut.flush();
        return buffer.toByteArray();
    } catch ( final IOException e ) {
        throw new IllegalArgumentException( "Non-serializable object", e );
    } finally {
        closeSilently( buffer );
        closeSilently( objectOut );
    }
}
/**
 * Restores a principal by letting the manager read it from an
 * {@link ObjectInputStream} over the given data.
 *
 * @param data the serialized principal.
 * @param manager the manager that knows how to read the principal.
 * @return the principal read by the manager.
 * @throws IllegalArgumentException if reading fails with an {@link IOException}
 *             or a {@link ClassNotFoundException}.
 */
private static Principal deserializePrincipal( final byte[] data, final SessionManager manager ) {
    final ByteArrayInputStream source = new ByteArrayInputStream( data );
    ObjectInputStream objectIn = null;
    try {
        objectIn = new ObjectInputStream( source );
        final Principal principal = manager.readPrincipal( objectIn );
        return principal;
    } catch ( final IOException e ) {
        throw new IllegalArgumentException( "Could not deserialize principal", e );
    } catch ( final ClassNotFoundException e ) {
        throw new IllegalArgumentException( "Could not deserialize principal", e );
    } finally {
        closeSilently( source );
        closeSilently( objectIn );
    }
}
/**
 * Serializes the given saved request. Written are, in this order: body, content type,
 * headers, locales, method, query string, request URI and decoded request URI.
 * Cookies and parameters are not written.
 *
 * @param obj the {@link SavedRequest} to serialize, can be <code>null</code>.
 * @return the serialized saved request, or <code>null</code> if <code>obj</code> was <code>null</code>.
 * @throws IllegalArgumentException if writing fails with an {@link IOException}.
 */
private static byte[] serializeSavedRequest( final Object obj ) {
    if ( obj == null ) {
        return null;
    }
    final SavedRequest request = (SavedRequest) obj;
    final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    ObjectOutputStream objectOut = null;
    try {
        objectOut = new ObjectOutputStream( buffer );
        // the order of the writes must match the order of the reads in deserializeSavedRequest
        objectOut.writeObject( request.getBody() );
        objectOut.writeObject( request.getContentType() );
        // Cookies not cloneable... omit for now
        objectOut.writeObject( getHeaders( request ) );
        objectOut.writeObject( newArrayList( request.getLocales() ) );
        objectOut.writeObject( request.getMethod() );
        // parameters are not used in tc6 and not existing in tc7 -> they are omitted
        objectOut.writeObject( request.getQueryString() );
        objectOut.writeObject( request.getRequestURI() );
        objectOut.writeObject( request.getDecodedRequestURI() );
        objectOut.flush();
        return buffer.toByteArray();
    } catch ( final IOException e ) {
        throw new IllegalArgumentException( "Non-serializable object", e );
    } finally {
        closeSilently( buffer );
        closeSilently( objectOut );
    }
}
/**
 * Restores a saved request from the given data. Read are, in this order: body,
 * content type, headers, locales, method, query string, request URI and decoded
 * request URI. Cookies are not restored.
 *
 * @param data the serialized saved request.
 * @return the restored {@link SavedRequest}.
 * @throws IllegalArgumentException if reading fails with an {@link IOException}
 *             or a {@link ClassNotFoundException}.
 */
@SuppressWarnings("unchecked")
private static SavedRequest deserializeSavedRequest( final byte[] data ) {
    final ByteArrayInputStream source = new ByteArrayInputStream( data );
    ObjectInputStream objectIn = null;
    try {
        objectIn = new ObjectInputStream( source );
        final SavedRequest restored = new SavedRequest();
        // the order of the reads must match the order of the writes in serializeSavedRequest
        restored.setBody( (ByteChunk) objectIn.readObject() );
        restored.setContentType( (String) objectIn.readObject() );
        // no cookies support
        final Map<String, List<String>> headers = (Map<String, List<String>>) objectIn.readObject();
        setHeaders( restored, headers );
        final List<Locale> locales = (List<Locale>) objectIn.readObject();
        setLocales( restored, locales );
        restored.setMethod( (String) objectIn.readObject() );
        restored.setQueryString( (String) objectIn.readObject() );
        restored.setRequestURI( (String) objectIn.readObject() );
        restored.setDecodedRequestURI( (String) objectIn.readObject() );
        return restored;
    } catch ( final IOException e ) {
        throw new IllegalArgumentException( "Could not deserialize SavedRequest", e );
    } catch ( final ClassNotFoundException e ) {
        throw new IllegalArgumentException( "Could not deserialize SavedRequest", e );
    } finally {
        closeSilently( source );
        closeSilently( objectIn );
    }
}
/**
 * Adds the given locales to the saved request, in iteration order.
 *
 * @param savedRequest the saved request to add the locales to.
 * @param locales the locales to add, can be <code>null</code> or empty.
 */
private static void setLocales(final SavedRequest savedRequest, final List<Locale> locales) {
    if ( locales == null ) {
        return;
    }
    // an empty list simply does not enter the loop
    for ( final Locale each : locales ) {
        savedRequest.addLocale( each );
    }
}
/**
 * Collects the remaining elements of the given iterator into a list.
 *
 * @param iter the iterator to drain, must not be <code>null</code>.
 * @return {@link Collections#emptyList()} if the iterator has no elements,
 *         otherwise a new {@link ArrayList} with the elements in iteration order.
 */
private static <T> List<T> newArrayList(final Iterator<T> iter) {
    if ( !iter.hasNext() ) {
        return Collections.emptyList();
    }
    final List<T> elements = new ArrayList<T>();
    // at least one element is available here
    do {
        elements.add( iter.next() );
    } while ( iter.hasNext() );
    return elements;
}
/**
 * Copies the headers of the given saved request into a map from header name
 * to the list of values of that header.
 *
 * @param obj the saved request to read the headers from.
 * @return a new map holding a value list for every header name of the request.
 */
private static Map<String, List<String>> getHeaders(final SavedRequest obj) {
    final Map<String, List<String>> headers = new HashMap<String, List<String>>();
    for ( final Iterator<String> names = obj.getHeaderNames(); names.hasNext(); ) {
        final String headerName = names.next();
        final List<String> headerValues = new ArrayList<String>();
        headers.put( headerName, headerValues );
        for ( final Iterator<String> values = obj.getHeaderValues( headerName ); values.hasNext(); ) {
            headerValues.add( values.next() );
        }
    }
    return headers;
}
/**
 * Adds every value of every header in the given map to the saved request.
 *
 * @param obj the saved request to add the headers to.
 * @param headers the header values by header name, can be <code>null</code>.
 */
private static void setHeaders(final SavedRequest obj, final Map<String, List<String>> headers) {
    if ( headers == null ) {
        return;
    }
    for ( final Entry<String, List<String>> header : headers.entrySet() ) {
        final String name = header.getKey();
        for ( final String headerValue : header.getValue() ) {
            obj.addHeader( name, headerValue );
        }
    }
}
/**
 * Writes the given number into <code>data</code> using <code>maxBytes</code> bytes,
 * most significant byte first, starting at <code>beginIndex</code>.
 *
 * @param num
 *            the number to encode
 * @param data
 *            the byte array the number is written to
 * @param beginIndex
 *            the index in data of the first (most significant) byte, inclusive.
 * @param maxBytes
 *            the number of bytes to store for the number
 * @return the next beginIndex (<code>beginIndex + maxBytes</code>).
 */
public static int encodeNum( final long num, final byte[] data, final int beginIndex, final int maxBytes ) {
    // fill from the end: the least significant byte goes to the last position
    int target = beginIndex + maxBytes;
    for ( int written = 0; written < maxBytes; written++ ) {
        data[--target] = (byte) ( num >> ( written << 3 ) );
    }
    return beginIndex + maxBytes;
}
/**
 * Reads a number from <code>numBytes</code> bytes of <code>data</code>, most significant
 * byte first, starting at <code>beginIndex</code>. Each byte is taken as an unsigned value.
 *
 * @param data
 *            the byte array to read from
 * @param beginIndex
 *            the index in data of the first (most significant) byte, inclusive.
 * @param numBytes
 *            the number of bytes to read
 * @return the decoded number.
 */
public static long decodeNum( final byte[] data, final int beginIndex, final int numBytes ) {
    long value = 0;
    for ( int read = 0; read < numBytes; read++ ) {
        // the mask maps the signed byte to 0..255
        value = ( value << 8 ) | ( data[beginIndex + read] & 0xFFL );
    }
    return value;
}
/**
 * Writes a boolean as a single byte, the character <code>'1'</code> for <code>true</code>
 * and <code>'0'</code> for <code>false</code>. It can be read back with
 * {@link #decodeBoolean(byte[], int)}.
 * @param b the boolean value
 * @param data the byte array to write the encoded byte to
 * @param index the index in the byte array to write at.
 * @return the incremented index that can be used next.
 */
private static int encodeBoolean( final boolean b, final byte[] data, final int index ) {
    if ( b ) {
        data[index] = (byte) '1';
    } else {
        data[index] = (byte) '0';
    }
    return index + 1;
}
/**
 * Reads a boolean that was written as a single byte: only the character
 * <code>'1'</code> is read as <code>true</code>, every other value as <code>false</code>.
 * @param in the byte array to read from
 * @param index the index of the byte to read.
 * @return the decoded boolean.
 */
private static boolean decodeBoolean( final byte[] in, final int index ) {
    final byte flag = in[index];
    return flag == '1';
}
/**
 * Decodes <code>length</code> bytes of <code>data</code>, starting at
 * <code>beginIndex</code>, as a UTF-8 string.
 * @param data the byte array to read from
 * @param beginIndex the index of the first byte of the string.
 * @param length the number of bytes that make up the string.
 * @return the decoded string.
 */
private static String decodeString( final byte[] data, final int beginIndex, final int length ) {
    try {
        final byte[] encoded = new byte[length];
        System.arraycopy( data, beginIndex, encoded, 0, length );
        final String decoded = new String( encoded, "UTF-8" );
        return decoded;
    } catch ( final UnsupportedEncodingException e ) {
        throw new RuntimeException( e );
    }
}
/**
 * Copies all bytes of <code>src</code> into <code>dest</code>, starting at
 * <code>destBeginIndex</code>. A <code>null</code> source is treated as "nothing to copy".
 * @param src the bytes to copy, can be <code>null</code>.
 * @param dest the byte array to copy into.
 * @param destBeginIndex the index in dest where the copy starts.
 * @return the index in dest directly after the copied bytes
 *         (<code>destBeginIndex</code> itself if src is <code>null</code>).
 */
protected static int copy( final byte[] src, final byte[] dest, final int destBeginIndex ) {
    int next = destBeginIndex;
    if ( src != null ) {
        System.arraycopy( src, 0, dest, destBeginIndex, src.length );
        next += src.length;
    }
    return next;
}
/**
 * Closes the given output stream, ignoring a <code>null</code> stream and any
 * {@link IOException} thrown by closing.
 * @param os the stream to close, can be <code>null</code>.
 */
private static void closeSilently( final OutputStream os ) {
    if ( os == null ) {
        return;
    }
    try {
        os.close();
    } catch ( final IOException ignored ) {
        // fail silently
    }
}
/**
 * Closes the given input stream, ignoring a <code>null</code> stream and any
 * {@link IOException} thrown by closing.
 * @param is the stream to close, can be <code>null</code>.
 */
private static void closeSilently( final InputStream is ) {
    if ( is == null ) {
        return;
    }
    try {
        is.close();
    } catch ( final IOException ignored ) {
        // fail silently
    }
}
/**
 * Maps the auth type strings of {@link Session#getAuthType()} to the numeric ids
 * that are written to the serialized session, and back.
 */
private static enum AuthType {

    NONE( (short)0, null ),
    BASIC( (short)1, HttpServletRequest.BASIC_AUTH ),
    CLIENT_CERT( (short)2, HttpServletRequest.CLIENT_CERT_AUTH ),
    DIGEST( (short)3, HttpServletRequest.DIGEST_AUTH ),
    FORM( (short)4, HttpServletRequest.FORM_AUTH );

    private final short _id;
    private final String _value;

    private AuthType( final short id, final String value ) {
        this._id = id;
        this._value = value;
    }

    /**
     * Looks up the auth type with the given id.
     * @param id the numeric id of the auth type.
     * @return the matching auth type.
     * @throws IllegalArgumentException if no auth type has the given id.
     */
    static AuthType valueOfId( final short id ) {
        for( final AuthType candidate : values() ) {
            if ( candidate._id == id ) {
                return candidate;
            }
        }
        throw new IllegalArgumentException( "No AuthType found for id " + id );
    }

    /**
     * Looks up the auth type with the given auth type string.
     * @param value the auth type string, <code>null</code> matches {@link #NONE}.
     * @return the matching auth type.
     * @throws IllegalArgumentException if no auth type has the given value.
     */
    static AuthType valueOfValue( final String value ) {
        for( final AuthType candidate : values() ) {
            // null only matches the constant without a value, anything else is compared with equals
            final boolean matches = value == null
                ? candidate._value == null
                : value.equals( candidate._value );
            if ( matches ) {
                return candidate;
            }
        }
        throw new IllegalArgumentException( "No AuthType found for value " + value );
    }

    /**
     * @return the numeric id written to the serialized session.
     */
    short getId() {
        return this._id;
    }

    /**
     * @return the auth type string, <code>null</code> for {@link #NONE}.
     */
    String getValue() {
        return this._value;
    }

}
}