Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/hcl/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ bool applyLegalizeCast(ModuleOp &module);
bool applyRemoveStrideMap(ModuleOp &module);
bool applyMemRefDCE(ModuleOp &module);
bool applyDataPlacement(ModuleOp &module);
/// Builds and returns a new module in which the functions of module1 and
/// module2 are paired in order and each pair is replaced by one unified
/// function. Pairing stops at the end of the shorter function list.
/// The context is used to create the location of the new module.
ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2, MLIRContext *context);

/// Registers all HCL transformation passes
void registerHCLPasses();
Expand Down
9 changes: 9 additions & 0 deletions lib/Bindings/Python/HCLModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ static bool memRefDCE(MlirModule &mlir_mod) {
return applyMemRefDCE(mod);
}

/// Python-facing wrapper around applyUnifyKernels.
///
/// Unwraps the two C-API modules and the C-API context, forwards them to
/// applyUnifyKernels, and wraps the module it returns back into an MlirModule.
/// NOTE(review): per the review thread below, the context argument is
/// expected to be the context both modules were created in -- confirm.
static MlirModule UnifyKernels(MlirModule &mlir_mod1, MlirModule &mlir_mod2,
MlirContext &mlir_context) {
Comment on lines +160 to +161
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit weird that we pass in MLIRContext as the third argument. Can we make it a return module? Otherwise, providing some comments here describing what this mlir_context argument means is necessary.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit weird that we pass in MLIRContext as the third argument. Can we make it a return module? Otherwise, providing some comments here describing what this mlir_context argument means is necessary.

For now, I'm thinking in the frontend, we will create a context and use it to create module1 and module2. Then, we called this primitive. This third argument is definitely removable, because we can just extract the context of module1, which will be the same context. We can modify this interface later according to our design choice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, sounds good

auto mod1 = unwrap(mlir_mod1);
auto mod2 = unwrap(mlir_mod2);
auto context = unwrap(mlir_context);
return wrap(applyUnifyKernels(mod1, mod2, context));
}

//===----------------------------------------------------------------------===//
// HCL Python module definition
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -259,4 +267,5 @@ PYBIND11_MODULE(_hcl, m) {

// Utility pass APIs.
hcl_m.def("memref_dce", &memRefDCE);
hcl_m.def("unify_kernels", &UnifyKernels);
}
1 change: 1 addition & 0 deletions lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_library(MLIRHCLPasses
MemRefDCE.cpp
DataPlacement.cpp
TransformInterpreter.cpp
UnifyKernels.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/hcl
Expand Down
336 changes: 336 additions & 0 deletions lib/Transforms/UnifyKernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,336 @@
/*
* Copyright HeteroCL authors. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

#include "PassDetail.h"

#include "hcl/Dialect/HeteroCLDialect.h"
#include "hcl/Dialect/HeteroCLOps.h"
#include "hcl/Dialect/HeteroCLTypes.h"
#include "hcl/Dialect/TransformOps/HCLTransformOps.h"
#include "hcl/Transforms/Passes.h"

#include "hcl-c/Dialect/Dialects.h"
#include "mlir/CAPI/IR.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/InitAllDialects.h"

using namespace mlir;
using namespace hcl;

namespace mlir {
namespace hcl {

/// Structurally compares two affine expressions.
///
/// Returns true only when both expressions have the same kind and, depending
/// on that kind, the same constant value, the same dim/symbol position, or
/// recursively equal left and right operands. Any other kind yields false.
bool compareAffineExprs(AffineExpr lhsExpr, AffineExpr rhsExpr) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question, did you do canonicalization somewhere? If I pass in d0+1 and 1+d0, will this function return true?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question, did you do canonicalization somewhere? If I pass in d0+1 and 1+d0, will this function return true?

Yeah, I think that would be an issue for now. As I commented "Todo" in compareAffineMaps function, this is now a temporary solution. I was thinking about use a evaluation strategy to match some more sophisticated case as you mentioned. I will test and see how it work in future pr.

// Compare the kinds of affine exprs
if (lhsExpr.getKind() != rhsExpr.getKind()) {
return false;
}

// Compare affine exprs based on kind
switch (lhsExpr.getKind()) {
case AffineExprKind::Constant: {
// Constants match when their integer values are equal.
auto lhsConst = lhsExpr.cast<AffineConstantExpr>();
auto rhsConst = rhsExpr.cast<AffineConstantExpr>();
return lhsConst.getValue() == rhsConst.getValue();
}
case AffineExprKind::DimId: {
// Dimensions match when they refer to the same position.
auto lhsDim = lhsExpr.cast<AffineDimExpr>();
auto rhsDim = rhsExpr.cast<AffineDimExpr>();
return lhsDim.getPosition() == rhsDim.getPosition();
}
case AffineExprKind::SymbolId: {
// Symbols match when they refer to the same position.
auto lhsSymbol = lhsExpr.cast<AffineSymbolExpr>();
auto rhsSymbol = rhsExpr.cast<AffineSymbolExpr>();
return lhsSymbol.getPosition() == rhsSymbol.getPosition();
}
case AffineExprKind::Add:
case AffineExprKind::Mul:
case AffineExprKind::Mod:
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
// Binary expressions are compared operand by operand in the same order
// (left with left, right with right). Operand order therefore matters,
// including for Add and Mul.
auto lhsBinary = lhsExpr.cast<AffineBinaryOpExpr>();
auto rhsBinary = rhsExpr.cast<AffineBinaryOpExpr>();
return compareAffineExprs(lhsBinary.getLHS(), rhsBinary.getLHS()) &&
compareAffineExprs(lhsBinary.getRHS(), rhsBinary.getRHS());
}
}
// Reached only when the kind is not handled by the switch above.
return false;
}

/// Returns true when two affine maps are considered equal after
/// simplification.
///
/// Both maps are first passed through simplifyAffineMap. They match only if
/// they have the same number of dims, symbols and results, and every pair of
/// corresponding result expressions is accepted by compareAffineExprs.
/// The operands the maps are applied to are not taken into account.
bool compareAffineMaps(AffineMap lhsMap, AffineMap rhsMap) {
  AffineMap simplifiedLhsMap = simplifyAffineMap(lhsMap);
  AffineMap simplifiedRhsMap = simplifyAffineMap(rhsMap);

  // A mismatch in ANY of the three counts makes the maps different. Checking
  // the result count here also keeps the loop below from indexing past the
  // last result of the shorter map.
  if (simplifiedLhsMap.getNumDims() != simplifiedRhsMap.getNumDims() ||
      simplifiedLhsMap.getNumSymbols() != simplifiedRhsMap.getNumSymbols() ||
      simplifiedLhsMap.getNumResults() != simplifiedRhsMap.getNumResults())
    return false;

  // Compare exprs
  for (unsigned i = 0, e = simplifiedLhsMap.getNumResults(); i < e; ++i) {
    if (!compareAffineExprs(simplifiedLhsMap.getResult(i),
                            simplifiedRhsMap.getResult(i))) {
      return false;
    }
  }

  // Todo: Might need to compare operands or use evaluation to compare
  return true;
}

/// Decides whether two affine.for ops have matching loop headers.
///
/// An op always matches itself. Otherwise the steps must be equal and both
/// the lower and upper bound maps must be accepted by compareAffineMaps.
/// Bound operands are not inspected.
bool compareAffineForOps(affine::AffineForOp &affineForOp1,
                         affine::AffineForOp &affineForOp2) {
  if (!(affineForOp1 == affineForOp2)) {
    // Different ops: the step is checked first, then the bounds.
    if (affineForOp1.getStep() != affineForOp2.getStep())
      return false;

    const bool lowerBoundsMatch = compareAffineMaps(
        affineForOp1.getLowerBoundMap(), affineForOp2.getLowerBoundMap());
    const bool boundsMatch =
        lowerBoundsMatch && compareAffineMaps(affineForOp1.getUpperBoundMap(),
                                              affineForOp2.getUpperBoundMap());
    return boundsMatch;
  }

  // Same handle on both sides.
  return true;
}

/// Merges two affine.for ops with matching headers into one new loop.
///
/// A new affine.for is created at the builder's current insertion point with
/// the bound maps and step of `op1`. The two bodies are then walked in lock
/// step: nested affine.for pairs accepted by compareAffineForOps are merged
/// recursively, and everything from the first mismatch onward is cloned into
/// an scf.if on `conditionArg` (rest of `op1` in the then region, rest of
/// `op2` in the else region). The affine.yield terminators are not cloned.
///
/// The new loop is created without iter_args, so loop-carried values of the
/// source loops are not forwarded.
///
/// \param builder         Builder used for all new ops; its insertion point
///                        is restored when the function returns.
/// \param op1             First loop; supplies bounds, step and location.
/// \param op2             Second loop.
/// \param mapping1        Value mapping for ops cloned from `op1`; extended
///                        in place.
/// \param mapping2        Value mapping for ops cloned from `op2`; extended
///                        in place.
/// \param conditionArg    Condition passed to the generated scf.if.
/// \param foundDifference In/out flag, set to true once the bodies diverge.
void mergeLoop(OpBuilder &builder, affine::AffineForOp &op1,
               affine::AffineForOp &op2, IRMapping &mapping1,
               IRMapping &mapping2, Value conditionArg, bool &foundDifference) {
  auto loc = op1->getLoc();

  // Save insertion point
  OpBuilder::InsertionGuard guard(builder);

  // The bound operands belong to the source function. Remap them so the new
  // loop only refers to values of the function being built; values without a
  // mapping are kept as they are.
  auto lowerBoundMap = op1.getLowerBoundMap();
  llvm::SmallVector<Value, 4> lowerBoundOperands;
  for (Value operand : op1.getLowerBoundOperands())
    lowerBoundOperands.push_back(mapping1.lookupOrDefault(operand));
  auto upperBoundMap = op1.getUpperBoundMap();
  llvm::SmallVector<Value, 4> upperBoundOperands;
  for (Value operand : op1.getUpperBoundOperands())
    upperBoundOperands.push_back(mapping1.lookupOrDefault(operand));
  int64_t step = op1.getStep();

  // Create new affine.for with same arguments
  auto newAffineForOp = builder.create<mlir::affine::AffineForOp>(
      loc, lowerBoundOperands, lowerBoundMap, upperBoundOperands, upperBoundMap,
      step);

  Block *body1 = op1.getBody();
  Block *body2 = op2.getBody();
  Block *newBody = newAffineForOp.getBody();

  // Set insertion point to current loop body
  builder.setInsertionPointToStart(newBody);

  // Both source bodies see the block arguments of the new loop body.
  for (size_t i = 0; i < body1->getNumArguments(); ++i) {
    mapping1.map(body1->getArgument(i), newBody->getArgument(i));
  }
  for (size_t i = 0; i < body2->getNumArguments(); ++i) {
    mapping2.map(body2->getArgument(i), newBody->getArgument(i));
  }

  auto body1It = body1->begin();
  auto body2It = body2->begin();

  // Walk both bodies in lock step until they diverge.
  while (!foundDifference && body1It != body1->end() &&
         body2It != body2->end()) {
    Operation *lhsOp = &*body1It;
    Operation *rhsOp = &*body2It;
    if (lhsOp == rhsOp) {
      // Same operation on both sides: emit it once. Clone through mapping1 so
      // its operands are remapped, and record the results in mapping2 as well
      // so later clones from either side resolve to the new values.
      Operation *cloned = builder.clone(*lhsOp, mapping1);
      for (unsigned i = 0, e = rhsOp->getNumResults(); i < e; ++i)
        mapping2.map(rhsOp->getResult(i), cloned->getResult(i));
    } else {
      // If we found an affine.for to merge
      // Todo: Support dynamic loop range
      auto affineForOp1 = dyn_cast<affine::AffineForOp>(lhsOp);
      auto affineForOp2 = dyn_cast<affine::AffineForOp>(rhsOp);
      if (!affineForOp1 || !affineForOp2 ||
          !compareAffineForOps(affineForOp1, affineForOp2)) {
        foundDifference = true;
        break;
      }
      mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2,
                conditionArg, foundDifference);
    }
    ++body1It;
    ++body2It;
  }

  // Create branch for the rest after difference is found. The iterators are
  // captured by reference and advanced up to (not including) affine.yield.
  builder.create<scf::IfOp>(
      loc, conditionArg,
      [&](OpBuilder &thenBuilder, Location thenLoc) {
        for (; body1It != body1->end() &&
               !isa<affine::AffineYieldOp>(&*body1It);
             ++body1It)
          thenBuilder.clone(*body1It, mapping1);
        thenBuilder.create<scf::YieldOp>(thenLoc);
      },
      [&](OpBuilder &elseBuilder, Location elseLoc) {
        for (; body2It != body2->end() &&
               !isa<affine::AffineYieldOp>(&*body2It);
             ++body2It)
          elseBuilder.clone(*body2It, mapping2);
        elseBuilder.create<scf::YieldOp>(elseLoc);
      });
}

/// Builds one function that can execute either `func1` or `func2`.
///
/// The new function is named "<func1>_<func2>_unified". It takes the inputs
/// of `func1` plus one trailing memref<2xi1> argument and returns the result
/// types of `func1`. Its body is an affine.for over [0, 2) with step 1 that
/// loads one i1 from the trailing argument per iteration and uses it as the
/// branch condition: the entry blocks of both functions are walked in lock
/// step, matching affine.for ops are merged via mergeLoop, and from the first
/// mismatch onward the rest of `func1` goes into the then region of an scf.if
/// and the rest of `func2` into its else region. Finally the return op of
/// `func1` is cloned at the end of the entry block.
///
/// Both functions must have a body. The function is created detached; the
/// caller is responsible for inserting it into a module.
///
/// \param func1   First kernel; supplies the signature and the return op.
/// \param func2   Second kernel.
/// \param builder Builder used for all new ops; its insertion point is left
///                at the end of the new entry block.
/// \returns The newly created function.
func::FuncOp unifyKernels(func::FuncOp &func1, func::FuncOp &func2,
                          OpBuilder &builder) {
  std::string newFuncName =
      func1.getName().str() + "_" + func2.getName().str() + "_unified";

  // Todo: Now assuming return types and input types are the same
  // Create new FuncOp with additional parameter
  auto oldFuncType = func1.getFunctionType();
  auto oldInputTypes = oldFuncType.getInputs();
  auto loc = builder.getUnknownLoc();
  SmallVector<Type, 4> newInputTypes(oldInputTypes.begin(),
                                     oldInputTypes.end());
  auto newOutputTypes = oldFuncType.getResults();
  Type i1Type = builder.getI1Type();
  Type memrefType = MemRefType::get({2}, i1Type);
  newInputTypes.push_back(memrefType);
  auto newFuncType = builder.getFunctionType(newInputTypes, newOutputTypes);
  auto newFuncOp = func::FuncOp::create(loc, newFuncName, newFuncType,
                                        ArrayRef<NamedAttribute>{});

  // Create new block for insertion; the selector is the last argument.
  Block *entryBlock = newFuncOp.addEntryBlock();
  auto inst = entryBlock->getArgument(entryBlock->getNumArguments() - 1);
  builder.setInsertionPointToStart(entryBlock);

  // Outer loop over the two selector entries; everything below is emitted
  // into its body.
  auto outerLoop = builder.create<mlir::affine::AffineForOp>(loc, 0, 2, 1);
  mlir::Value loopIndex = outerLoop.getInductionVar();
  builder.setInsertionPointToStart(outerLoop.getBody());
  mlir::Value conditionArg =
      builder.create<mlir::affine::AffineLoadOp>(loc, inst, loopIndex);

  auto &block1 = func1.front();
  auto &block2 = func2.front();
  auto block1It = block1.begin();
  auto block2It = block2.begin();
  bool foundDifference = false;

  // Create IRMapping for latter cloning: arguments of both source functions
  // map positionally onto the arguments of the new function.
  IRMapping mapping1;
  for (size_t i = 0; i < block1.getNumArguments(); ++i) {
    mapping1.map(block1.getArgument(i), entryBlock->getArgument(i));
  }
  IRMapping mapping2;
  for (size_t i = 0; i < block2.getNumArguments(); ++i) {
    mapping2.map(block2.getArgument(i), entryBlock->getArgument(i));
  }

  // Iterate over two FuncOps to find branch location
  while (!foundDifference && block1It != block1.end() &&
         block2It != block2.end()) {
    Operation *lhsOp = &*block1It;
    Operation *rhsOp = &*block2It;

    // Never consume a terminator here: the return op is emitted once, at the
    // end of the entry block, after the scan.
    if (isa<func::ReturnOp>(lhsOp) || isa<func::ReturnOp>(rhsOp))
      break;

    if (lhsOp == rhsOp) {
      // Same operation on both sides: emit it once. Clone through mapping1 so
      // its operands are remapped, and record the results in mapping2 too.
      Operation *cloned = builder.clone(*lhsOp, mapping1);
      for (unsigned i = 0, e = rhsOp->getNumResults(); i < e; ++i)
        mapping2.map(rhsOp->getResult(i), cloned->getResult(i));
    } else {
      // If we found an affine.for to merge
      // Todo: Support dynamic loop range
      auto affineForOp1 = dyn_cast<affine::AffineForOp>(lhsOp);
      auto affineForOp2 = dyn_cast<affine::AffineForOp>(rhsOp);
      if (!affineForOp1 || !affineForOp2 ||
          !compareAffineForOps(affineForOp1, affineForOp2)) {
        foundDifference = true;
        break;
      }
      mergeLoop(builder, affineForOp1, affineForOp2, mapping1, mapping2,
                conditionArg, foundDifference);
    }
    ++block1It;
    ++block2It;
  }

  // Only dereference iterators that still point at an operation.
  const bool atReturn1 =
      block1It != block1.end() && isa<func::ReturnOp>(&*block1It);
  const bool atReturn2 =
      block2It != block2.end() && isa<func::ReturnOp>(&*block2It);

  // Create branch for the rest after difference is found. The iterators are
  // captured by reference and advanced up to (not including) the return op.
  if (!atReturn1 || !atReturn2) {
    builder.create<scf::IfOp>(
        loc, conditionArg,
        [&](OpBuilder &thenBuilder, Location thenLoc) {
          for (; block1It != block1.end() && !isa<func::ReturnOp>(&*block1It);
               ++block1It)
            thenBuilder.clone(*block1It, mapping1);
          thenBuilder.create<scf::YieldOp>(thenLoc);
        },
        [&](OpBuilder &elseBuilder, Location elseLoc) {
          for (; block2It != block2.end() && !isa<func::ReturnOp>(&*block2It);
               ++block2It)
            elseBuilder.clone(*block2It, mapping2);
          elseBuilder.create<scf::YieldOp>(elseLoc);
        });
  }

  // Create returnOp
  // Todo: Now assume the return value is the same
  builder.setInsertionPointToEnd(entryBlock);
  if (block1It != block1.end())
    builder.clone(*block1It, mapping1);

  return newFuncOp;
}

/// Pass entry point.
///
/// Creates a fresh module (located at an unknown location in `context`) and
/// fills it with one unified function per pair of functions taken, in order,
/// from `module1` and `module2`. Pairing stops as soon as either module runs
/// out of functions. The input modules are only read.
ModuleOp applyUnifyKernels(ModuleOp &module1, ModuleOp &module2,
                           MLIRContext *context) {
  ModuleOp unifiedModule = ModuleOp::create(UnknownLoc::get(context));
  OpBuilder builder(unifiedModule.getContext());

  auto lhsFuncs = module1.getOps<func::FuncOp>();
  auto rhsFuncs = module2.getOps<func::FuncOp>();

  // Advance through both function lists together.
  for (auto lhsIt = lhsFuncs.begin(), rhsIt = rhsFuncs.begin();
       lhsIt != lhsFuncs.end() && rhsIt != rhsFuncs.end(); ++lhsIt, ++rhsIt) {
    func::FuncOp lhsFunc = *lhsIt;
    func::FuncOp rhsFunc = *rhsIt;
    func::FuncOp merged = unifyKernels(lhsFunc, rhsFunc, builder);
    unifiedModule.push_back(merged);
  }

  return unifiedModule;
}

} // namespace hcl
} // namespace mlir