|
| 1 | +package src.main.java; |
| 2 | + |
| 3 | +import com.sun.jna.Library; |
| 4 | +import com.sun.jna.Native; |
| 5 | +import com.sun.jna.Pointer; |
| 6 | +import com.sun.jna.ptr.PointerByReference; |
| 7 | + |
| 8 | +import tensorflow.eas.Predict; |
| 9 | + |
| 10 | +class Demo { |
| 11 | + |
| 12 | + public static String modelConfig = |
| 13 | + "{\"omp_num_threads\": 4," + |
| 14 | + "\"kmp_blocktime\": 0," + |
| 15 | + "\"feature_store_type\": \"memory\"," + |
| 16 | + "\"serialize_protocol\": \"protobuf\"," + |
| 17 | + "\"inter_op_parallelism_threads\": 10," + |
| 18 | + "\"intra_op_parallelism_threads\": 10," + |
| 19 | + "\"init_timeout_minutes\": 1," + |
| 20 | + "\"signature_name\": \"serving_default\"," + |
| 21 | + "\"read_thread_num\": 3," + |
| 22 | + "\"update_thread_num\": 2," + |
| 23 | + "\"model_store_type\": \"local\"," + |
| 24 | + "\"checkpoint_dir\": \"/tmp/checkpoint/\"," + |
| 25 | + "\"savedmodel_dir\": \"/tmp/saved_model/\"}"; |
| 26 | + |
| 27 | + // Load shared library via JNA |
| 28 | + public interface Processor extends Library |
| 29 | + { |
| 30 | + Processor INSTANCE = (Processor) Native.load("serving_processor", Processor.class); |
| 31 | + |
| 32 | + // Define shared library function prototype |
| 33 | + public Pointer initialize(String modelEntry, String modelConfig, int[] state); |
| 34 | + |
| 35 | + public int process(Pointer model, byte[] buffer, int size, PointerByReference outputData, int[] outputSize); |
| 36 | + } |
| 37 | + |
| 38 | + public static void main(String[] args) { |
| 39 | + Demo demo = new Demo(); |
| 40 | + String modelEntry = ""; |
| 41 | + int[] state = {0}; |
| 42 | + Pointer model = Processor.INSTANCE.initialize(modelEntry, modelConfig, state); |
| 43 | + if (state[0] == -1) { |
| 44 | + System.err.println("initialize error"); |
| 45 | + } |
| 46 | + |
| 47 | + // input type: float |
| 48 | + Predict.ArrayDataType dtype = Predict.ArrayDataType.DT_FLOAT; |
| 49 | + // input shape: [1, 1] |
| 50 | + Predict.ArrayShape arrayShape = |
| 51 | + Predict.ArrayShape.newBuilder() |
| 52 | + .addDim(1) |
| 53 | + .addDim(1) |
| 54 | + .build(); |
| 55 | + // input array |
| 56 | + Predict.ArrayProto input = |
| 57 | + Predict.ArrayProto.newBuilder() |
| 58 | + .addFloatVal((float) 1.0) |
| 59 | + .setDtype(dtype) |
| 60 | + .setArrayShape(arrayShape) |
| 61 | + .build(); |
| 62 | + // PredictRequest |
| 63 | + Predict.PredictRequest req = |
| 64 | + Predict.PredictRequest.newBuilder() |
| 65 | + .setSignatureName("serving_default") |
| 66 | + .addOutputFilter("y:0") |
| 67 | + .putInputs("x:0", input) |
| 68 | + .build(); |
| 69 | + byte[] buffer = req.toByteArray(); |
| 70 | + int size = req.getSerializedSize(); |
| 71 | + |
| 72 | + // do process |
| 73 | + PointerByReference output = new PointerByReference(); |
| 74 | + int[] outputSize = {0}; |
| 75 | + state[0] = Processor.INSTANCE.process(model, buffer, size, output, outputSize); |
| 76 | + |
| 77 | + // parse response |
| 78 | + byte[] outputString = output.getValue().getByteArray(0, outputSize[0]); |
| 79 | + String s = new String(outputString); |
| 80 | + try { |
| 81 | + Predict.PredictResponse resp = |
| 82 | + Predict.PredictResponse.newBuilder() |
| 83 | + .mergeFrom(outputString) |
| 84 | + .build(); |
| 85 | + System.out.println(resp.toString()); |
| 86 | + } catch (Exception e) { |
| 87 | + System.err.println("parse response error"); |
| 88 | + } |
| 89 | + |
| 90 | + } |
| 91 | +} |
0 commit comments