Skip to content

Commit c5d7e32

Browse files
authored
Split Python interpreter in two classes (#675)
* Split Python interpreter in two classes * Update docstring * Use lazy import of C++ module
1 parent 9dbea32 commit c5d7e32

File tree

3 files changed

+126
-100
lines changed

3 files changed

+126
-100
lines changed

larq_compute_engine/tflite/python/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,22 @@ pybind_extension(
2929
],
3030
)
3131

32+
py_library(
33+
name = "interpreter_base",
34+
srcs = [
35+
"__init__.py",
36+
"interpreter_base.py",
37+
],
38+
)
39+
3240
py_library(
3341
name = "interpreter",
3442
srcs = [
3543
"__init__.py",
3644
"interpreter.py",
3745
],
3846
deps = [
47+
":interpreter_base",
3948
":interpreter_wrapper_lite",
4049
],
4150
)
Lines changed: 10 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,9 @@
1-
from typing import Iterator, List, Tuple, Union, Optional
2-
3-
import numpy as np
4-
from tqdm import tqdm
5-
6-
from larq_compute_engine.tflite.python import interpreter_wrapper_lite
1+
from larq_compute_engine.tflite.python.interpreter_base import InterpreterBase
72

83
__all__ = ["Interpreter"]
94

10-
Data = Union[np.ndarray, List[np.ndarray]]
11-
12-
13-
def data_generator(x: Union[Data, Iterator[Data]]) -> Iterator[List[np.ndarray]]:
14-
if isinstance(x, np.ndarray):
15-
for inputs in x:
16-
yield [np.expand_dims(inputs, axis=0)]
17-
elif isinstance(x, list):
18-
for inputs in zip(*x):
19-
yield [np.expand_dims(inp, axis=0) for inp in inputs]
20-
elif hasattr(x, "__next__") and hasattr(x, "__iter__"):
21-
for inputs in x:
22-
if isinstance(inputs, np.ndarray):
23-
yield [np.expand_dims(inputs, axis=0)]
24-
else:
25-
yield [np.expand_dims(inp, axis=0) for inp in inputs]
26-
else:
27-
raise ValueError(
28-
"Expected either a list of inputs or a Numpy array with implicit initial "
29-
f"batch dimension or an iterator yielding one of the above. Received: {x}"
30-
)
31-
325

33-
class Interpreter:
6+
class Interpreter(InterpreterBase):
347
"""Interpreter interface for Larq Compute Engine Models.
358
369
!!! warning
@@ -46,16 +19,12 @@ class Interpreter:
4619
interpreter.predict(input_data, verbose=1)
4720
```
4821
22+
See the base class `InterpreterBase` for the full interface.
23+
4924
# Arguments
5025
flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format.
5126
num_threads: The number of threads used by the interpreter.
5227
use_reference_bconv: When True, uses the reference implementation of LceBconv2d.
53-
54-
# Attributes
55-
input_types: Returns a list of input types.
56-
input_shapes: Returns a list of input shapes.
57-
output_types: Returns a list of output types.
58-
output_shapes: Returns a list of output shapes.
5928
"""
6029

6130
def __init__(
@@ -64,69 +33,10 @@ def __init__(
6433
num_threads: int = 1,
6534
use_reference_bconv: bool = False,
6635
):
67-
self.interpreter = interpreter_wrapper_lite.LiteInterpreter(
68-
flatbuffer_model, num_threads, use_reference_bconv
69-
)
70-
71-
@property
72-
def input_types(self) -> list:
73-
"""Returns a list of input types."""
74-
return self.interpreter.input_types
75-
76-
@property
77-
def input_shapes(self) -> List[Tuple[int]]:
78-
"""Returns a list of input shapes."""
79-
return self.interpreter.input_shapes
80-
81-
@property
82-
def input_scales(self) -> List[Optional[Union[float, List[float]]]]:
83-
"""Returns a list of input scales."""
84-
return self.interpreter.input_scales
36+
from larq_compute_engine.tflite.python import interpreter_wrapper_lite
8537

86-
@property
87-
def input_zero_points(self) -> List[Optional[int]]:
88-
"""Returns a list of input zero points."""
89-
return self.interpreter.input_zero_points
90-
91-
@property
92-
def output_types(self) -> list:
93-
"""Returns a list of output types."""
94-
return self.interpreter.output_types
95-
96-
@property
97-
def output_shapes(self) -> List[Tuple[int]]:
98-
"""Returns a list of output shapes."""
99-
return self.interpreter.output_shapes
100-
101-
@property
102-
def output_scales(self) -> List[Optional[Union[float, List[float]]]]:
103-
"""Returns a list of input scales."""
104-
return self.interpreter.output_scales
105-
106-
@property
107-
def output_zero_points(self) -> List[Optional[int]]:
108-
"""Returns a list of input zero points."""
109-
return self.interpreter.output_zero_points
110-
111-
def predict(self, x: Union[Data, Iterator[Data]], verbose: int = 0) -> Data:
112-
"""Generates output predictions for the input samples.
113-
114-
# Arguments
115-
x: Input samples. A Numpy array, or a list of arrays in case the model has
116-
multiple inputs.
117-
verbose: Verbosity mode, 0 or 1.
118-
119-
# Returns
120-
Numpy array(s) of output predictions.
121-
"""
122-
123-
data_iterator = data_generator(x)
124-
if verbose >= 1:
125-
data_iterator = tqdm(data_iterator)
126-
127-
prediction_iter = (self.interpreter.predict(inputs) for inputs in data_iterator)
128-
outputs = [np.concatenate(batches) for batches in zip(*prediction_iter)]
129-
130-
if len(self.output_shapes) == 1:
131-
return outputs[0]
132-
return outputs
38+
super().__init__(
39+
interpreter_wrapper_lite.LiteInterpreter(
40+
flatbuffer_model, num_threads, use_reference_bconv
41+
)
42+
)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from typing import Iterator, List, Tuple, Union, Optional
2+
3+
import numpy as np
4+
from tqdm import tqdm
5+
6+
Data = Union[np.ndarray, List[np.ndarray]]
7+
8+
9+
def data_generator(x: Union[Data, Iterator[Data]]) -> Iterator[List[np.ndarray]]:
10+
if isinstance(x, np.ndarray):
11+
for inputs in x:
12+
yield [np.expand_dims(inputs, axis=0)]
13+
elif isinstance(x, list):
14+
for inputs in zip(*x):
15+
yield [np.expand_dims(inp, axis=0) for inp in inputs]
16+
elif hasattr(x, "__next__") and hasattr(x, "__iter__"):
17+
for inputs in x:
18+
if isinstance(inputs, np.ndarray):
19+
yield [np.expand_dims(inputs, axis=0)]
20+
else:
21+
yield [np.expand_dims(inp, axis=0) for inp in inputs]
22+
else:
23+
raise ValueError(
24+
"Expected either a list of inputs or a Numpy array with implicit initial "
25+
f"batch dimension or an iterator yielding one of the above. Received: {x}"
26+
)
27+
28+
29+
class InterpreterBase:
30+
"""Interpreter interface for Larq Compute Engine Models.
31+
32+
# Attributes
33+
input_types: Returns a list of input types.
34+
input_shapes: Returns a list of input shapes.
35+
input_scales: Returns a list of input scales.
36+
input_zero_points: Returns a list of input zero points.
37+
output_types: Returns a list of output types.
38+
output_shapes: Returns a list of output shapes.
39+
output_scales: Returns a list of input scales.
40+
output_zero_points: Returns a list of input zero points.
41+
"""
42+
43+
def __init__(self, interpreter):
44+
self.interpreter = interpreter
45+
46+
@property
47+
def input_types(self) -> list:
48+
"""Returns a list of input types."""
49+
return self.interpreter.input_types
50+
51+
@property
52+
def input_shapes(self) -> List[Tuple[int]]:
53+
"""Returns a list of input shapes."""
54+
return self.interpreter.input_shapes
55+
56+
@property
57+
def input_scales(self) -> List[Optional[Union[float, List[float]]]]:
58+
"""Returns a list of input scales."""
59+
return self.interpreter.input_scales
60+
61+
@property
62+
def input_zero_points(self) -> List[Optional[int]]:
63+
"""Returns a list of input zero points."""
64+
return self.interpreter.input_zero_points
65+
66+
@property
67+
def output_types(self) -> list:
68+
"""Returns a list of output types."""
69+
return self.interpreter.output_types
70+
71+
@property
72+
def output_shapes(self) -> List[Tuple[int]]:
73+
"""Returns a list of output shapes."""
74+
return self.interpreter.output_shapes
75+
76+
@property
77+
def output_scales(self) -> List[Optional[Union[float, List[float]]]]:
78+
"""Returns a list of input scales."""
79+
return self.interpreter.output_scales
80+
81+
@property
82+
def output_zero_points(self) -> List[Optional[int]]:
83+
"""Returns a list of input zero points."""
84+
return self.interpreter.output_zero_points
85+
86+
def predict(self, x: Union[Data, Iterator[Data]], verbose: int = 0) -> Data:
87+
"""Generates output predictions for the input samples.
88+
89+
# Arguments
90+
x: Input samples. A Numpy array, or a list of arrays in case the model has
91+
multiple inputs.
92+
verbose: Verbosity mode, 0 or 1.
93+
94+
# Returns
95+
Numpy array(s) of output predictions.
96+
"""
97+
98+
data_iterator = data_generator(x)
99+
if verbose >= 1:
100+
data_iterator = tqdm(data_iterator)
101+
102+
prediction_iter = (self.interpreter.predict(inputs) for inputs in data_iterator)
103+
outputs = [np.concatenate(batches) for batches in zip(*prediction_iter)]
104+
105+
if len(self.output_shapes) == 1:
106+
return outputs[0]
107+
return outputs

0 commit comments

Comments
 (0)