diff --git a/llvm/include/llvm/Analysis/PtrUseVisitor.h b/llvm/include/llvm/Analysis/PtrUseVisitor.h
index 0858d8aee2186..a39f6881f24f3 100644
--- a/llvm/include/llvm/Analysis/PtrUseVisitor.h
+++ b/llvm/include/llvm/Analysis/PtrUseVisitor.h
@@ -134,6 +134,7 @@ class PtrUseVisitorBase {
     UseAndIsOffsetKnownPair UseAndIsOffsetKnown;
     APInt Offset;
+    Value *ProtectedFieldDisc;
   };
 
   /// The worklist of to-visit uses.
@@ -158,6 +159,10 @@ class PtrUseVisitorBase {
   /// The constant offset of the use if that is known.
   APInt Offset;
 
+  /// When this access is via an llvm.protected.field.ptr intrinsic, contains
+  /// the second argument to the intrinsic, the discriminator.
+  Value *ProtectedFieldDisc;
+
   /// @}
 
   /// Note that the constructor is protected because this class must be a base
@@ -230,6 +235,7 @@ class PtrUseVisitor : protected InstVisitor<DerivedT>,
     IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType(I.getType()));
     IsOffsetKnown = true;
     Offset = APInt(IntIdxTy->getBitWidth(), 0);
+    ProtectedFieldDisc = nullptr;
     PI.reset();
 
     // Enqueue the uses of this pointer.
@@ -242,6 +248,7 @@ class PtrUseVisitor : protected InstVisitor<DerivedT>,
       IsOffsetKnown = ToVisit.UseAndIsOffsetKnown.getInt();
       if (IsOffsetKnown)
         Offset = std::move(ToVisit.Offset);
+      ProtectedFieldDisc = ToVisit.ProtectedFieldDisc;
 
       Instruction *I = cast<Instruction>(U->getUser());
       static_cast<DerivedT *>(this)->visit(I);
@@ -300,6 +307,14 @@ class PtrUseVisitor : protected InstVisitor<DerivedT>,
     case Intrinsic::lifetime_start:
     case Intrinsic::lifetime_end:
       return; // No-op intrinsics.
+
+    case Intrinsic::protected_field_ptr: {
+      if (!IsOffsetKnown)
+        return Base::visitIntrinsicInst(II);
+      ProtectedFieldDisc = II.getArgOperand(1);
+      enqueueUsers(II);
+      break;
+    }
     }
   }
diff --git a/llvm/lib/Analysis/PtrUseVisitor.cpp b/llvm/lib/Analysis/PtrUseVisitor.cpp
index 9c79546f491ef..59a09c4ea8721 100644
--- a/llvm/lib/Analysis/PtrUseVisitor.cpp
+++ b/llvm/lib/Analysis/PtrUseVisitor.cpp
@@ -22,7 +22,8 @@ void detail::PtrUseVisitorBase::enqueueUsers(Value &I) {
     if (VisitedUses.insert(&U).second) {
       UseToVisit NewU = {
           UseToVisit::UseAndIsOffsetKnownPair(&U, IsOffsetKnown),
-          Offset
+          Offset,
+          ProtectedFieldDisc,
       };
       Worklist.push_back(std::move(NewU));
     }
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 23256cf2acbd2..c212c4f45dc37 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -62,6 +62,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
@@ -523,9 +524,10 @@ class Slice {
 public:
   Slice() = default;
 
-  Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable)
+  Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable,
+        Value *ProtectedFieldDisc)
       : BeginOffset(BeginOffset), EndOffset(EndOffset),
-        UseAndIsSplittable(U, IsSplittable) {}
+        UseAndIsSplittable(U, IsSplittable), ProtectedFieldDisc(ProtectedFieldDisc) {}
 
   uint64_t beginOffset() const { return BeginOffset; }
   uint64_t endOffset() const { return EndOffset; }
@@ -538,6 +540,10 @@ class Slice {
   bool isDead() const { return getUse() == nullptr; }
   void kill() { UseAndIsSplittable.setPointer(nullptr); }
 
+  /// When this access is via an llvm.protected.field.ptr intrinsic, contains
+  /// the second argument to the intrinsic, the discriminator.
+  Value *ProtectedFieldDisc;
+
   /// Support for ordering ranges.
   ///
   /// This provides an ordering over ranges such that start offsets are
@@ -631,6 +637,9 @@ class AllocaSlices {
   /// Access the dead users for this alloca.
   ArrayRef<Instruction *> getDeadUsers() const { return DeadUsers; }
 
+  /// Access the PFP users for this alloca.
+  ArrayRef<IntrinsicInst *> getPFPUsers() const { return PFPUsers; }
+
   /// Access Uses that should be dropped if the alloca is promotable.
   ArrayRef<Use *> getDeadUsesIfPromotable() const {
     return DeadUseIfPromotable;
@@ -691,6 +700,10 @@ class AllocaSlices {
   /// they come from outside of the allocated space.
   SmallVector<Instruction *, 8> DeadUsers;
 
+  /// Users that are llvm.protected.field.ptr intrinsics. These will be RAUW'd
+  /// to their first argument if we rewrite the alloca.
+  SmallVector<IntrinsicInst *, 8> PFPUsers;
+
   /// Uses which will become dead if can promote the alloca.
   SmallVector<Use *, 8> DeadUseIfPromotable;
 
@@ -1064,7 +1077,8 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
       EndOffset = AllocSize;
     }
 
-    AS.Slices.push_back(Slice(BeginOffset, EndOffset, U, IsSplittable));
+    AS.Slices.push_back(
+        Slice(BeginOffset, EndOffset, U, IsSplittable, ProtectedFieldDisc));
   }
 
   void visitBitCastInst(BitCastInst &BC) {
@@ -1274,6 +1288,9 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
       return;
     }
 
+    if (II.getIntrinsicID() == Intrinsic::protected_field_ptr)
+      AS.PFPUsers.push_back(&II);
+
     Base::visitIntrinsicInst(II);
   }
 
@@ -4682,7 +4699,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
       NewSlices.push_back(
          Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize,
                 &PLoad->getOperandUse(PLoad->getPointerOperandIndex()),
-                /*IsSplittable*/ false));
+                /*IsSplittable*/ false, nullptr));
       LLVM_DEBUG(dbgs() << "    new slice [" << NewSlices.back().beginOffset()
                         << ", " << NewSlices.back().endOffset()
                         << "): " << *PLoad << "\n");
@@ -4838,10 +4855,12 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
                                 LLVMContext::MD_access_group});
 
       // Now build a new slice for the alloca.
+      // Passing nullptr for ProtectedFieldDisc here is inaccurate, but it is
+      // harmless: we have already checked that all accesses are consistent.
       NewSlices.push_back(
           Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize,
                 &PStore->getOperandUse(PStore->getPointerOperandIndex()),
-                /*IsSplittable*/ false));
+                /*IsSplittable*/ false, nullptr));
       LLVM_DEBUG(dbgs() << "    new slice [" << NewSlices.back().beginOffset()
                         << ", " << NewSlices.back().endOffset()
                         << "): " << *PStore << "\n");
@@ -5618,6 +5637,32 @@ SROA::runOnAlloca(AllocaInst &AI) {
     return {Changed, CFGChanged};
   }
 
+  for (auto &P : AS.partitions()) {
+    // For now, we can't split if a field is accessed both via an
+    // llvm.protected.field.ptr intrinsic and directly.
+    std::optional<Value *> ProtectedFieldDisc;
+    for (Slice &S : P) {
+      if (auto *II = dyn_cast<IntrinsicInst>(S.getUse()->getUser()))
+        if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
+            II->getIntrinsicID() == Intrinsic::lifetime_end)
+          continue;
+      if (!ProtectedFieldDisc)
+        ProtectedFieldDisc = S.ProtectedFieldDisc;
+      if (*ProtectedFieldDisc != S.ProtectedFieldDisc)
+        return {Changed, CFGChanged};
+    }
+    for (Slice *S : P.splitSliceTails()) {
+      if (auto *II = dyn_cast<IntrinsicInst>(S->getUse()->getUser()))
+        if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
+            II->getIntrinsicID() == Intrinsic::lifetime_end)
+          continue;
+      if (!ProtectedFieldDisc)
+        ProtectedFieldDisc = S->ProtectedFieldDisc;
+      if (*ProtectedFieldDisc != S->ProtectedFieldDisc)
+        return {Changed, CFGChanged};
+    }
+  }
+
   // Delete all the dead users of this alloca before splitting and rewriting it.
   for (Instruction *DeadUser : AS.getDeadUsers()) {
     // Free up everything used by this instruction.
@@ -5635,6 +5680,12 @@ SROA::runOnAlloca(AllocaInst &AI) {
     clobberUse(*DeadOp);
     Changed = true;
   }
+  for (IntrinsicInst *PFPUser : AS.getPFPUsers()) {
+    PFPUser->replaceAllUsesWith(PFPUser->getArgOperand(0));
+
+    DeadInsts.push_back(PFPUser);
+    Changed = true;
+  }
 
   // No slices to split. Leave the dead alloca for a later pass to clean up.
   if (AS.begin() == AS.end())
diff --git a/llvm/test/Transforms/SROA/protected-field-pointer.ll b/llvm/test/Transforms/SROA/protected-field-pointer.ll
new file mode 100644
index 0000000000000..d4d3432487e83
--- /dev/null
+++ b/llvm/test/Transforms/SROA/protected-field-pointer.ll
@@ -0,0 +1,73 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=sroa -S < %s | FileCheck %s
+
+define void @slice(ptr %ptr1, ptr %ptr2, ptr %out1, ptr %out2) {
+; CHECK-LABEL: define void @slice(
+; CHECK-SAME: ptr [[PTR1:%.*]], ptr [[PTR2:%.*]], ptr [[OUT1:%.*]], ptr [[OUT2:%.*]]) {
+; CHECK-NEXT:    store ptr [[PTR1]], ptr [[OUT1]], align 8
+; CHECK-NEXT:    store ptr [[PTR2]], ptr [[OUT2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %alloca = alloca { ptr, ptr }
+
+  %protptrptr1.1 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
+  store ptr %ptr1, ptr %protptrptr1.1
+  %protptrptr1.2 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
+  %ptr1a = load ptr, ptr %protptrptr1.2
+
+  %gep = getelementptr { ptr, ptr }, ptr %alloca, i64 0, i32 1
+  %protptrptr2.1 = call ptr @llvm.protected.field.ptr(ptr %gep, i64 2, i1 true)
+  store ptr %ptr2, ptr %protptrptr2.1
+  %protptrptr2.2 = call ptr @llvm.protected.field.ptr(ptr %gep, i64 2, i1 true)
+  %ptr2a = load ptr, ptr %protptrptr2.2
+
+  store ptr %ptr1a, ptr %out1
+  store ptr %ptr2a, ptr %out2
+  ret void
+}
+
+define ptr @mixed(ptr %ptr) {
+; CHECK-LABEL: define ptr @mixed(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca ptr, align 8
+; CHECK-NEXT:    store ptr [[PTR]], ptr [[ALLOCA]], align 8
+; CHECK-NEXT:    [[PROTPTRPTR1_2:%.*]] = call ptr @llvm.protected.field.ptr(ptr [[ALLOCA]], i64 1, i1 true)
+; CHECK-NEXT:    [[PTR1A:%.*]] = load ptr, ptr [[PROTPTRPTR1_2]], align 8
+; CHECK-NEXT:    ret ptr [[PTR1A]]
+;
+  %alloca = alloca ptr
+
+  store ptr %ptr, ptr %alloca
+  %protptrptr1.2 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
+  %ptr1a = load ptr, ptr %protptrptr1.2
+
+  ret ptr %ptr1a
+}
+
+define void @split_non_promotable(ptr %ptr1, ptr %ptr2, ptr %out1, ptr %out2) {
+; CHECK-LABEL: define void @split_non_promotable(
+; CHECK-SAME: ptr [[PTR1:%.*]], ptr [[PTR2:%.*]], ptr [[OUT1:%.*]], ptr [[OUT2:%.*]]) {
+; CHECK-NEXT:    [[ALLOCA_SROA_2:%.*]] = alloca ptr, align 8
+; CHECK-NEXT:    store volatile ptr [[PTR2]], ptr [[ALLOCA_SROA_2]], align 8
+; CHECK-NEXT:    [[PTR2A:%.*]] = load volatile ptr, ptr [[ALLOCA_SROA_2]], align 8
+; CHECK-NEXT:    store ptr [[PTR1]], ptr [[OUT1]], align 8
+; CHECK-NEXT:    store ptr [[PTR2A]], ptr [[OUT2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %alloca = alloca { ptr, ptr }
+
+  %protptrptr1.1 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
+  store ptr %ptr1, ptr %protptrptr1.1
+  %protptrptr1.2 = call ptr @llvm.protected.field.ptr(ptr %alloca, i64 1, i1 true)
+  %ptr1a = load ptr, ptr %protptrptr1.2
+
+  %gep = getelementptr { ptr, ptr }, ptr %alloca, i64 0, i32 1
+  %protptrptr2.1 = call ptr @llvm.protected.field.ptr(ptr %gep, i64 2, i1 true)
+  store volatile ptr %ptr2, ptr %protptrptr2.1
+  %protptrptr2.2 = call ptr @llvm.protected.field.ptr(ptr %gep, i64 2, i1 true)
+  %ptr2a = load volatile ptr, ptr %protptrptr2.2
+
+  store ptr %ptr1a, ptr %out1
+  store ptr %ptr2a, ptr %out2
+  ret void
+}
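
A minimal sketch of the behavior this patch enables, mirroring the @slice test above (the @roundtrip function name is illustrative and not part of the patch): when every access to a slice goes through llvm.protected.field.ptr with the same discriminator, SROA replaces each call with its first (pointer) argument, then splits and promotes the alloca as usual, so the protected-field round trip folds away entirely.

  ; Before SROA: both accesses use discriminator 1, so the partition is
  ; consistent and the intrinsic calls can be dropped.
  define ptr @roundtrip(ptr %p) {
    %a = alloca ptr
    %f.store = call ptr @llvm.protected.field.ptr(ptr %a, i64 1, i1 true)
    store ptr %p, ptr %f.store
    %f.load = call ptr @llvm.protected.field.ptr(ptr %a, i64 1, i1 true)
    %v = load ptr, ptr %f.load
    ret ptr %v
  }

  ; After SROA: the calls are RAUW'd to %a, %a is promoted, and the body
  ; reduces to "ret ptr %p".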