Skip to content

Commit f5cfef8

Browse files
[TOSA] TorchToTosa option to emit partial conversion with Torch IR (#4106)
Add "require-full-tosa-conversion" option to TorchToTosa pipeline pass. The default option is "true". When this option is set to "false", models with non-legalized aten ops can still be partially converted with a mixed of Torch and TOSA ops in the IR. Example usage: "builtin.module(torch-backend-to-tosa-backend-pipeline{require-full-tosa-conversion=false})" Signed-off-by: Justin Ngo <justin.ngo@arm.com>
1 parent 493bb33 commit f5cfef8

File tree

5 files changed

+42
-6
lines changed

5 files changed

+42
-6
lines changed

include/torch-mlir/Conversion/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
122122
guards in case of shape mismatches.
123123
}];
124124
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
125+
126+
let options = [
127+
Option<"requireFullTosaConversion", "require-full-tosa-conversion",
128+
"bool", /*default=*/"true",
129+
"Require TorchToTosa full conversion by adding Torch Dialect to "
130+
"TorchToTosa list of illegal dialects">,
131+
];
125132
}
126133
#endif
127134

include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
3030
RewritePatternSet &patterns);
3131

3232
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
33+
std::unique_ptr<OperationPass<func::FuncOp>>
34+
createConvertTorchToTosaPass(bool requireFullTosaConversion);
3335
} // namespace torch
3436
} // namespace mlir
3537

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,19 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
2828

2929
// Do not register the TOSA options if the TOSA target is disabled
3030
#ifdef TORCH_MLIR_ENABLE_TOSA
31+
struct TosaBackendPipelineOptions
32+
: public PassPipelineOptions<TosaBackendPipelineOptions> {
33+
Option<bool> requireFullTosaConversion{
34+
*this, "require-full-tosa-conversion",
35+
llvm::cl::desc("Require full TorchToTosa conversion by adding Torch "
36+
"Dialect to TorchToTosa list of illegal dialects"),
37+
llvm::cl::init(true)};
38+
};
39+
3140
/// Creates a pipeline that lowers from the torch backend contract to the
3241
/// TOSA backend contract.
33-
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
42+
void createTorchBackendToTosaBackendPipeline(
43+
OpPassManager &pm, const TosaBackendPipelineOptions &options);
3444

3545
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
3646
#endif // TORCH_MLIR_ENABLE_TOSA

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8950,6 +8950,11 @@ LogicalResult ConvertAtenOp<AtenUnfoldOp>::matchAndRewrite(
89508950
namespace {
89518951
class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
89528952
public:
8953+
ConvertTorchToTosa() = default;
8954+
ConvertTorchToTosa(bool requireFullTosaConversion) {
8955+
this->requireFullTosaConversion = requireFullTosaConversion;
8956+
}
8957+
89538958
void getDependentDialects(DialectRegistry &registry) const override {
89548959
registry.insert<tosa::TosaDialect>();
89558960
registry.insert<tensor::TensorDialect>();
@@ -8962,7 +8967,12 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
89628967
ConversionTarget target(*context);
89638968
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
89648969
arith::ArithDialect>();
8965-
target.addIllegalDialect<Torch::TorchDialect>();
8970+
8971+
if (this->requireFullTosaConversion) {
8972+
target.addIllegalDialect<Torch::TorchDialect>();
8973+
} else {
8974+
target.addLegalDialect<Torch::TorchDialect>();
8975+
}
89668976

89678977
TypeConverter typeConverter;
89688978
typeConverter.addConversion([](Type type) { return type; });
@@ -9318,5 +9328,10 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
93189328

93199329
std::unique_ptr<OperationPass<func::FuncOp>>
93209330
mlir::torch::createConvertTorchToTosaPass() {
9321-
return std::make_unique<ConvertTorchToTosa>();
9331+
return std::make_unique<ConvertTorchToTosa>(true);
9332+
}
9333+
9334+
std::unique_ptr<OperationPass<func::FuncOp>>
9335+
mlir::torch::createConvertTorchToTosaPass(bool requireFullTosaConversion) {
9336+
return std::make_unique<ConvertTorchToTosa>(requireFullTosaConversion);
93229337
}

lib/Dialect/TorchConversion/Transforms/Passes.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void mlir::torch::registerTorchConversionPasses() {
5050
"contract.",
5151
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
5252
#ifdef TORCH_MLIR_ENABLE_TOSA
53-
mlir::PassPipelineRegistration<>(
53+
mlir::PassPipelineRegistration<TorchConversion::TosaBackendPipelineOptions>(
5454
"torch-backend-to-tosa-backend-pipeline",
5555
"Pipeline lowering torch backend contract to TOSA backend "
5656
"contract.",
@@ -113,8 +113,10 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
113113

114114
#ifdef TORCH_MLIR_ENABLE_TOSA
115115
void TorchConversion::createTorchBackendToTosaBackendPipeline(
116-
OpPassManager &pm) {
117-
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
116+
OpPassManager &pm,
117+
const TorchConversion::TosaBackendPipelineOptions &options) {
118+
pm.addNestedPass<func::FuncOp>(
119+
createConvertTorchToTosaPass(options.requireFullTosaConversion));
118120
// Perform rank broadcasting so TosaToLinalg pass works
119121
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
120122

0 commit comments

Comments
 (0)