11# Copyright (c) OpenMMLab. All rights reserved.
22import argparse
3+ import json
4+
35import cv2
4- from tritonclient .grpc import InferenceServerClient , InferInput , InferRequestedOutput
56import numpy as np
6- import json
7+ from tritonclient .grpc import (InferenceServerClient , InferInput ,
8+ InferRequestedOutput )
79
810
911def parse_args ():
1012 parser = argparse .ArgumentParser ()
11- parser .add_argument ('model_name' , type = str ,
12- help = 'model name' )
13- parser .add_argument ('image' , type = str ,
14- help = 'image path' )
13+ parser .add_argument ('model_name' , type = str , help = 'model name' )
14+ parser .add_argument ('image' , type = str , help = 'image path' )
1515 return parser .parse_args ()
1616
1717
@@ -24,14 +24,16 @@ def __init__(self, url, model_name, model_version):
2424 self ._client = InferenceServerClient (self ._url )
2525 model_config = self ._client .get_model_config (self ._model_name ,
2626 self ._model_version )
27- model_metadata = self ._client .get_model_metadata (self . _model_name ,
28- self ._model_version )
27+ model_metadata = self ._client .get_model_metadata (
28+ self . _model_name , self ._model_version )
2929 print (f'[model config]:\n { model_config } ' )
3030 print (f'[model metadata]:\n { model_metadata } ' )
3131 self ._inputs = {input .name : input for input in model_metadata .inputs }
3232 self ._input_names = list (self ._inputs )
3333 self ._outputs = {
34- output .name : output for output in model_metadata .outputs }
34+ output .name : output
35+ for output in model_metadata .outputs
36+ }
3537 self ._output_names = list (self ._outputs )
3638 self ._outputs_req = [
3739 InferRequestedOutput (name ) for name in self ._outputs
@@ -46,10 +48,10 @@ def infer(self, image, box):
4648 results: dict, {name : numpy.array}
4749 """
4850
49- inputs = [InferInput ( self . _input_names [ 0 ], image . shape ,
50- " UINT8" ),
51- InferInput (self ._input_names [1 ], box .shape ,
52- "BYTES" ) ]
51+ inputs = [
52+ InferInput ( self . _input_names [ 0 ], image . shape , ' UINT8' ),
53+ InferInput (self ._input_names [1 ], box .shape , 'BYTES' )
54+ ]
5355 inputs [0 ].set_data_from_numpy (image )
5456 inputs [1 ].set_data_from_numpy (box )
5557 results = self ._client .infer (
@@ -72,20 +74,18 @@ def visualize(img, results):
7274 cv2 .imwrite ('keypoint-detection.jpg' , img )
7375
7476
75- if __name__ == " __main__" :
77+ if __name__ == ' __main__' :
7678 args = parse_args ()
7779 model_name = args .model_name
78- model_version = "1"
79- url = " localhost:8001"
80+ model_version = '1'
81+ url = ' localhost:8001'
8082 client = GRPCTritonClient (url , model_name , model_version )
8183 img = cv2 .imread (args .image )
8284 bbox = {
8385 'type' : 'PoseBbox' ,
84- 'value' : [
85- {
86- 'bbox' : [0.0 , 0.0 , img .shape [1 ], img .shape [0 ]]
87- }
88- ]
86+ 'value' : [{
87+ 'bbox' : [0.0 , 0.0 , img .shape [1 ], img .shape [0 ]]
88+ }]
8989 }
9090 bbox = np .array ([json .dumps (bbox ).encode ('utf-8' )])
9191 results = client .infer (img , bbox )
0 commit comments