package com.bizunited.platform.rbac.okta.starter.handle;

import com.bizunited.platform.common.controller.model.ResponseCode;
import com.bizunited.platform.common.controller.model.ResponseModel;
import com.bizunited.platform.rbac.security.starter.handle.HandleOutPut;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Base64;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;

/**
 * Security error handler for the Okta integration. It answers both unauthenticated and
 * unauthorized requests by writing a {@link ResponseModel} through
 * {@link HandleOutPut#writeResponse}.
 *
 * <ul>
 *   <li>As an {@link AuthenticationEntryPoint}, {@link #commence} handles requests whose login is
 *       missing or has expired. It builds an OAuth2 authorization request for the {@code okta}
 *       client registration and returns its authorization request URI together with
 *       {@link ResponseCode#E601}.</li>
 *   <li>As an {@link AccessDeniedHandler}, {@link #handle} only takes effect when authorization
 *       explicitly throws an {@link AccessDeniedException}. It returns {@link ResponseCode#E602}.</li>
 * </ul>
 *
 * @author yinwenjie
 */
public class SimpleAccessDeniedHandler implements AuthenticationEntryPoint, AccessDeniedHandler, InitializingBean, HandleOutPut {

  /** Message returned when the user is not logged in or the login has expired. */
  private static final String SESSION_USER_MISS = "用户已失效或未登录!";

  /** Message returned when the user is not allowed to access the requested function. */
  private static final String SESSION_DENIED = "用户无权访问该功能!";

  /** Registration id of the Okta client inside the {@link ClientRegistrationRepository}. */
  private static final String OKTA_REGISTRATION_ID = "okta";

  /** Value substituted for the {@code {action}} variable of the redirect URI template. */
  private static final String LOGIN_ACTION = "login";

  /** Generates the {@code state} value (URL-safe Base64) for each authorization request. */
  private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());

  /** Source of the Okta {@link ClientRegistration}; supplied through the constructor. */
  private final ClientRegistrationRepository clientRegistrationRepository;

  /** Session-backed store for pending authorization-code requests. */
  private final AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
      new HttpSessionOAuth2AuthorizationRequestRepository();

  /**
   * Creates the handler.
   *
   * @param clientRegistrationRepository repository used to look up the {@code okta} client registration
   */
  public SimpleAccessDeniedHandler(ClientRegistrationRepository clientRegistrationRepository) {
    this.clientRegistrationRepository = clientRegistrationRepository;
  }

  /**
   * Responds to an unauthenticated request with {@link ResponseCode#E601}. The response carries the
   * authorization request URI of a freshly built Okta authorization request.
   *
   * <p>For the authorization-code grant, the generated request is also saved to
   * {@code authorizationRequestRepository}. Presumably this lets the OAuth2 login callback find it
   * later; confirm against the login filter configuration.
   *
   * @param request the current request, used to derive the redirect URI
   * @param response the response the {@link ResponseModel} is written to
   * @param authException the triggering exception (not inspected)
   * @throws IllegalArgumentException if no {@code okta} registration exists or its grant type is
   *     neither authorization code nor implicit
   */
  @Override
  public void commence(HttpServletRequest request, HttpServletResponse response,
      AuthenticationException authException) throws IOException, ServletException {
    OAuth2AuthorizationRequest authorizationRequest = this.getOktaOAuth2AuthorizationRequest(request);
    if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationRequest.getGrantType())) {
      this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
    }
    ResponseModel result = new ResponseModel(System.currentTimeMillis(),
        authorizationRequest.getAuthorizationRequestUri(), ResponseCode.E601,
        new IllegalAccessException(SESSION_USER_MISS));
    this.writeResponse(response, result);
  }

  /**
   * No post-construction initialisation is needed; all dependencies arrive through the constructor.
   */
  @Override
  public void afterPropertiesSet() throws Exception {
    // intentionally empty
  }

  /**
   * Responds to a request the authenticated user may not access with {@link ResponseCode#E602}.
   *
   * @param request the current request (not inspected)
   * @param response the response the {@link ResponseModel} is written to
   * @param accessDeniedException the triggering exception (not inspected)
   */
  @Override
  public void handle(HttpServletRequest request, HttpServletResponse response,
      AccessDeniedException accessDeniedException) throws IOException, ServletException {
    ResponseModel result = new ResponseModel(System.currentTimeMillis(), null, ResponseCode.E602,
        new IllegalAccessException(SESSION_DENIED));
    this.writeResponse(response, result);
  }

  /**
   * Builds the OAuth2 authorization request for the Okta client registration. The result carries
   * the Okta login URL and related data.
   *
   * @param request the current request, used to expand the redirect URI template
   * @return a new authorization request with a freshly generated {@code state}
   * @throws IllegalArgumentException if no {@code okta} registration exists or its grant type is
   *     neither authorization code nor implicit
   */
  private OAuth2AuthorizationRequest getOktaOAuth2AuthorizationRequest(HttpServletRequest request) {
    ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(OKTA_REGISTRATION_ID);
    if (clientRegistration == null) {
      throw new IllegalArgumentException("Invalid Client Registration with Id: " + OKTA_REGISTRATION_ID);
    }

    OAuth2AuthorizationRequest.Builder builder;
    if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
      builder = OAuth2AuthorizationRequest.authorizationCode();
    } else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) {
      builder = OAuth2AuthorizationRequest.implicit();
    } else {
      throw new IllegalArgumentException("Invalid Authorization Grant Type ("  +
          clientRegistration.getAuthorizationGrantType().getValue() +
          ") for Client Registration with Id: " + clientRegistration.getRegistrationId());
    }

    String redirectUriStr = this.expandRedirectUri(request, clientRegistration, LOGIN_ACTION);

    // The registration id travels with the request as an additional parameter.
    Map<String, Object> additionalParameters = new HashMap<>();
    additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());

    return builder
        .clientId(clientRegistration.getClientId())
        .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
        .redirectUri(redirectUriStr)
        .scopes(clientRegistration.getScopes())
        .state(this.stateGenerator.generateKey())
        .additionalParameters(additionalParameters)
        .build();
  }

  /**
   * Expands the client registration's redirect URI template against the current request.
   *
   * <p>Supported template variables:
   * <ul>
   *   <li>{@code baseUrl}: scheme://host[:port]/contextPath of the current request</li>
   *   <li>{@code baseScheme}, {@code baseHost}, {@code basePath}: the corresponding parts of {@code baseUrl}</li>
   *   <li>{@code basePort}: {@code ":port"}, or empty when {@code baseUrl} carries no port</li>
   *   <li>{@code action}: the given action, or empty when it is {@code null}</li>
   *   <li>{@code registrationId}: the client registration id</li>
   * </ul>
   *
   * @param request the current request
   * @param clientRegistration registration whose redirect URI template is expanded
   * @param action value for {@code {action}}; may be {@code null}
   * @return the expanded redirect URI
   */
  private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
    // Same base URL as before: the full request URL with the query dropped and the path
    // replaced by the context path.
    UriComponents baseComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
        .replaceQuery(null)
        .replacePath(request.getContextPath())
        .build();

    Map<String, String> uriVariables = new HashMap<>();
    uriVariables.put("registrationId", clientRegistration.getRegistrationId());
    uriVariables.put("baseUrl", baseComponents.toUriString());

    String scheme = baseComponents.getScheme();
    uriVariables.put("baseScheme", scheme == null ? "" : scheme);
    String host = baseComponents.getHost();
    uriVariables.put("baseHost", host == null ? "" : host);
    int port = baseComponents.getPort();
    uriVariables.put("basePort", port == -1 ? "" : ":" + port);
    String path = baseComponents.getPath();
    if (path == null) {
      path = "";
    } else if (!path.isEmpty() && path.charAt(0) != '/') {
      path = "/" + path;
    }
    uriVariables.put("basePath", path);
    // Always provide every variable: buildAndExpand(Map) throws for any placeholder missing from the map.
    uriVariables.put("action", action == null ? "" : action);

    return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
        .buildAndExpand(uriVariables)
        .toUriString();
  }
}
