package com.bizunited.platform.core.repository.internal;

import com.alibaba.druid.pool.DruidDataSource;
import com.bizunited.platform.common.repository.PageRepositoryImpl;
import com.bizunited.platform.core.common.enums.DataSourceTypeEnum;
import com.bizunited.platform.core.entity.DataSourceTableEntity;
import com.google.common.collect.Sets;
import org.apache.commons.lang3.StringUtils;
import org.hibernate.Session;
import org.hibernate.SessionFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.orm.hibernate5.SessionFactoryUtils;
import org.springframework.stereotype.Repository;

import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Persistence implementation for data source tables ({@link DataSourceTableEntity}).
 *
 * <p>Provides a paged HQL query over the stored table entities, and JDBC-based listing of the
 * physical tables present in a data source.
 *
 * @Author: Paul Chan
 * @Date: 2019-04-15 15:39
 */
@Repository("DataSourceTableRepositoryImpl")
public class DataSourceTableRepositoryImpl implements DataSourceTableRepositoryCustom, PageRepositoryImpl {

  private static final Logger LOGGER = LoggerFactory.getLogger(DataSourceTableRepositoryImpl.class);
  /**
   * Oracle: lists the tables in the current user's schema ({@code user_tables}).
   */
  private static final String ORACLE_QUERY_TABLE_SQL = "select TABLE_NAME from user_tables";

  @Autowired
  @PersistenceContext
  private EntityManager entityManager;

  /**
   * Pages through {@link DataSourceTableEntity} rows matching the given optional filters.
   *
   * @param pageable     paging information
   * @param tableType    if non-null, matches {@code tableType} exactly
   * @param dataSourceId if non-blank, matches the associated data source id; if blank, only
   *                     entities with no associated data source are returned
   *                     (presumably the main data source — confirm with callers)
   * @param status       if non-null, matches {@code tStatus} exactly
   * @param tableName    if non-blank, matches {@code tableName} exactly
   * @return the requested page of matching entities
   */
  @Override
  public Page<DataSourceTableEntity> findByConditions(Pageable pageable, Integer tableType, String dataSourceId, Integer status, String tableName) {
    StringBuilder hql = new StringBuilder("select dtt from DataSourceTableEntity dtt left join fetch dtt.dataSource ds where 1=1 ");
    // The count query must not use "fetch", but it shares the same join alias and conditions.
    StringBuilder countHql = new StringBuilder("select count(*) from DataSourceTableEntity dtt left join dtt.dataSource ds where 1=1 ");
    StringBuilder conditions = new StringBuilder();
    Map<String, Object> parameters = new HashMap<>();
    if (tableType != null) {
      conditions.append(" AND dtt.tableType = :tableType");
      parameters.put("tableType", tableType);
    }
    if (status != null) {
      conditions.append(" AND dtt.tStatus = :status");
      parameters.put("status", status);
    }
    if (StringUtils.isNotBlank(tableName)) {
      conditions.append(" AND dtt.tableName = :tableName");
      parameters.put("tableName", tableName);
    }
    if (StringUtils.isNotBlank(dataSourceId)) {
      conditions.append(" AND ds.id = :dataSourceId");
      parameters.put("dataSourceId", dataSourceId);
    } else {
      conditions.append(" AND ds.id is null");
    }
    hql.append(conditions);
    countHql.append(conditions);
    return queryByConditions(entityManager, hql.toString(), countHql.toString(), parameters, pageable, false, null);
  }

  /**
   * Lists the physical table names of the data source behind this repository's EntityManager.
   *
   * @return the table names; empty if they could not be read
   */
  @Override
  public Set<String> queryMainDataSourceTables() {
    // unwrap() is the standard JPA way to reach the provider's Session (getDelegate() is provider-specific).
    Session session = entityManager.unwrap(Session.class);
    return queryDataSourceTables(session.getSessionFactory());
  }

  /**
   * Lists the physical table names of the data source backing the given SessionFactory.
   *
   * <p>MySQL is queried via JDBC metadata; any other detected type is queried with the Oracle
   * {@code user_tables} SQL. Failures are logged and yield an empty set (best effort).
   *
   * @param sessionFactory the Hibernate SessionFactory whose DataSource is inspected
   * @return a mutable set of table names; empty when no DataSource is found or an SQL error occurs
   */
  @Override
  public Set<String> queryDataSourceTables(SessionFactory sessionFactory) {
    DataSource dataSource = SessionFactoryUtils.getDataSource(sessionFactory);
    if (dataSource == null) {
      // SessionFactoryUtils returns null when the DataSource cannot be determined.
      LOGGER.warn("no DataSource could be resolved from the given SessionFactory");
      return Sets.newHashSet();
    }
    try (Connection conn = dataSource.getConnection()) {
      String driverClass = this.getDriverClass(dataSource);
      DataSourceTypeEnum dataSourceType = DataSourceTypeEnum.valueOfDriver(driverClass);
      if (DataSourceTypeEnum.MYSQL.equals(dataSourceType)) {
        return this.queryMysqlTables(conn);
      }
      return this.queryOracleTables(conn);
    } catch (SQLException e) {
      LOGGER.warn("failed to query tables of data source: {}", e.getMessage(), e);
    }
    return Sets.newHashSet();
  }

  /**
   * Lists MySQL tables of the connection's current catalog using JDBC metadata.
   *
   * @param conn an open connection
   * @return the names of objects of type {@code TABLE}
   * @throws SQLException on any JDBC error
   */
  private Set<String> queryMysqlTables(Connection conn) throws SQLException {
    String catalog = conn.getCatalog();
    DatabaseMetaData metaData = conn.getMetaData();
    // try-with-resources closes the ResultSet even if reading it fails.
    try (ResultSet rs = metaData.getTables(catalog, null, null, new String[] {"TABLE"})) {
      return this.getTableNames(rs);
    }
  }

  /**
   * Lists Oracle tables of the current user by running {@link #ORACLE_QUERY_TABLE_SQL}.
   *
   * @param conn an open connection
   * @return the table names
   * @throws SQLException on any JDBC error
   */
  private Set<String> queryOracleTables(Connection conn) throws SQLException {
    try (PreparedStatement pre = conn.prepareStatement(ORACLE_QUERY_TABLE_SQL);
         ResultSet rs = pre.executeQuery()) {
      return this.getTableNames(rs);
    }
  }

  /**
   * Collects the {@code TABLE_NAME} column of every row in the result set.
   * The caller owns (and closes) the result set.
   *
   * @param rs an open result set exposing a {@code TABLE_NAME} column
   * @return the collected table names
   * @throws SQLException on any JDBC error
   */
  private Set<String> getTableNames(ResultSet rs) throws SQLException {
    Set<String> tableNames = new HashSet<>();
    while (rs.next()) {
      tableNames.add(rs.getString("TABLE_NAME"));
    }
    return tableNames;
  }

  /**
   * Determines the JDBC driver class name of the given data source.
   *
   * <p>Only Druid data sources (including subclasses) expose their driver class here; any other
   * data source falls back to the first MySQL driver class name.
   *
   * @param dataSource the data source to inspect
   * @return the driver class name
   */
  private String getDriverClass(DataSource dataSource) {
    if (dataSource instanceof DruidDataSource) {
      DruidDataSource druidDataSource = (DruidDataSource) dataSource;
      return druidDataSource.getDriverClassName();
    }
    return DataSourceTypeEnum.MYSQL.getDriverClassNames()[0];
  }

}
