/* * Copyright 2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ratpack.pac4j.internal; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import io.netty.handler.codec.http.cookie.DefaultCookie; import org.pac4j.core.context.Cookie; import org.pac4j.core.context.WebContext; import org.pac4j.core.exception.RequiresHttpAction; import ratpack.exec.Promise; import ratpack.form.Form; import ratpack.form.internal.DefaultForm; import ratpack.form.internal.FormDecoder; import ratpack.handling.Context; import ratpack.http.*; import ratpack.server.PublicAddress; import ratpack.session.Session; import ratpack.session.SessionData; import ratpack.util.Exceptions; import ratpack.util.MultiValueMap; import java.net.URI; import java.util.*; public class RatpackWebContext implements WebContext { private final Context context; private final SessionData session; private final Request request; private final Response response; private final Form form; private String responseContent = ""; public RatpackWebContext(Context ctx, TypedData body, SessionData session) { this.context = ctx; this.session = session; this.request = ctx.getRequest(); this.response = ctx.getResponse(); if (isFormAvailable(request, body)) { this.form = FormDecoder.parseForm(ctx, body, MultiValueMap.empty()); } else { this.form = new DefaultForm(MultiValueMap.empty(), MultiValueMap.empty()); } } public static Promise<RatpackWebContext> from(Context ctx, boolean bodyBacked) { Promise<SessionData> sessionDataPromise = ctx.get(Session.class).getData(); if (bodyBacked) { return ctx.getRequest().getBody().flatMap(body -> sessionDataPromise.map(sessionData -> new RatpackWebContext(ctx, body, sessionData)) ); } else { return sessionDataPromise.map(sessionData -> new RatpackWebContext(ctx, null, sessionData)); } } @Override public String getRequestParameter(String name) { return Optional.ofNullable(request.getQueryParams().get(name)) .orElseGet(() -> form.get(name)); } @Override public Map<String, String[]> getRequestParameters() { return flattenMap(combineMaps(request.getQueryParams(), form)); } private RequestAttributes getRequestAttributes() { RequestAttributes attributes = request.get(RequestAttributes.class); if (attributes == null) { attributes = new RequestAttributes(); request.add(attributes); } return attributes; } @Override public Object getRequestAttribute(String name) { return getRequestAttributes().getAttributes().get(name); } @Override public void setRequestAttribute(String name, Object value) { getRequestAttributes().getAttributes().put(name, value); } @Override public String getRequestHeader(String name) { return request.getHeaders().get(name); } @Override public void setSessionAttribute(String name, Object value) { if (value == null) { session.remove(name); } else { Exceptions.uncheck(() -> session.set(name, value, session.getJavaSerializer())); } } @Override public Object getSessionAttribute(String name) { return Exceptions.uncheck(() -> session.get(name, session.getJavaSerializer()).orElse(null)); } @Override public Object getSessionIdentifier() { return session.getSession().getId(); } @Override public String getRequestMethod() { return request.getMethod().getName(); } @Override public String getRemoteAddr() { return request.getRemoteAddress().getHost(); } @Override public void writeResponseContent(String responseContent) { this.responseContent = responseContent; } @Override public void setResponseStatus(int code) { response.status(code); } @Override public void setResponseHeader(String name, String value) { response.getHeaders().set(name, value); } @Override public void setResponseContentType(String content) { response.contentType(content); } @Override public String getServerName() { return getAddress().getHost(); } @Override public int getServerPort() { return getAddress().getPort(); } @Override public String getScheme() { return getAddress().getScheme(); } @Override public boolean isSecure() { return "HTTPS".equalsIgnoreCase(getScheme()); } @Override public String getFullRequestURL() { return getAddress().toString() + request.getUri(); } public void sendResponse(RequiresHttpAction action) { response.status(action.getCode()); sendResponse(); } public void sendResponse() { int statusCode = response.getStatus().getCode(); if (statusCode >= 400) { context.clientError(statusCode); } else { response.send(MediaType.TEXT_HTML, responseContent); } } @Override public Collection<Cookie> getRequestCookies() { final List<Cookie> newCookies = new ArrayList<>(); final Set<io.netty.handler.codec.http.cookie.Cookie> cookies = request.getCookies(); for (final io.netty.handler.codec.http.cookie.Cookie cookie : cookies) { final Cookie newCookie = new Cookie(cookie.name(), cookie.value()); newCookie.setDomain(cookie.domain()); newCookie.setPath(cookie.path()); newCookie.setMaxAge((int) cookie.maxAge()); newCookie.setSecure(cookie.isSecure()); newCookie.setHttpOnly(cookie.isHttpOnly()); newCookies.add(newCookie); } return newCookies; } @Override public void addResponseCookie(Cookie cookie) { final DefaultCookie newCookie = new DefaultCookie(cookie.getName(), cookie.getValue()); newCookie.setDomain(cookie.getDomain()); newCookie.setPath(cookie.getPath()); newCookie.setMaxAge(cookie.getMaxAge()); newCookie.setSecure(cookie.isSecure()); newCookie.setHttpOnly(cookie.isHttpOnly()); response.getCookies().add(newCookie); } @Override public String getPath() { return request.getPath(); } public SessionData getSession() { return session; } private URI getAddress() { return context.get(PublicAddress.class).get(); } private static boolean isFormAvailable(Request request, TypedData body) { HttpMethod method = request.getMethod(); return body != null && body.getContentType().isForm() && (method.isPost() || method.isPut()); } private Map<String, List<String>> combineMaps(MultiValueMap<String, String> first, MultiValueMap<String, String> second) { Map<String, List<String>> result = Maps.newLinkedHashMap(); Set<String> keys = Sets.newLinkedHashSet(Iterables.concat(first.keySet(), second.keySet())); for (String key : keys) { result.put(key, Lists.newArrayList(Iterables.concat(first.getAll(key), second.getAll(key)))); } return result; } private Map<String, String[]> flattenMap(Map<String, List<String>> map) { Map<String, String[]> result = Maps.newLinkedHashMap(); for (String key : map.keySet()) { result.put(key, Iterables.toArray(map.get(key), String.class)); } return result; } private class RequestAttributes { private Map<String, Object> attributes = new HashMap<>(); public Map<String, Object> getAttributes() { return attributes; } } }