Skip to content

Commit 893b984

Browse files
authored
Fork ConvertTFExecutorToTFLOrFlatbuffer to re-enable argument-name truncation (#649)
1 parent b63f878 commit 893b984

File tree

4 files changed

+139
-19
lines changed

4 files changed

+139
-19
lines changed

larq_compute_engine/mlir/BUILD

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,23 @@ cc_library(
293293
],
294294
)
295295

296+
# Fork of tensorflow's ConvertTFExecutorToTFLOrFlatbuffer that plugs in a
# custom (truncating) OpOrArgLocNameMapper; see tf_to_tfl_flatbuffer.h.
cc_library(
    name = "tf_to_tfl_flatbuffer",
    srcs = ["tf_to_tfl_flatbuffer.cc"],
    hdrs = ["tf_to_tfl_flatbuffer.h"],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@org_tensorflow//tensorflow/compiler/mlir:op_or_arg_name_mapper",
        "@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_export",
        "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:error_util",
        "@org_tensorflow//tensorflow/stream_executor/lib",
    ],
)
312+
296313
tf_cc_binary(
297314
name = "lce-tf-opt",
298315
srcs = ["lce_mlir_opt.cc"],
@@ -308,9 +325,8 @@ pybind_extension(
308325
module_name = "graphdef_tfl_flatbuffer",
309326
deps = [
310327
":lce_tfl_passes",
311-
"@org_tensorflow//tensorflow/compiler/mlir:op_or_arg_name_mapper",
328+
":tf_to_tfl_flatbuffer",
312329
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
313-
"@org_tensorflow//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
314330
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
315331
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:import_utils",
316332
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",

larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <exception>
22

33
#include "larq_compute_engine/mlir/tf_tfl_passes.h"
4+
#include "larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h"
45
#include "llvm/Support/MemoryBuffer.h"
56
#include "llvm/Support/ToolOutputFile.h"
67
#include "mlir/IR/MLIRContext.h"
@@ -10,7 +11,6 @@
1011
#include "pybind11/pytypes.h"
1112
#include "pybind11/stl.h"
1213
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
13-
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
1414
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
1515
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
1616
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
@@ -24,17 +24,6 @@
2424

2525
namespace tensorflow {
2626

27-
// Truncates names to a maximum length of ~50 characters since LCE op location
28-
// names can be very long otherwise.
29-
class TruncateOpOrArgLocNameMapper : public OpOrArgLocNameMapper {
30-
protected:
31-
std::string GetName(OpOrVal op_or_val) override {
32-
auto name = OpOrArgLocNameMapper::GetName(op_or_val);
33-
if (name.length() > 50) return name.substr(0, 50);
34-
return name;
35-
}
36-
};
37-
3827
pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(
3928
const pybind11::bytes& graphdef_bytes,
4029
const std::vector<string>& input_arrays,
@@ -118,11 +107,8 @@ pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(
118107
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
119108

120109
std::string result;
121-
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
122-
module->get(), /*export_to_mlir=*/false, /*emit_builtin_tflite_ops=*/true,
123-
/*emit_select_tf_ops=*/false, /*emit_custom_ops=*/true,
124-
/*select_user_tf_ops=*/{}, quant_specs, /*saved_model_tags=*/{}, &result,
125-
&pm);
110+
auto status = ConvertTFExecutorToFlatbuffer(
111+
module->get(), /*export_to_mlir=*/false, &result, &pm);
126112

127113
if (!status.ok()) {
128114
throw std::runtime_error("Could not translate to flatbuffer.");
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#include "larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h"
2+
3+
#include "llvm/Support/raw_ostream.h"
4+
#include "mlir/IR/BuiltinOps.h"
5+
#include "mlir/Pass/PassManager.h"
6+
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
7+
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
8+
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
9+
#include "tensorflow/core/framework/op.h"
10+
#include "tensorflow/stream_executor/lib/statusor.h"
11+
12+
namespace tensorflow {
13+
namespace {
14+
using mlir::ModuleOp;
15+
using mlir::Operation;
16+
17+
// Returns true if `op` is one of the tf_executor Control Flow V1 ops
// (Switch/Merge/Enter/Exit/NextIteration), which the TFLite converter
// cannot handle.
bool IsControlFlowV1Op(Operation* op) {
  namespace tfe = mlir::tf_executor;
  return mlir::isa<tfe::SwitchOp, tfe::MergeOp, tfe::EnterOp, tfe::ExitOp,
                   tfe::NextIterationSinkOp, tfe::NextIterationSourceOp>(op);
}
23+
24+
// Succeeds iff `module` contains no Control Flow V1 ops; otherwise emits an
// error on the module explaining the limitation and returns failure.
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
  // Walk every op and stop as soon as a V1 control-flow op is found.
  auto walk_result = module.walk([](Operation* op) {
    if (IsControlFlowV1Op(op)) return mlir::WalkResult::interrupt();
    return mlir::WalkResult::advance();
  });
  if (!walk_result.wasInterrupted()) return mlir::success();
  module.emitError(
      "The graph has Control Flow V1 ops. TFLite converter doesn't support "
      "Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
      "https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
      "enable_control_flow_v2.");
  return mlir::failure();
}
39+
40+
// Truncates names to a maximum length of ~50 characters since LCE op location
41+
// names can be very long otherwise.
42+
class TruncateOpOrArgLocNameMapper : public OpOrArgLocNameMapper {
43+
protected:
44+
std::string GetName(OpOrVal op_or_val) override {
45+
auto name = OpOrArgLocNameMapper::GetName(op_or_val);
46+
if (name.length() > 50) return name.substr(0, 50);
47+
return name;
48+
}
49+
};
50+
51+
} // namespace
52+
53+
// Runs `pass_manager` over `module` and, on success, serializes the module
// either as printed MLIR (when `export_to_mlir` is true) or as a TFLite
// flatbuffer into `*result`. Returns a non-OK status on any failure.
Status ConvertTFExecutorToFlatbuffer(mlir::ModuleOp module, bool export_to_mlir,
                                     std::string* result,
                                     mlir::PassManager* pass_manager) {
  // Explicitly disable dumping Op details on failures.
  module.getContext()->printOpOnDiagnostic(false);

  // Register a handler that only logs warning diagnostics (to stdout and the
  // warning log); all other diagnostics fall through to the status handler.
  mlir::ScopedDiagnosticHandler warning_handler(
      module.getContext(), [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
          for (auto& note : diag.getNotes()) {
            std::cout << note.str() << "\n";
            LOG(WARNING) << note.str() << "\n";
          }
        }
        return mlir::failure();
      });

  // Converts MLIR diagnostics into a tensorflow::Status we can return.
  mlir::StatusScopedDiagnosticHandler status_handler(module.getContext(),
                                                     /*propagate=*/true);

  // Reject Control Flow V1 graphs up front, then run the pipeline.
  if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
    return status_handler.ConsumeStatus();
  }

  if (export_to_mlir) {
    llvm::raw_string_ostream os(*result);
    module.print(os);
    return Status::OK();
  }

  // This is the only modification compared to the upstream tensorflow file:
  // use the truncating name mapper so exported names stay short.
  TruncateOpOrArgLocNameMapper op_or_arg_name_mapper;
  tflite::FlatbufferExportOptions options;
  options.emit_builtin_tflite_ops = true;
  options.emit_custom_ops = true;
  options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
  if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
    return status_handler.ConsumeStatus();
  }

  if (mlir::failed(module.verify())) {
    return tensorflow::errors::Unknown("Final module is invalid");
  }
  return Status::OK();
}
99+
100+
} // namespace tensorflow
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_
#define LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

// Runs `pass_manager` over `module` and serializes the result into `*result`,
// either as printed MLIR (when `export_to_mlir` is true) or as a TFLite
// flatbuffer.
//
// This is a fork of ConvertTFExecutorToTFLOrFlatbuffer to enable custom
// OpOrArgLocNameMapper
// https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h#L55-L69
Status ConvertTFExecutorToFlatbuffer(mlir::ModuleOp module, bool export_to_mlir,
                                     std::string* result,
                                     mlir::PassManager* pass_manager);

}  // namespace tensorflow

#endif  // LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_

0 commit comments

Comments
 (0)