package edu.mayo.cts2.framework.webapp.rest.filter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import edu.mayo.cts2.framework.model.exception.UnspecifiedCts2Exception;
/**
 * Servlet filter that lets clients choose the response representation with a
 * {@code format} request parameter instead of an {@code Accept} header.
 * <p>
 * When the request carries {@code format=json} or {@code format=xml}, the request is
 * wrapped so that downstream components see an {@code Accept} header of
 * {@code application/json} or {@code application/xml} respectively. Requests without a
 * {@code format} parameter are passed down the chain unchanged.
 */
public class AcceptHeaderAdjustingFilter implements Filter {

    /** Canonical name of the HTTP header this filter overrides. */
    private static final String ACCEPT_HEADER = "Accept";

    /** Name of the request parameter that selects the representation. */
    private static final String FORMAT_PARAMETER = "format";

    @Override
    public void destroy() {
        // No resources to release.
    }

    /**
     * Wraps the request with an overridden {@code Accept} header when a {@code format}
     * parameter is present, then continues the filter chain.
     *
     * @throws UnspecifiedCts2Exception if the request is not an {@link HttpServletRequest}
     * @throws IllegalStateException if more than one {@code format} value is supplied, or
     *     the value is neither {@code json} nor {@code xml}
     */
    @Override
    public void doFilter(
            ServletRequest request,
            ServletResponse response,
            FilterChain chain) throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest)) {
            throw new UnspecifiedCts2Exception("ServletRequest expected to be of type HttpServletRequest");
        }
        HttpServletRequest httpRequest = (HttpServletRequest) request;

        // getParameterValues returns null exactly when the parameter is absent, which
        // avoids the unchecked cast the raw getParameterMap() would require.
        String[] formats = httpRequest.getParameterValues(FORMAT_PARAMETER);
        if (formats == null) {
            chain.doFilter(request, response);
            return;
        }
        if (formats.length != 1) {
            throw new IllegalStateException("Only one 'format' parameter allowed.");
        }
        String type = toMediaType(formats[0]);
        chain.doFilter(new AcceptTypeChangingRequest(httpRequest, type), response);
    }

    /**
     * Maps a {@code format} parameter value to the media type reported as the
     * {@code Accept} header.
     *
     * @param format the raw parameter value (matched case-sensitively)
     * @return {@code application/json} for {@code json}, {@code application/xml} for {@code xml}
     * @throws IllegalStateException for any other value
     */
    private static String toMediaType(String format) {
        if ("json".equals(format)) {
            return "application/json";
        }
        if ("xml".equals(format)) {
            return "application/xml";
        }
        throw new IllegalStateException("Format: " + format + " not recognized.");
    }

    @Override
    public void init(FilterConfig config) throws ServletException {
        // No configuration required.
    }

    /**
     * Request wrapper that reports a fixed value for the {@code Accept} header (matched
     * case-insensitively) and delegates every other header lookup to the wrapped request.
     */
    public class AcceptTypeChangingRequest extends HttpServletRequestWrapper {

        private final String acceptHeader;

        /**
         * @param request the request to wrap
         * @param acceptHeader the value to report for the {@code Accept} header
         */
        public AcceptTypeChangingRequest(HttpServletRequest request, String acceptHeader) {
            super(request);
            this.acceptHeader = acceptHeader;
        }

        /**
         * Returns a single-element enumeration holding the overridden value for
         * {@code Accept}; otherwise delegates to the wrapped request.
         */
        @Override
        @SuppressWarnings("unchecked")
        public Enumeration<String> getHeaders(String name) {
            // Constant-first comparison tolerates a null header name.
            if (ACCEPT_HEADER.equalsIgnoreCase(name)) {
                return Collections.enumeration(Collections.singletonList(acceptHeader));
            }
            return super.getHeaders(name);
        }

        /**
         * Returns the overridden value for {@code Accept}; otherwise delegates to the
         * wrapped request.
         */
        @Override
        public String getHeader(String name) {
            if (ACCEPT_HEADER.equalsIgnoreCase(name)) {
                return acceptHeader;
            }
            return super.getHeader(name);
        }

        /**
         * Returns the wrapped request's header names, adding {@code Accept} only when it
         * is not already present so the name is never reported twice.
         */
        @Override
        @SuppressWarnings("unchecked")
        public Enumeration<String> getHeaderNames() {
            List<String> headers = new ArrayList<String>();
            boolean acceptPresent = false;
            // The servlet API permits getHeaderNames() to return null when the container
            // does not expose headers; treat that as "no headers" instead of failing.
            Enumeration<String> originalNames = super.getHeaderNames();
            if (originalNames != null) {
                while (originalNames.hasMoreElements()) {
                    String headerName = originalNames.nextElement();
                    if (ACCEPT_HEADER.equalsIgnoreCase(headerName)) {
                        acceptPresent = true;
                    }
                    headers.add(headerName);
                }
            }
            if (!acceptPresent) {
                headers.add(ACCEPT_HEADER);
            }
            return Collections.enumeration(headers);
        }
    }
}