/* * Copyright 2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ratpack.session.internal; import com.google.common.collect.FluentIterable; import com.google.common.collect.ImmutableSortedMap; import com.google.common.collect.Iterables; import com.google.common.collect.Ordering; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.ByteBufOutputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import ratpack.exec.Operation; import ratpack.exec.Promise; import ratpack.http.Response; import ratpack.session.*; import ratpack.util.Types; import java.io.*; import java.util.*; public class DefaultSession implements Session { private static final Logger LOGGER = LoggerFactory.getLogger(Session.class); private Map<SessionKey<?>, byte[]> entries; private final SessionId sessionId; private final ByteBufAllocator bufferAllocator; private final SessionStore storeAdapter; private final Response response; private final SessionSerializer defaultSerializer; private final JavaSessionSerializer javaSerializer; private enum State { NOT_LOADED, CLEAN, DIRTY } private State state = State.NOT_LOADED; private boolean callbackAdded; private final SessionData data = new Data(); private static class SerializedForm implements Externalizable { private static final long serialVersionUID = 2; private static final Ordering<SessionKey<?>> KEY_NAME_ORDERING = Ordering.natural() .nullsFirst() .onResultOf(SessionKey::getName); private static final Ordering<SessionKey<?>> KEY_TYPE_ORDERING = Ordering.natural() .nullsFirst() .onResultOf(k -> k.getType() == null ? null : k.getType().getName()); private static final Comparator<SessionKey<?>> COMPARATOR = KEY_NAME_ORDERING.compound(KEY_TYPE_ORDERING); private Map<SessionKey<?>, byte[]> entries; public SerializedForm() { } @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeShort(1); // schema version out.writeShort(entries.size()); Map<SessionKey<?>, byte[]> sorted = ImmutableSortedMap.copyOf(entries, COMPARATOR); for (Map.Entry<SessionKey<?>, byte[]> entry : sorted.entrySet()) { String name = entry.getKey().getName(); if (name == null) { out.writeBoolean(false); } else { out.writeBoolean(true); out.writeUTF(name); } Class<?> type = entry.getKey().getType(); if (type == null) { out.writeBoolean(false); } else { out.writeBoolean(true); out.writeUTF(type.getName()); } byte[] bytes = entry.getValue(); out.writeInt(bytes.length); out.write(bytes); } } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); in.readShort(); // schema version short num = in.readShort(); entries = new HashMap<>(num); for (short i = 0; i < num; ++i) { String name = in.readBoolean() ? in.readUTF() : null; Class<Object> type; if (in.readBoolean()) { String typeName = in.readUTF(); Class<?> o = classLoader.loadClass(typeName); type = Types.cast(o); } else { type = null; } int bytesLength = in.readInt(); byte[] bytes = new byte[bytesLength]; int read = in.read(bytes); while (read < bytesLength) { read += in.read(bytes, read, bytesLength - read); } entries.put(new DefaultSessionKey<>(name, type), bytes); } } } public DefaultSession(SessionId sessionId, ByteBufAllocator bufferAllocator, SessionStore storeAdapter, Response response, SessionSerializer defaultSerializer, JavaSessionSerializer javaSerializer) { this.sessionId = sessionId; this.bufferAllocator = bufferAllocator; this.storeAdapter = storeAdapter; this.response = response; this.defaultSerializer = defaultSerializer; this.javaSerializer = javaSerializer; } @Override public String getId() { return sessionId.getValue().toString(); } @Override public Promise<SessionData> getData() { if (state == State.NOT_LOADED) { return storeAdapter.load(sessionId.getValue()).map(bytes -> { state = State.CLEAN; try { hydrate(bytes); } finally { bytes.release(); } return data; }); } else { return Promise.value(data); } } private void hydrate(ByteBuf bytes) throws Exception { if (bytes.readableBytes() > 0) { try { SerializedForm deserialized = defaultSerializer.deserialize(SerializedForm.class, new ByteBufInputStream(bytes)); if (deserialized == null) { this.entries = new HashMap<>(); } else { entries = deserialized.entries; } } catch (Exception e) { LOGGER.warn("Exception thrown deserializing session " + getId() + " with serializer " + defaultSerializer + " (session will be discarded)", e); this.entries = new HashMap<>(); markDirty(); } } else { this.entries = new HashMap<>(); } } @Override public JavaSessionSerializer getJavaSerializer() { return javaSerializer; } @Override public SessionSerializer getDefaultSerializer() { return defaultSerializer; } @Override public boolean isDirty() { return state == State.DIRTY; } private ByteBuf serialize() throws Exception { SerializedForm serializable = new SerializedForm(); serializable.entries = entries; ByteBuf buffer = bufferAllocator.buffer(); OutputStream outputStream = new ByteBufOutputStream(buffer); try { defaultSerializer.serialize(SerializedForm.class, serializable, outputStream); outputStream.close(); return buffer; } catch (Throwable e) { buffer.release(); throw e; } } @Override public Operation save() { return Operation.of(() -> { if (state != State.NOT_LOADED) { ByteBuf serialized = serialize(); storeAdapter.store(sessionId.getValue(), serialized) .wiretap(o -> serialized.release()) .then(() -> state = State.CLEAN); } }); } @Override public Operation terminate() { return storeAdapter.remove(sessionId.getValue()) .next(() -> { sessionId.terminate(); if (entries != null) { entries.clear(); } state = State.NOT_LOADED; }); } private void markDirty() { state = State.DIRTY; if (!callbackAdded) { callbackAdded = true; response.beforeSend(responseMetaData -> { callbackAdded = false; // another before send may try and use the session if (state == State.DIRTY) { save().then(); } }); } } private class Data implements SessionData { @Override public <T> Optional<T> get(SessionKey<T> key, SessionSerializer serializer) throws Exception { String name = key.getName(); if (key.getType() == null) { key = Types.cast(findKey(name)); if (key == null) { return Optional.empty(); } } byte[] bytes = entries.get(key); if (bytes == null) { return Optional.empty(); } else { try { T value = serializer.deserialize(key.getType(), new ByteArrayInputStream(bytes)); return Optional.ofNullable(value); } catch (Exception e) { LOGGER.warn("Exception thrown deserializing entry " + key + " with serializer " + serializer + " (value will be discarded from session)", e); remove(key); return Optional.empty(); } } } private SessionKey<?> findKey(String name) { List<Map.Entry<SessionKey<?>, byte[]>> entries = FluentIterable.from(DefaultSession.this.entries.entrySet()) .filter(e -> Objects.equals(e.getKey().getName(), name)) .toList(); if (entries.isEmpty()) { return null; } else if (entries.size() == 1) { return entries.get(0).getKey(); } else { throw new IllegalArgumentException("Found more than one session entry with name '" + name + "': " + Iterables.transform(entries, Map.Entry::getKey)); } } @Override public <T> void set(SessionKey<T> key, T value, SessionSerializer serializer) throws Exception { Objects.requireNonNull(key, "session key cannot be null"); Objects.requireNonNull(value, "session value for key " + key.getName() + " cannot be null"); ByteArrayOutputStream out = new ByteArrayOutputStream(); serializer.serialize(key.getType(), value, out); entries.put(key, out.toByteArray()); markDirty(); } @Override public Set<SessionKey<?>> getKeys() { return entries.keySet(); } @Override public SessionSerializer getDefaultSerializer() { return defaultSerializer; } @Override public void remove(SessionKey<?> key) { if (key.getType() == null) { key = findKey(key.getName()); if (key == null) { return; } } if (entries.remove(key) != null) { markDirty(); } } @Override public void clear() { entries.clear(); markDirty(); } @Override public Session getSession() { return DefaultSession.this; } } }