1- from typing import Iterator , List , Tuple , Union , Optional
2-
3- import numpy as np
4- from tqdm import tqdm
5-
6- from larq_compute_engine .tflite .python import interpreter_wrapper_lite
1+ from larq_compute_engine .tflite .python .interpreter_base import InterpreterBase
72
83__all__ = ["Interpreter" ]
94
10- Data = Union [np .ndarray , List [np .ndarray ]]
11-
12-
13- def data_generator (x : Union [Data , Iterator [Data ]]) -> Iterator [List [np .ndarray ]]:
14- if isinstance (x , np .ndarray ):
15- for inputs in x :
16- yield [np .expand_dims (inputs , axis = 0 )]
17- elif isinstance (x , list ):
18- for inputs in zip (* x ):
19- yield [np .expand_dims (inp , axis = 0 ) for inp in inputs ]
20- elif hasattr (x , "__next__" ) and hasattr (x , "__iter__" ):
21- for inputs in x :
22- if isinstance (inputs , np .ndarray ):
23- yield [np .expand_dims (inputs , axis = 0 )]
24- else :
25- yield [np .expand_dims (inp , axis = 0 ) for inp in inputs ]
26- else :
27- raise ValueError (
28- "Expected either a list of inputs or a Numpy array with implicit initial "
29- f"batch dimension or an iterator yielding one of the above. Received: { x } "
30- )
31-
325
33- class Interpreter :
6+ class Interpreter ( InterpreterBase ) :
347 """Interpreter interface for Larq Compute Engine Models.
358
369 !!! warning
@@ -46,16 +19,12 @@ class Interpreter:
4619 interpreter.predict(input_data, verbose=1)
4720 ```
4821
22+ See the base class `InterpreterBase` for the full interface.
23+
4924 # Arguments
5025 flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format.
5126 num_threads: The number of threads used by the interpreter.
5227 use_reference_bconv: When True, uses the reference implementation of LceBconv2d.
53-
54- # Attributes
55- input_types: Returns a list of input types.
56- input_shapes: Returns a list of input shapes.
57- output_types: Returns a list of output types.
58- output_shapes: Returns a list of output shapes.
5928 """
6029
6130 def __init__ (
@@ -64,69 +33,10 @@ def __init__(
6433 num_threads : int = 1 ,
6534 use_reference_bconv : bool = False ,
6635 ):
67- self .interpreter = interpreter_wrapper_lite .LiteInterpreter (
68- flatbuffer_model , num_threads , use_reference_bconv
69- )
70-
71- @property
72- def input_types (self ) -> list :
73- """Returns a list of input types."""
74- return self .interpreter .input_types
75-
76- @property
77- def input_shapes (self ) -> List [Tuple [int ]]:
78- """Returns a list of input shapes."""
79- return self .interpreter .input_shapes
80-
81- @property
82- def input_scales (self ) -> List [Optional [Union [float , List [float ]]]]:
83- """Returns a list of input scales."""
84- return self .interpreter .input_scales
36+ from larq_compute_engine .tflite .python import interpreter_wrapper_lite
8537
86- @property
87- def input_zero_points (self ) -> List [Optional [int ]]:
88- """Returns a list of input zero points."""
89- return self .interpreter .input_zero_points
90-
91- @property
92- def output_types (self ) -> list :
93- """Returns a list of output types."""
94- return self .interpreter .output_types
95-
96- @property
97- def output_shapes (self ) -> List [Tuple [int ]]:
98- """Returns a list of output shapes."""
99- return self .interpreter .output_shapes
100-
101- @property
102- def output_scales (self ) -> List [Optional [Union [float , List [float ]]]]:
103- """Returns a list of input scales."""
104- return self .interpreter .output_scales
105-
106- @property
107- def output_zero_points (self ) -> List [Optional [int ]]:
108- """Returns a list of input zero points."""
109- return self .interpreter .output_zero_points
110-
111- def predict (self , x : Union [Data , Iterator [Data ]], verbose : int = 0 ) -> Data :
112- """Generates output predictions for the input samples.
113-
114- # Arguments
115- x: Input samples. A Numpy array, or a list of arrays in case the model has
116- multiple inputs.
117- verbose: Verbosity mode, 0 or 1.
118-
119- # Returns
120- Numpy array(s) of output predictions.
121- """
122-
123- data_iterator = data_generator (x )
124- if verbose >= 1 :
125- data_iterator = tqdm (data_iterator )
126-
127- prediction_iter = (self .interpreter .predict (inputs ) for inputs in data_iterator )
128- outputs = [np .concatenate (batches ) for batches in zip (* prediction_iter )]
129-
130- if len (self .output_shapes ) == 1 :
131- return outputs [0 ]
132- return outputs
38+ super ().__init__ (
39+ interpreter_wrapper_lite .LiteInterpreter (
40+ flatbuffer_model , num_threads , use_reference_bconv
41+ )
42+ )
0 commit comments