Skip to content

Commit 9a77138

Browse files
authored
Remove anonymous namespaces in MLIR transforms (#768)
1 parent 4c68119 commit 9a77138

File tree

11 files changed

+45
-66
lines changed

11 files changed

+45
-66
lines changed

larq_compute_engine/mlir/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ cc_library(
287287
cc_library(
288288
name = "larq_compute_engine_prepare",
289289
srcs = [
290+
"transforms/common.h",
290291
"transforms/generated_prepare_target_arm.inc",
291292
"transforms/generated_prepare_target_other.inc",
292293
"transforms/prepare_tf.cc",
@@ -310,6 +311,7 @@ cc_library(
310311
cc_library(
311312
name = "larq_compute_engine_optimize",
312313
srcs = [
314+
"transforms/common.h",
313315
"transforms/generated_bitpack_activations.inc",
314316
"transforms/generated_optimize_target_arm.inc",
315317
"transforms/generated_optimize_target_other.inc",

larq_compute_engine/mlir/transforms/bitpack_weights.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
namespace mlir {
1111
namespace TFL {
1212

13-
namespace {
14-
1513
struct BitpackWeightsLCE
1614
: public PassWrapper<BitpackWeightsLCE, OperationPass<mlir::func::FuncOp>> {
1715
llvm::StringRef getArgument() const final {
@@ -30,18 +28,18 @@ bool IsConv2DFilter(TypedAttr filter) {
3028
filter_type.getShape().size() == 4;
3129
}
3230

31+
namespace bitpackweights {
3332
#include "larq_compute_engine/mlir/transforms/generated_bitpack_weights.inc"
33+
} // namespace bitpackweights
3434

3535
void BitpackWeightsLCE::runOnOperation() {
3636
RewritePatternSet patterns(&getContext());
3737
auto func = getOperation();
3838

39-
TFL::populateWithGenerated(patterns);
39+
bitpackweights::populateWithGenerated(patterns);
4040
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
4141
}
4242

43-
} // namespace
44-
4543
// Creates an instance of the TensorFlow dialect BitpackWeights pass.
4644
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
4745
CreateBitpackWeightsLCEPass() {
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#pragma once
2+
3+
#include "mlir/IR/Attributes.h"
4+
#include "mlir/IR/BuiltinAttributes.h"
5+
6+
namespace mlir {
7+
namespace TFL {
8+
9+
inline bool IsConstantValue(Attribute values, float expected_value) {
10+
if (!values.isa<DenseElementsAttr>()) return false;
11+
12+
for (auto value : values.cast<DenseElementsAttr>().getValues<float>()) {
13+
if (value != expected_value) return false;
14+
}
15+
return true;
16+
}
17+
18+
} // namespace TFL
19+
} // namespace mlir

larq_compute_engine/mlir/transforms/fuse_padding.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
namespace mlir {
77
namespace TFL {
88

9-
namespace {
10-
119
bool NoBatchAndChannelPadding(Attribute paddings_attr) {
1210
auto paddings = GetValidPadAttr(paddings_attr);
1311
if (!paddings) return false;
@@ -33,7 +31,9 @@ bool IsSamePaddingPartial(Attribute paddings_attr, Value input, Value output,
3331
output_shape[dimension], stride);
3432
}
3533

34+
namespace fuse_padding {
3635
#include "larq_compute_engine/mlir/transforms/generated_fuse_padding.inc"
36+
}
3737

3838
// Prepare LCE operations in functions for subsequent legalization.
3939
struct FusePadding
@@ -49,16 +49,14 @@ struct FusePadding
4949
auto* ctx = &getContext();
5050
RewritePatternSet patterns(ctx);
5151
auto func = getOperation();
52-
populateWithGenerated(patterns);
52+
fuse_padding::populateWithGenerated(patterns);
5353
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
5454
}
5555
void getDependentDialects(DialectRegistry& registry) const override {
5656
registry.insert<::mlir::TFL::TensorFlowLiteDialect>();
5757
}
5858
};
5959

60-
} // namespace
61-
6260
// Creates an instance of the TensorFlow dialect FusePadding pass.
6361
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateFusePaddingPass() {
6462
return std::make_unique<FusePadding>();

larq_compute_engine/mlir/transforms/legalize_tflite.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
namespace mlir {
88
namespace TFL {
99

10-
namespace {
11-
1210
struct LegalizeLCE
1311
: public PassWrapper<LegalizeLCE, OperationPass<mlir::func::FuncOp>> {
1412
llvm::StringRef getArgument() const final { return "tfl-legalize-lce"; }
@@ -55,8 +53,6 @@ void LegalizeLCE::runOnOperation() {
5553
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
5654
}
5755

58-
} // namespace
59-
6056
// Creates an instance of the LegalizeLCE pass.
6157
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateLegalizeLCEPass() {
6258
return std::make_unique<LegalizeLCE>();

larq_compute_engine/mlir/transforms/op_removal.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
namespace mlir {
99
namespace TFL {
1010

11-
namespace {
12-
1311
// Op removal of pass through ops to make following patterns easier and enable
1412
// early constant folding
1513
struct OpRemoval
@@ -21,18 +19,18 @@ struct OpRemoval
2119
void runOnOperation() override;
2220
};
2321

22+
namespace op_removal {
2423
#include "larq_compute_engine/mlir/transforms/generated_op_removal.inc"
24+
} // namespace op_removal
2525

2626
void OpRemoval::runOnOperation() {
2727
RewritePatternSet patterns(&getContext());
2828
auto func = getOperation();
2929

30-
TFL::populateWithGenerated(patterns);
30+
op_removal::populateWithGenerated(patterns);
3131
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
3232
}
3333

34-
} // namespace
35-
3634
// Creates an instance of the TensorFlow dialect OpRemoval pass.
3735
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateOpRemovalPass() {
3836
return std::make_unique<OpRemoval>();

larq_compute_engine/mlir/transforms/optimize.cc

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "larq_compute_engine/core/bitpacking/bitpack.h"
44
#include "larq_compute_engine/mlir/ir/lce_ops.h"
5+
#include "larq_compute_engine/mlir/transforms/common.h"
56
#include "larq_compute_engine/mlir/transforms/passes.h"
67
#include "llvm/ADT/Optional.h"
78
#include "llvm/ADT/STLExtras.h"
@@ -16,8 +17,6 @@
1617
namespace mlir {
1718
namespace TFL {
1819

19-
namespace {
20-
2120
// Optimize LCE operations in functions.
2221
struct OptimizeLCE
2322
: public PassWrapper<OptimizeLCE, OperationPass<mlir::func::FuncOp>> {
@@ -38,15 +37,6 @@ struct OptimizeLCE
3837
clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))};
3938
};
4039

41-
bool IsConstantValue(Attribute values, float expected_value) {
42-
if (!values.isa<DenseElementsAttr>()) return false;
43-
44-
for (auto value : values.cast<DenseElementsAttr>().getValues<float>()) {
45-
if (value != expected_value) return false;
46-
}
47-
return true;
48-
}
49-
5040
/**
5141
* =================================================
5242
* Computing thresholds for writing bitpacked output
@@ -254,15 +244,15 @@ DenseElementsAttr GetBitpackedOutputThresholds(
254244
return DenseElementsAttr::get(type, thresholds);
255245
}
256246

257-
namespace target_arm {
247+
namespace optimize_target_arm {
258248
#include "larq_compute_engine/mlir/transforms/generated_optimize_target_arm.inc"
259249
}
260250

261-
namespace target_other {
251+
namespace optimize_target_other {
262252
#include "larq_compute_engine/mlir/transforms/generated_optimize_target_other.inc"
263253
}
264254

265-
namespace bitpack_activations {
255+
namespace optimize_bitpack_activations {
266256
#include "larq_compute_engine/mlir/transforms/generated_bitpack_activations.inc"
267257
}
268258

@@ -271,17 +261,15 @@ void OptimizeLCE::runOnOperation() {
271261
auto func = getOperation();
272262

273263
if (target_ == LCETarget::ARM) {
274-
target_arm::populateWithGenerated(patterns);
264+
optimize_target_arm::populateWithGenerated(patterns);
275265
} else {
276-
target_other::populateWithGenerated(patterns);
266+
optimize_target_other::populateWithGenerated(patterns);
277267
}
278-
bitpack_activations::populateWithGenerated(patterns);
268+
optimize_bitpack_activations::populateWithGenerated(patterns);
279269

280270
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
281271
}
282272

283-
} // namespace
284-
285273
// Creates an instance of the TensorFlow dialect OptimizeLCE pass.
286274
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateOptimizeLCEPass(
287275
const LCETarget target) {

larq_compute_engine/mlir/transforms/prepare_tf.cc

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "larq_compute_engine/core/types.h"
22
#include "larq_compute_engine/mlir/ir/lce_ops.h"
3+
#include "larq_compute_engine/mlir/transforms/common.h"
34
#include "larq_compute_engine/mlir/transforms/padding.h"
45
#include "larq_compute_engine/mlir/transforms/passes.h"
56
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -14,8 +15,6 @@
1415
namespace mlir {
1516
namespace TFL {
1617

17-
namespace {
18-
1918
using compute_engine::core::bitpacking_bitwidth;
2019

2120
// Prepare LCE operations in functions for subsequent legalization.
@@ -39,14 +38,6 @@ struct PrepareLCE
3938
clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))};
4039
};
4140

42-
bool IsConstantValue(Attribute values, float expected_value) {
43-
if (!values.isa<DenseElementsAttr>()) return false;
44-
45-
for (auto value : values.cast<DenseElementsAttr>().getValues<float>()) {
46-
if (value != expected_value) return false;
47-
}
48-
return true;
49-
}
5041
DenseElementsAttr GetConstantVector(TypedAttr filter, float val) {
5142
auto filter_type = filter.getType().cast<ShapedType>();
5243
auto filter_shape = filter_type.getShape();
@@ -162,11 +153,11 @@ IntegerAttr GetNumChannels(Builder& b, Value output_val) {
162153
return b.getI32IntegerAttr(shape_vector[shape_vector.size() - 1]);
163154
}
164155

165-
namespace target_arm {
156+
namespace prepare_target_arm {
166157
#include "larq_compute_engine/mlir/transforms/generated_prepare_target_arm.inc"
167158
}
168159

169-
namespace target_other {
160+
namespace prepare_target_other {
170161
#include "larq_compute_engine/mlir/transforms/generated_prepare_target_other.inc"
171162
}
172163

@@ -181,16 +172,14 @@ void PrepareLCE::runOnOperation() {
181172
patterns.add<ConvertTFDilatedConvOp<TF::Conv2DOp>>(ctx);
182173

183174
if (target_ == LCETarget::ARM) {
184-
target_arm::populateWithGenerated(patterns);
175+
prepare_target_arm::populateWithGenerated(patterns);
185176
} else {
186-
target_other::populateWithGenerated(patterns);
177+
prepare_target_other::populateWithGenerated(patterns);
187178
}
188179

189180
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
190181
}
191182

192-
} // namespace
193-
194183
// Creates an instance of the TensorFlow dialect PrepareLCE pass.
195184
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreatePrepareLCEPass(
196185
const LCETarget target) {

larq_compute_engine/mlir/transforms/quantize.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ namespace TFL {
1212
//===----------------------------------------------------------------------===//
1313
// The actual Quantize Pass.
1414
//
15-
namespace {
16-
1715
// Applies quantization on the model in TFL dialect.
1816
struct LCEQuantizePass
1917
: public PassWrapper<LCEQuantizePass, OperationPass<mlir::func::FuncOp>> {
@@ -24,15 +22,16 @@ struct LCEQuantizePass
2422
void runOnOperation() override;
2523
};
2624

25+
namespace lce_quantize {
2726
#include "larq_compute_engine/mlir/transforms/generated_quantize.inc"
27+
}
2828

2929
void LCEQuantizePass::runOnOperation() {
3030
RewritePatternSet patterns(&getContext());
3131
auto func = getOperation();
32-
TFL::populateWithGenerated(patterns);
32+
lce_quantize::populateWithGenerated(patterns);
3333
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
3434
}
35-
} // namespace
3635

3736
// Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
3837
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateLCEQuantizePass() {

larq_compute_engine/mlir/transforms/set_batch_size.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
namespace mlir {
88

9-
namespace {
10-
119
mlir::Type SetBatchSize(mlir::Type type) {
1210
auto tensor_type = type.dyn_cast<mlir::TensorType>();
1311
if (tensor_type && tensor_type.hasRank()) {
@@ -59,8 +57,6 @@ struct SetBatchSizePass
5957
}
6058
};
6159

62-
} // namespace
63-
6460
// Creates an instance of the ZeroPointCompatibility pass.
6561
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateSetBatchSizePass() {
6662
return std::make_unique<SetBatchSizePass>();

0 commit comments

Comments
 (0)