Commit 6e0f432

Add grouped binary convolution support (3/3): indirect BGEMM kernel. (#551)
Add support for grouped binary convolutions to the optimised indirect BGEMM kernel.
1 parent 48bb86b commit 6e0f432
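
For context on what "grouped" means here: a grouped convolution splits the input channels into `groups` independent slices, and each filter convolves over only `channels_in / groups` of them. Below is a hypothetical sketch of that arithmetic; the field names mirror `BConv2DParams` from the diff further down, but the struct and helper themselves are illustrative, not LCE code.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-in for the relevant BConv2DParams fields.
struct ConvParams {
  std::int32_t channels_in;   // total input channels
  std::int32_t channels_out;  // total output channels
  std::int32_t groups;        // number of convolution groups
};

// Each filter sees only channels_in / groups input channels, so both
// channel counts must be exact multiples of the group count.
inline std::int32_t InputDepthPerGroup(const ConvParams& p) {
  assert(p.channels_in % p.groups == 0);
  assert(p.channels_out % p.groups == 0);
  return p.channels_in / p.groups;
}
```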

File tree

17 files changed (+1512 additions, -1090 deletions)

larq_compute_engine/core/bconv2d/BUILD

Lines changed: 0 additions & 1 deletion
@@ -65,7 +65,6 @@ cc_library(
     deps = [
         ":zero_padding_correction",
         "//larq_compute_engine/core/indirect_bgemm:kernels",
-        "//larq_compute_engine/core/indirect_bgemm:prepare",
         "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_context",
         "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_gemm",
         "@org_tensorflow//tensorflow/lite/kernels:padding",

larq_compute_engine/core/bconv2d/optimized_indirect_bgemm.h

Lines changed: 20 additions & 21 deletions
@@ -12,27 +12,25 @@ namespace bconv2d {
 
 template <typename AccumScalar, typename DstScalar>
 inline void BConv2DOptimizedIndirectBGEMM(
-    const indirect_bgemm::IndirectBGEMMKernel<DstScalar> kernel,
-    const BConv2DParams* bconv2d_params,
+    const indirect_bgemm::Kernel* kernel, const BConv2DParams* bconv2d_params,
     const RuntimeShape& bitpacked_input_shape, const RuntimeShape& output_shape,
-    const OutputTransform<DstScalar>& output_transform,
-    const TBitpacked* packed_weights, const TBitpacked** indirection_buffer,
-    DstScalar* output_data, const float* padding_buffer, const int pad_value) {
-  TF_LITE_ASSERT_EQ(bitpacked_input_shape.DimensionsCount(), 4);
-  TF_LITE_ASSERT_EQ(output_shape.DimensionsCount(), 4);
-
+    DstScalar* output_ptr, const float* padding_buffer, const int pad_value) {
   ruy::profiler::ScopeLabel label("BConv2D (optimized, indirect BGEMM)");
 
-  const std::int32_t conv_kernel_size =
-      bconv2d_params->filter_height * bconv2d_params->filter_width;
-  const std::int32_t bitpacked_input_channels = bitpacked_input_shape.Dims(3);
-  const std::int32_t output_size =
-      output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
-  const std::int32_t output_channels = bconv2d_params->channels_out;
+  // If writing bitpacked output with a channel count that isn't a multiple of
+  // 32 (i.e. where padding bits will be required in the output), fill the
+  // output tensor with zeroes in advance so that the BGEMM doesn't have to
+  // worry about doing the padding.
+  if (std::is_same<DstScalar, TBitpacked>::value &&
+      (kernel->output_channels % bitpacking_bitwidth != 0)) {
+    std::fill(
+        output_ptr,
+        output_ptr + kernel->num_output_pixels *
+                         bitpacking::GetBitpackedSize(kernel->output_channels),
+        TBitpacked(0));
+  }
 
-  indirect_bgemm::RunKernel(kernel, conv_kernel_size, bitpacked_input_channels,
-                            output_size, output_channels, output_transform,
-                            packed_weights, indirection_buffer, output_data);
+  kernel->Dispatch(reinterpret_cast<void*>(output_ptr));
 
   if (std::is_same<DstScalar, float>::value &&
       bconv2d_params->padding_type == TfLitePadding::kTfLitePaddingSame &&
@@ -44,7 +42,8 @@ inline void BConv2DOptimizedIndirectBGEMM(
     const int dilation_width_factor = bconv2d_params->dilation_width_factor;
     const int dilation_height_factor = bconv2d_params->dilation_height_factor;
     const int batches = MatchingDim(bitpacked_input_shape, 0, output_shape, 0);
-    const int input_depth = bconv2d_params->channels_in;
+    const int input_depth_per_group =
+        bconv2d_params->channels_in / bconv2d_params->groups;
     const int input_width = bitpacked_input_shape.Dims(2);
     const int input_height = bitpacked_input_shape.Dims(1);
     const int filter_height = bconv2d_params->filter_height;
@@ -54,10 +53,10 @@ inline void BConv2DOptimizedIndirectBGEMM(
     const int output_height = output_shape.Dims(1);
 
     zero_padding_correction::ApplyCorrection(
-        batches, input_height, input_width, input_depth, filter_height,
-        filter_width, output_depth, stride_height, stride_width,
+        batches, input_height, input_width, input_depth_per_group,
+        filter_height, filter_width, output_depth, stride_height, stride_width,
         dilation_height_factor, dilation_width_factor,
-        reinterpret_cast<float*>(output_data), output_height, output_width,
+        reinterpret_cast<float*>(output_ptr), output_height, output_width,
         padding_buffer);
   }
 }
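
The pre-zeroing step above can be illustrated in isolation. This is a standalone sketch, not LCE code: the local `kBitpackingBitwidth` constant and `GetBitpackedSize` helper stand in for LCE's `bitpacking_bitwidth` and `bitpacking::GetBitpackedSize`. With 40 output channels, each output pixel occupies two 32-bit words, of which the last 24 bits are padding; zero-filling the buffer up front means the kernel never has to mask them.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

using TBitpacked = std::int32_t;
constexpr int kBitpackingBitwidth = 32;

// Number of 32-bit words needed to hold `unpacked_elements` bits.
constexpr int GetBitpackedSize(int unpacked_elements) {
  return (unpacked_elements + kBitpackingBitwidth - 1) / kBitpackingBitwidth;
}

int main() {
  const int num_output_pixels = 8 * 8;  // e.g. an 8x8 single-batch output
  const int output_channels = 40;       // not a multiple of 32
  // In the real op this buffer is arbitrary (uninitialized) tensor memory.
  std::vector<TBitpacked> output(num_output_pixels *
                                 GetBitpackedSize(output_channels));
  if (output_channels % kBitpackingBitwidth != 0) {
    // Clear every word, including the padding bits in the final word of
    // each pixel, so the kernel can store whole words unmasked.
    std::fill(output.begin(), output.end(), TBitpacked(0));
  }
}
```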

larq_compute_engine/core/bitpacking/bitpack.h

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ namespace bitpacking {
 // Utility functions
 
 constexpr int GetBitpackedSize(int unpacked_elements) {
-  return (unpacked_elements + bitpacking_bitwidth - 1) / bitpacking_bitwidth;
+  return CeilDiv(unpacked_elements, bitpacking_bitwidth);
 }
 
 constexpr int GetBitpackedMatrixSize(int rows, int cols) {
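
This replacement is behavior-preserving: for the non-negative operand counts used here, `(a + b - 1) / b` is exactly ceiling division. A minimal sketch, assuming LCE's `CeilDiv` helper has this shape (the name comes from the diff; its definition is assumed):

```cpp
// Assumed shape of the CeilDiv helper referenced above; sketch only.
constexpr int CeilDiv(int a, int b) { return (a + b - 1) / b; }

// With a 32-bit bitpacking width, channel counts round up to whole words:
static_assert(CeilDiv(0, 32) == 0, "0 channels -> 0 bitpacked words");
static_assert(CeilDiv(32, 32) == 1, "exact multiple -> no padding word");
static_assert(CeilDiv(40, 32) == 2, "40 channels -> 2 words, 24 padding bits");
```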

larq_compute_engine/core/indirect_bgemm/BUILD

Lines changed: 6 additions & 14 deletions
@@ -2,31 +2,23 @@ licenses(["notice"])  # Apache 2.0
 
 package(default_visibility = ["//visibility:public"])
 
-cc_library(
-    name = "prepare",
-    hdrs = [
-        "prepare.h",
-    ],
-    deps = [
-        "//larq_compute_engine/core:types",
-        "//larq_compute_engine/core/bconv2d:params",
-        "@org_tensorflow//tensorflow/lite/kernels/internal:types",
-    ],
-)
-
 cc_library(
     name = "kernels",
-    hdrs = [
-        "kernel.h",
+    srcs = [
         "kernel_4x2_portable.h",
         "kernel_8x4x1_aarch64.h",
         "kernel_8x4x2_aarch64.h",
         "kernel_8x4x4_aarch64.h",
     ],
+    hdrs = [
+        "kernel.h",
+        "select_kernel.h",
+    ],
     deps = [
         "//larq_compute_engine/core:types",
         "//larq_compute_engine/core/bconv2d:output_transform",
         "//larq_compute_engine/core/bconv2d:params",
+        "//larq_compute_engine/core/bitpacking:bitpack",
         "@org_tensorflow//tensorflow/lite/kernels/internal:types",
         "@ruy//ruy/profiler:instrumentation",
     ],
