package com.biz.ai.core;

import com.biz.ai.Model.DisplayItem;
import com.biz.ai.Model.Picture;
import com.biz.ai.Model.Wine;
import com.biz.ai.util.Constant;
import com.biz.ai.util.Util;
import com.google.protobuf.ByteString;
import com.google.protobuf.Int64Value;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import org.junit.Test;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * gRPC client that sends image files to a TensorFlow Serving prediction service and converts the
 * detection tensors of the response into a {@link Picture}.
 */
public class InceptionPredictClient {
    private static final Logger logger = Logger.getLogger(InceptionPredictClient.class.getName());

    /** Minimum score a detection must reach to be kept in the result. */
    private static final double SCORE_THRESHOLD = 0.5;

    /** Number of coordinate values stored per detection box. */
    private static final int BOX_COORDINATES = 4;

    private final ManagedChannel channel;
    private final PredictionServiceGrpc.PredictionServiceBlockingStub blockingStub;
    /** Per-call deadline in milliseconds, applied freshly to every RPC. */
    private final long timeoutMillis;

    /**
     * Creates a client connected to the given prediction server.
     *
     * @param host    server host
     * @param port    server port
     * @param timeout deadline for each individual RPC, in milliseconds
     */
    public InceptionPredictClient(String host, int port, long timeout) {
        channel = ManagedChannelBuilder.forAddress(host, port)
                // Channels are secure by default (via SSL/TLS). TLS is disabled here to avoid
                // needing certificates.
                .usePlaintext(true)
                .build();
        // The deadline is deliberately NOT attached to the shared stub: withDeadlineAfter fixes an
        // absolute instant measured from the moment it is called, so a deadline set here would
        // make every call fail once the client is older than the timeout. It is applied per call
        // in newCallStub() instead.
        blockingStub = PredictionServiceGrpc.newBlockingStub(channel);
        timeoutMillis = timeout;
    }


    /**
     * Placeholder test entry point; it performs no work.
     */
    @Test
    public void predict() {
        // Intentionally empty.
    }


    /**
     * Shuts the channel down, waiting up to five seconds for in-flight calls to finish and
     * forcing termination if they do not.
     *
     * @throws InterruptedException if the thread is interrupted while waiting for termination
     */
    public void shutdown() throws InterruptedException {
        if (!channel.shutdown().awaitTermination(5, TimeUnit.SECONDS)) {
            channel.shutdownNow();
        }
    }


    /**
     * Sends one image to the prediction service under the input key {@code "inputs"} and parses
     * the detection result.
     *
     * @param modelName name of the served model
     * @param sign      signature name of the served model
     * @param imagePath path of the image file to send
     * @return the parsed detections, or {@code null} if the image cannot be read, the RPC fails,
     *         or the response cannot be parsed
     */
    public Picture do_predict(String modelName, String sign, String imagePath) {
        TensorProto featuresTensorProto;
        try {
            logger.info("Start to convert the image: " + imagePath);
            featuresTensorProto = buildImageTensor(imagePath);
        } catch (IOException e) {
            // An unreadable image is reported like every other failure of this method (null)
            // instead of terminating the whole JVM.
            logger.log(Level.WARNING, "Failed to read image: " + imagePath, e);
            return null;
        }

        // Version 0 means "do not pin a model version".
        Predict.PredictRequest request = buildRequest(modelName, sign, 0, "inputs", featuresTensorProto);

        try {
            Predict.PredictResponse response = newCallStub().predict(request);
            // Parse the display (shelf) detection result.
            return chenglie(response.getOutputsMap(), imagePath);
        } catch (StatusRuntimeException e) {
            logger.log(Level.WARNING, "RPC failed: " + e.getStatus(), e);
            return null;
        } catch (Exception e) {
            logger.log(Level.WARNING, "Failed to process the prediction response for: " + imagePath, e);
            return null;
        }
    }


    /**
     * Parses the detection output tensors into a {@link Picture}, keeps the detections whose
     * score is at least {@link #SCORE_THRESHOLD}, and passes the resulting boxes to
     * {@code Util.label} when the image file exists.
     *
     * <p>If the score tensor is absent the returned picture carries an empty wine list.
     *
     * <p>NOTE(review): the four box values are mapped in order to x1, y1, x2, y2 exactly as
     * before; confirm this matches the coordinate order the model emits.
     *
     * @param outputs output tensors of the prediction response, keyed by output name
     * @param imgpath path of the image the prediction was made for
     * @return the parsed picture
     * @throws Exception if a required tensor is missing or empty, or is not a float tensor
     */
    private Picture chenglie(Map<String, TensorProto> outputs, String imgpath) throws Exception {

        // Number of detections.
        TensorProto num = outputs.get(Constant.num_detections);
        // Detected class of each detection.
        TensorProto kind = outputs.get(Constant.detection_classes);
        // Score of each detection.
        TensorProto score = outputs.get(Constant.detection_scores);
        // Box position of each detection.
        TensorProto position = outputs.get(Constant.detection_boxes);

        Picture picture = new Picture();
        List<Wine> wines = new ArrayList<>();
        if (score != null) {
            float detectionCount = readDetectionCount(num);
            if (detectionCount > 0) {
                List<Float> scoreList = floatValues(score, "detection_scores");
                // Classes and boxes are only validated once a detection actually passes the
                // threshold, so a response without usable detections never fails on them.
                List<Float> kindList = null;
                List<Float> positionList = null;

                for (int i = 0; i < detectionCount; i++) {
                    // Only detections scoring at least the threshold are collected.
                    if (scoreList.get(i) >= SCORE_THRESHOLD) {
                        if (kindList == null) {
                            kindList = floatValues(kind, "detection_classes");
                            positionList = floatValues(position, "detection_boxes");
                        }
                        Wine wine = new Wine();
                        wine.setKind(kindList.get(i));
                        wine.setScore(scoreList.get(i));
                        // Each detection owns BOX_COORDINATES consecutive values.
                        int from = i * BOX_COORDINATES;
                        wine.setPosition(new ArrayList<>(positionList.subList(from, from + BOX_COORDINATES)));
                        wines.add(wine);
                    }
                }
            }
        }
        // Always set the list so the picture never carries an unset wine list.
        picture.setWineList(wines);

        List<DisplayItem> displayItemList = new ArrayList<>();
        for (Wine wine : wines) {
            List<Float> box = wine.getPosition();
            DisplayItem displayItem = new DisplayItem();
            displayItem.setKind(wine.getKind());
            displayItem.setX1(box.get(0));
            displayItem.setY1(box.get(1));
            displayItem.setX2(box.get(2));
            displayItem.setY2(box.get(3));
            displayItemList.add(displayItem);
        }

        File file = new File(imgpath);
        if (file.exists()) {
            String path = file.getParent();
            String name = file.getName();
            String nameNoExt = Util.getFileName(name);
            String ext = Util.getFileExt(name);
            Util.label(path, nameNoExt, ext, displayItemList);
        }
        System.out.println(picture.toString());
        return picture;
    }


    /**
     * Sends every image, one request per image, under the input key {@code "images"} and prints
     * the total elapsed time in milliseconds. The responses are discarded.
     *
     * <p>If an image cannot be read, the failure is logged, the remaining images are skipped and
     * no elapsed time is printed. RPC failures propagate as {@link StatusRuntimeException}.
     *
     * @param modelName    name of the served model
     * @param sign         signature name of the served model
     * @param modelVersion model version to pin; values of 0 or less leave the version unset
     * @param imagePath    paths of the image files to send
     */
    public void for_do_predict(String modelName, String sign, long modelVersion, String[] imagePath) {
        long start = System.currentTimeMillis();
        try {
            for (String path : imagePath) {
                logger.info("Start to convert the image: " + path);
                // Each request carries exactly one image, so the tensor is built with a single
                // element rather than one sized by the number of paths.
                TensorProto featuresTensorProto = buildImageTensor(path);
                Predict.PredictRequest request =
                        buildRequest(modelName, sign, modelVersion, "images", featuresTensorProto);
                newCallStub().predict(request);
            }
        } catch (IOException e) {
            // Abort the run without terminating the JVM.
            logger.log(Level.WARNING, "Failed to read image, aborting the remaining requests", e);
            return;
        }
        long end = System.currentTimeMillis();
        System.out.println(end - start);
    }


    /** Returns a stub whose deadline starts counting now, for use by exactly one call. */
    private PredictionServiceGrpc.PredictionServiceBlockingStub newCallStub() {
        return blockingStub.withDeadlineAfter(timeoutMillis, TimeUnit.MILLISECONDS);
    }


    /** Reads the image file into a one-element DT_STRING tensor, always closing the stream. */
    private static TensorProto buildImageTensor(String imagePath) throws IOException {
        ByteString imageData;
        try (InputStream imageStream = new FileInputStream(imagePath)) {
            imageData = ByteString.readFrom(imageStream);
        }
        TensorShapeProto.Dim batchDim = TensorShapeProto.Dim.newBuilder().setSize(1).build();
        TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(batchDim).build();
        return TensorProto.newBuilder()
                .addStringVal(imageData)
                .setDtype(DataType.DT_STRING)
                .setTensorShape(shape)
                .build();
    }


    /** Builds a predict request for the model, pinning the version only when it is positive. */
    private static Predict.PredictRequest buildRequest(String modelName, String sign, long modelVersion,
                                                       String inputKey, TensorProto input) {
        Model.ModelSpec.Builder modelSpecBuilder =
                Model.ModelSpec.newBuilder().setName(modelName).setSignatureName(sign);
        if (modelVersion > 0) {
            modelSpecBuilder.setVersion(Int64Value.newBuilder().setValue(modelVersion).build());
        }
        return Predict.PredictRequest.newBuilder()
                .setModelSpec(modelSpecBuilder.build())
                .putInputs(inputKey, input)
                .build();
    }


    /** Returns the last float value of the detection-count tensor, validating it first. */
    private static float readDetectionCount(TensorProto num) throws Exception {
        if (num == null) {
            throw new Exception("missing output tensor: num_detections");
        }
        if (num.getDtype() != DataType.DT_FLOAT) {
            throw new Exception("unknown data type");
        }
        int valueCount = num.getFloatValCount();
        if (valueCount == 0) {
            throw new Exception("output tensor has no values: num_detections");
        }
        return num.getFloatVal(valueCount - 1);
    }


    /** Returns the float values of a tensor, failing if it is missing or not a float tensor. */
    private static List<Float> floatValues(TensorProto tensor, String name) throws Exception {
        if (tensor == null) {
            throw new Exception("missing output tensor: " + name);
        }
        if (tensor.getDtype() != DataType.DT_FLOAT) {
            throw new Exception("unknown data type");
        }
        return tensor.getFloatValList();
    }
}