Skip to content

SROA: Recognize llvm.protected.field.ptr intrinsics. #151650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: users/pcc/spr/main.sroa-recognize-llvmprotectedfieldptr-intrinsics
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions llvm/include/llvm/Analysis/PtrUseVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class PtrUseVisitorBase {

UseAndIsOffsetKnownPair UseAndIsOffsetKnown;
APInt Offset;
Value *ProtectedFieldDisc;
};

/// The worklist of to-visit uses.
Expand All @@ -158,6 +159,10 @@ class PtrUseVisitorBase {
/// The constant offset of the use if that is known.
APInt Offset;

// When this access is via an llvm.protected.field.ptr intrinsic, contains
// the second argument to the intrinsic, the discriminator.
Value *ProtectedFieldDisc;

/// @}

/// Note that the constructor is protected because this class must be a base
Expand Down Expand Up @@ -230,6 +235,7 @@ class PtrUseVisitor : protected InstVisitor<DerivedT>,
IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType(I.getType()));
IsOffsetKnown = true;
Offset = APInt(IntIdxTy->getBitWidth(), 0);
ProtectedFieldDisc = nullptr;
PI.reset();

// Enqueue the uses of this pointer.
Expand All @@ -242,6 +248,7 @@ class PtrUseVisitor : protected InstVisitor<DerivedT>,
IsOffsetKnown = ToVisit.UseAndIsOffsetKnown.getInt();
if (IsOffsetKnown)
Offset = std::move(ToVisit.Offset);
ProtectedFieldDisc = ToVisit.ProtectedFieldDisc;

Instruction *I = cast<Instruction>(U->getUser());
static_cast<DerivedT*>(this)->visit(I);
Expand Down Expand Up @@ -300,6 +307,14 @@ class PtrUseVisitor : protected InstVisitor<DerivedT>,
case Intrinsic::lifetime_start:
case Intrinsic::lifetime_end:
return; // No-op intrinsics.

case Intrinsic::protected_field_ptr: {
if (!IsOffsetKnown)
return Base::visitIntrinsicInst(II);
ProtectedFieldDisc = II.getArgOperand(1);
enqueueUsers(II);
break;
}
}
}

Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Analysis/PtrUseVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ void detail::PtrUseVisitorBase::enqueueUsers(Value &I) {
if (VisitedUses.insert(&U).second) {
UseToVisit NewU = {
UseToVisit::UseAndIsOffsetKnownPair(&U, IsOffsetKnown),
Offset
Offset,
ProtectedFieldDisc,
};
Worklist.push_back(std::move(NewU));
}
Expand Down
61 changes: 56 additions & 5 deletions llvm/lib/Transforms/Scalar/SROA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -523,9 +524,10 @@ class Slice {
public:
Slice() = default;

Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable)
Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable,
Value *ProtectedFieldDisc)
: BeginOffset(BeginOffset), EndOffset(EndOffset),
UseAndIsSplittable(U, IsSplittable) {}
UseAndIsSplittable(U, IsSplittable), ProtectedFieldDisc(ProtectedFieldDisc) {}

uint64_t beginOffset() const { return BeginOffset; }
uint64_t endOffset() const { return EndOffset; }
Expand All @@ -538,6 +540,10 @@ class Slice {
bool isDead() const { return getUse() == nullptr; }
void kill() { UseAndIsSplittable.setPointer(nullptr); }

// When this access is via an llvm.protected.field.ptr intrinsic, contains
// the second argument to the intrinsic, the discriminator.
Value *ProtectedFieldDisc;

/// Support for ordering ranges.
///
/// This provides an ordering over ranges such that start offsets are
Expand Down Expand Up @@ -631,6 +637,9 @@ class AllocaSlices {
/// Access the dead users for this alloca.
ArrayRef<Instruction *> getDeadUsers() const { return DeadUsers; }

/// Access the PFP users for this alloca.
ArrayRef<IntrinsicInst *> getPFPUsers() const { return PFPUsers; }

/// Access Uses that should be dropped if the alloca is promotable.
ArrayRef<Use *> getDeadUsesIfPromotable() const {
return DeadUseIfPromotable;
Expand Down Expand Up @@ -691,6 +700,10 @@ class AllocaSlices {
/// they come from outside of the allocated space.
SmallVector<Instruction *, 8> DeadUsers;

/// Users that are llvm.protected.field.ptr intrinsics. These will be RAUW'd
/// to their first argument if we rewrite the alloca.
SmallVector<IntrinsicInst *, 0> PFPUsers;

/// Uses which will become dead if can promote the alloca.
SmallVector<Use *, 8> DeadUseIfPromotable;

Expand Down Expand Up @@ -1064,7 +1077,8 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
EndOffset = AllocSize;
}

AS.Slices.push_back(Slice(BeginOffset, EndOffset, U, IsSplittable));
AS.Slices.push_back(
Slice(BeginOffset, EndOffset, U, IsSplittable, ProtectedFieldDisc));
}

void visitBitCastInst(BitCastInst &BC) {
Expand Down Expand Up @@ -1274,6 +1288,9 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
return;
}

if (II.getIntrinsicID() == Intrinsic::protected_field_ptr)
AS.PFPUsers.push_back(&II);

Base::visitIntrinsicInst(II);
}

Expand Down Expand Up @@ -4682,7 +4699,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
NewSlices.push_back(
Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize,
&PLoad->getOperandUse(PLoad->getPointerOperandIndex()),
/*IsSplittable*/ false));
/*IsSplittable*/ false, nullptr));
LLVM_DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset()
<< ", " << NewSlices.back().endOffset()
<< "): " << *PLoad << "\n");
Expand Down Expand Up @@ -4838,10 +4855,12 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
LLVMContext::MD_access_group});

// Now build a new slice for the alloca.
// ProtectedFieldDisc==nullptr is a lie, but it doesn't matter because we
// already determined that all accesses are consistent.
NewSlices.push_back(
Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize,
&PStore->getOperandUse(PStore->getPointerOperandIndex()),
/*IsSplittable*/ false));
/*IsSplittable*/ false, nullptr));
LLVM_DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset()
<< ", " << NewSlices.back().endOffset()
<< "): " << *PStore << "\n");
Expand Down Expand Up @@ -5618,6 +5637,32 @@ SROA::runOnAlloca(AllocaInst &AI) {
return {Changed, CFGChanged};
}

for (auto &P : AS.partitions()) {
std::optional<Value *> ProtectedFieldDisc;
// For now, we can't split if a field is accessed both via protected
// field and not.
for (Slice &S : P) {
if (auto *II = dyn_cast<IntrinsicInst>(S.getUse()->getUser()))
if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
II->getIntrinsicID() == Intrinsic::lifetime_end)
continue;
if (!ProtectedFieldDisc)
ProtectedFieldDisc = S.ProtectedFieldDisc;
if (*ProtectedFieldDisc != S.ProtectedFieldDisc)
return {Changed, CFGChanged};
}
for (Slice *S : P.splitSliceTails()) {
if (auto *II = dyn_cast<IntrinsicInst>(S->getUse()->getUser()))
if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
II->getIntrinsicID() == Intrinsic::lifetime_end)
continue;
if (!ProtectedFieldDisc)
ProtectedFieldDisc = S->ProtectedFieldDisc;
if (*ProtectedFieldDisc != S->ProtectedFieldDisc)
return {Changed, CFGChanged};
}
}

// Delete all the dead users of this alloca before splitting and rewriting it.
for (Instruction *DeadUser : AS.getDeadUsers()) {
// Free up everything used by this instruction.
Expand All @@ -5635,6 +5680,12 @@ SROA::runOnAlloca(AllocaInst &AI) {
clobberUse(*DeadOp);
Changed = true;
}
for (IntrinsicInst *PFPUser : AS.getPFPUsers()) {
PFPUser->replaceAllUsesWith(PFPUser->getArgOperand(0));

DeadInsts.push_back(PFPUser);
Changed = true;
}

// No slices to split. Leave the dead alloca for a later pass to clean up.
if (AS.begin() == AS.end())
Expand Down
73 changes: 73 additions & 0 deletions llvm/test/Transforms/SROA/protected-field-pointer.ll
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use update_test_checks.py.

Please also add a test where the alloca is split but not promoted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes=sroa -S < %s | FileCheck %s

; Both fields of the { ptr, ptr } alloca are accessed only through
; llvm.protected.field.ptr, and each field uses one consistent
; discriminator (1 for the first field, 2 for the second). The CHECK
; lines show the expected result: SROA splits and fully promotes the
; alloca, so the intrinsic calls and the alloca disappear and the
; stored pointers flow straight through to the output stores.
define void @slice(ptr %ptr1, ptr %ptr2, ptr %out1, ptr %out2) {
; CHECK-LABEL: define void @slice(
; CHECK-SAME: ptr [[PTR1:%.*]], ptr [[PTR2:%.*]], ptr [[OUT1:%.*]], ptr [[OUT2:%.*]]) {
; CHECK-NEXT: store ptr [[PTR1]], ptr [[OUT1]], align 8
; CHECK-NEXT: store ptr [[PTR2]], ptr [[OUT2]], align 8
; CHECK-NEXT: ret void
;
%alloca = alloca { ptr, ptr }

; Field 0: store then load via the intrinsic, discriminator 1 both times.
%protptrptr1.1 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
store ptr %ptr1, ptr %protptrptr1.1
%protptrptr1.2 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
%ptr1a = load ptr, ptr %protptrptr1.2

; Field 1: same pattern at offset 8, discriminator 2 both times.
%gep = getelementptr { ptr, ptr }, ptr %alloca, i64 0, i32 1
%protptrptr2.1 = call ptr @llvm.protected.field.ptr(ptr %gep, i64 2, i1 true)
store ptr %ptr2, ptr %protptrptr2.1
%protptrptr2.2 = call ptr @llvm.protected.field.ptr(ptr %gep, i64 2, i1 true)
%ptr2a = load ptr, ptr %protptrptr2.2

store ptr %ptr1a, ptr %out1
store ptr %ptr2a, ptr %out2
ret void
}

; Mixed-access case: the same alloca slot is written directly (plain
; store) but read through llvm.protected.field.ptr. The partition
; check added by this patch treats that as inconsistent, so the CHECK
; lines verify SROA leaves the alloca and the intrinsic call intact
; rather than promoting it.
define ptr @mixed(ptr %ptr) {
; CHECK-LABEL: define ptr @mixed(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca ptr, align 8
; CHECK-NEXT: store ptr [[PTR]], ptr [[ALLOCA]], align 8
; CHECK-NEXT: [[PROTPTRPTR1_2:%.*]] = call ptr @llvm.protected.field.ptr(ptr [[ALLOCA]], i64 1, i1 true)
; CHECK-NEXT: [[PTR1A:%.*]] = load ptr, ptr [[PROTPTRPTR1_2]], align 8
; CHECK-NEXT: ret ptr [[PTR1A]]
;
%alloca = alloca ptr

; Direct store, protected load: discriminators for this slot disagree
; (nullptr vs. the intrinsic's i64 1), which blocks promotion.
store ptr %ptr, ptr %alloca
%protptrptr1.2 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
%ptr1a = load ptr, ptr %protptrptr1.2

ret ptr %ptr1a
}

; Split-but-not-promoted case (requested in review): field 0 is fully
; promotable, while field 1 uses volatile accesses and so can be split
; into its own alloca but not promoted. The CHECK lines verify that
; only the volatile slice survives as [[ALLOCA_SROA_2]] and that the
; protected.field.ptr calls are gone in both halves (RAUW'd to their
; pointer argument before rewriting).
define void @split_non_promotable(ptr %ptr1, ptr %ptr2, ptr %out1, ptr %out2) {
; CHECK-LABEL: define void @split_non_promotable(
; CHECK-SAME: ptr [[PTR1:%.*]], ptr [[PTR2:%.*]], ptr [[OUT1:%.*]], ptr [[OUT2:%.*]]) {
; CHECK-NEXT: [[ALLOCA_SROA_2:%.*]] = alloca ptr, align 8
; CHECK-NEXT: store volatile ptr [[PTR2]], ptr [[ALLOCA_SROA_2]], align 8
; CHECK-NEXT: [[PTR2A:%.*]] = load volatile ptr, ptr [[ALLOCA_SROA_2]], align 8
; CHECK-NEXT: store ptr [[PTR1]], ptr [[OUT1]], align 8
; CHECK-NEXT: store ptr [[PTR2A]], ptr [[OUT2]], align 8
; CHECK-NEXT: ret void
;
%alloca = alloca { ptr, ptr }

; Field 0: consistent discriminator 1, non-volatile -> promoted away.
%protptrptr1.1 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
store ptr %ptr1, ptr %protptrptr1.1
%protptrptr1.2 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
%ptr1a = load ptr, ptr %protptrptr1.2

; Field 1: consistent discriminator 2, but volatile accesses keep the
; split slice as a separate (non-promoted) alloca.
%gep = getelementptr { ptr, ptr }, ptr %alloca, i64 0, i32 1
%protptrptr2.1 = call ptr @llvm.protected.field.ptr(ptr %gep, i64 2, i1 true)
store volatile ptr %ptr2, ptr %protptrptr2.1
%protptrptr2.2 = call ptr @llvm.protected.field.ptr(ptr %gep, i64 2, i1 true)
%ptr2a = load volatile ptr, ptr %protptrptr2.2

store ptr %ptr1a, ptr %out1
store ptr %ptr2a, ptr %out2
ret void
}
Loading