package com.bizunited.platform.core.repository.dynamic;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.AfterReturning;
import org.aspectj.lang.annotation.AfterThrowing;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.hibernate.Session;
import org.hibernate.SessionFactory;
import org.hibernate.Transaction;
import org.hibernate.resource.transaction.spi.TransactionStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.aop.aspectj.MethodInvocationProceedingJoinPoint;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * 专门给Async("dynamicExecutor")设计的拦截器，用来在方法开始前开启指定数据源的session会话
 * 并且在方法执行完成后，视情况提交或者回滚session事务
 * @author yinwenjie
 */
@Aspect
@Component
public class DynamicDataSourceTaskAspect {
  
  @Autowired
  private DynamicDataSourceManager dynamicDataSourceManager;
  private static final String ERROR_MESS = "第三方数据源调用时，至少需要传入dataSourceCode(java.lang.String)参数值!!";
  private static final String ERROR_MESS_CODE = "未发现指定dataSourceCode(%s)对应的数据源信息,请检查";
  private static Logger LOGGER = LoggerFactory.getLogger(DynamicDataSourceTaskAspect.class);
  
  /**
   * 以Async注解为切面关注点
   */
  @Pointcut("@annotation(org.springframework.scheduling.annotation.Async)")
  public void aspectHandle() { 
    
  }
  
  @Before(value="aspectHandle() && @annotation(async)" , argNames = "async")
  public void before(JoinPoint point , Async async) {
    /*
     * -这里主要是为当前线程附加hibernate jdbc session信息，步骤如下：
     * 1、只有Async注解，且这个注解中的value属性为“dynamicExecutor”时，才进行相关工作
     * 2、排查参数中DataSourceId参数，这是一个字符串参数，那么判断方式为：
     *  a、如果当前参数值中只有一个字符串参数，那么就是这个参数是DataSourceId参数
     *  b、如果当前参数值中有多个字符串参数，那么在这个版本中进行报错
     *  // TODO 后续，需要找到这个方法中加了@DataSourceId注解的字符串参数
     * 3、获取当前dynamicDataSourceManager中的currentSession信息，并根据session的状态进行判断
     *  (注意，没有获取到这个DataSourceId对应的第三方数据源，则抛出异常)
     *  a、如果当前session没有关闭，则不用处理了
     *  b、如果当前session关闭了，则重新打开
     * */
    // 1、=======
    if(!this.checkAsync(async)) {
      return;
    }
    
    // 2、=======
    MethodInvocationProceedingJoinPoint methodJoinPoint = (MethodInvocationProceedingJoinPoint)point;
    Object[] args = methodJoinPoint.getArgs();
    Validate.isTrue(args != null && args.length > 0 , ERROR_MESS);
    // 开始处理不为null的参数，以及这些不为null的参数的参数类型
    List<Object> notNullArgs = Arrays.stream(args).filter(item -> item != null).collect(Collectors.toList());
    List<Class<?>> notNullArgsClasses = notNullArgs.stream().map(Object::getClass).collect(Collectors.toList());
    Validate.isTrue(!notNullArgsClasses.isEmpty() , ERROR_MESS);
    // 找到String类型的参数所在位置
    List<Integer> stringClassIndexs = new LinkedList<>();
    for(int index = 0 ; index < notNullArgsClasses.size() ; index++) {
      Class<?> notNullArgsClass = notNullArgsClasses.get(index);
      if(notNullArgsClass == String.class) {
        stringClassIndexs.add(index);
      }
    }
    Validate.isTrue(!stringClassIndexs.isEmpty() , ERROR_MESS);
    // a、====
    String dataSourceCode = null;
    if(stringClassIndexs.size() == 1) {
      dataSourceCode = notNullArgs.get(stringClassIndexs.get(0)).toString();
    } 
    // b、====
    else {
      throw new IllegalArgumentException("目前版本支持的第三方数据源调用时，只支持传入一个字符串类型的参数dataSourceCode(java.lang.String)!!");
    }
    
    // 3、=========
    SessionFactory sessionFactory = this.dynamicDataSourceManager.getCurrentSessionFactory(dataSourceCode);
    Validate.notNull(sessionFactory , ERROR_MESS_CODE , dataSourceCode);
    Session currentSession = sessionFactory.getCurrentSession();
    if(currentSession == null || !currentSession.isOpen()) {
      currentSession = sessionFactory.openSession();
    }
    currentSession.beginTransaction();
    LOGGER.debug("before()");
  }
  
  private boolean checkAsync(Async async) {
    if(async == null) {
      return false;
    }
    String asyncValue = async.value();
    if(StringUtils.isBlank(asyncValue) || !StringUtils.equals(asyncValue, "dynamicExecutor")) {
      return false;
    }
    
    return true;
  }
  
  @AfterReturning(value="aspectHandle() && @annotation(async)" , argNames = "async")
  public void afterReturning(JoinPoint point, Async async) {
    /*
     * 1、只有Async注解，且这个注解中的value属性为“dynamicExecutor”时，才进行相关工作
     * 2、这里主要为当前线程已经附加的hibernate jdbc session做事务提交操作
     * */
    // 1、=======
    if(!this.checkAsync(async)) {
      return;
    }
    // 取得当前dataSourceCode的值
    String dataSourceCode = this.getDataSourceCode(point);
    
    // 2、=======
    SessionFactory sessionFactory = this.dynamicDataSourceManager.getCurrentSessionFactory(dataSourceCode);
    Validate.notNull(sessionFactory , ERROR_MESS_CODE , dataSourceCode);
    Session currentSession = sessionFactory.getCurrentSession();
    if(currentSession == null || !currentSession.isOpen()) {
      currentSession = sessionFactory.openSession();
    }
    if(!currentSession.isOpen()) {
      return;
    }
    
    // 执行到这里，就要尝试提交事务并关闭session了
    Transaction transaction = currentSession.getTransaction();
    if(transaction.getStatus() == TransactionStatus.ACTIVE) {
      transaction.commit();
    } else {
      currentSession.close();
    }
    LOGGER.debug("afterReturning()");
  }
  
  @AfterThrowing(value="aspectHandle() && @annotation(async)" , argNames = "async")
  public void afterThrowing(JoinPoint point, Async async) {
    /*
     * 1、只有Async注解，且这个注解中的value属性为“dynamicExecutor”时，才进行相关工作
     * 2、这里主要为当前线程任务在抛出异常后，进行hibernate jdbc session做事务的回滚
     * */
    // 1、=======
    if(!this.checkAsync(async)) {
      return;
    }
    // 取得当前dataSourceCode的值
    String dataSourceCode = this.getDataSourceCode(point);
    
    // 2、=======
    SessionFactory sessionFactory = this.dynamicDataSourceManager.getCurrentSessionFactory(dataSourceCode);
    Validate.notNull(sessionFactory , ERROR_MESS_CODE , dataSourceCode);
    Session currentSession = sessionFactory.getCurrentSession();
    if(currentSession == null || !currentSession.isOpen()) {
      currentSession = sessionFactory.openSession();
    }
    if(!currentSession.isOpen()) {
      return;
    }
    
    // 执行到这里，就要尝试提交事务并关闭session了
    Transaction transaction = currentSession.getTransaction();
    if(transaction.getStatus() == TransactionStatus.ACTIVE) {
      transaction.rollback();
    } else {
      currentSession.close();
    }
    LOGGER.debug("afterThrowing()");
  }
  
  /**
   * 这个私有方法基本不需要验证，因为在before中已经验证过了
   * @param point
   * @return
   */
  private String getDataSourceCode(JoinPoint point) {
    MethodInvocationProceedingJoinPoint methodJoinPoint = (MethodInvocationProceedingJoinPoint)point;
    Object[] args = methodJoinPoint.getArgs();
    List<Object> notNullArgs = Arrays.stream(args).filter(item -> item != null).collect(Collectors.toList());
    List<Class<?>> notNullArgsClasses = notNullArgs.stream().map(Object::getClass).collect(Collectors.toList());
    String dataSourceCode = null;
    for(int index = 0 ; index < notNullArgsClasses.size() ; index++) {
      Class<?> notNullArgsClass = notNullArgsClasses.get(index);
      if(notNullArgsClass == String.class) {
        dataSourceCode = notNullArgs.get(index).toString();
        break;
      }
    }
    
    return dataSourceCode;
  }
}
