Skip to content

Commit eea1035

Browse files
Jezurkofrabert
authored andcommitted
[MLIR][mlir-link] Add hook to perform per Module computation before summary
This allows the linker interface to pre-compute some information that might be necessary for the linking.
1 parent e08816f commit eea1035

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
149149
return state.clone(src);
150150
}
151151

152+
/// Perform tasks that need to be computed on whole-module basis before actual summary.
153+
/// E.g. Pre-compute COMDAT resolution before actually linking the modules.
154+
virtual LogicalResult moduleOpSummary(ModuleOp module) {
155+
return success();
156+
}
157+
152158
/// Dependencies of the given operation required to be linked.
153159
virtual SmallVector<Operation *>
154160
dependencies(Operation *op, SymbolTableCollection &collection) const = 0;
@@ -286,6 +292,14 @@ class SymbolLinkerInterfaces {
286292
return Conflict::noConflict(src);
287293
}
288294

295+
LogicalResult moduleOpSummary(ModuleOp src) {
296+
for (SymbolLinkerInterface *linker : interfaces) {
297+
if (failed(linker->moduleOpSummary(src)))
298+
return failure();
299+
}
300+
return success();
301+
}
302+
289303
private:
290304
SetVector<SymbolLinkerInterface *> interfaces;
291305
};

mlir/lib/IR/BuiltinLinkerInterface.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,20 @@ class BuiltinLinkerInterface : public ModuleLinkerInterface {
3939
SymbolTableCollection &collection) override {
4040
// Collect all operations to process in parallel
4141
SmallVector<Operation *> ops;
42-
src.walk([&](Operation *op) {
43-
if (op != src)
44-
ops.push_back(op);
42+
WalkResult result = src.walk([&](Operation *op) {
43+
if (op == src) {
44+
if (symbolLinkers.moduleOpSummary(src).failed())
45+
return WalkResult::interrupt();
46+
return WalkResult::advance();
47+
}
48+
ops.push_back(op);
49+
return WalkResult::advance();
4550
});
4651

52+
if (result.wasInterrupted()) {
53+
return failure();
54+
}
55+
4756
// Process operations in parallel
4857
return failableParallelForEach(
4958
src.getContext(), ops, [&](Operation *op) {

0 commit comments

Comments
 (0)