@@ -12,27 +12,25 @@ namespace bconv2d {
1212
1313template <typename AccumScalar, typename DstScalar>
1414inline void BConv2DOptimizedIndirectBGEMM (
15- const indirect_bgemm::IndirectBGEMMKernel<DstScalar> kernel,
16- const BConv2DParams* bconv2d_params,
15+ const indirect_bgemm::Kernel* kernel, const BConv2DParams* bconv2d_params,
1716 const RuntimeShape& bitpacked_input_shape, const RuntimeShape& output_shape,
18- const OutputTransform<DstScalar>& output_transform,
19- const TBitpacked* packed_weights, const TBitpacked** indirection_buffer,
20- DstScalar* output_data, const float * padding_buffer, const int pad_value) {
21- TF_LITE_ASSERT_EQ (bitpacked_input_shape.DimensionsCount (), 4 );
22- TF_LITE_ASSERT_EQ (output_shape.DimensionsCount (), 4 );
23-
17+ DstScalar* output_ptr, const float * padding_buffer, const int pad_value) {
2418 ruy::profiler::ScopeLabel label (" BConv2D (optimized, indirect BGEMM)" );
2519
26- const std::int32_t conv_kernel_size =
27- bconv2d_params->filter_height * bconv2d_params->filter_width ;
28- const std::int32_t bitpacked_input_channels = bitpacked_input_shape.Dims (3 );
29- const std::int32_t output_size =
30- output_shape.Dims (0 ) * output_shape.Dims (1 ) * output_shape.Dims (2 );
31- const std::int32_t output_channels = bconv2d_params->channels_out ;
20+ // If writing bitpacked output with a channel count that isn't a multiple of
21+ // 32 (i.e. where padding bits will be required in the output), fill the
22+ // output tensor with zeroes in advance so that the BGEMM doesn't have to
23+ // worry about doing the padding.
24+ if (std::is_same<DstScalar, TBitpacked>::value &&
25+ (kernel->output_channels % bitpacking_bitwidth != 0 )) {
26+ std::fill (
27+ output_ptr,
28+ output_ptr + kernel->num_output_pixels *
29+ bitpacking::GetBitpackedSize (kernel->output_channels ),
30+ TBitpacked (0 ));
31+ }
3232
33- indirect_bgemm::RunKernel (kernel, conv_kernel_size, bitpacked_input_channels,
34- output_size, output_channels, output_transform,
35- packed_weights, indirection_buffer, output_data);
33+ kernel->Dispatch (reinterpret_cast <void *>(output_ptr));
3634
3735 if (std::is_same<DstScalar, float >::value &&
3836 bconv2d_params->padding_type == TfLitePadding::kTfLitePaddingSame &&
@@ -44,7 +42,8 @@ inline void BConv2DOptimizedIndirectBGEMM(
4442 const int dilation_width_factor = bconv2d_params->dilation_width_factor ;
4543 const int dilation_height_factor = bconv2d_params->dilation_height_factor ;
4644 const int batches = MatchingDim (bitpacked_input_shape, 0 , output_shape, 0 );
47- const int input_depth = bconv2d_params->channels_in ;
45+ const int input_depth_per_group =
46+ bconv2d_params->channels_in / bconv2d_params->groups ;
4847 const int input_width = bitpacked_input_shape.Dims (2 );
4948 const int input_height = bitpacked_input_shape.Dims (1 );
5049 const int filter_height = bconv2d_params->filter_height ;
@@ -54,10 +53,10 @@ inline void BConv2DOptimizedIndirectBGEMM(
5453 const int output_height = output_shape.Dims (1 );
5554
5655 zero_padding_correction::ApplyCorrection (
57- batches, input_height, input_width, input_depth, filter_height ,
58- filter_width, output_depth, stride_height, stride_width,
56+ batches, input_height, input_width, input_depth_per_group ,
57+ filter_height, filter_width, output_depth, stride_height, stride_width,
5958 dilation_height_factor, dilation_width_factor,
60- reinterpret_cast <float *>(output_data ), output_height, output_width,
59+ reinterpret_cast <float *>(output_ptr ), output_height, output_width,
6160 padding_buffer);
6261 }
6362}
0 commit comments