From 979e0f6f247279a5341c88e47fded590b9ad45a9 Mon Sep 17 00:00:00 2001 From: Han Meng Date: Mon, 16 Sep 2024 23:26:41 -0400 Subject: [PATCH 1/7] can't clone block2 --- include/hcl/Transforms/Passes.h | 1 + lib/Bindings/Python/HCLModule.cpp | 7 +++++++ lib/Transforms/CMakeLists.txt | 1 + 3 files changed, 9 insertions(+) diff --git a/include/hcl/Transforms/Passes.h b/include/hcl/Transforms/Passes.h index be8f9426..ab1a3c1e 100644 --- a/include/hcl/Transforms/Passes.h +++ b/include/hcl/Transforms/Passes.h @@ -28,6 +28,7 @@ bool applyLegalizeCast(ModuleOp &module); bool applyRemoveStrideMap(ModuleOp &module); bool applyMemRefDCE(ModuleOp &module); bool applyDataPlacement(ModuleOp &module); +ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2); /// Registers all HCL transformation passes void registerHCLPasses(); diff --git a/lib/Bindings/Python/HCLModule.cpp b/lib/Bindings/Python/HCLModule.cpp index ecb00c83..463c4993 100644 --- a/lib/Bindings/Python/HCLModule.cpp +++ b/lib/Bindings/Python/HCLModule.cpp @@ -157,6 +157,12 @@ static bool memRefDCE(MlirModule &mlir_mod) { return applyMemRefDCE(mod); } +static MlirModule UnifyKernels(MlirModule &mlir_mod1, MlirModule &mlir_mod2) { + auto mod1 = unwrap(mlir_mod1); + auto mod2 = unwrap(mlir_mod2); + return wrap(applyUnifyKernels(mod1, mod2)); +} + //===----------------------------------------------------------------------===// // HCL Python module definition //===----------------------------------------------------------------------===// @@ -259,4 +265,5 @@ PYBIND11_MODULE(_hcl, m) { // Utility pass APIs. hcl_m.def("memref_dce", &memRefDCE); + hcl_m.def("unify_kernels", &UnifyKernels); } diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index f2b87b1b..7002fea9 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_library(MLIRHCLPasses MemRefDCE.cpp DataPlacement.cpp TransformInterpreter.cpp + UnifyKernels.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/hcl From fd984f21f4df2e597525cfdf56580116bcbf02d1 Mon Sep 17 00:00:00 2001 From: Han Meng Date: Tue, 17 Sep 2024 10:48:14 -0400 Subject: [PATCH 2/7] add unify kernels --- lib/Transforms/UnifyKernels.cpp | 342 ++++++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100644 lib/Transforms/UnifyKernels.cpp diff --git a/lib/Transforms/UnifyKernels.cpp b/lib/Transforms/UnifyKernels.cpp new file mode 100644 index 00000000..18be4c0f --- /dev/null +++ b/lib/Transforms/UnifyKernels.cpp @@ -0,0 +1,342 @@ +/* + * Copyright HeteroCL authors. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "PassDetail.h" + +#include "hcl/Dialect/HeteroCLDialect.h" +#include "hcl/Dialect/HeteroCLOps.h" +#include "hcl/Dialect/HeteroCLTypes.h" +#include "hcl/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/AffineMap.h" + +using namespace mlir; +using namespace hcl; + +namespace mlir { +namespace hcl { + +bool compareAffineExprs(AffineExpr lhsExpr, AffineExpr rhsExpr) { + // Compare the kinds of affine exprs + if (lhsExpr.getKind() != rhsExpr.getKind()) { + return false; + } + + // Compare affine exprs based on kind + switch (lhsExpr.getKind()) { + case AffineExprKind::Constant: { + auto lhsConst = lhsExpr.cast(); + auto rhsConst = rhsExpr.cast(); + return lhsConst.getValue() == rhsConst.getValue(); + } + case AffineExprKind::DimId: { + auto lhsDim = lhsExpr.cast(); + auto rhsDim = rhsExpr.cast(); + return lhsDim.getPosition() == rhsDim.getPosition(); + } + case AffineExprKind::SymbolId: { + auto lhsSymbol = lhsExpr.cast(); + auto rhsSymbol = rhsExpr.cast(); + return lhsSymbol.getPosition() == rhsSymbol.getPosition(); + } + case AffineExprKind::Add: + case AffineExprKind::Mul: + case AffineExprKind::Mod: + case AffineExprKind::FloorDiv: + case AffineExprKind::CeilDiv: { + auto lhsBinary = lhsExpr.cast(); + auto rhsBinary = rhsExpr.cast(); + return compareAffineExprs(lhsBinary.getLHS(), rhsBinary.getLHS()) && + compareAffineExprs(lhsBinary.getRHS(), rhsBinary.getRHS()); + } + } + return false; +} + +bool compareAffineMaps(AffineMap lhsMap, AffineMap rhsMap) { + AffineMap simplifiedLhsMap = simplifyAffineMap(lhsMap); + AffineMap simplifiedRhsMap = simplifyAffineMap(rhsMap); + + if (simplifiedLhsMap.getNumDims() != simplifiedRhsMap.getNumDims() && + simplifiedLhsMap.getNumSymbols() != simplifiedRhsMap.getNumSymbols() && + simplifiedLhsMap.getNumResults() != simplifiedRhsMap.getNumResults()) + return false; + + // Compare exprs + for (unsigned i = 0; i < simplifiedLhsMap.getNumResults(); ++i) { + if (!compareAffineExprs(simplifiedLhsMap.getResult(i), simplifiedRhsMap.getResult(i))) { + return false; + } + } + + // Todo: Might need to compare operands or use evaluation to compare + return true; +} + +bool compareAffineForOps(affine::AffineForOp &affineForOp1, affine::AffineForOp &affineForOp2) { + if (affineForOp1 == affineForOp2) + return true; + + if (affineForOp1.getStep() != affineForOp2.getStep()) return false; + if (!compareAffineMaps(affineForOp1.getLowerBoundMap(), affineForOp2.getLowerBoundMap()) || + !compareAffineMaps(affineForOp1.getUpperBoundMap(), affineForOp2.getUpperBoundMap())) + return false; + return true; +} + +void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp &op2, Value conditionArg, bool &foundDifference) { + auto loc = op1->getLoc(); + + // Save insertion point + OpBuilder::InsertionGuard guard(builder); + + // Create new affine.for with same arguments + auto lowerBoundMap = op1.getLowerBoundMap(); + auto lowerBoundOperands = llvm::SmallVector(op1.getLowerBoundOperands().begin(), op1.getLowerBoundOperands().end()); + auto upperBoundMap = op1.getUpperBoundMap(); + auto upperBoundOperands = llvm::SmallVector(op1.getUpperBoundOperands().begin(), op1.getUpperBoundOperands().end()); + int64_t step = op1.getStep(); + + auto newAffineForOp = builder.create( + loc, lowerBoundOperands, lowerBoundMap, upperBoundOperands, upperBoundMap, step); + + Block *body1 = op1.getBody(); + Block *body2 = op2.getBody(); + Block *newBody = newAffineForOp.getBody(); + + // Set insertion point to current loop body + builder.setInsertionPointToStart(newBody); + + auto body1It = body1->begin(); + auto body2It = body2->begin(); + + // Iterate over two FuncOps to find branch location + while (body1It != body1->end() && body2It != body2->end()) { + if (!foundDifference) { + if (!(&(*body1It) == &(*body2It))) { + // If we found an affine.for to merge + // Todo: Support dynamic loop range + auto affineForOp1 = dyn_cast(&(*body1It)); + auto affineForOp2 = dyn_cast(&(*body2It)); + if (affineForOp1 && affineForOp2 && + compareAffineForOps(affineForOp1, affineForOp2)) { + mergeLoop(builder, affineForOp1, affineForOp2, conditionArg, foundDifference); + } + else { + foundDifference = true; + break; + } + } else { + builder.clone(*body1It); + } + ++body1It; + ++body2It; + } else { + break; + } + } + + // Create branch for the rest after difference is found + IRMapping mapping1; + IRMapping mapping2; + builder.create( + loc, conditionArg, + [&](OpBuilder &thenBuilder, Location thenLoc) { + while (body1It != body1->end()) { + thenBuilder.clone(*body1It, mapping1); + ++body1It; + } + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + while (body2It != body2->end()) { + elseBuilder.clone(*body2It, mapping2); + ++body2It; + } + elseBuilder.create(elseLoc); + } + ); +} + +void printIRMapping(mlir::IRMapping &mapping) { + llvm::outs() << "IRMapping contents:\n"; + for (auto &kv : mapping.getValueMap()) { + llvm::outs() << "From Value: " << kv.first << " To Value: " << kv.second << "\n"; + } + + for (auto &kv : mapping.getBlockMap()) { + llvm::outs() << "From Block: "; + kv.first->print(llvm::outs()); + llvm::outs() << " To Block: "; + kv.second->print(llvm::outs()); + llvm::outs() << "\n"; + } +} + +func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &builder) { + std::string newFuncName = func1.getName().str() + "_" + func2.getName().str() + "_unified"; + + // Todo: Now assuming return types and input types are the same + // Create new FuncOp with additional parameter + auto oldFuncType = func1.getFunctionType(); + auto oldInputTypes = oldFuncType.getInputs(); + auto resultTypes = oldFuncType.getResults(); + auto loc = builder.getUnknownLoc(); + SmallVector newInputTypes(oldInputTypes.begin(), oldInputTypes.end()); + Type instType = builder.getI1Type(); + newInputTypes.push_back(instType); + auto newFuncType = builder.getFunctionType(newInputTypes, resultTypes); + auto newFuncOp = func::FuncOp::create(loc, newFuncName, newFuncType); + + // Create new block for insertion + Block *entryBlock = newFuncOp.addEntryBlock(); + auto conditionArg = entryBlock->getArgument(entryBlock->getNumArguments() - 1); + builder.setInsertionPointToStart(entryBlock); + + auto &block1 = func1.front(); + auto &block2 = func2.front(); + auto block1It = block1.begin(); + auto block2It = block2.begin(); + bool foundDifference = false; + + // Iterate over two FuncOps to find branch location + while (block1It != block1.end() && block2It != block2.end()) { + if (!foundDifference) { + if (!(&(*block1It) == &(*block2It))) { + // If we found an affine.for to merge + // Todo: Support dynamic loop range + auto affineForOp1 = dyn_cast(&(*block1It)); + auto affineForOp2 = dyn_cast(&(*block2It)); + if (affineForOp1 && affineForOp2 && + compareAffineForOps(affineForOp1, affineForOp2)) { + mergeLoop(builder, affineForOp1, affineForOp2, conditionArg, foundDifference); + } + else { + foundDifference = true; + break; + } + } else { + builder.clone(*block1It); + } + ++block1It; + ++block2It; + } else { + break; + } + } + + // Create branch for the rest after difference is found + IRMapping mapping1; + IRMapping mapping2; + // builder.create( + // loc, conditionArg, + // [&](OpBuilder &thenBuilder, Location thenLoc) { + // for (size_t i = 0; i < block1.getNumArguments(); ++i) { + // mapping1.map(block1.getArgument(i), entryBlock->getArgument(i)); + // } + // while (block1It != block1.end()) { + // auto &op = *block1It; + // if (auto returnOp = dyn_cast(&op)) { + // break; + // } + // thenBuilder.clone(*block1It, mapping1); + // ++block1It; + // } + // thenBuilder.create(thenLoc); + // }, + // [&](OpBuilder &elseBuilder, Location elseLoc) { + // for (size_t i = 0; i < block2.getNumArguments(); ++i) { + // mapping2.map(block2.getArgument(i), entryBlock->getArgument(i)); + // } + // while (block2It != block2.end()) { + // auto &op = *block2It; + // if (auto returnOp = dyn_cast(&op)) { + // break; + // } + // elseBuilder.clone(*block2It, mapping2); + // ++block2It; + // } + // elseBuilder.create(elseLoc); + // } + // ); + auto ifOp = builder.create(loc, conditionArg, /*hasElse*/ true); + + // Create then block + builder.setInsertionPointToStart(ifOp.thenBlock()); + for (size_t i = 0; i < block1.getNumArguments(); ++i) { + mapping1.map(block1.getArgument(i), entryBlock->getArgument(i)); + } + while (block1It != block1.end()) { + auto &op = *block1It; + if (auto returnOp = dyn_cast(&op)) { + break; + } + builder.clone(*block1It, mapping1); + ++block1It; + } + + // Create else block + builder.setInsertionPointToStart(ifOp.elseBlock()); + for (size_t i = 0; i < block2.getNumArguments(); ++i) { + mapping2.map(block2.getArgument(i), entryBlock->getArgument(i)); + } + while (block2It != block2.end()) { + auto &op = *block2It; + if (auto returnOp = dyn_cast(&op)) { + break; + } + builder.clone(*block2It, mapping2); + ++block2It; + } + + // Create return Op + builder.setInsertionPointAfter(ifOp); + + printIRMapping(mapping1); + printIRMapping(mapping2); + + builder.clone(*block1It, mapping1); + // builder.clone(*block2It, mapping2); + + return newFuncOp; +} + +/// Pass entry point +ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2) { + auto funcOps1 = module1.getOps(); + auto funcOps2 = module2.getOps(); + + auto it1 = funcOps1.begin(); + auto it2 = funcOps2.begin(); + + MLIRContext *context = module1.getContext(); + OpBuilder builder(context); + + // + ModuleOp newModule = ModuleOp::create(module1.getLoc()); + + while (it1 != funcOps1.end() && it2 != funcOps2.end()) { + func::FuncOp funcOp1 = *it1; + func::FuncOp funcOp2 = *it2; + func::FuncOp newFuncOp = unifyKernels(funcOp1, funcOp2, builder); + newModule.push_back(newFuncOp); + + ++it1; + ++it2; + } + + return newModule; +} + +} // namespace hcl +} // namespace mlir \ No newline at end of file From 2981d033f383ae7ecff20d6fc13dc07a894d0b22 Mon Sep 17 00:00:00 2001 From: Han Meng Date: Sat, 21 Sep 2024 20:30:14 -0400 Subject: [PATCH 3/7] basic and nested loop unify --- include/hcl/Transforms/Passes.h | 2 +- lib/Bindings/Python/HCLModule.cpp | 5 +- lib/Transforms/UnifyKernels.cpp | 172 +++++++--------- tools/hcl-opt/hcl-opt.cpp | 321 ++++++++++++++++++------------ 4 files changed, 271 insertions(+), 229 deletions(-) diff --git a/include/hcl/Transforms/Passes.h b/include/hcl/Transforms/Passes.h index ab1a3c1e..1af6cce6 100644 --- a/include/hcl/Transforms/Passes.h +++ b/include/hcl/Transforms/Passes.h @@ -28,7 +28,7 @@ bool applyLegalizeCast(ModuleOp &module); bool applyRemoveStrideMap(ModuleOp &module); bool applyMemRefDCE(ModuleOp &module); bool applyDataPlacement(ModuleOp &module); -ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2); +ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2, MLIRContext *context); /// Registers all HCL transformation passes void registerHCLPasses(); diff --git a/lib/Bindings/Python/HCLModule.cpp b/lib/Bindings/Python/HCLModule.cpp index 463c4993..f341e21f 100644 --- a/lib/Bindings/Python/HCLModule.cpp +++ b/lib/Bindings/Python/HCLModule.cpp @@ -157,10 +157,11 @@ static bool memRefDCE(MlirModule &mlir_mod) { return applyMemRefDCE(mod); } -static MlirModule UnifyKernels(MlirModule &mlir_mod1, MlirModule &mlir_mod2) { +static MlirModule UnifyKernels(MlirModule &mlir_mod1, MlirModule &mlir_mod2, MlirContext &mlir_context) { auto mod1 = unwrap(mlir_mod1); auto mod2 = unwrap(mlir_mod2); - return wrap(applyUnifyKernels(mod1, mod2)); + auto context = unwrap(mlir_context); + return wrap(applyUnifyKernels(mod1, mod2, context)); } //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/UnifyKernels.cpp b/lib/Transforms/UnifyKernels.cpp index 18be4c0f..bddeaf1b 100644 --- a/lib/Transforms/UnifyKernels.cpp +++ b/lib/Transforms/UnifyKernels.cpp @@ -11,7 +11,12 @@ #include "hcl/Dialect/HeteroCLOps.h" #include "hcl/Dialect/HeteroCLTypes.h" #include "hcl/Transforms/Passes.h" +#include "hcl/Dialect/TransformOps/HCLTransformOps.h" +#include "hcl-c/Dialect/Dialects.h" +#include "mlir/CAPI/IR.h" + +#include "mlir/InitAllDialects.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/IRMapping.h" @@ -94,7 +99,8 @@ bool compareAffineForOps(affine::AffineForOp &affineForOp1, affine::AffineForOp return true; } -void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp &op2, Value conditionArg, bool &foundDifference) { +void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp &op2, IRMapping &mapping1, IRMapping &mapping2, + Value conditionArg, bool &foundDifference) { auto loc = op1->getLoc(); // Save insertion point @@ -117,6 +123,14 @@ void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp // Set insertion point to current loop body builder.setInsertionPointToStart(newBody); + // Add IRMapping for latter cloning + for (size_t i = 0; i < body1->getNumArguments(); ++i) { + mapping1.map(body1->getArgument(i), newBody->getArgument(i)); + } + for (size_t i = 0; i < body2->getNumArguments(); ++i) { + mapping2.map(body2->getArgument(i), newBody->getArgument(i)); + } + auto body1It = body1->begin(); auto body2It = body2->begin(); @@ -130,7 +144,7 @@ void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp auto affineForOp2 = dyn_cast(&(*body2It)); if (affineForOp1 && affineForOp2 && compareAffineForOps(affineForOp1, affineForOp2)) { - mergeLoop(builder, affineForOp1, affineForOp2, conditionArg, foundDifference); + mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2, conditionArg, foundDifference); } else { foundDifference = true; @@ -147,12 +161,14 @@ void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp } // Create branch for the rest after difference is found - IRMapping mapping1; - IRMapping mapping2; builder.create( loc, conditionArg, [&](OpBuilder &thenBuilder, Location thenLoc) { while (body1It != body1->end()) { + auto &op = *body1It; + if (auto yieldOp = dyn_cast(&op)) { + break; + } thenBuilder.clone(*body1It, mapping1); ++body1It; } @@ -160,6 +176,10 @@ void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp }, [&](OpBuilder &elseBuilder, Location elseLoc) { while (body2It != body2->end()) { + auto &op = *body2It; + if (auto yieldOp = dyn_cast(&op)) { + break; + } elseBuilder.clone(*body2It, mapping2); ++body2It; } @@ -168,21 +188,6 @@ void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp ); } -void printIRMapping(mlir::IRMapping &mapping) { - llvm::outs() << "IRMapping contents:\n"; - for (auto &kv : mapping.getValueMap()) { - llvm::outs() << "From Value: " << kv.first << " To Value: " << kv.second << "\n"; - } - - for (auto &kv : mapping.getBlockMap()) { - llvm::outs() << "From Block: "; - kv.first->print(llvm::outs()); - llvm::outs() << " To Block: "; - kv.second->print(llvm::outs()); - llvm::outs() << "\n"; - } -} - func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &builder) { std::string newFuncName = func1.getName().str() + "_" + func2.getName().str() + "_unified"; @@ -190,13 +195,13 @@ func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &b // Create new FuncOp with additional parameter auto oldFuncType = func1.getFunctionType(); auto oldInputTypes = oldFuncType.getInputs(); - auto resultTypes = oldFuncType.getResults(); auto loc = builder.getUnknownLoc(); SmallVector newInputTypes(oldInputTypes.begin(), oldInputTypes.end()); + auto newOutputTypes = oldFuncType.getResults(); Type instType = builder.getI1Type(); newInputTypes.push_back(instType); - auto newFuncType = builder.getFunctionType(newInputTypes, resultTypes); - auto newFuncOp = func::FuncOp::create(loc, newFuncName, newFuncType); + auto newFuncType = builder.getFunctionType(newInputTypes, newOutputTypes); + auto newFuncOp = func::FuncOp::create(loc, newFuncName, newFuncType, ArrayRef{}); // Create new block for insertion Block *entryBlock = newFuncOp.addEntryBlock(); @@ -209,6 +214,16 @@ func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &b auto block2It = block2.begin(); bool foundDifference = false; + // Create IRMapping for latter cloning + IRMapping mapping1; + for (size_t i = 0; i < block1.getNumArguments(); ++i) { + mapping1.map(block1.getArgument(i), entryBlock->getArgument(i)); + } + IRMapping mapping2; + for (size_t i = 0; i < block2.getNumArguments(); ++i) { + mapping2.map(block2.getArgument(i), entryBlock->getArgument(i)); + } + // Iterate over two FuncOps to find branch location while (block1It != block1.end() && block2It != block2.end()) { if (!foundDifference) { @@ -219,7 +234,7 @@ func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &b auto affineForOp2 = dyn_cast(&(*block2It)); if (affineForOp1 && affineForOp2 && compareAffineForOps(affineForOp1, affineForOp2)) { - mergeLoop(builder, affineForOp1, affineForOp2, conditionArg, foundDifference); + mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2, conditionArg, foundDifference); } else { foundDifference = true; @@ -235,95 +250,56 @@ func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &b } } + auto &op1 = *block1It; + auto &op2 = *block2It; + auto returnOp1 = dyn_cast(&op1); + auto returnOp2 = dyn_cast(&op2); // Create branch for the rest after difference is found - IRMapping mapping1; - IRMapping mapping2; - // builder.create( - // loc, conditionArg, - // [&](OpBuilder &thenBuilder, Location thenLoc) { - // for (size_t i = 0; i < block1.getNumArguments(); ++i) { - // mapping1.map(block1.getArgument(i), entryBlock->getArgument(i)); - // } - // while (block1It != block1.end()) { - // auto &op = *block1It; - // if (auto returnOp = dyn_cast(&op)) { - // break; - // } - // thenBuilder.clone(*block1It, mapping1); - // ++block1It; - // } - // thenBuilder.create(thenLoc); - // }, - // [&](OpBuilder &elseBuilder, Location elseLoc) { - // for (size_t i = 0; i < block2.getNumArguments(); ++i) { - // mapping2.map(block2.getArgument(i), entryBlock->getArgument(i)); - // } - // while (block2It != block2.end()) { - // auto &op = *block2It; - // if (auto returnOp = dyn_cast(&op)) { - // break; - // } - // elseBuilder.clone(*block2It, mapping2); - // ++block2It; - // } - // elseBuilder.create(elseLoc); - // } - // ); - auto ifOp = builder.create(loc, conditionArg, /*hasElse*/ true); - - // Create then block - builder.setInsertionPointToStart(ifOp.thenBlock()); - for (size_t i = 0; i < block1.getNumArguments(); ++i) { - mapping1.map(block1.getArgument(i), entryBlock->getArgument(i)); - } - while (block1It != block1.end()) { - auto &op = *block1It; - if (auto returnOp = dyn_cast(&op)) { - break; - } - builder.clone(*block1It, mapping1); - ++block1It; - } - - // Create else block - builder.setInsertionPointToStart(ifOp.elseBlock()); - for (size_t i = 0; i < block2.getNumArguments(); ++i) { - mapping2.map(block2.getArgument(i), entryBlock->getArgument(i)); - } - while (block2It != block2.end()) { - auto &op = *block2It; - if (auto returnOp = dyn_cast(&op)) { - break; - } - builder.clone(*block2It, mapping2); - ++block2It; + if (!returnOp1 || !returnOp2) { + builder.create( + loc, conditionArg, + [&](OpBuilder &thenBuilder, Location thenLoc) { + while (block1It != block1.end()) { + auto &op = *block1It; + if (auto returnOp = dyn_cast(&op)) { + break; + } + thenBuilder.clone(*block1It, mapping1); + ++block1It; + } + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + while (block2It != block2.end()) { + auto &op = *block2It; + if (auto returnOp = dyn_cast(&op)) { + break; + } + elseBuilder.clone(*block2It, mapping2); + ++block2It; + } + elseBuilder.create(elseLoc); + } + ); } - - // Create return Op - builder.setInsertionPointAfter(ifOp); - - printIRMapping(mapping1); - printIRMapping(mapping2); - + + // Create returnOp + // Todo: Now assume the return value is the same builder.clone(*block1It, mapping1); - // builder.clone(*block2It, mapping2); return newFuncOp; } /// Pass entry point -ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2) { +ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2, MLIRContext *context) { auto funcOps1 = module1.getOps(); auto funcOps2 = module2.getOps(); auto it1 = funcOps1.begin(); auto it2 = funcOps2.begin(); - MLIRContext *context = module1.getContext(); - OpBuilder builder(context); - - // - ModuleOp newModule = ModuleOp::create(module1.getLoc()); + ModuleOp newModule = ModuleOp::create(UnknownLoc::get(context)); + OpBuilder builder(newModule.getContext()); while (it1 != funcOps1.end() && it2 != funcOps2.end()) { func::FuncOp funcOp1 = *it1; diff --git a/tools/hcl-opt/hcl-opt.cpp b/tools/hcl-opt/hcl-opt.cpp index 230ff8dc..3d4debdd 100644 --- a/tools/hcl-opt/hcl-opt.cpp +++ b/tools/hcl-opt/hcl-opt.cpp @@ -243,136 +243,201 @@ int runJiTCompiler(mlir::ModuleOp module) { return 0; } +void test() { + #include "mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/Dialect/SCF/IR/SCF.h" + #include "mlir/Dialect/Arith/IR/Arith.h" + // 创建一个 MLIRContext + mlir::MLIRContext context; + context.getOrLoadDialect(); + context.getOrLoadDialect(); + + // 创建第一个 ModuleOp + mlir::OpBuilder builder1(&context); + auto module1 = mlir::ModuleOp::create(builder1.getUnknownLoc()); + + // 在 module1 中添加一个空函数 + builder1.setInsertionPointToEnd(module1.getBody()); + mlir::FunctionType funcType1 = builder1.getFunctionType({}, builder1.getI32Type()); + mlir::func::FuncOp func1 = builder1.create( + builder1.getUnknownLoc(), + "func1", + funcType1 + ); + mlir::Block *entryBlock1 = func1.addEntryBlock(); + + // 在函数中添加常量和返回操作 + builder1.setInsertionPointToStart(entryBlock1); + auto constant1 = builder1.create( + builder1.getUnknownLoc(), + builder1.getI32Type(), + builder1.getI32IntegerAttr(42) // 常量值为 42 + ); + builder1.create(builder1.getUnknownLoc(), constant1.getResult()); + + // 打印第一个 ModuleOp + module1.dump(); + + // 创建第二个 ModuleOp + mlir::OpBuilder builder2(&context); + auto module2 = mlir::ModuleOp::create(builder2.getUnknownLoc()); + + // 在 module2 中添加一个空函数 + mlir::FunctionType funcType2 = builder2.getFunctionType({}, builder2.getI32Type()); + builder2.setInsertionPointToEnd(module2.getBody()); + mlir::func::FuncOp func2 = builder2.create( + builder2.getUnknownLoc(), + "func2", + funcType2 + ); + mlir::Block *entryBlock2 = func2.addEntryBlock(); + + // 在函数中添加常量和返回操作 + builder2.setInsertionPointToStart(entryBlock2); + auto constant2 = builder2.create( + builder2.getUnknownLoc(), + builder2.getI32Type(), + builder2.getI32IntegerAttr(7) // 常量值为 7 + ); + builder2.create(builder2.getUnknownLoc(), constant2.getResult()); + + // 打印第二个 ModuleOp + module2.dump(); + + // mlir::hcl::applyUnifyKernels(module1, module2); +} + int main(int argc, char **argv) { + test(); // Register dialects and passes in current context - mlir::DialectRegistry registry; - mlir::registerAllDialects(registry); - registry.insert(); - mlir::hcl::registerTransformDialectExtension(registry); - - mlir::MLIRContext context; - context.appendDialectRegistry(registry); - context.allowUnregisteredDialects(true); - context.printOpOnDiagnostic(true); - context.loadAllAvailableDialects(); - - mlir::registerAllPasses(); - mlir::hcl::registerHCLPasses(); - mlir::hcl::registerHCLConversionPasses(); - - // Parse pass names in main to ensure static initialization completed - llvm::cl::ParseCommandLineOptions(argc, argv, - "MLIR modular optimizer driver\n"); - - mlir::OwningOpRef module; - if (int error = loadMLIR(context, module)) - return error; - - // Initialize a pass manager - // https://mlir.llvm.org/docs/PassManagement/ - // Operation agnostic passes - mlir::PassManager pm(&context); - // Operation specific passes - mlir::OpPassManager &optPM = pm.nest(); - if (enableOpt) { - pm.addPass(mlir::hcl::createLoopTransformationPass()); - } - - if (dataPlacement) { - pm.addPass(mlir::hcl::createDataPlacementPass()); - } - - if (memRefDCE) { - pm.addPass(mlir::hcl::createMemRefDCEPass()); - } - - if (lowerComposite) { - pm.addPass(mlir::hcl::createLowerCompositeTypePass()); - } - - if (fixedPointToInteger) { - pm.addPass(mlir::hcl::createFixedPointToIntegerPass()); - } - - // lowerPrintOps should be run after lowering fixed point to integer - if (lowerPrintOps) { - pm.addPass(mlir::hcl::createLowerPrintOpsPass()); - } - - if (anyWidthInteger) { - pm.addPass(mlir::hcl::createAnyWidthIntegerPass()); - } - - if (moveReturnToInput) { - pm.addPass(mlir::hcl::createMoveReturnToInputPass()); - } - - if (lowerBitOps) { - pm.addPass(mlir::hcl::createLowerBitOpsPass()); - } - - if (legalizeCast) { - pm.addPass(mlir::hcl::createLegalizeCastPass()); - } - - if (removeStrideMap) { - pm.addPass(mlir::hcl::createRemoveStrideMapPass()); - } - - if (bufferization) { - pm.addPass(mlir::bufferization::createOneShotBufferizePass()); - } - - if (linalgConversion) { - optPM.addPass(mlir::createConvertLinalgToAffineLoopsPass()); - } - - if (enableNormalize) { - // To make all loop steps to 1. - optPM.addPass(mlir::affine::createAffineLoopNormalizePass()); - - // Sparse Conditional Constant Propagation (SCCP) - pm.addPass(mlir::createSCCPPass()); - - // To factor out the redundant AffineApply/AffineIf operations. - // optPM.addPass(mlir::createCanonicalizerPass()); - // optPM.addPass(mlir::createSimplifyAffineStructuresPass()); - - // To simplify the memory accessing. - pm.addPass(mlir::memref::createNormalizeMemRefsPass()); - - // Generic common sub expression elimination. - // pm.addPass(mlir::createCSEPass()); - } - - if (applyTransform) - pm.addPass(mlir::hcl::createTransformInterpreterPass()); - - if (runJiT || lowerToLLVM) { - if (!removeStrideMap) { - pm.addPass(mlir::hcl::createRemoveStrideMapPass()); - } - pm.addPass(mlir::hcl::createHCLToLLVMLoweringPass()); - } - - // Run the pass pipeline - if (mlir::failed(pm.run(*module))) { - return 4; - } - - // print output - std::string errorMessage; - auto outfile = mlir::openOutputFile(outputFilename, &errorMessage); - if (!outfile) { - llvm::errs() << errorMessage << "\n"; - return 2; - } - module->print(outfile->os()); - outfile->os() << "\n"; - - // run JiT - if (runJiT) - return runJiTCompiler(*module); + // mlir::DialectRegistry registry; + // mlir::registerAllDialects(registry); + // registry.insert(); + // mlir::hcl::registerTransformDialectExtension(registry); + + // mlir::MLIRContext context; + // context.appendDialectRegistry(registry); + // context.allowUnregisteredDialects(true); + // context.printOpOnDiagnostic(true); + // context.loadAllAvailableDialects(); + + // mlir::registerAllPasses(); + // mlir::hcl::registerHCLPasses(); + // mlir::hcl::registerHCLConversionPasses(); + + // // Parse pass names in main to ensure static initialization completed + // llvm::cl::ParseCommandLineOptions(argc, argv, + // "MLIR modular optimizer driver\n"); + + // mlir::OwningOpRef module; + // if (int error = loadMLIR(context, module)) + // return error; + + // // Initialize a pass manager + // // https://mlir.llvm.org/docs/PassManagement/ + // // Operation agnostic passes + // mlir::PassManager pm(&context); + // // Operation specific passes + // mlir::OpPassManager &optPM = pm.nest(); + // if (enableOpt) { + // pm.addPass(mlir::hcl::createLoopTransformationPass()); + // } + + // if (dataPlacement) { + // pm.addPass(mlir::hcl::createDataPlacementPass()); + // } + + // if (memRefDCE) { + // pm.addPass(mlir::hcl::createMemRefDCEPass()); + // } + + // if (lowerComposite) { + // pm.addPass(mlir::hcl::createLowerCompositeTypePass()); + // } + + // if (fixedPointToInteger) { + // pm.addPass(mlir::hcl::createFixedPointToIntegerPass()); + // } + + // // lowerPrintOps should be run after lowering fixed point to integer + // if (lowerPrintOps) { + // pm.addPass(mlir::hcl::createLowerPrintOpsPass()); + // } + + // if (anyWidthInteger) { + // pm.addPass(mlir::hcl::createAnyWidthIntegerPass()); + // } + + // if (moveReturnToInput) { + // pm.addPass(mlir::hcl::createMoveReturnToInputPass()); + // } + + // if (lowerBitOps) { + // pm.addPass(mlir::hcl::createLowerBitOpsPass()); + // } + + // if (legalizeCast) { + // pm.addPass(mlir::hcl::createLegalizeCastPass()); + // } + + // if (removeStrideMap) { + // pm.addPass(mlir::hcl::createRemoveStrideMapPass()); + // } + + // if (bufferization) { + // pm.addPass(mlir::bufferization::createOneShotBufferizePass()); + // } + + // if (linalgConversion) { + // optPM.addPass(mlir::createConvertLinalgToAffineLoopsPass()); + // } + + // if (enableNormalize) { + // // To make all loop steps to 1. + // optPM.addPass(mlir::affine::createAffineLoopNormalizePass()); + + // // Sparse Conditional Constant Propagation (SCCP) + // pm.addPass(mlir::createSCCPPass()); + + // // To factor out the redundant AffineApply/AffineIf operations. + // // optPM.addPass(mlir::createCanonicalizerPass()); + // // optPM.addPass(mlir::createSimplifyAffineStructuresPass()); + + // // To simplify the memory accessing. + // pm.addPass(mlir::memref::createNormalizeMemRefsPass()); + + // // Generic common sub expression elimination. + // // pm.addPass(mlir::createCSEPass()); + // } + + // if (applyTransform) + // pm.addPass(mlir::hcl::createTransformInterpreterPass()); + + // if (runJiT || lowerToLLVM) { + // if (!removeStrideMap) { + // pm.addPass(mlir::hcl::createRemoveStrideMapPass()); + // } + // pm.addPass(mlir::hcl::createHCLToLLVMLoweringPass()); + // } + + // // Run the pass pipeline + // if (mlir::failed(pm.run(*module))) { + // return 4; + // } + + // // print output + // std::string errorMessage; + // auto outfile = mlir::openOutputFile(outputFilename, &errorMessage); + // if (!outfile) { + // llvm::errs() << errorMessage << "\n"; + // return 2; + // } + // module->print(outfile->os()); + // outfile->os() << "\n"; + + // // run JiT + // if (runJiT) + // return runJiTCompiler(*module); return 0; } \ No newline at end of file From 3d0f02408c095ac2fe0dfac427afae895eafdf0f Mon Sep 17 00:00:00 2001 From: Han Meng Date: Sat, 21 Sep 2024 20:45:19 -0400 Subject: [PATCH 4/7] restore changes in hcl-opt.cpp --- tools/hcl-opt/hcl-opt.cpp | 321 +++++++++++++++----------------------- 1 file changed, 128 insertions(+), 193 deletions(-) diff --git a/tools/hcl-opt/hcl-opt.cpp b/tools/hcl-opt/hcl-opt.cpp index 3d4debdd..230ff8dc 100644 --- a/tools/hcl-opt/hcl-opt.cpp +++ b/tools/hcl-opt/hcl-opt.cpp @@ -243,201 +243,136 @@ int runJiTCompiler(mlir::ModuleOp module) { return 0; } -void test() { - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/Dialect/SCF/IR/SCF.h" - #include "mlir/Dialect/Arith/IR/Arith.h" - // 创建一个 MLIRContext - mlir::MLIRContext context; - context.getOrLoadDialect(); - context.getOrLoadDialect(); - - // 创建第一个 ModuleOp - mlir::OpBuilder builder1(&context); - auto module1 = mlir::ModuleOp::create(builder1.getUnknownLoc()); - - // 在 module1 中添加一个空函数 - builder1.setInsertionPointToEnd(module1.getBody()); - mlir::FunctionType funcType1 = builder1.getFunctionType({}, builder1.getI32Type()); - mlir::func::FuncOp func1 = builder1.create( - builder1.getUnknownLoc(), - "func1", - funcType1 - ); - mlir::Block *entryBlock1 = func1.addEntryBlock(); - - // 在函数中添加常量和返回操作 - builder1.setInsertionPointToStart(entryBlock1); - auto constant1 = builder1.create( - builder1.getUnknownLoc(), - builder1.getI32Type(), - builder1.getI32IntegerAttr(42) // 常量值为 42 - ); - builder1.create(builder1.getUnknownLoc(), constant1.getResult()); - - // 打印第一个 ModuleOp - module1.dump(); - - // 创建第二个 ModuleOp - mlir::OpBuilder builder2(&context); - auto module2 = mlir::ModuleOp::create(builder2.getUnknownLoc()); - - // 在 module2 中添加一个空函数 - mlir::FunctionType funcType2 = builder2.getFunctionType({}, builder2.getI32Type()); - builder2.setInsertionPointToEnd(module2.getBody()); - mlir::func::FuncOp func2 = builder2.create( - builder2.getUnknownLoc(), - "func2", - funcType2 - ); - mlir::Block *entryBlock2 = func2.addEntryBlock(); - - // 在函数中添加常量和返回操作 - builder2.setInsertionPointToStart(entryBlock2); - auto constant2 = builder2.create( - builder2.getUnknownLoc(), - builder2.getI32Type(), - builder2.getI32IntegerAttr(7) // 常量值为 7 - ); - builder2.create(builder2.getUnknownLoc(), constant2.getResult()); - - // 打印第二个 ModuleOp - module2.dump(); - - // mlir::hcl::applyUnifyKernels(module1, module2); -} - int main(int argc, char **argv) { - test(); // Register dialects and passes in current context - // mlir::DialectRegistry registry; - // mlir::registerAllDialects(registry); - // registry.insert(); - // mlir::hcl::registerTransformDialectExtension(registry); - - // mlir::MLIRContext context; - // context.appendDialectRegistry(registry); - // context.allowUnregisteredDialects(true); - // context.printOpOnDiagnostic(true); - // context.loadAllAvailableDialects(); - - // mlir::registerAllPasses(); - // mlir::hcl::registerHCLPasses(); - // mlir::hcl::registerHCLConversionPasses(); - - // // Parse pass names in main to ensure static initialization completed - // llvm::cl::ParseCommandLineOptions(argc, argv, - // "MLIR modular optimizer driver\n"); - - // mlir::OwningOpRef module; - // if (int error = loadMLIR(context, module)) - // return error; - - // // Initialize a pass manager - // // https://mlir.llvm.org/docs/PassManagement/ - // // Operation agnostic passes - // mlir::PassManager pm(&context); - // // Operation specific passes - // mlir::OpPassManager &optPM = pm.nest(); - // if (enableOpt) { - // pm.addPass(mlir::hcl::createLoopTransformationPass()); - // } - - // if (dataPlacement) { - // pm.addPass(mlir::hcl::createDataPlacementPass()); - // } - - // if (memRefDCE) { - // pm.addPass(mlir::hcl::createMemRefDCEPass()); - // } - - // if (lowerComposite) { - // pm.addPass(mlir::hcl::createLowerCompositeTypePass()); - // } - - // if (fixedPointToInteger) { - // pm.addPass(mlir::hcl::createFixedPointToIntegerPass()); - // } - - // // lowerPrintOps should be run after lowering fixed point to integer - // if (lowerPrintOps) { - // pm.addPass(mlir::hcl::createLowerPrintOpsPass()); - // } - - // if (anyWidthInteger) { - // pm.addPass(mlir::hcl::createAnyWidthIntegerPass()); - // } - - // if (moveReturnToInput) { - // pm.addPass(mlir::hcl::createMoveReturnToInputPass()); - // } - - // if (lowerBitOps) { - // pm.addPass(mlir::hcl::createLowerBitOpsPass()); - // } - - // if (legalizeCast) { - // pm.addPass(mlir::hcl::createLegalizeCastPass()); - // } - - // if (removeStrideMap) { - // pm.addPass(mlir::hcl::createRemoveStrideMapPass()); - // } - - // if (bufferization) { - // pm.addPass(mlir::bufferization::createOneShotBufferizePass()); - // } - - // if (linalgConversion) { - // optPM.addPass(mlir::createConvertLinalgToAffineLoopsPass()); - // } - - // if (enableNormalize) { - // // To make all loop steps to 1. - // optPM.addPass(mlir::affine::createAffineLoopNormalizePass()); - - // // Sparse Conditional Constant Propagation (SCCP) - // pm.addPass(mlir::createSCCPPass()); - - // // To factor out the redundant AffineApply/AffineIf operations. - // // optPM.addPass(mlir::createCanonicalizerPass()); - // // optPM.addPass(mlir::createSimplifyAffineStructuresPass()); - - // // To simplify the memory accessing. - // pm.addPass(mlir::memref::createNormalizeMemRefsPass()); - - // // Generic common sub expression elimination. - // // pm.addPass(mlir::createCSEPass()); - // } - - // if (applyTransform) - // pm.addPass(mlir::hcl::createTransformInterpreterPass()); - - // if (runJiT || lowerToLLVM) { - // if (!removeStrideMap) { - // pm.addPass(mlir::hcl::createRemoveStrideMapPass()); - // } - // pm.addPass(mlir::hcl::createHCLToLLVMLoweringPass()); - // } - - // // Run the pass pipeline - // if (mlir::failed(pm.run(*module))) { - // return 4; - // } - - // // print output - // std::string errorMessage; - // auto outfile = mlir::openOutputFile(outputFilename, &errorMessage); - // if (!outfile) { - // llvm::errs() << errorMessage << "\n"; - // return 2; - // } - // module->print(outfile->os()); - // outfile->os() << "\n"; - - // // run JiT - // if (runJiT) - // return runJiTCompiler(*module); + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + registry.insert(); + mlir::hcl::registerTransformDialectExtension(registry); + + mlir::MLIRContext context; + context.appendDialectRegistry(registry); + context.allowUnregisteredDialects(true); + context.printOpOnDiagnostic(true); + context.loadAllAvailableDialects(); + + mlir::registerAllPasses(); + mlir::hcl::registerHCLPasses(); + mlir::hcl::registerHCLConversionPasses(); + + // Parse pass names in main to ensure static initialization completed + llvm::cl::ParseCommandLineOptions(argc, argv, + "MLIR modular optimizer driver\n"); + + mlir::OwningOpRef module; + if (int error = loadMLIR(context, module)) + return error; + + // Initialize a pass manager + // https://mlir.llvm.org/docs/PassManagement/ + // Operation agnostic passes + mlir::PassManager pm(&context); + // Operation specific passes + mlir::OpPassManager &optPM = pm.nest(); + if (enableOpt) { + pm.addPass(mlir::hcl::createLoopTransformationPass()); + } + + if (dataPlacement) { + pm.addPass(mlir::hcl::createDataPlacementPass()); + } + + if (memRefDCE) { + pm.addPass(mlir::hcl::createMemRefDCEPass()); + } + + if (lowerComposite) { + pm.addPass(mlir::hcl::createLowerCompositeTypePass()); + } + + if (fixedPointToInteger) { + pm.addPass(mlir::hcl::createFixedPointToIntegerPass()); + } + + // lowerPrintOps should be run after lowering fixed point to integer + if (lowerPrintOps) { + pm.addPass(mlir::hcl::createLowerPrintOpsPass()); + } + + if (anyWidthInteger) { + pm.addPass(mlir::hcl::createAnyWidthIntegerPass()); + } + + if (moveReturnToInput) { + pm.addPass(mlir::hcl::createMoveReturnToInputPass()); + } + + if (lowerBitOps) { + pm.addPass(mlir::hcl::createLowerBitOpsPass()); + } + + if (legalizeCast) { + pm.addPass(mlir::hcl::createLegalizeCastPass()); + } + + if (removeStrideMap) { + pm.addPass(mlir::hcl::createRemoveStrideMapPass()); + } + + if (bufferization) { + pm.addPass(mlir::bufferization::createOneShotBufferizePass()); + } + + if (linalgConversion) { + optPM.addPass(mlir::createConvertLinalgToAffineLoopsPass()); + } + + if (enableNormalize) { + // To make all loop steps to 1. + optPM.addPass(mlir::affine::createAffineLoopNormalizePass()); + + // Sparse Conditional Constant Propagation (SCCP) + pm.addPass(mlir::createSCCPPass()); + + // To factor out the redundant AffineApply/AffineIf operations. + // optPM.addPass(mlir::createCanonicalizerPass()); + // optPM.addPass(mlir::createSimplifyAffineStructuresPass()); + + // To simplify the memory accessing. + pm.addPass(mlir::memref::createNormalizeMemRefsPass()); + + // Generic common sub expression elimination. + // pm.addPass(mlir::createCSEPass()); + } + + if (applyTransform) + pm.addPass(mlir::hcl::createTransformInterpreterPass()); + + if (runJiT || lowerToLLVM) { + if (!removeStrideMap) { + pm.addPass(mlir::hcl::createRemoveStrideMapPass()); + } + pm.addPass(mlir::hcl::createHCLToLLVMLoweringPass()); + } + + // Run the pass pipeline + if (mlir::failed(pm.run(*module))) { + return 4; + } + + // print output + std::string errorMessage; + auto outfile = mlir::openOutputFile(outputFilename, &errorMessage); + if (!outfile) { + llvm::errs() << errorMessage << "\n"; + return 2; + } + module->print(outfile->os()); + outfile->os() << "\n"; + + // run JiT + if (runJiT) + return runJiTCompiler(*module); return 0; } \ No newline at end of file From 8dec902c782be627a698d8d0b4c4f42a7156f147 Mon Sep 17 00:00:00 2001 From: Han Meng Date: Sat, 21 Sep 2024 21:09:46 -0400 Subject: [PATCH 5/7] clang-format new files --- lib/Bindings/Python/HCLModule.cpp | 3 +- lib/Transforms/UnifyKernels.cpp | 215 ++++++++++++++++-------------- 2 files changed, 116 insertions(+), 102 deletions(-) diff --git a/lib/Bindings/Python/HCLModule.cpp b/lib/Bindings/Python/HCLModule.cpp index f341e21f..000160b9 100644 --- a/lib/Bindings/Python/HCLModule.cpp +++ b/lib/Bindings/Python/HCLModule.cpp @@ -157,7 +157,8 @@ static bool memRefDCE(MlirModule &mlir_mod) { return applyMemRefDCE(mod); } -static MlirModule UnifyKernels(MlirModule &mlir_mod1, MlirModule &mlir_mod2, MlirContext &mlir_context) { +static MlirModule UnifyKernels(MlirModule &mlir_mod1, MlirModule &mlir_mod2, + MlirContext &mlir_context) { auto mod1 = unwrap(mlir_mod1); auto mod2 = unwrap(mlir_mod2); auto context = unwrap(mlir_context); diff --git a/lib/Transforms/UnifyKernels.cpp b/lib/Transforms/UnifyKernels.cpp index bddeaf1b..58049bac 100644 --- a/lib/Transforms/UnifyKernels.cpp +++ b/lib/Transforms/UnifyKernels.cpp @@ -10,20 +10,20 @@ #include "hcl/Dialect/HeteroCLDialect.h" #include "hcl/Dialect/HeteroCLOps.h" #include "hcl/Dialect/HeteroCLTypes.h" -#include "hcl/Transforms/Passes.h" #include "hcl/Dialect/TransformOps/HCLTransformOps.h" +#include "hcl/Transforms/Passes.h" #include "hcl-c/Dialect/Dialects.h" #include "mlir/CAPI/IR.h" -#include "mlir/InitAllDialects.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/OperationSupport.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/InitAllDialects.h" using namespace mlir; using namespace hcl; @@ -39,31 +39,31 @@ bool compareAffineExprs(AffineExpr lhsExpr, AffineExpr rhsExpr) { // Compare affine exprs based on kind switch (lhsExpr.getKind()) { - case AffineExprKind::Constant: { - auto lhsConst = lhsExpr.cast(); - auto rhsConst = rhsExpr.cast(); - return lhsConst.getValue() == rhsConst.getValue(); - } - case AffineExprKind::DimId: { - auto lhsDim = lhsExpr.cast(); - auto rhsDim = rhsExpr.cast(); - return lhsDim.getPosition() == rhsDim.getPosition(); - } - case AffineExprKind::SymbolId: { - auto lhsSymbol = lhsExpr.cast(); - auto rhsSymbol = rhsExpr.cast(); - return lhsSymbol.getPosition() == rhsSymbol.getPosition(); - } - case AffineExprKind::Add: - case AffineExprKind::Mul: - case AffineExprKind::Mod: - case AffineExprKind::FloorDiv: - case AffineExprKind::CeilDiv: { - auto lhsBinary = lhsExpr.cast(); - auto rhsBinary = rhsExpr.cast(); - return compareAffineExprs(lhsBinary.getLHS(), rhsBinary.getLHS()) && - compareAffineExprs(lhsBinary.getRHS(), rhsBinary.getRHS()); - } + case AffineExprKind::Constant: { + auto lhsConst = lhsExpr.cast(); + auto rhsConst = rhsExpr.cast(); + return lhsConst.getValue() == rhsConst.getValue(); + } + case AffineExprKind::DimId: { + auto lhsDim = lhsExpr.cast(); + auto rhsDim = rhsExpr.cast(); + return lhsDim.getPosition() == rhsDim.getPosition(); + } + case AffineExprKind::SymbolId: { + auto lhsSymbol = lhsExpr.cast(); + auto rhsSymbol = rhsExpr.cast(); + return lhsSymbol.getPosition() == rhsSymbol.getPosition(); + } + case AffineExprKind::Add: + case AffineExprKind::Mul: + case AffineExprKind::Mod: + case AffineExprKind::FloorDiv: + case AffineExprKind::CeilDiv: { + auto lhsBinary = lhsExpr.cast(); + auto rhsBinary = rhsExpr.cast(); + return compareAffineExprs(lhsBinary.getLHS(), rhsBinary.getLHS()) && + compareAffineExprs(lhsBinary.getRHS(), rhsBinary.getRHS()); + } } return false; } @@ -74,12 +74,13 @@ bool compareAffineMaps(AffineMap lhsMap, AffineMap rhsMap) { if (simplifiedLhsMap.getNumDims() != simplifiedRhsMap.getNumDims() && simplifiedLhsMap.getNumSymbols() != simplifiedRhsMap.getNumSymbols() && - simplifiedLhsMap.getNumResults() != simplifiedRhsMap.getNumResults()) + simplifiedLhsMap.getNumResults() != simplifiedRhsMap.getNumResults()) return false; // Compare exprs for (unsigned i = 0; i < simplifiedLhsMap.getNumResults(); ++i) { - if (!compareAffineExprs(simplifiedLhsMap.getResult(i), simplifiedRhsMap.getResult(i))) { + if (!compareAffineExprs(simplifiedLhsMap.getResult(i), + simplifiedRhsMap.getResult(i))) { return false; } } @@ -88,19 +89,24 @@ bool compareAffineMaps(AffineMap lhsMap, AffineMap rhsMap) { return true; } -bool compareAffineForOps(affine::AffineForOp &affineForOp1, affine::AffineForOp &affineForOp2) { +bool compareAffineForOps(affine::AffineForOp &affineForOp1, + affine::AffineForOp &affineForOp2) { if (affineForOp1 == affineForOp2) return true; - if (affineForOp1.getStep() != affineForOp2.getStep()) return false; - if (!compareAffineMaps(affineForOp1.getLowerBoundMap(), affineForOp2.getLowerBoundMap()) || - !compareAffineMaps(affineForOp1.getUpperBoundMap(), affineForOp2.getUpperBoundMap())) + if (affineForOp1.getStep() != affineForOp2.getStep()) + return false; + if (!compareAffineMaps(affineForOp1.getLowerBoundMap(), + affineForOp2.getLowerBoundMap()) || + !compareAffineMaps(affineForOp1.getUpperBoundMap(), + affineForOp2.getUpperBoundMap())) return false; return true; } -void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp &op2, IRMapping &mapping1, IRMapping &mapping2, - Value conditionArg, bool &foundDifference) { +void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, + affine::AffineForOp &op2, IRMapping &mapping1, + IRMapping &mapping2, Value conditionArg, bool &foundDifference) { auto loc = op1->getLoc(); // Save insertion point @@ -108,13 +114,16 @@ void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp // Create new affine.for with same arguments auto lowerBoundMap = op1.getLowerBoundMap(); - auto lowerBoundOperands = llvm::SmallVector(op1.getLowerBoundOperands().begin(), op1.getLowerBoundOperands().end()); + auto lowerBoundOperands = llvm::SmallVector( + op1.getLowerBoundOperands().begin(), op1.getLowerBoundOperands().end()); auto upperBoundMap = op1.getUpperBoundMap(); - auto upperBoundOperands = llvm::SmallVector(op1.getUpperBoundOperands().begin(), op1.getUpperBoundOperands().end()); + auto upperBoundOperands = llvm::SmallVector( + op1.getUpperBoundOperands().begin(), op1.getUpperBoundOperands().end()); int64_t step = op1.getStep(); auto newAffineForOp = builder.create( - loc, lowerBoundOperands, lowerBoundMap, upperBoundOperands, upperBoundMap, step); + loc, lowerBoundOperands, lowerBoundMap, upperBoundOperands, upperBoundMap, + step); Block *body1 = op1.getBody(); Block *body2 = op2.getBody(); @@ -142,11 +151,11 @@ void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp // Todo: Support dynamic loop range auto affineForOp1 = dyn_cast(&(*body1It)); auto affineForOp2 = dyn_cast(&(*body2It)); - if (affineForOp1 && affineForOp2 && + if (affineForOp1 && affineForOp2 && compareAffineForOps(affineForOp1, affineForOp2)) { - mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2, conditionArg, foundDifference); - } - else { + mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2, + conditionArg, foundDifference); + } else { foundDifference = true; break; } @@ -162,50 +171,54 @@ void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, affine::AffineForOp // Create branch for the rest after difference is found builder.create( - loc, conditionArg, - [&](OpBuilder &thenBuilder, Location thenLoc) { - while (body1It != body1->end()) { - auto &op = *body1It; - if (auto yieldOp = dyn_cast(&op)) { - break; + loc, conditionArg, + [&](OpBuilder &thenBuilder, Location thenLoc) { + while (body1It != body1->end()) { + auto &op = *body1It; + if (auto yieldOp = dyn_cast(&op)) { + break; + } + thenBuilder.clone(*body1It, mapping1); + ++body1It; } - thenBuilder.clone(*body1It, mapping1); - ++body1It; - } - thenBuilder.create(thenLoc); - }, - [&](OpBuilder &elseBuilder, Location elseLoc) { - while (body2It != body2->end()) { - auto &op = *body2It; - if (auto yieldOp = dyn_cast(&op)) { - break; + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + while (body2It != body2->end()) { + auto &op = *body2It; + if (auto yieldOp = dyn_cast(&op)) { + break; + } + elseBuilder.clone(*body2It, mapping2); + ++body2It; } - elseBuilder.clone(*body2It, mapping2); - ++body2It; - } - elseBuilder.create(elseLoc); - } - ); + elseBuilder.create(elseLoc); + }); } -func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &builder) { - std::string newFuncName = func1.getName().str() + "_" + func2.getName().str() + "_unified"; +func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, + OpBuilder &builder) { + std::string newFuncName = + func1.getName().str() + "_" + func2.getName().str() + "_unified"; // Todo: Now assuming return types and input types are the same // Create new FuncOp with additional parameter auto oldFuncType = func1.getFunctionType(); auto oldInputTypes = oldFuncType.getInputs(); auto loc = builder.getUnknownLoc(); - SmallVector newInputTypes(oldInputTypes.begin(), oldInputTypes.end()); + SmallVector newInputTypes(oldInputTypes.begin(), + oldInputTypes.end()); auto newOutputTypes = oldFuncType.getResults(); Type instType = builder.getI1Type(); newInputTypes.push_back(instType); auto newFuncType = builder.getFunctionType(newInputTypes, newOutputTypes); - auto newFuncOp = func::FuncOp::create(loc, newFuncName, newFuncType, ArrayRef{}); + auto newFuncOp = func::FuncOp::create(loc, newFuncName, newFuncType, + ArrayRef{}); // Create new block for insertion Block *entryBlock = newFuncOp.addEntryBlock(); - auto conditionArg = entryBlock->getArgument(entryBlock->getNumArguments() - 1); + auto conditionArg = + entryBlock->getArgument(entryBlock->getNumArguments() - 1); builder.setInsertionPointToStart(entryBlock); auto &block1 = func1.front(); @@ -226,17 +239,17 @@ func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &b // Iterate over two FuncOps to find branch location while (block1It != block1.end() && block2It != block2.end()) { - if (!foundDifference) { + if (!foundDifference) { if (!(&(*block1It) == &(*block2It))) { // If we found an affine.for to merge // Todo: Support dynamic loop range auto affineForOp1 = dyn_cast(&(*block1It)); auto affineForOp2 = dyn_cast(&(*block2It)); - if (affineForOp1 && affineForOp2 && + if (affineForOp1 && affineForOp2 && compareAffineForOps(affineForOp1, affineForOp2)) { - mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2, conditionArg, foundDifference); - } - else { + mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2, + conditionArg, foundDifference); + } else { foundDifference = true; break; } @@ -257,32 +270,31 @@ func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &b // Create branch for the rest after difference is found if (!returnOp1 || !returnOp2) { builder.create( - loc, conditionArg, - [&](OpBuilder &thenBuilder, Location thenLoc) { - while (block1It != block1.end()) { - auto &op = *block1It; - if (auto returnOp = dyn_cast(&op)) { - break; + loc, conditionArg, + [&](OpBuilder &thenBuilder, Location thenLoc) { + while (block1It != block1.end()) { + auto &op = *block1It; + if (auto returnOp = dyn_cast(&op)) { + break; + } + thenBuilder.clone(*block1It, mapping1); + ++block1It; } - thenBuilder.clone(*block1It, mapping1); - ++block1It; - } - thenBuilder.create(thenLoc); - }, - [&](OpBuilder &elseBuilder, Location elseLoc) { - while (block2It != block2.end()) { - auto &op = *block2It; - if (auto returnOp = dyn_cast(&op)) { - break; + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + while (block2It != block2.end()) { + auto &op = *block2It; + if (auto returnOp = dyn_cast(&op)) { + break; + } + elseBuilder.clone(*block2It, mapping2); + ++block2It; } - elseBuilder.clone(*block2It, mapping2); - ++block2It; - } - elseBuilder.create(elseLoc); - } - ); + elseBuilder.create(elseLoc); + }); } - + // Create returnOp // Todo: Now assume the return value is the same builder.clone(*block1It, mapping1); @@ -291,7 +303,8 @@ func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, OpBuilder &b } /// Pass entry point -ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2, MLIRContext *context) { +ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2, + MLIRContext *context) { auto funcOps1 = module1.getOps(); auto funcOps2 = module2.getOps(); From 40b4aa2936d022c4fc184812ba379a1095e293ed Mon Sep 17 00:00:00 2001 From: Han Meng Date: Sat, 21 Sep 2024 23:06:43 -0400 Subject: [PATCH 6/7] remove iostream inclusion --- lib/Transforms/UnifyKernels.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/Transforms/UnifyKernels.cpp b/lib/Transforms/UnifyKernels.cpp index 58049bac..1ac78fe1 100644 --- a/lib/Transforms/UnifyKernels.cpp +++ b/lib/Transforms/UnifyKernels.cpp @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include - #include "PassDetail.h" #include "hcl/Dialect/HeteroCLDialect.h" From 06be7177557a02cffc21d0dfcd32aa7fd4ebc422 Mon Sep 17 00:00:00 2001 From: Han Meng Date: Tue, 24 Sep 2024 21:07:41 -0400 Subject: [PATCH 7/7] support inst as an array --- lib/Transforms/UnifyKernels.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/lib/Transforms/UnifyKernels.cpp b/lib/Transforms/UnifyKernels.cpp index 1ac78fe1..630d0701 100644 --- a/lib/Transforms/UnifyKernels.cpp +++ b/lib/Transforms/UnifyKernels.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/InitAllDialects.h" using namespace mlir; @@ -207,18 +208,23 @@ func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, SmallVector newInputTypes(oldInputTypes.begin(), oldInputTypes.end()); auto newOutputTypes = oldFuncType.getResults(); - Type instType = builder.getI1Type(); - newInputTypes.push_back(instType); + Type i1Type = builder.getI1Type(); + Type memrefType = MemRefType::get({2}, i1Type); + newInputTypes.push_back(memrefType); auto newFuncType = builder.getFunctionType(newInputTypes, newOutputTypes); auto newFuncOp = func::FuncOp::create(loc, newFuncName, newFuncType, ArrayRef{}); // Create new block for insertion Block *entryBlock = newFuncOp.addEntryBlock(); - auto conditionArg = - entryBlock->getArgument(entryBlock->getNumArguments() - 1); + auto inst = entryBlock->getArgument(entryBlock->getNumArguments() - 1); builder.setInsertionPointToStart(entryBlock); + auto outterLoop = builder.create(loc, 0, 2, 1); + mlir::Value loopIndex = outterLoop.getInductionVar(); + builder.setInsertionPointToStart(outterLoop.getBody()); + mlir::Value conditionArg = builder.create(loc, inst, loopIndex); + auto &block1 = func1.front(); auto &block2 = func2.front(); auto block1It = block1.begin(); @@ -295,6 +301,7 @@ func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, // Create returnOp // Todo: Now assume the return value is the same + builder.setInsertionPointToEnd(entryBlock); builder.clone(*block1It, mapping1); return newFuncOp;