package com.bizunited.platform.mars.policy.process.runtime.service;

import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import com.bizunited.platform.mars.policy.process.cache.RuntimeDefinition;
import com.bizunited.platform.mars.policy.process.cache.RuntimeNode;
import com.bizunited.platform.mars.policy.process.cache.RuntimeNodeNexts;
import com.bizunited.platform.mars.policy.process.cache.RuntimeNodeType;
import com.bizunited.platform.mars.policy.process.cache.RuntimeProcessorLinked;
import com.bizunited.platform.mars.policy.process.runtime.contexts.RuleRuntimeContext;

/**
 * 规则节点定义运行时服务的实现
 * @author yinwenjie
 *
 */
@Component("SimpleRuntimeNodeService")
public class SimpleRuntimeNodeService extends AbstractRuntimeService implements RuntimeNodeService {
  
  private static final Logger LOGGER = LoggerFactory.getLogger(SimpleRuntimeNodeService.class);
  @Autowired
  private ApplicationContext applicationContext;
  
  @Override
  public RuntimeNode findStartByDefinition(RuntimeDefinition currentDefinition) {
    if(currentDefinition == null) {
      return null;
    }
    Set<RuntimeNode> nodes = currentDefinition.getNodes();
    if(CollectionUtils.isEmpty(nodes)) {
      return null;
    }
    
    // 查询开始节点
    return nodes.stream().filter(item -> item.getType() == RuntimeNodeType.BEGIN).findFirst().orElse(null);
  }
  
  @Override
  public RuntimeNode findStartByRuntimeProcessorLinked(RuntimeProcessorLinked runtimeProcessorLinked) {
    if(runtimeProcessorLinked == null) {
      return null;
    }
    return runtimeProcessorLinked.getFounder();
  }

  public Set<RuntimeNode> findEndByDefinition(RuntimeDefinition currentDefinition) {
    if(currentDefinition == null) {
      return null;
    }
    Set<RuntimeNode> nodes = currentDefinition.getNodes();
    if(CollectionUtils.isEmpty(nodes)) {
      return null;
    }
    
    // 查询开始节点
    return nodes.stream().filter(item -> item.getType() == RuntimeNodeType.END).collect(Collectors.toSet());
  }

  @Override
  public RuntimeNode findByCodeAndContext(String nodeCode, RuleRuntimeContext context) {
    if(StringUtils.isBlank(nodeCode)|| context == null) {
      return null;
    }
    RuntimeDefinition runtimeDefinition = context.getRuntimeDefinition();
    Set<RuntimeNode> runtimeNodes = runtimeDefinition.getNodes();
    RuntimeNode currentNode = runtimeNodes.stream().filter(item -> StringUtils.equals(item.getCode(), nodeCode)).findFirst().orElse(null);
    return currentNode;
  }

  @Override
  public RuntimeNode findNextByContext(RuntimeNode currentNode, RuleRuntimeContext context) {
    if(!this.validate(currentNode, context)) {
      return null;
    }
    Set<RuntimeNodeNexts> nexts = currentNode.getNexts();
    RuntimeNodeType type = currentNode.getType();
    
    /*
     * 处理过程为：
     * 1、如果当前连线存在多个连接线的场景，则抛出异常，要求使用StarterRuleable的createProcessorLinkeds方法进行处理
     * 2、接着直接通过当前节点所处上下文运行时，查找下一个节点，必须要找到，否则也要抛出异常
     * */
    // 1、========
    Validate.isTrue(type != RuntimeNodeType.CONDITION && type != RuntimeNodeType.CONCURRENCY , "当前运行节点为多分支节点，不能使用该方法进行后续结点的查询，必须使用StarterRuleable的createProcessorLinkeds方法进行处理");
    Validate.isTrue(nexts.stream().filter(item -> item.getLineType() == 1).count() == 1l , "当前运行的规则节点定义，不是条件类型或者并行分支类型的节点，不能有多个常规连线，请检查!!");
    RuntimeNodeNexts next = nexts.stream().filter(item -> item.getLineType() == 1).findFirst().orElse(null);
    String toNodeCode = next.getToNodeCode();
    LOGGER.info("规则引擎正在查找下一节点：nodeCode = " + toNodeCode);
    
    // 2、========
    RuntimeNode targetNode = this.findByCodeAndContext(toNodeCode, context);
    Validate.notNull(targetNode , "未找到任何后续运行节点（%s）" , toNodeCode);
    return targetNode;
  }
  
  // 如果验证后，如果过程中发现不能继续，则返回false
  private boolean validate(RuntimeNode currentNode , RuleRuntimeContext context) {
    if(currentNode == null || context == null) {
      return false;
    }
    String fromId = currentNode.getId();
    RuntimeNodeType type = currentNode.getType();
    if(StringUtils.isBlank(fromId)) {
      return false;
    }
    // 如果当前节点是结束节点，则不用进行查找后续结点
    if(type == RuntimeNodeType.END) {
      return false;
    }
    Set<RuntimeNodeNexts> nexts = currentNode.getNexts();
    if(CollectionUtils.isEmpty(nexts)) {
      throw new IllegalArgumentException("未发现指定节点的后续节点定义，请检查规则模板设定!!");
    }
    return true;
  }
  
  @Override
  public RuntimeNode findExceptionNextByContext(RuntimeNode currentNode, RuleRuntimeContext context) {
    if(!this.validate(currentNode, context)) {
      return null;
    } 
    Set<RuntimeNodeNexts> nexts = currentNode.getNexts();
    Set<RuntimeNode> nodes = context.getRuntimeDefinition().getNodes();
    Throwable currentThrowable = context.getCurrentThrowable();
    if(currentThrowable == null) {
      return null;
    }
    if(currentThrowable instanceof InvocationTargetException) {
      currentThrowable = ((InvocationTargetException)currentThrowable).getTargetException();
    }
    
    /*
     * 处理方式为：
     * 1、首先nexts信息需要排序，然后再根据判定表达式，依次执行判定
     * 注意，只需要对异常线进行排序
     * 2、判定当前上下文记录的异常和分支允许的异常是否一致，只有一致才进行处理
     * 3、如果以上过程没有发现任何符合要求的分支节点（包括没有默认分支），则返回null
     * (和findNextByDefinition类似，请参考)
     * */
    
    // 1、=======
    Set<RuntimeNodeNexts> sortNexts = nexts.stream().filter(item -> item.getLineType() == 2).sorted((source , target) -> source.getSort() - target.getSort()).collect(Collectors.toSet());
    if(CollectionUtils.isEmpty(sortNexts)) {
      return null;
    } 
    
    // 2、=======
    ClassLoader classLoader = applicationContext.getClassLoader();
    RuntimeNodeNexts currentNext = null;
    for(RuntimeNodeNexts next : sortNexts) {
      String exceptions = next.getExceptions();
      if(StringUtils.isBlank(exceptions)) {
        continue;
      }
      String exceptionValues[] = StringUtils.split(exceptions, ",");
      for (String exceptionValue : exceptionValues) {
        Class<?> exceptionClass = null;
        try {
          exceptionClass = classLoader.loadClass(exceptionValue);
        } catch(Exception e) {
          LOGGER.error(e.getMessage() , e);
          throw new IllegalArgumentException(String.format("未发现指定的异常定义类[%s]，请检查!!", exceptionValue), e); 
        }
        
        // 如果条件成立，说明找到了匹配的分支
        if(exceptionClass.isAssignableFrom(currentThrowable.getClass())) {
          currentNext = next;
        }
      }
    }
    
    // 3、========
    // 如果没有找到任何匹配的异常分支线，则寻找没有填写任何异常条件的默认分支
    if(currentNext == null) {
      currentNext = sortNexts.stream().filter(item -> StringUtils.isBlank(item.getConditions())).findFirst().orElse(null);
    }
    if(currentNext == null) {
      return null;
    }
    String toNodeCode = currentNext.getToNodeCode();
    RuntimeNode targetNode = nodes.stream().filter(item -> item.getCode() == toNodeCode).findFirst().orElse(null);
    Validate.notNull(targetNode , "没有发现指定的节点[%s]信息，请检查!!" , toNodeCode);
    return targetNode;
  }
}
