-
Notifications
You must be signed in to change notification settings - Fork 15
Unify kernels for basic and nested affine for loop #217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
979e0f6
fd984f2
2981d03
3d0f024
8dec902
40b4aa2
06be717
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,336 @@ | ||
/* | ||
* Copyright HeteroCL authors. All Rights Reserved. | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#include "PassDetail.h" | ||
|
||
#include "hcl/Dialect/HeteroCLDialect.h" | ||
#include "hcl/Dialect/HeteroCLOps.h" | ||
#include "hcl/Dialect/HeteroCLTypes.h" | ||
#include "hcl/Dialect/TransformOps/HCLTransformOps.h" | ||
#include "hcl/Transforms/Passes.h" | ||
|
||
#include "hcl-c/Dialect/Dialects.h" | ||
#include "mlir/CAPI/IR.h" | ||
|
||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/IR/AffineMap.h" | ||
#include "mlir/IR/IRMapping.h" | ||
#include "mlir/IR/OperationSupport.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/InitAllDialects.h" | ||
|
||
using namespace mlir; | ||
using namespace hcl; | ||
|
||
namespace mlir { | ||
namespace hcl { | ||
|
||
/// Structurally compares two affine expressions.
///
/// Two expressions are considered equal when they have the same kind and
///   - constants carry the same value,
///   - dimensions / symbols refer to the same position,
///   - binary expressions (add, mul, mod, floordiv, ceildiv) have pairwise
///     equal left-hand and right-hand operands, in that order.
///
/// This function performs no canonicalization of its own, so expressions that
/// are algebraically equal but built differently (for example with swapped
/// operands) compare unequal. compareAffineMaps runs simplifyAffineMap on both
/// maps before calling in here; an evaluation-based comparison was discussed
/// in review as a possible follow-up for the more sophisticated cases.
///
/// NOTE(review): AffineExpr instances are presumably uniqued per MLIRContext,
/// in which case this is equivalent to `lhsExpr == rhsExpr` when both
/// expressions live in the same context -- confirm before simplifying.
bool compareAffineExprs(AffineExpr lhsExpr, AffineExpr rhsExpr) {
  const AffineExprKind kind = lhsExpr.getKind();

  // Expressions of different kinds can never match.
  if (kind != rhsExpr.getKind())
    return false;

  switch (kind) {
  case AffineExprKind::Constant:
    return lhsExpr.cast<AffineConstantExpr>().getValue() ==
           rhsExpr.cast<AffineConstantExpr>().getValue();
  case AffineExprKind::DimId:
    return lhsExpr.cast<AffineDimExpr>().getPosition() ==
           rhsExpr.cast<AffineDimExpr>().getPosition();
  case AffineExprKind::SymbolId:
    return lhsExpr.cast<AffineSymbolExpr>().getPosition() ==
           rhsExpr.cast<AffineSymbolExpr>().getPosition();
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Mod:
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv: {
    // All binary kinds share one representation: recurse into both operands,
    // left first, and stop early as soon as one side differs.
    auto lhsOperation = lhsExpr.cast<AffineBinaryOpExpr>();
    auto rhsOperation = rhsExpr.cast<AffineBinaryOpExpr>();
    if (!compareAffineExprs(lhsOperation.getLHS(), rhsOperation.getLHS()))
      return false;
    return compareAffineExprs(lhsOperation.getRHS(), rhsOperation.getRHS());
  }
  }

  // Any kind not listed above is treated as "not equal".
  return false;
}
|
||
/// Returns true when two affine maps describe the same mapping as far as a
/// structural comparison can tell.
///
/// Both maps are first passed through simplifyAffineMap. They are equal when
/// they agree on the number of dimensions, symbols and results, and every
/// pair of corresponding result expressions is equal under
/// compareAffineExprs.
bool compareAffineMaps(AffineMap lhsMap, AffineMap rhsMap) {
  AffineMap simplifiedLhsMap = simplifyAffineMap(lhsMap);
  AffineMap simplifiedRhsMap = simplifyAffineMap(rhsMap);

  // A mismatch in ANY of the three counts makes the maps different. This must
  // be a disjunction: besides being the correct notion of equality, it
  // guarantees that both maps have the same number of results, so the loop
  // below never indexes past the end of the right-hand map.
  if (simplifiedLhsMap.getNumDims() != simplifiedRhsMap.getNumDims() ||
      simplifiedLhsMap.getNumSymbols() != simplifiedRhsMap.getNumSymbols() ||
      simplifiedLhsMap.getNumResults() != simplifiedRhsMap.getNumResults())
    return false;

  // Compare the result expressions pairwise.
  for (unsigned i = 0; i < simplifiedLhsMap.getNumResults(); ++i) {
    if (!compareAffineExprs(simplifiedLhsMap.getResult(i),
                            simplifiedRhsMap.getResult(i))) {
      return false;
    }
  }

  // Todo: Might need to compare operands or use evaluation to compare
  return true;
}
|
||
/// Decides whether two affine.for loops iterate over the same range.
///
/// The loops match when they are the very same operation, or when their
/// steps are equal and both their lower-bound and upper-bound maps are equal
/// under compareAffineMaps. Only the bound maps are compared here; the
/// operands the maps are applied to are not inspected.
bool compareAffineForOps(affine::AffineForOp &affineForOp1,
                         affine::AffineForOp &affineForOp2) {
  // Same underlying operation: nothing else to check.
  if (affineForOp1 == affineForOp2)
    return true;

  // Cheapest check first.
  if (affineForOp1.getStep() != affineForOp2.getStep())
    return false;

  // Lower bounds are checked before upper bounds; the upper bounds are only
  // looked at when the lower bounds already agree.
  if (!compareAffineMaps(affineForOp1.getLowerBoundMap(),
                         affineForOp2.getLowerBoundMap()))
    return false;

  return compareAffineMaps(affineForOp1.getUpperBoundMap(),
                           affineForOp2.getUpperBoundMap());
}
|
||
/// Merges two affine.for loops with matching iteration ranges into one new
/// affine.for created at the builder's current insertion point.
///
/// The new loop takes its bounds and step from `op1`. The two bodies are
/// walked in lock step:
///   - an operation shared by both bodies is cloned once,
///   - a pair of nested affine.for loops that match under compareAffineForOps
///     is merged recursively,
///   - at the first other mismatch the walk stops and `foundDifference` is
///     set.
/// Whatever is left of each body (up to, but excluding, its affine.yield) is
/// then cloned into the two branches of an scf.if on `conditionArg`: the
/// remainder of `op1` goes into the "then" branch, the remainder of `op2`
/// into the "else" branch.
///
/// \param builder          Builder used to create the merged IR; its
///                         insertion point is restored before returning.
/// \param op1, op2         The loops to merge.
/// \param mapping1         Value mapping from op1's IR to the merged IR;
///                         extended with op1's induction variable and with
///                         every operation cloned from op1.
/// \param mapping2         Same as `mapping1`, for op2.
/// \param conditionArg     Condition selecting op1's (true) or op2's (false)
///                         remaining operations.
/// \param foundDifference  In/out flag; set to true once the bodies diverge.
void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1,
               affine::AffineForOp &op2, IRMapping &mapping1,
               IRMapping &mapping2, Value conditionArg, bool &foundDifference) {
  auto loc = op1->getLoc();

  // Restore the caller's insertion point when this function returns.
  OpBuilder::InsertionGuard guard(builder);

  // Create the new affine.for with op1's bounds and step. Bound operands are
  // routed through mapping1 so that a bound depending on a value that has
  // already been re-created in the merged IR (e.g. the induction variable of
  // an enclosing merged loop) refers to the new value; values without a
  // mapping are used unchanged.
  auto lowerBoundMap = op1.getLowerBoundMap();
  llvm::SmallVector<Value, 4> lowerBoundOperands;
  for (Value operand : op1.getLowerBoundOperands())
    lowerBoundOperands.push_back(mapping1.lookupOrDefault(operand));
  auto upperBoundMap = op1.getUpperBoundMap();
  llvm::SmallVector<Value, 4> upperBoundOperands;
  for (Value operand : op1.getUpperBoundOperands())
    upperBoundOperands.push_back(mapping1.lookupOrDefault(operand));
  int64_t step = op1.getStep();

  auto newAffineForOp = builder.create<mlir::affine::AffineForOp>(
      loc, lowerBoundOperands, lowerBoundMap, upperBoundOperands, upperBoundMap,
      step);

  Block *body1 = op1.getBody();
  Block *body2 = op2.getBody();
  Block *newBody = newAffineForOp.getBody();

  // Everything below is emitted into the body of the new loop.
  builder.setInsertionPointToStart(newBody);

  // Map the block arguments of both original bodies onto the new body's
  // arguments for the clones made later.
  for (size_t i = 0; i < body1->getNumArguments(); ++i) {
    mapping1.map(body1->getArgument(i), newBody->getArgument(i));
  }
  for (size_t i = 0; i < body2->getNumArguments(); ++i) {
    mapping2.map(body2->getArgument(i), newBody->getArgument(i));
  }

  auto body1It = body1->begin();
  auto body2It = body2->begin();

  // Walk both bodies in lock step until they diverge.
  while (!foundDifference && body1It != body1->end() &&
         body2It != body2->end()) {
    Operation *inner1 = &(*body1It);
    Operation *inner2 = &(*body2It);

    if (inner1 == inner2) {
      // Shared operation. The terminator is left for the tail handling below
      // so the merged body does not receive a second affine.yield.
      if (isa<affine::AffineYieldOp>(inner1))
        break;
      // Clone through the mapping so operands refer to the merged loop's
      // values instead of the original loop's.
      builder.clone(*inner1, mapping1);
    } else {
      // Only a pair of matching affine.for loops can be merged further.
      // Todo: Support dynamic loop range
      auto affineForOp1 = dyn_cast<affine::AffineForOp>(inner1);
      auto affineForOp2 = dyn_cast<affine::AffineForOp>(inner2);
      if (!(affineForOp1 && affineForOp2 &&
            compareAffineForOps(affineForOp1, affineForOp2))) {
        foundDifference = true;
        break;
      }
      mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2,
                conditionArg, foundDifference);
    }
    ++body1It;
    ++body2It;
  }

  // Create the branch holding the rest of each body. The lambdas advance the
  // captured iterators, stopping in front of each body's affine.yield.
  builder.create<scf::IfOp>(
      loc, conditionArg,
      [&](OpBuilder &thenBuilder, Location thenLoc) {
        while (body1It != body1->end()) {
          if (isa<affine::AffineYieldOp>(&(*body1It))) {
            break;
          }
          thenBuilder.clone(*body1It, mapping1);
          ++body1It;
        }
        thenBuilder.create<scf::YieldOp>(thenLoc);
      },
      [&](OpBuilder &elseBuilder, Location elseLoc) {
        while (body2It != body2->end()) {
          if (isa<affine::AffineYieldOp>(&(*body2It))) {
            break;
          }
          elseBuilder.clone(*body2It, mapping2);
          ++body2It;
        }
        elseBuilder.create<scf::YieldOp>(elseLoc);
      });
}
|
||
/// Builds a single function that can execute either `func1` or `func2`.
///
/// The new function is named "<func1>_<func2>_unified". It has func1's
/// argument and result types plus one trailing `memref<2xi1>` argument. Its
/// body is an affine.for over [0, 2) that loads one i1 flag per iteration
/// from that memref and uses it as the selector: matching affine.for loops of
/// the two kernels are merged via mergeLoop, and once the kernels diverge the
/// remaining operations are placed in an scf.if (func1 in the "then" branch,
/// func2 in the "else" branch). func1's return operation is cloned at the end
/// of the new entry block.
///
/// The new function is created detached; the caller is responsible for
/// inserting it into a module.
///
/// \param func1, func2  The kernels to unify.
/// \param builder       Builder used to create the new IR; its insertion
///                      point is left at the end of the new entry block.
/// \return The newly created unified function.
func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2,
                          OpBuilder &builder) {
  std::string newFuncName =
      func1.getName().str() + "_" + func2.getName().str() + "_unified";

  // Todo: Now assuming return types and input types are the same
  // Create the new FuncOp with one additional parameter for the selector.
  auto oldFuncType = func1.getFunctionType();
  auto oldInputTypes = oldFuncType.getInputs();
  auto loc = builder.getUnknownLoc();
  SmallVector<Type, 4> newInputTypes(oldInputTypes.begin(),
                                     oldInputTypes.end());
  auto newOutputTypes = oldFuncType.getResults();
  Type i1Type = builder.getI1Type();
  Type memrefType = MemRefType::get({2}, i1Type);
  newInputTypes.push_back(memrefType);
  auto newFuncType = builder.getFunctionType(newInputTypes, newOutputTypes);
  auto newFuncOp = func::FuncOp::create(loc, newFuncName, newFuncType,
                                        ArrayRef<NamedAttribute>{});

  // Create the entry block; the selector memref is its last argument.
  Block *entryBlock = newFuncOp.addEntryBlock();
  auto inst = entryBlock->getArgument(entryBlock->getNumArguments() - 1);
  builder.setInsertionPointToStart(entryBlock);

  // Outer dispatch loop: one iteration per selector entry.
  auto outterLoop = builder.create<mlir::affine::AffineForOp>(loc, 0, 2, 1);
  mlir::Value loopIndex = outterLoop.getInductionVar();
  builder.setInsertionPointToStart(outterLoop.getBody());
  mlir::Value conditionArg =
      builder.create<mlir::affine::AffineLoadOp>(loc, inst, loopIndex);

  auto &block1 = func1.front();
  auto &block2 = func2.front();
  auto block1It = block1.begin();
  auto block2It = block2.begin();
  bool foundDifference = false;

  // Map the arguments of both kernels onto the new function's arguments for
  // the clones made later.
  IRMapping mapping1;
  for (size_t i = 0; i < block1.getNumArguments(); ++i) {
    mapping1.map(block1.getArgument(i), entryBlock->getArgument(i));
  }
  IRMapping mapping2;
  for (size_t i = 0; i < block2.getNumArguments(); ++i) {
    mapping2.map(block2.getArgument(i), entryBlock->getArgument(i));
  }

  // Walk both kernels in lock step until they diverge.
  while (!foundDifference && block1It != block1.end() &&
         block2It != block2.end()) {
    Operation *inner1 = &(*block1It);
    Operation *inner2 = &(*block2It);

    if (inner1 == inner2) {
      // Shared operation. The return is handled after the loop so it is not
      // cloned into the body of the dispatch loop.
      if (isa<func::ReturnOp>(inner1))
        break;
      // Clone through the mapping so operands refer to the new function's
      // values instead of the original kernel's.
      builder.clone(*inner1, mapping1);
    } else {
      // Only a pair of matching affine.for loops can be merged further.
      // Todo: Support dynamic loop range
      auto affineForOp1 = dyn_cast<affine::AffineForOp>(inner1);
      auto affineForOp2 = dyn_cast<affine::AffineForOp>(inner2);
      if (!(affineForOp1 && affineForOp2 &&
            compareAffineForOps(affineForOp1, affineForOp2))) {
        foundDifference = true;
        break;
      }
      mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2,
                conditionArg, foundDifference);
    }
    ++block1It;
    ++block2It;
  }

  // An iterator may have reached the end of its block; only dereference it
  // after checking.
  const bool atReturn1 =
      block1It != block1.end() && isa<func::ReturnOp>(&(*block1It));
  const bool atReturn2 =
      block2It != block2.end() && isa<func::ReturnOp>(&(*block2It));

  // Create the branch holding the rest of each kernel, unless both kernels
  // have nothing left but their return. The lambdas advance the captured
  // iterators, stopping in front of each kernel's return.
  if (!atReturn1 || !atReturn2) {
    builder.create<scf::IfOp>(
        loc, conditionArg,
        [&](OpBuilder &thenBuilder, Location thenLoc) {
          while (block1It != block1.end()) {
            if (isa<func::ReturnOp>(&(*block1It))) {
              break;
            }
            thenBuilder.clone(*block1It, mapping1);
            ++block1It;
          }
          thenBuilder.create<scf::YieldOp>(thenLoc);
        },
        [&](OpBuilder &elseBuilder, Location elseLoc) {
          while (block2It != block2.end()) {
            if (isa<func::ReturnOp>(&(*block2It))) {
              break;
            }
            elseBuilder.clone(*block2It, mapping2);
            ++block2It;
          }
          elseBuilder.create<scf::YieldOp>(elseLoc);
        });
  }

  // Create the return at the end of the entry block, after the dispatch loop.
  // Todo: Now assume the return value is the same
  builder.setInsertionPointToEnd(entryBlock);
  if (block1It != block1.end())
    builder.clone(*block1It, mapping1);

  return newFuncOp;
}
|
||
/// Pass entry point.
///
/// Pairs up the functions of `module1` and `module2` in declaration order,
/// unifies each pair with unifyKernels and collects the results in a newly
/// created module. Pairing stops as soon as either module runs out of
/// functions, so surplus functions of the longer module are ignored.
///
/// \param module1, module2  Modules holding the kernels to unify.
/// \param context           Context used to create the location of the new
///                          module. Per the review discussion it is expected
///                          to be the context both input modules were created
///                          in.
/// \return The new module containing one unified function per pair.
ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2,
                           MLIRContext *context) {
  ModuleOp unifiedModule = ModuleOp::create(UnknownLoc::get(context));
  OpBuilder builder(unifiedModule.getContext());

  auto kernels1 = module1.getOps<func::FuncOp>();
  auto kernels2 = module2.getOps<func::FuncOp>();

  auto cursor1 = kernels1.begin();
  auto cursor2 = kernels2.begin();
  for (; cursor1 != kernels1.end() && cursor2 != kernels2.end();
       ++cursor1, ++cursor2) {
    func::FuncOp lhsKernel = *cursor1;
    func::FuncOp rhsKernel = *cursor2;
    unifiedModule.push_back(unifyKernels(lhsKernel, rhsKernel, builder));
  }

  return unifiedModule;
}
|
||
} // namespace hcl | ||
} // namespace mlir |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a bit weird that we pass in MLIRContext as the third argument. Can we make it a return module? Otherwise, we need some comments here describing what this `mlir_context` argument means.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now, I'm thinking that in the frontend we will create a context and use it to create module1 and module2, and then call this primitive. The third argument is definitely removable, because we can just extract the context from module1, which will be the same context. We can modify this interface later according to our design choice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, sounds good