package com.netease.wd.crossorigin.filter;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.log4j.Logger;

import com.netease.wd.crossorigin.utils.ConfigUtil;
import com.netease.wd.crossorigin.utils.Constants;
import com.netease.wd.crossorigin.utils.CrossOriginUtil;

/**
 * Cross Origin Filter.<br>
 * Implements the W3C Cross-Origin Resource Sharing (CORS) specification, see
 * <a href="http://www.w3.org/TR/cors/">http://www.w3.org/TR/cors/</a>.
 * 
 * @author zhouyongchang(zhouyongchang@corp.netease.com)
 */
public class CrossOriginFilter implements Filter {

	/** Header name used to tell caches that the response depends on the request Origin. */
	private static final String VARY_HEADER = "Vary";
	private static final String VARY_ORIGIN = "Origin";

	private static final Logger logger = Logger.getLogger(CrossOriginFilter.class);

	/**
	 * Allow-lists parsed from the filter init parameters "allowOrigins",
	 * "allowMethods" and "allowHeaders".
	 */
	private List<String> allowOrigins = null;
	private List<String> allowMethods = null;
	private List<String> allowHeaders = null;

	/**
	 * Comma-joined header values, computed once in {@link #init(FilterConfig)} and
	 * reused for every response.
	 */
	private String allowOriginStr = null;
	private String allowMethodStr = null;
	private String allowHeaderStr = null;
	private String exposeHeaderStr = null;

	/**
	 * When true, the request origin is echoed back and the credentials header
	 * ({@code Constants.ACA_CREDENTIALS}) is set to "true".
	 */
	private boolean supportCredentials = false;

	/** Preflight cache lifetime in seconds; values &lt;= 0 suppress the max-age header. */
	private int maxAge = -1;

	/**
	 * Whether to compare the Referer and Host headers to detect cross-origin
	 * requests that carry no Origin header (e.g. IE6 and similar legacy browsers).
	 */
	private boolean checkReferer = false;

	@Override
	public void destroy() {
		// Nothing to release: the filter only holds configuration read in init().
	}

	/**
	 * Entry point: routes OPTIONS requests to the preflight handler and every
	 * other method to the actual-request handler.
	 */
	@Override
	public void doFilter(ServletRequest req, ServletResponse resp,
			FilterChain chain) throws IOException, ServletException {
		HttpServletRequest request = (HttpServletRequest) req;
		HttpServletResponse response = (HttpServletResponse) resp;

		// Constant-first comparison stays safe even if getMethod() returned null.
		if ("OPTIONS".equalsIgnoreCase(request.getMethod())) {
			this.doOptions(request, response, chain);
		} else {
			this.doRequest(request, response, chain);
		}
	}

	/**
	 * Reads the CORS configuration from the filter init parameters and
	 * pre-computes the joined header values.
	 */
	@Override
	public void init(FilterConfig config) throws ServletException {
		this.allowOrigins = ConfigUtil.parseParam(config
				.getInitParameter("allowOrigins"));
		this.allowMethods = ConfigUtil.parseParam(config
				.getInitParameter("allowMethods"));
		this.allowHeaders = ConfigUtil.parseParam(config
				.getInitParameter("allowHeaders"));
		this.maxAge = ConfigUtil.getIntParam(config.getInitParameter("maxAge"),
				-1);
		this.supportCredentials = ConfigUtil.getBooleanParam(
				config.getInitParameter("supportCredentials"), false);
		this.exposeHeaderStr = ConfigUtil.formatListParam(config
				.getInitParameter("exposeHeaders"));
		this.checkReferer = ConfigUtil.getBooleanParam(
				config.getInitParameter("checkReferer"), false);

		this.allowOriginStr = ConfigUtil.composeString(this.allowOrigins, ",");
		this.allowMethodStr = ConfigUtil.composeString(this.allowMethods, ",");
		this.allowHeaderStr = ConfigUtil.composeString(this.allowHeaders, ",");
	}

	/**
	 * Handles an OPTIONS (preflight) request.
	 * 
	 * Requests without an origin or without Access-Control-Request-Method are
	 * passed down the chain untouched. Requests with a disallowed origin, method
	 * or header are refused. Otherwise the allow headers are written and the
	 * request is answered here without invoking the chain.
	 * 
	 * @param req
	 * @param resp
	 * @param chain
	 * @throws IOException
	 * @throws ServletException
	 */
	private void doOptions(HttpServletRequest req, HttpServletResponse resp,
			FilterChain chain) throws IOException, ServletException {
		String origin = this.resolveOrigin(req);
		if (isEmptyOrigin(origin)) {
			logger.info("Skip empty origin request for OPTIONS.");
			skip(req, resp, chain);
			return;
		}
		// Origin matching is case-sensitive.
		if (!CrossOriginUtil.match(origin, this.allowOrigins, true)) {
			logger.info("Refuse invalid Origin request for OPTIONS. Origin:"
					+ origin);
			refuse(req, resp, chain);
			return;
		}

		String method = req.getHeader(Constants.ACR_METHOD);
		if (method == null) {
			// Not a preflight: let the application handle the plain OPTIONS request.
			logger.info("Skip empty Access-Control-Request-Method request for OPTIONS.");
			skip(req, resp, chain);
			return;
		}
		// Method matching is case-sensitive.
		if (!CrossOriginUtil.match(method, this.allowMethods, true)) {
			logger.info("Refuse invalid Access-Control-Request-Method Request for OPTIONS. Method:"
					+ method);
			refuse(req, resp, chain);
			return;
		}

		String acrHeaders = req.getHeader(Constants.ACR_HEADERS);
		if (!checkACRHeaders(acrHeaders)) {
			logger.info("Refuse invalid Access-Control-Request-Headers for OPTIONS. Headers:"
					+ acrHeaders);
			refuse(req, resp, chain);
			return;
		}

		this.addAllowHeaders(origin, method, acrHeaders, resp);

		if (this.maxAge > 0) {
			resp.setHeader(Constants.AC_MAX_AGE, Integer.toString(this.maxAge));
		}

		resp.setStatus(HttpServletResponse.SC_OK);
		writeSingleSpaceBody(resp);
	}

	/**
	 * Handles a non-OPTIONS (actual) request.
	 * 
	 * Requests without an origin are passed through unchanged. Requests with a
	 * disallowed origin are refused without invoking the chain. Allowed requests
	 * get the allow/expose headers and then continue down the chain.
	 * 
	 * @param req
	 * @param resp
	 * @param chain
	 * @throws IOException
	 * @throws ServletException
	 */
	private void doRequest(HttpServletRequest req, HttpServletResponse resp,
			FilterChain chain) throws IOException, ServletException {
		String origin = this.resolveOrigin(req);
		if (isEmptyOrigin(origin)) {
			logger.warn("Skip empty origin request.");
			skip(req, resp, chain);
			return;
		}

		// Origin matching is case-sensitive.
		if (!CrossOriginUtil.match(origin, this.allowOrigins, true)) {
			logger.error("Refuse invalid origin request. origin:" + origin);
			refuse(req, resp, chain);
			return;
		}

		this.addAllowHeaders(origin, null, null, resp);

		if (this.hasExposeHeader()) {
			resp.addHeader(Constants.AC_EXPOSE_HEADERS, this.exposeHeaderStr);
		}

		chain.doFilter(req, resp);
	}

	/**
	 * Returns the request's Origin header. If it is absent, referer checking is
	 * enabled and the Referer/Host comparison indicates a cross-origin request,
	 * the value derived from the Referer is returned instead.
	 * 
	 * @param req
	 * @return the origin, or null when none could be determined
	 */
	private String resolveOrigin(HttpServletRequest req) {
		String origin = req.getHeader("Origin");
		if (origin == null && this.checkReferer && this.checkCrossOriginByReferer(req)) {
			origin = CrossOriginUtil.parseHostFromUrl(req.getHeader("referer"));
		}
		return origin;
	}

	/**
	 * True when the origin is missing or is the literal "null" origin, which is
	 * sent by privacy-sensitive contexts.
	 */
	private static boolean isEmptyOrigin(String origin) {
		return origin == null || origin.trim().equalsIgnoreCase("null");
	}

	/**
	 * Checks the Access-Control-Request-Headers value. Returns true if it is
	 * empty or if every listed header name is allowed (case-insensitive);
	 * otherwise returns false. Surrounding whitespace and empty list elements
	 * (e.g. "a,,b") are ignored.
	 * 
	 * @param headerNames raw header value, may be null
	 * @return whether all requested headers are allowed
	 */
	private boolean checkACRHeaders(String headerNames) {
		if (headerNames == null) {
			return true;
		}
		String trimmed = headerNames.trim();
		if (trimmed.length() == 0) {
			return true;
		}

		// Trim first so that the first/last tokens carry no stray whitespace.
		for (String name : trimmed.split("\\s*,\\s*")) {
			if (name.length() == 0) {
				continue;
			}
			if (!CrossOriginUtil.match(name, this.allowHeaders, false)) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Writes the Access-Control-Allow-* headers.
	 * 
	 * @param origin the already-validated request origin
	 * @param method the requested method (preflight), or null to send the configured list
	 * @param headers the requested headers (preflight), or null to send the configured list
	 * @param resp
	 */
	private void addAllowHeaders(String origin, String method, String headers,
			HttpServletResponse resp) {
		if (this.supportCredentials) {
			// Credentialed responses must name the exact origin; "*" is not permitted.
			resp.setHeader(Constants.ACA_ORIGIN, origin);
			resp.addHeader(VARY_HEADER, VARY_ORIGIN);
			resp.setHeader(Constants.ACA_CREDENTIALS, "true");
		} else if ("*".equals(this.allowOriginStr)) {
			resp.setHeader(Constants.ACA_ORIGIN, this.allowOriginStr);
		} else {
			// Access-Control-Allow-Origin takes a single origin, not a list, so
			// echo the origin that already matched the allow-list.
			resp.setHeader(Constants.ACA_ORIGIN, origin);
			resp.addHeader(VARY_HEADER, VARY_ORIGIN);
		}
		resp.setHeader(Constants.ACA_METHODS,
				method != null ? method : this.allowMethodStr);
		resp.setHeader(Constants.ACA_HEADERS,
				headers != null ? headers : this.allowHeaderStr);
	}

	/**
	 * Whether an expose-headers value was configured.
	 * 
	 * @return true if exposeHeaderStr is non-empty
	 */
	private boolean hasExposeHeader() {
		return this.exposeHeaderStr != null
				&& this.exposeHeaderStr.length() > 0;
	}

	/**
	 * The request is outside the scope of CORS handling: pass it down the chain
	 * unchanged.
	 * 
	 * @param req
	 * @param resp
	 * @param chain
	 * @throws IOException
	 * @throws ServletException
	 */
	private void skip(HttpServletRequest req, HttpServletResponse resp,
			FilterChain chain) throws IOException, ServletException {
		chain.doFilter(req, resp);
	}

	/**
	 * The request failed the CORS checks: answer it here without invoking the
	 * chain. The allow-origin header is written with an empty value, which can
	 * never match the caller's origin, so a CORS-enforcing browser withholds the
	 * response from the calling script.
	 * 
	 * @param req
	 * @param resp
	 * @param chain
	 * @throws IOException
	 * @throws ServletException
	 */
	private void refuse(HttpServletRequest req, HttpServletResponse resp,
			FilterChain chain) throws IOException, ServletException {
		resp.setStatus(HttpServletResponse.SC_OK);
		resp.addHeader(Constants.ACA_ORIGIN, "");
		writeSingleSpaceBody(resp);
	}

	/**
	 * Writes a one-byte body (a single space) whose length matches the declared
	 * Content-Length. print() is used rather than println(), which would append
	 * a line separator and exceed the declared length.
	 */
	private static void writeSingleSpaceBody(HttpServletResponse resp)
			throws IOException {
		resp.setContentLength(1);
		PrintWriter out = resp.getWriter();
		out.print(" ");
		out.flush();
	}

	/**
	 * Uses Referer and Host to decide whether the request is cross-origin.
	 * This mainly covers browsers such as IE6 that send no Origin header on
	 * cross-origin requests.
	 * 
	 * Referer is the URL of the requesting page, e.g. http://x.y.com/index.html.
	 * Host is the authority of the requested resource, e.g. a.b.com or
	 * a.b.com:8080 (no scheme). The request is treated as same-origin when the
	 * Referer's authority equals Host, ignoring case.
	 * 
	 * @param req
	 * @return true if the request appears to be cross-origin
	 */
	private boolean checkCrossOriginByReferer(HttpServletRequest req) {
		String referer = req.getHeader("Referer");
		String host = req.getHeader("Host");
		if (referer == null || referer.trim().length() == 0
				|| host == null || host.trim().length() == 0) {
			// NOTICE: with either header missing we cannot tell; currently treated
			// as not cross-origin.
			return false;
		}
		// NOTE(review): an explicit default port in the Referer (":80"/":443")
		// while Host omits it would be reported as cross-origin.
		return !extractAuthority(referer.trim()).equalsIgnoreCase(host.trim());
	}

	/**
	 * Extracts the authority (host[:port]) from a URL such as
	 * "http://user@x.y.com:8080/path?q". Any user-info part is removed. If the
	 * URL has no "://", the text before the first '/', '?' or '#' is used.
	 * 
	 * @param url a non-null URL string
	 * @return the authority part, possibly empty
	 */
	private static String extractAuthority(String url) {
		int schemeEnd = url.indexOf("://");
		int start = (schemeEnd < 0) ? 0 : schemeEnd + 3;
		int end = url.length();
		for (int i = start; i < url.length(); i++) {
			char c = url.charAt(i);
			if (c == '/' || c == '?' || c == '#') {
				end = i;
				break;
			}
		}
		String authority = url.substring(start, end);
		int at = authority.lastIndexOf('@');
		return (at < 0) ? authority : authority.substring(at + 1);
	}
}
