package com.biz.cascore.interceptor;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang.StringUtils;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.springframework.util.ReflectionUtils;

import com.biz.cascore.pagination.Page;
import com.biz.cascore.pagination.Pageable;

/**
 * MyBatis plugin that adds physical pagination to selected mapped statements.
 *
 * <p>The plugin intercepts {@code StatementHandler.prepare(Connection)}. When the id of the
 * mapped statement matches the configured {@code pageSqlId} regular expression and the
 * statement parameter is (or, for a {@link Map} parameter, contains) a {@link Page}, it:
 * <ol>
 *   <li>runs a {@code select count(*)} variant of the statement and stores the result on the
 *       page via {@code Page.setTotal};</li>
 *   <li>replaces the SQL of the current {@link BoundSql} with a dialect specific paginated
 *       statement ({@code mysql} and {@code oracle} are supported; any other dialect leaves
 *       the SQL unchanged).</li>
 * </ol>
 *
 * <p>Configuration properties: {@code dialect} and {@code pageSqlId}.
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})})
public class PageableHelper implements Interceptor {

	/** SQLSTATE reported for a communication link failure; the count query is retried. */
	private static final String SQLSTATE_COMMUNICATION_FAILURE = "08S01";
	/** SQLSTATE reported for a serialization failure / deadlock; the count query is retried. */
	private static final String SQLSTATE_SERIALIZATION_FAILURE = "40001";
	/** Maximum number of attempts for the count query when the failure is retryable. */
	private static final int MAX_COUNT_ATTEMPTS = 5;
	/** Keyword searched for when deriving the count statement. */
	private static final String FROM_KEYWORD = "from";

	/** Regular expression matched (with {@code find()}) against the mapped statement id. */
	private String pageSqlId;
	/** {@link #pageSqlId} compiled once; {@code null} when no expression is configured. */
	private Pattern pageSqlIdPattern;
	/** Database dialect name, compared case-insensitively ("mysql" or "oracle"). */
	private String dialect;

	/**
	 * Rewrites the statement for pagination when applicable, then lets the invocation proceed.
	 *
	 * @param invocation the intercepted {@code prepare} call
	 * @return whatever the intercepted method returns
	 * @throws Throwable any failure of the count query or of the intercepted method
	 */
	public Object intercept(Invocation invocation) throws Throwable {
		if (invocation.getTarget() instanceof StatementHandler) {
			statementHandlerExecutor(invocation);
		}
		return invocation.proceed();
	}

	/**
	 * Wraps {@link StatementHandler} targets with this plugin and returns every other target
	 * untouched.
	 *
	 * @param target the object MyBatis offers for interception
	 * @return the proxied target, or {@code target} itself when it is not a statement handler
	 */
	public Object plugin(Object target) {
		if (target instanceof StatementHandler) {
			return Plugin.wrap(target, this);
		}
		return target;
	}

	/**
	 * Reads the {@code dialect} and {@code pageSqlId} plugin properties.
	 *
	 * @param properties plugin properties from the MyBatis configuration
	 */
	public void setProperties(Properties properties) {
		setDialect(properties.getProperty("dialect"));
		setPageSqlId(properties.getProperty("pageSqlId"));
	}

	/**
	 * Runs the count query for the intercepted statement and stores the result on the page.
	 *
	 * <p>The count statement is bound with the same parameter object, parameter mappings and
	 * additional parameters as the intercepted statement, so it works both when the parameter
	 * is the page itself and when the page is one entry of a parameter map.
	 *
	 * <p>Failures whose SQLSTATE is {@code 08S01} or {@code 40001} are retried up to
	 * {@value #MAX_COUNT_ATTEMPTS} times; any other failure, or the last retryable one, is
	 * rethrown instead of being silently dropped.
	 *
	 * @param page page object that receives the total record count
	 * @param mappedStatement mapped statement being executed
	 * @param boundSql bound SQL of the intercepted statement
	 * @param connection connection handed to {@code prepare}
	 * @throws SQLException when the count query cannot be executed
	 */
	private void setTotalRecord(Page<?> page, MappedStatement mappedStatement, BoundSql boundSql,
			Connection connection) throws SQLException {
		String countSql = this.getCountSql(boundSql.getSql());
		Object parameterObject = boundSql.getParameterObject();
		List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();

		BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,
				parameterMappings, parameterObject);
		// Values produced by dynamic SQL (e.g. <foreach>, <bind>) live in the additional
		// parameters of the original BoundSql and must be carried over to be bindable.
		if (parameterMappings != null) {
			for (ParameterMapping mapping : parameterMappings) {
				String property = mapping.getProperty();
				if (boundSql.hasAdditionalParameter(property)) {
					countBoundSql.setAdditionalParameter(property, boundSql.getAdditionalParameter(property));
				}
			}
		}
		ParameterHandler parameterHandler =
				new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);

		int attemptsLeft = MAX_COUNT_ATTEMPTS;
		SQLException lastFailure;
		do {
			PreparedStatement pstmt = null;
			ResultSet rs = null;
			try {
				pstmt = connection.prepareStatement(countSql);
				parameterHandler.setParameters(pstmt);
				rs = pstmt.executeQuery();
				if (rs.next()) {
					Long totalRecordCount = rs.getLong(1);
					page.setTotal(totalRecordCount);
				}
				return;
			} catch (SQLException e) {
				lastFailure = e;
				String sqlState = e.getSQLState();
				if (SQLSTATE_COMMUNICATION_FAILURE.equals(sqlState)
						|| SQLSTATE_SERIALIZATION_FAILURE.equals(sqlState)) {
					attemptsLeft--;
				} else {
					attemptsLeft = 0;
				}
			} finally {
				closeQuietly(rs);
				closeQuietly(pstmt);
			}
		} while (attemptsLeft > 0);
		throw lastFailure;
	}

	/** Closes the result set, ignoring close failures so they cannot mask the real outcome. */
	private static void closeQuietly(ResultSet rs) {
		if (rs == null) {
			return;
		}
		try {
			rs.close();
		} catch (SQLException ignored) {
			// Nothing useful can be done about a failed close.
		}
	}

	/** Closes the statement, ignoring close failures so they cannot mask the real outcome. */
	private static void closeQuietly(PreparedStatement pstmt) {
		if (pstmt == null) {
			return;
		}
		try {
			pstmt.close();
		} catch (SQLException ignored) {
			// Nothing useful can be done about a failed close.
		}
	}

	/**
	 * Derives the count statement from the original statement.
	 *
	 * <p>The select list is replaced by {@code count(*)}: everything from the first top-level
	 * {@code FROM} keyword onwards is kept. The keyword is matched case-insensitively, as a
	 * whole word, outside parentheses and outside single-quoted literals. When no such keyword
	 * is found the original statement is wrapped in a derived table instead.
	 *
	 * <p>NOTE(review): SQL comments and backslash-escaped quotes are not interpreted by the
	 * scan, and bind placeholders located in the select list are dropped together with it.
	 *
	 * @param sql original SQL statement
	 * @return SQL statement that counts the rows of {@code sql}
	 */
	private String getCountSql(String sql) {
		int fromIndex = indexOfTopLevelFrom(sql);
		if (fromIndex < 0) {
			return "select count(*) from (" + sql + ") tmp_count";
		}
		return "select count(*) " + sql.substring(fromIndex);
	}

	/** Returns the index of the first top-level {@code FROM} keyword, or -1 when there is none. */
	private static int indexOfTopLevelFrom(String sql) {
		int depth = 0;
		boolean inLiteral = false;
		int length = sql.length();
		for (int i = 0; i < length; i++) {
			char c = sql.charAt(i);
			if (inLiteral) {
				// A doubled quote ('') leaves and immediately re-enters the literal.
				if (c == '\'') {
					inLiteral = false;
				}
				continue;
			}
			if (c == '\'') {
				inLiteral = true;
			} else if (c == '(') {
				depth++;
			} else if (c == ')') {
				depth--;
			} else if (depth == 0
					&& sql.regionMatches(true, i, FROM_KEYWORD, 0, FROM_KEYWORD.length())
					&& isKeywordBoundary(sql, i - 1)
					&& isKeywordBoundary(sql, i + FROM_KEYWORD.length())) {
				return i;
			}
		}
		return -1;
	}

	/** Tells whether the character at {@code index} (or the string edge) can delimit a keyword. */
	private static boolean isKeywordBoundary(String sql, int index) {
		if (index < 0 || index >= sql.length()) {
			return true;
		}
		char c = sql.charAt(index);
		return !(Character.isLetterOrDigit(c) || c == '_' || c == '$');
	}

	/**
	 * Builds the paginated statement for the configured dialect. Only MySQL and Oracle are
	 * supported; for any other dialect the statement is returned unchanged.
	 *
	 * @param pageable requested page number and page size
	 * @param sql original SQL statement
	 * @return paginated SQL statement
	 */
	private String getPageSql(Pageable pageable, String sql) {
		if ("mysql".equalsIgnoreCase(dialect)) {
			return getMysqlPageSql(pageable, sql);
		}
		if ("oracle".equalsIgnoreCase(dialect)) {
			return getOraclePageSql(pageable, sql);
		}
		return sql;
	}

	/**
	 * Builds the Oracle paginated statement using nested {@code ROWNUM} filtering.
	 *
	 * @param pageable requested page number (1-based) and page size
	 * @param sql original SQL statement
	 * @return Oracle paginated SQL statement
	 */
	private String getOraclePageSql(Pageable pageable, String sql) {
		int offset = (pageable.getPage() - 1) * pageable.getRows();
		int endResult = pageable.getPage() * pageable.getRows();
		StringBuilder pageSql = new StringBuilder();
		// The inner query caps the result at the last requested row, the outer one skips
		// the rows that belong to previous pages.
		pageSql.append(" SELECT T.* FROM ( SELECT ROWNUM RN,TMP.* FROM ( ").append(sql)
				.append(" )TMP WHERE ROWNUM  <= ").append(endResult).append(" )  T ");
		pageSql.append(" WHERE RN >").append(offset);
		return pageSql.toString();
	}

	/**
	 * Builds the MySQL paginated statement by appending a {@code limit offset,rows} clause.
	 *
	 * @param pageable requested page number (1-based) and page size
	 * @param sql original SQL statement
	 * @return MySQL paginated SQL statement
	 */
	private String getMysqlPageSql(Pageable pageable, String sql) {
		// MySQL row offsets are zero-based.
		int offset = (pageable.getPage() - 1) * pageable.getRows();
		return new StringBuilder(sql).append(" limit ").append(offset).append(",")
				.append(pageable.getRows()).toString();
	}

	/**
	 * Returns the statement id filter.
	 *
	 * @return the regular expression matched against mapped statement ids
	 */
	public String getPageSqlId() {
		return pageSqlId;
	}

	/**
	 * Returns the database dialect.
	 *
	 * @return the dialect name
	 */
	public String getDialect() {
		return dialect;
	}

	/**
	 * Sets the statement id filter and compiles it once for later use.
	 *
	 * @param pageSqlId regular expression matched against mapped statement ids;
	 *        {@code null} disables pagination
	 * @throws java.util.regex.PatternSyntaxException when the expression is invalid
	 */
	public void setPageSqlId(String pageSqlId) {
		this.pageSqlId = pageSqlId;
		this.pageSqlIdPattern = (pageSqlId == null) ? null : Pattern.compile(pageSqlId);
	}

	/**
	 * Sets the database dialect.
	 *
	 * @param dialect the dialect name ("mysql" or "oracle", case-insensitive)
	 */
	public void setDialect(String dialect) {
		this.dialect = dialect;
	}

	/**
	 * Applies pagination to the intercepted statement when its id matches {@code pageSqlId}
	 * and a {@link Page} parameter is present: sets the total record count on the page and
	 * replaces the SQL of the current {@link BoundSql} with the paginated statement.
	 *
	 * @param invocation the intercepted {@code prepare} call
	 * @throws Throwable when the count query fails or the handler internals cannot be read
	 */
	private void statementHandlerExecutor(Invocation invocation) throws Throwable {
		Pattern idPattern = pageSqlIdPattern;
		if (idPattern == null) {
			// No filter configured: nothing is eligible for pagination.
			return;
		}

		RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
		// RoutingStatementHandler does not expose its delegate, nor the delegate its
		// mapped statement, hence the reflective access.
		StatementHandler statementHandler = (StatementHandler) readField(handler, "delegate");
		MappedStatement mappedStatement = (MappedStatement) readField(statementHandler, "mappedStatement");

		String queryId = mappedStatement.getId();
		if (StringUtils.isBlank(queryId)) {
			return;
		}
		Matcher matcher = idPattern.matcher(queryId);
		if (!matcher.find()) {
			return;
		}

		BoundSql boundSql = statementHandler.getBoundSql();
		Page<?> page = findPage(boundSql.getParameterObject());
		if (page == null) {
			return;
		}

		Pageable pageable = page.getPageable();
		// The single argument of the intercepted prepare method is the connection.
		Connection connection = (Connection) invocation.getArgs()[0];
		String sql = boundSql.getSql();
		this.setTotalRecord(page, mappedStatement, boundSql, connection);
		String pageSql = this.getPageSql(pageable, sql);

		// BoundSql has no setter for its SQL, so the field is replaced reflectively.
		Field sqlField = ReflectionUtils.findField(boundSql.getClass(), "sql");
		if (sqlField == null) {
			throw new IllegalStateException("Field 'sql' not found on " + boundSql.getClass().getName());
		}
		sqlField.setAccessible(true);
		ReflectionUtils.setField(sqlField, boundSql, pageSql);
	}

	/**
	 * Locates the page in the statement parameter: the parameter itself, or the first
	 * {@link Page} value when the parameter is a {@link Map} (multi-parameter mappers).
	 *
	 * @param parameterObj statement parameter object, may be {@code null}
	 * @return the page, or {@code null} when the parameter carries none
	 */
	private static Page<?> findPage(Object parameterObj) {
		if (parameterObj instanceof Page) {
			return (Page<?>) parameterObj;
		}
		if (parameterObj instanceof Map) {
			for (Object value : ((Map<?, ?>) parameterObj).values()) {
				if (value instanceof Page) {
					return (Page<?>) value;
				}
			}
		}
		return null;
	}

	/** Reads a (possibly inherited, non-public) field of {@code target} by name. */
	private static Object readField(Object target, String fieldName) {
		Field field = ReflectionUtils.findField(target.getClass(), fieldName);
		if (field == null) {
			throw new IllegalStateException(
					"Field '" + fieldName + "' not found on " + target.getClass().getName());
		}
		field.setAccessible(true);
		return ReflectionUtils.getField(field, target);
	}

}
