package com.biz.crm.interceptor;

import com.biz.crm.common.privilege.PrivilegeSearchVo;
import com.biz.crm.util.OperationConfig;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * @Project crm
 * @PackageName com.biz.crm
 * @ClassName interceptor
 * @Author HuangLong
 * @Date 2020/11/19 13:47
 * @Description Data-permission interceptor.
 *
 * <p>Intercepts both {@code Executor.query(...)} overloads. For SELECT statements whose mapper
 * method carries {@link SqlPrivilege} with {@code flag() == true}, the first {@code 1=1}
 * placeholder in the SQL is replaced by the condition returned from
 * {@code OperationConfig.createOperationSql(SqlPrivilege)}. Statement ids ending in {@code _COUNT}
 * are resolved against the mapper method without that suffix, so paging count queries are
 * filtered as well.
 */
@Slf4j
@Component
@Intercepts({@Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
), @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
)})
public class DataPermissionInterceptor implements Interceptor {

    /**
     * Positions of the intercepted {@code query} arguments, matching the {@code args} of the
     * {@link Signature}s above. Indices 4 and 5 exist only for the 6-argument overload.
     */
    static final int MAPPED_STATEMENT_INDEX = 0;
    static final int PARAMETER_INDEX = 1;
    private static final int ROW_BOUNDS_INDEX = 2;
    private static final int CACHE_KEY_INDEX = 4;
    private static final int BOUND_SQL_INDEX = 5;

    /** Suffix of paging count statement ids; stripped so they resolve to the annotated mapper method. */
    private static final String COUNT_SUFFIX = "_COUNT";

    /**
     * Placeholder that the mapper SQL must contain for the permission condition to be injected.
     * Any whitespace around {@code =} is tolerated because count SQL may be reformatted
     * ({@code 1=1} becomes {@code 1 = 1}). The word boundaries keep it from matching inside
     * predicates such as {@code type1 = 1} or {@code 1 = 10}.
     */
    private static final Pattern PRIVILEGE_PLACEHOLDER = Pattern.compile("\\b1\\s*=\\s*1\\b");

    /**
     * Injects the data-permission condition into annotated SELECT statements.
     *
     * <p>For the 4-argument {@code query}, the rewritten SQL is carried by a copied
     * {@link MappedStatement}. For the 6-argument {@code query}, the executor runs the supplied
     * {@link BoundSql} directly. That object is therefore rewritten, and the {@link CacheKey} is
     * recomputed so it describes the SQL that actually runs.
     *
     * @param invocation the intercepted {@code Executor.query} call
     * @return the result of the (possibly rewritten) query
     * @throws Throwable anything thrown while rewriting or by the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[MAPPED_STATEMENT_INDEX];
        // Only SELECT statements are filtered.
        if (ObjectUtils.notEqual(SqlCommandType.SELECT, mappedStatement.getSqlCommandType())) {
            return invocation.proceed();
        }
        String statementId = mappedStatement.getId();
        SqlPrivilege sqlPrivilege = findEnabledPrivilege(statementId);
        if (sqlPrivilege == null) {
            return invocation.proceed();
        }
        String privilegeSql = OperationConfig.createOperationSql(sqlPrivilege);
        if (StringUtils.isEmpty(privilegeSql)) {
            // No condition to inject: run the statement unchanged.
            return invocation.proceed();
        }

        Object parameter = args[PARAMETER_INDEX];
        boolean fullSignature = args.length > BOUND_SQL_INDEX;
        // In the 6-argument overload the caller's BoundSql is what gets executed, so rewrite that one.
        BoundSql boundSql = fullSignature
                ? (BoundSql) args[BOUND_SQL_INDEX]
                : mappedStatement.getBoundSql(parameter);

        Matcher matcher = PRIVILEGE_PLACEHOLDER.matcher(boundSql.getSql());
        if (!matcher.find()) {
            log.warn("No 1=1 placeholder in SQL of {}; data permission was not applied", statementId);
            return invocation.proceed();
        }
        // quoteReplacement: '$' and '\' in the condition are inserted literally, not as group references.
        String privilegedSql = matcher.replaceFirst(Matcher.quoteReplacement(privilegeSql));
        // The SQL text is set reflectively on this BoundSql instance. Editing a BoundSql obtained
        // elsewhere has no effect, because the executor only sees the instance passed on below.
        MetaObject.forObject(boundSql, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(),
                new DefaultReflectorFactory()).setValue("sql", privilegedSql);

        if (fullSignature) {
            // The caller's key was built from the unfiltered SQL; rebuild it from the rewritten one.
            Executor executor = (Executor) invocation.getTarget();
            RowBounds rowBounds = (RowBounds) args[ROW_BOUNDS_INDEX];
            args[CACHE_KEY_INDEX] = executor.createCacheKey(mappedStatement, parameter, rowBounds, boundSql);
        } else {
            args[MAPPED_STATEMENT_INDEX] = copyFromMappedStatement(mappedStatement, new BoundSqlSource(boundSql));
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object object) {
        return Plugin.wrap(object, this);
    }

    @Override
    public void setProperties(Properties properties) {

    }

    /**
     * Finds the enabled {@link SqlPrivilege} on the mapper method behind a statement id.
     *
     * @param statementId fully qualified statement id, e.g. {@code com.x.UserMapper.selectList}
     *                    or {@code com.x.UserMapper.selectList_COUNT}
     * @return the annotation of the first same-named mapper method whose {@code flag()} is true,
     *         or {@code null} when the id has no namespace, the namespace is not a loadable class,
     *         or no such annotated method exists
     */
    private static SqlPrivilege findEnabledPrivilege(String statementId) {
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        String methodName = statementId.substring(lastDot + 1);
        // Paging count statements reuse the original method's permission.
        if (methodName.endsWith(COUNT_SUFFIX)) {
            methodName = methodName.substring(0, methodName.length() - COUNT_SUFFIX.length());
        }
        Class<?> mapperType;
        try {
            mapperType = Class.forName(statementId.substring(0, lastDot));
        } catch (ClassNotFoundException e) {
            // A namespace that is not a loadable class (e.g. an XML-only mapper) cannot carry @SqlPrivilege.
            log.debug("Namespace of {} is not a class; data permission skipped", statementId);
            return null;
        }
        for (Method method : mapperType.getMethods()) {
            if (methodName.equals(method.getName())) {
                SqlPrivilege privilege = method.getAnnotation(SqlPrivilege.class);
                if (privilege != null && privilege.flag()) {
                    return privilege;
                }
            }
        }
        return null;
    }

    /**
     * Copies {@code ms}, replacing only its SQL source with {@code newSqlSource}.
     *
     * @param ms           statement to copy (id, result maps, cache settings, etc. are kept)
     * @param newSqlSource source that yields the rewritten SQL
     * @return the copied statement
     */
    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
                ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());

        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
            // The builder expects a comma-separated list.
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.resultOrdered(ms.isResultOrdered());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());

        return builder.build();
    }

    /**
     * {@link SqlSource} that always returns one pre-built (already rewritten) {@link BoundSql},
     * ignoring the parameter object it is given.
     */
    private static final class BoundSqlSource implements SqlSource {

        private final BoundSql boundSql;

        private BoundSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

}
