diff --git a/include/hcl/Transforms/Passes.h b/include/hcl/Transforms/Passes.h index be8f9426..1af6cce6 100644 --- a/include/hcl/Transforms/Passes.h +++ b/include/hcl/Transforms/Passes.h @@ -28,6 +28,7 @@ bool applyLegalizeCast(ModuleOp &module); bool applyRemoveStrideMap(ModuleOp &module); bool applyMemRefDCE(ModuleOp &module); bool applyDataPlacement(ModuleOp &module); +ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2, MLIRContext *context); /// Registers all HCL transformation passes void registerHCLPasses(); diff --git a/lib/Bindings/Python/HCLModule.cpp b/lib/Bindings/Python/HCLModule.cpp index ecb00c83..000160b9 100644 --- a/lib/Bindings/Python/HCLModule.cpp +++ b/lib/Bindings/Python/HCLModule.cpp @@ -157,6 +157,14 @@ static bool memRefDCE(MlirModule &mlir_mod) { return applyMemRefDCE(mod); } +static MlirModule UnifyKernels(MlirModule &mlir_mod1, MlirModule &mlir_mod2, + MlirContext &mlir_context) { + auto mod1 = unwrap(mlir_mod1); + auto mod2 = unwrap(mlir_mod2); + auto context = unwrap(mlir_context); + return wrap(applyUnifyKernels(mod1, mod2, context)); +} + //===----------------------------------------------------------------------===// // HCL Python module definition //===----------------------------------------------------------------------===// @@ -259,4 +267,5 @@ PYBIND11_MODULE(_hcl, m) { // Utility pass APIs. hcl_m.def("memref_dce", &memRefDCE); + hcl_m.def("unify_kernels", &UnifyKernels); } diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index f2b87b1b..7002fea9 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_library(MLIRHCLPasses MemRefDCE.cpp DataPlacement.cpp TransformInterpreter.cpp + UnifyKernels.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/hcl diff --git a/lib/Transforms/UnifyKernels.cpp b/lib/Transforms/UnifyKernels.cpp new file mode 100644 index 00000000..630d0701 --- /dev/null +++ b/lib/Transforms/UnifyKernels.cpp @@ -0,0 +1,336 @@ +/* + * Copyright HeteroCL authors. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "PassDetail.h" + +#include "hcl/Dialect/HeteroCLDialect.h" +#include "hcl/Dialect/HeteroCLOps.h" +#include "hcl/Dialect/HeteroCLTypes.h" +#include "hcl/Dialect/TransformOps/HCLTransformOps.h" +#include "hcl/Transforms/Passes.h" + +#include "hcl-c/Dialect/Dialects.h" +#include "mlir/CAPI/IR.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/InitAllDialects.h" + +using namespace mlir; +using namespace hcl; + +namespace mlir { +namespace hcl { + +bool compareAffineExprs(AffineExpr lhsExpr, AffineExpr rhsExpr) { + // Compare the kinds of affine exprs + if (lhsExpr.getKind() != rhsExpr.getKind()) { + return false; + } + + // Compare affine exprs based on kind + switch (lhsExpr.getKind()) { + case AffineExprKind::Constant: { + auto lhsConst = lhsExpr.cast(); + auto rhsConst = rhsExpr.cast(); + return lhsConst.getValue() == rhsConst.getValue(); + } + case AffineExprKind::DimId: { + auto lhsDim = lhsExpr.cast(); + auto rhsDim = rhsExpr.cast(); + return lhsDim.getPosition() == rhsDim.getPosition(); + } + case AffineExprKind::SymbolId: { + auto lhsSymbol = lhsExpr.cast(); + auto rhsSymbol = rhsExpr.cast(); + return lhsSymbol.getPosition() == rhsSymbol.getPosition(); + } + case AffineExprKind::Add: + case AffineExprKind::Mul: + case AffineExprKind::Mod: + case AffineExprKind::FloorDiv: + case AffineExprKind::CeilDiv: { + auto lhsBinary = lhsExpr.cast(); + auto rhsBinary = rhsExpr.cast(); + return compareAffineExprs(lhsBinary.getLHS(), rhsBinary.getLHS()) && + compareAffineExprs(lhsBinary.getRHS(), rhsBinary.getRHS()); + } + } + return false; +} + +bool compareAffineMaps(AffineMap lhsMap, AffineMap rhsMap) { + AffineMap simplifiedLhsMap = simplifyAffineMap(lhsMap); + AffineMap simplifiedRhsMap = simplifyAffineMap(rhsMap); + + if (simplifiedLhsMap.getNumDims() != simplifiedRhsMap.getNumDims() && + simplifiedLhsMap.getNumSymbols() != simplifiedRhsMap.getNumSymbols() && + simplifiedLhsMap.getNumResults() != simplifiedRhsMap.getNumResults()) + return false; + + // Compare exprs + for (unsigned i = 0; i < simplifiedLhsMap.getNumResults(); ++i) { + if (!compareAffineExprs(simplifiedLhsMap.getResult(i), + simplifiedRhsMap.getResult(i))) { + return false; + } + } + + // Todo: Might need to compare operands or use evaluation to compare + return true; +} + +bool compareAffineForOps(affine::AffineForOp &affineForOp1, + affine::AffineForOp &affineForOp2) { + if (affineForOp1 == affineForOp2) + return true; + + if (affineForOp1.getStep() != affineForOp2.getStep()) + return false; + if (!compareAffineMaps(affineForOp1.getLowerBoundMap(), + affineForOp2.getLowerBoundMap()) || + !compareAffineMaps(affineForOp1.getUpperBoundMap(), + affineForOp2.getUpperBoundMap())) + return false; + return true; +} + +void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1, + affine::AffineForOp &op2, IRMapping &mapping1, + IRMapping &mapping2, Value conditionArg, bool &foundDifference) { + auto loc = op1->getLoc(); + + // Save insertion point + OpBuilder::InsertionGuard guard(builder); + + // Create new affine.for with same arguments + auto lowerBoundMap = op1.getLowerBoundMap(); + auto lowerBoundOperands = llvm::SmallVector( + op1.getLowerBoundOperands().begin(), op1.getLowerBoundOperands().end()); + auto upperBoundMap = op1.getUpperBoundMap(); + auto upperBoundOperands = llvm::SmallVector( + op1.getUpperBoundOperands().begin(), op1.getUpperBoundOperands().end()); + int64_t step = op1.getStep(); + + auto newAffineForOp = builder.create( + loc, lowerBoundOperands, lowerBoundMap, upperBoundOperands, upperBoundMap, + step); + + Block *body1 = op1.getBody(); + Block *body2 = op2.getBody(); + Block *newBody = newAffineForOp.getBody(); + + // Set insertion point to current loop body + builder.setInsertionPointToStart(newBody); + + // Add IRMapping for latter cloning + for (size_t i = 0; i < body1->getNumArguments(); ++i) { + mapping1.map(body1->getArgument(i), newBody->getArgument(i)); + } + for (size_t i = 0; i < body2->getNumArguments(); ++i) { + mapping2.map(body2->getArgument(i), newBody->getArgument(i)); + } + + auto body1It = body1->begin(); + auto body2It = body2->begin(); + + // Iterate over two FuncOps to find branch location + while (body1It != body1->end() && body2It != body2->end()) { + if (!foundDifference) { + if (!(&(*body1It) == &(*body2It))) { + // If we found an affine.for to merge + // Todo: Support dynamic loop range + auto affineForOp1 = dyn_cast(&(*body1It)); + auto affineForOp2 = dyn_cast(&(*body2It)); + if (affineForOp1 && affineForOp2 && + compareAffineForOps(affineForOp1, affineForOp2)) { + mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2, + conditionArg, foundDifference); + } else { + foundDifference = true; + break; + } + } else { + builder.clone(*body1It); + } + ++body1It; + ++body2It; + } else { + break; + } + } + + // Create branch for the rest after difference is found + builder.create( + loc, conditionArg, + [&](OpBuilder &thenBuilder, Location thenLoc) { + while (body1It != body1->end()) { + auto &op = *body1It; + if (auto yieldOp = dyn_cast(&op)) { + break; + } + thenBuilder.clone(*body1It, mapping1); + ++body1It; + } + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + while (body2It != body2->end()) { + auto &op = *body2It; + if (auto yieldOp = dyn_cast(&op)) { + break; + } + elseBuilder.clone(*body2It, mapping2); + ++body2It; + } + elseBuilder.create(elseLoc); + }); +} + +func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2, + OpBuilder &builder) { + std::string newFuncName = + func1.getName().str() + "_" + func2.getName().str() + "_unified"; + + // Todo: Now assuming return types and input types are the same + // Create new FuncOp with additional parameter + auto oldFuncType = func1.getFunctionType(); + auto oldInputTypes = oldFuncType.getInputs(); + auto loc = builder.getUnknownLoc(); + SmallVector newInputTypes(oldInputTypes.begin(), + oldInputTypes.end()); + auto newOutputTypes = oldFuncType.getResults(); + Type i1Type = builder.getI1Type(); + Type memrefType = MemRefType::get({2}, i1Type); + newInputTypes.push_back(memrefType); + auto newFuncType = builder.getFunctionType(newInputTypes, newOutputTypes); + auto newFuncOp = func::FuncOp::create(loc, newFuncName, newFuncType, + ArrayRef{}); + + // Create new block for insertion + Block *entryBlock = newFuncOp.addEntryBlock(); + auto inst = entryBlock->getArgument(entryBlock->getNumArguments() - 1); + builder.setInsertionPointToStart(entryBlock); + + auto outterLoop = builder.create(loc, 0, 2, 1); + mlir::Value loopIndex = outterLoop.getInductionVar(); + builder.setInsertionPointToStart(outterLoop.getBody()); + mlir::Value conditionArg = builder.create(loc, inst, loopIndex); + + auto &block1 = func1.front(); + auto &block2 = func2.front(); + auto block1It = block1.begin(); + auto block2It = block2.begin(); + bool foundDifference = false; + + // Create IRMapping for latter cloning + IRMapping mapping1; + for (size_t i = 0; i < block1.getNumArguments(); ++i) { + mapping1.map(block1.getArgument(i), entryBlock->getArgument(i)); + } + IRMapping mapping2; + for (size_t i = 0; i < block2.getNumArguments(); ++i) { + mapping2.map(block2.getArgument(i), entryBlock->getArgument(i)); + } + + // Iterate over two FuncOps to find branch location + while (block1It != block1.end() && block2It != block2.end()) { + if (!foundDifference) { + if (!(&(*block1It) == &(*block2It))) { + // If we found an affine.for to merge + // Todo: Support dynamic loop range + auto affineForOp1 = dyn_cast(&(*block1It)); + auto affineForOp2 = dyn_cast(&(*block2It)); + if (affineForOp1 && affineForOp2 && + compareAffineForOps(affineForOp1, affineForOp2)) { + mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2, + conditionArg, foundDifference); + } else { + foundDifference = true; + break; + } + } else { + builder.clone(*block1It); + } + ++block1It; + ++block2It; + } else { + break; + } + } + + auto &op1 = *block1It; + auto &op2 = *block2It; + auto returnOp1 = dyn_cast(&op1); + auto returnOp2 = dyn_cast(&op2); + // Create branch for the rest after difference is found + if (!returnOp1 || !returnOp2) { + builder.create( + loc, conditionArg, + [&](OpBuilder &thenBuilder, Location thenLoc) { + while (block1It != block1.end()) { + auto &op = *block1It; + if (auto returnOp = dyn_cast(&op)) { + break; + } + thenBuilder.clone(*block1It, mapping1); + ++block1It; + } + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + while (block2It != block2.end()) { + auto &op = *block2It; + if (auto returnOp = dyn_cast(&op)) { + break; + } + elseBuilder.clone(*block2It, mapping2); + ++block2It; + } + elseBuilder.create(elseLoc); + }); + } + + // Create returnOp + // Todo: Now assume the return value is the same + builder.setInsertionPointToEnd(entryBlock); + builder.clone(*block1It, mapping1); + + return newFuncOp; +} + +/// Pass entry point +ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2, + MLIRContext *context) { + auto funcOps1 = module1.getOps(); + auto funcOps2 = module2.getOps(); + + auto it1 = funcOps1.begin(); + auto it2 = funcOps2.begin(); + + ModuleOp newModule = ModuleOp::create(UnknownLoc::get(context)); + OpBuilder builder(newModule.getContext()); + + while (it1 != funcOps1.end() && it2 != funcOps2.end()) { + func::FuncOp funcOp1 = *it1; + func::FuncOp funcOp2 = *it2; + func::FuncOp newFuncOp = unifyKernels(funcOp1, funcOp2, builder); + newModule.push_back(newFuncOp); + + ++it1; + ++it2; + } + + return newModule; +} + +} // namespace hcl +} // namespace mlir \ No newline at end of file