Skip to content

Commit 3e1b148

Browse files
lgeigerTombana
andauthored
Add folders for lq.quantize and lq.dequantize (#654)
Co-Authored-By: Tom Bannink <Tombana@users.noreply.github.com> Co-authored-by: Tom Bannink <Tombana@users.noreply.github.com>
1 parent 6e0f432 commit 3e1b148

File tree

12 files changed

+255
-37
lines changed

12 files changed

+255
-37
lines changed

configure.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def get_input(question):
7272

7373
def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, var_default):
7474
"""Get var_name either from env, or user or default.
75+
7576
If var_name has been set as environment variable, use the preset value, else
7677
ask for user input. If no input is provided, the default is used.
7778
Args:
@@ -101,6 +102,7 @@ def get_var(
101102
no_reply=None,
102103
):
103104
"""Get boolean input from user.
105+
104106
If var_name is not set in env, ask user to enable query_item or not. If the
105107
response is empty, use the default.
106108
Args:
@@ -380,6 +382,7 @@ def setup_python(environ_cp):
380382

381383
def set_cc_opt_flags(environ_cp):
382384
"""Set up architecture-dependent optimization flags.
385+
383386
Also append CC optimization flags to bazel.rc.
384387
Args:
385388
environ_cp: copy of the os.environ.

larq_compute_engine/core/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ cc_library(
77
hdrs = [
88
"types.h",
99
],
10+
deps = [
11+
"@org_tensorflow//tensorflow/lite/kernels/internal:cppmath",
12+
],
1013
)
1114

1215
cc_library(

larq_compute_engine/core/types.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <limits>
77
#include <type_traits>
88

9+
#include "tensorflow/lite/kernels/internal/cppmath.h"
10+
911
namespace compute_engine {
1012
namespace core {
1113

@@ -17,6 +19,12 @@ namespace core {
1719
#define LCE_UNLIKELY(condition) (condition)
1820
#endif
1921

22+
#if defined(__GNUC__)
23+
#define FORCE_INLINE __attribute__((always_inline)) inline
24+
#else
25+
#define FORCE_INLINE inline
26+
#endif
27+
2028
// Check that 0 <= index < limit using a single comparison, assuming
2129
// that 0 <= limit if Index is signed. Intended for use in performance
2230
// critical contexts where 0 <= index < limit is almost always true.
@@ -38,6 +46,53 @@ inline int xor_popcount(const TBitpacked& a, const TBitpacked& b) {
3846
return std::bitset<bitpacking_bitwidth>(a ^ b).count();
3947
}
4048

49+
// Clamp an int32 value to int8 range
50+
inline std::int8_t saturate(std::int32_t x) {
51+
#ifdef __arm__
52+
std::int8_t y;
53+
asm("ssat %[y], #8, %[x]\n" : [y] "=r"(y) : [x] "r"(x));
54+
return y;
55+
#else
56+
x = std::min<std::int32_t>(x, std::numeric_limits<std::int8_t>::max());
57+
x = std::max<std::int32_t>(x, std::numeric_limits<std::int8_t>::lowest());
58+
return static_cast<std::int8_t>(x);
59+
#endif
60+
}
61+
62+
// arithmetic right shift and clamp an int32 value to int8 range
63+
template <int shift>
64+
inline std::int8_t shift_saturate(std::int32_t x) {
65+
#ifdef __arm__
66+
std::int8_t y;
67+
asm("ssat %[y], #8, %[x], asr %[shift]\n"
68+
: [y] "=r"(y)
69+
: [x] "r"(x), [shift] "i"(shift));
70+
return y;
71+
#else
72+
x = x >> shift;
73+
x = std::min<std::int32_t>(x, std::numeric_limits<std::int8_t>::max());
74+
x = std::max<std::int32_t>(x, std::numeric_limits<std::int8_t>::lowest());
75+
return static_cast<std::int8_t>(x);
76+
#endif
77+
}
78+
79+
// Round-to-nearest. Handling of ties is allowed to be anything, as discussed in
80+
// https://github.com/tensorflow/tensorflow/issues/25087
81+
inline std::int32_t round(float x) {
82+
#if defined(__thumb__) && defined(__VFP_FP__) && !defined(__SOFTFP__)
83+
// The `vcvtr` instructions follows the IEEE 754 rounding standard which
84+
// rounds halfway points to the nearest *even* integer.
85+
std::int32_t y;
86+
asm("vcvtr.s32.f32 %[x], %[x] \n"
87+
"vmov %[y], %[x] \n"
88+
: [y] "=r"(y)
89+
: [x] "t"(x)); // The "t" means `x` will be in an FPU register
90+
return y;
91+
#else
92+
return ::tflite::TfLiteRound(x);
93+
#endif
94+
}
95+
4196
template <typename T, typename S>
4297
constexpr T CeilDiv(T a, S b) {
4398
return (a + b - 1) / b;

larq_compute_engine/mlir/BUILD

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,21 @@ gentbl(
134134
],
135135
)
136136

137+
cc_library(
138+
name = "larq_compute_engine_bitpack",
139+
srcs = [
140+
"transforms/bitpack.cc",
141+
],
142+
hdrs = [
143+
"transforms/bitpack.h",
144+
],
145+
deps = [
146+
"//larq_compute_engine/core:types",
147+
"//larq_compute_engine/core/bitpacking:bitpack",
148+
"@llvm-project//mlir:IR",
149+
],
150+
)
151+
137152
cc_library(
138153
name = "larq_compute_engine",
139154
srcs = [
@@ -147,6 +162,7 @@ cc_library(
147162
"transforms/passes.h",
148163
],
149164
deps = [
165+
":larq_compute_engine_bitpack",
150166
"//larq_compute_engine/core/bitpacking:bitpack",
151167
"@flatbuffers",
152168
"@llvm-project//mlir:QuantOps",
@@ -225,8 +241,7 @@ cc_library(
225241
],
226242
deps = [
227243
":larq_compute_engine",
228-
"//larq_compute_engine/core:types",
229-
"//larq_compute_engine/core/bitpacking:bitpack",
244+
":larq_compute_engine_bitpack",
230245
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
231246
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
232247
],

larq_compute_engine/mlir/ir/lce_ops.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "flatbuffers/flexbuffers.h"
44
#include "larq_compute_engine/core/bitpacking/bitpack.h"
5+
#include "larq_compute_engine/mlir/transforms/bitpack.h"
56
#include "tensorflow/lite/schema/schema_generated.h"
67

78
static tflite::Padding ConvertPaddingAttr(llvm::StringRef str) {
@@ -67,6 +68,18 @@ void QuantizeOp::build(OpBuilder& builder, OperationState& state, Value x) {
6768
state.addTypes(RankedTensorType::get(shape, builder.getIntegerType(32)));
6869
}
6970

71+
OpFoldResult QuantizeOp::fold(ArrayRef<Attribute> operands) {
72+
mlir::OpBuilder builder(getOperation());
73+
if (!operands[0]) return nullptr;
74+
return mlir::TFL::Bitpack(&builder, operands[0]);
75+
}
76+
77+
OpFoldResult DequantizeOp::fold(ArrayRef<Attribute> operands) {
78+
auto result_type = getType().cast<ShapedType>();
79+
if (!operands[0]) return nullptr;
80+
return mlir::TFL::Unpack(operands[0], result_type);
81+
}
82+
7083
void LarqDialect::initialize() {
7184
addOperations<
7285
#define GET_OP_LIST

larq_compute_engine/mlir/ir/lce_ops.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ Converts floating point or integer tensors to binarized bitpacked tensors.
8282
);
8383

8484
let builders = [OpBuilder<(ins "Value":$x)>];
85+
86+
let hasFolder = 1;
8587
}
8688

8789
def LQ_DequantizeOp : LQ_Op<"Dequantize", [NoSideEffect]> {
@@ -98,6 +100,8 @@ Converts binarized bitpacked tensors to floating point or integer tensors.
98100
let results = (outs
99101
TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16]>:$y
100102
);
103+
104+
let hasFolder = 1;
101105
}
102106

103107
def LQ_Bconv2dOp : LQ_Op<"Bconv2d", [NoSideEffect]> {

larq_compute_engine/mlir/lce_mlir_opt.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
#include "mlir/Dialect/Quant/QuantOps.h"
33
#include "mlir/Dialect/StandardOps/IR/Ops.h"
44
#include "mlir/Support/MlirOptMain.h"
5+
#include "mlir/Transforms/Passes.h"
56
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
67
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
78

89
int main(int argc, char** argv) {
10+
mlir::registerTransformsPasses();
911
mlir::DialectRegistry registry;
1012
registry.insert<mlir::StandardOpsDialect, mlir::quant::QuantizationDialect,
1113
mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: lce-tf-opt %s -canonicalize | FileCheck %s
2+
3+
// CHECK-LABEL: @quantize
4+
func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
5+
%pos = constant dense< 0.5> : tensor<1x1x2x32xf32>
6+
%neg = constant dense<-0.5> : tensor<1x1x2x32xf32>
7+
%0 = "lq.Quantize"(%pos) {} : (tensor<1x1x2x32xf32>) -> tensor<1x1x2x1xi32>
8+
%1 = "lq.Quantize"(%neg) {} : (tensor<1x1x2x32xf32>) -> tensor<1x1x2x1xi32>
9+
return %0, %1 : tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>
10+
11+
// CHECK: %[[neg:.*]] = constant dense<-1> : tensor<1x1x2x1xi32>
12+
// CHECK: %[[pos:.*]] = constant dense<0> : tensor<1x1x2x1xi32>
13+
// CHECK: return %[[pos]], %[[neg]] : tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>
14+
}
15+
16+
// CHECK-LABEL: @dequantize
17+
func @dequantize() -> (tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>) {
18+
%pos = constant dense<0> : tensor<1x1x2x1xi32>
19+
%neg = constant dense<-1> : tensor<1x1x2x1xi32>
20+
%0 = "lq.Dequantize"(%pos) {} : (tensor<1x1x2x1xi32>) -> tensor<1x1x2x32xf32>
21+
%1 = "lq.Dequantize"(%neg) {} : (tensor<1x1x2x1xi32>) -> tensor<1x1x2x32xf32>
22+
return %0, %1 : tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>
23+
24+
// CHECK: %[[neg:.*]] = constant dense<-1.000000e+00> : tensor<1x1x2x32xf32>
25+
// CHECK: %[[pos:.*]] = constant dense<1.000000e+00> : tensor<1x1x2x32xf32>
26+
// CHECK: return %[[pos]], %[[neg]] : tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>
27+
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#include "larq_compute_engine/mlir/transforms/bitpack.h"
2+
3+
#include <cmath>
4+
#include <vector>
5+
6+
#include "larq_compute_engine/core/bitpacking/bitpack.h"
7+
#include "larq_compute_engine/core/types.h"
8+
#include "mlir/Dialect/Quant/QuantTypes.h"
9+
10+
namespace mlir {
11+
namespace TFL {
12+
13+
using compute_engine::core::bitpacking_bitwidth;
14+
using compute_engine::core::round;
15+
using compute_engine::core::saturate;
16+
using compute_engine::core::TBitpacked;
17+
using namespace compute_engine::core::bitpacking;
18+
19+
DenseElementsAttr Bitpack(mlir::Builder* builder, Attribute x) {
20+
if (!x) return nullptr;
21+
22+
// ShapedType is something like tensor<1x2x3xf32> and element_type is f32
23+
auto shaped_type = x.getType().cast<ShapedType>();
24+
auto shape = shaped_type.getShape();
25+
auto element_type = shaped_type.getElementType();
26+
27+
int num_rows = shape[0] * shape[1] * shape[2];
28+
int unpacked_channels = shape[3];
29+
int packed_channels = GetBitpackedSize(unpacked_channels);
30+
31+
std::vector<TBitpacked> new_values(num_rows * packed_channels);
32+
33+
if (element_type.isF32()) {
34+
const auto& dense_elements_iter =
35+
x.cast<DenseElementsAttr>().getValues<float>();
36+
37+
std::vector<float> old_values(num_rows * unpacked_channels);
38+
39+
int i = 0;
40+
for (float x : dense_elements_iter) {
41+
old_values[i++] = x;
42+
}
43+
assert(i == num_rows * unpacked_channels);
44+
45+
bitpack_matrix(old_values.data(), num_rows, unpacked_channels,
46+
new_values.data());
47+
} else {
48+
// constant-fold bitpacking int8 tensors is currently not supported
49+
return nullptr;
50+
}
51+
52+
RankedTensorType out_tensor_type =
53+
RankedTensorType::get({shape[0], shape[1], shape[2], packed_channels},
54+
builder->getIntegerType(bitpacking_bitwidth));
55+
56+
return DenseElementsAttr::get<TBitpacked>(out_tensor_type, new_values);
57+
}
58+
59+
DenseElementsAttr Unpack(Attribute x, ShapedType result_type) {
60+
if (!x) return nullptr;
61+
if (!result_type.hasStaticShape()) return nullptr;
62+
63+
auto input_shape = x.getType().cast<ShapedType>().getShape();
64+
auto output_shape = result_type.getShape();
65+
auto output_type = result_type.getElementType();
66+
67+
int num_rows = output_shape[0] * output_shape[1] * output_shape[2];
68+
int unpacked_channels = output_shape[3];
69+
int packed_channels = GetBitpackedSize(unpacked_channels);
70+
if (input_shape[0] != output_shape[0] || input_shape[1] != output_shape[1] ||
71+
input_shape[2] != output_shape[2] || input_shape[3] != packed_channels) {
72+
return nullptr;
73+
}
74+
75+
std::vector<TBitpacked> old_values(num_rows * packed_channels);
76+
77+
const auto& dense_elements_iter =
78+
x.cast<DenseElementsAttr>().getValues<TBitpacked>();
79+
80+
int i = 0;
81+
for (TBitpacked x : dense_elements_iter) {
82+
old_values[i++] = x;
83+
}
84+
assert(i == num_rows * packed_channels);
85+
86+
if (output_type.isF32()) {
87+
std::vector<float> new_values(num_rows * unpacked_channels);
88+
89+
unpack_matrix(old_values.data(), num_rows, unpacked_channels,
90+
new_values.data());
91+
92+
return DenseElementsAttr::get<float>(result_type, new_values);
93+
} else {
94+
auto quant_type = output_type.cast<mlir::quant::UniformQuantizedType>();
95+
const double scale = quant_type.getScale();
96+
const int zero_point = quant_type.getZeroPoint();
97+
98+
std::int8_t zero_bit_result = saturate(zero_point + round(+1.0 / scale));
99+
std::int8_t one_bit_result = saturate(zero_point + round(-1.0 / scale));
100+
101+
std::vector<std::int8_t> new_values(num_rows * unpacked_channels);
102+
103+
unpack_matrix(old_values.data(), num_rows, unpacked_channels,
104+
new_values.data(), zero_bit_result, one_bit_result);
105+
106+
return DenseElementsAttr::get<std::int8_t>(result_type, new_values);
107+
}
108+
}
109+
110+
} // namespace TFL
111+
} // namespace mlir
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef LARQ_COMPUTE_ENGINE_MLIR_TRANSFORMS_BITPACK_H_
2+
#define LARQ_COMPUTE_ENGINE_MLIR_TRANSFORMS_BITPACK_H_
3+
4+
#include "mlir/IR/Attributes.h"
5+
#include "mlir/IR/Builders.h"
6+
#include "mlir/IR/BuiltinTypes.h"
7+
8+
namespace mlir {
9+
namespace TFL {
10+
11+
DenseElementsAttr Bitpack(mlir::Builder* builder, Attribute x);
12+
13+
DenseElementsAttr Unpack(Attribute x, ShapedType result_type);
14+
15+
} // namespace TFL
16+
} // namespace mlir
17+
18+
#endif

0 commit comments

Comments
 (0)