Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/spirv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ function finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
entry = wrap_byval(job, mod, entry)
end

# SPIR-V does not support i128, convert alloca arrays to vector types
convert_i128_allocas!(mod)

# add module metadata
## OpenCL 2.0
push!(metadata(mod)["opencl.ocl.version"],
Expand Down Expand Up @@ -283,6 +286,62 @@ function rm_freeze!(mod::LLVM.Module)
return changed
end

# convert alloca [N x i128] to alloca [N x <2 x i64>]
# SPIR-V doesn't support i128 types, but we can represent them as vectors
function convert_i128_allocas!(mod::LLVM.Module)
    # type-asserts that `current_job` holds a `CompilerJob`; throws otherwise
    job = current_job::CompilerJob
    changed = false
    @tracepoint "convert i128 allocas" begin

    for f in functions(mod), bb in blocks(f), inst in instructions(bb)
        # only stack allocations are rewritten by this pass
        inst isa LLVM.AllocaInst || continue

        allocated = LLVMType(LLVM.API.LLVMGetAllocatedType(inst))

        # look through a single array level: both `i128` and `[N x i128]` qualify
        is_array = allocated isa LLVM.ArrayType
        scalar = is_array ? eltype(allocated) : allocated
        (scalar isa LLVM.IntegerType && width(scalar) == 128) || continue

        # an i128 is represented as <2 x i64>; arrays keep their element count
        halves = LLVM.VectorType(LLVM.Int64Type(), 2)
        replacement = is_array ? LLVM.ArrayType(halves, length(allocated)) : halves

        # remember the alignment so the replacement alloca can carry it over
        old_align = alignment(inst)

        @dispose builder=IRBuilder() begin
            # emit the replacement right before the alloca it supersedes
            position!(builder, inst)
            replacement_alloca = alloca!(builder, replacement)
            alignment!(replacement_alloca, old_align)

            # Cast the replacement back to the pointer type of the original alloca.
            # XXX: The issue only seems to manifest itself on LLVM >= 18, where
            #      opaque pointers are used anyways, so this may not be needed.
            ptr_type = LLVMType(LLVM.API.LLVMTypeOf(inst.ref))
            casted = bitcast!(builder, replacement_alloca, ptr_type)

            # redirect every user to the new pointer, then drop the old alloca
            replace_uses!(inst, casted)
            erase!(inst)
            changed = true
        end
    end

    end
    return changed
end

# wrap byval pointers in a single-value struct
function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
ft = function_type(f)::LLVM.FunctionType
Expand Down
33 changes: 33 additions & 0 deletions test/spirv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,37 @@ end

end

# Checks the LLVM IR produced for kernels operating on 128-bit and 256-bit SIMD
# vectors: no `i128` should be matched by the CHECK-NOT directive, and (on
# Julia 1.12+) the allocas are expected to use `<2 x i64>` based types.
@testset "replace i128 allocas" begin
# define the helpers in a fresh, uniquely named module so they do not clash
# with definitions from other testsets
mod = @eval module $(gensym())
# reimplement some of SIMD.jl
# `Vec` wraps a homogeneous tuple of `VecElement`s
struct Vec{N, T}
data::NTuple{N, Core.VecElement{T}}
end
# element-wise addition of two Float32 vectors, emitted directly as an LLVM
# `fadd` on `<N x float>` through `llvmcall`; generated so that `N` can be
# spliced into the IR string
@generated function fadd(x::Vec{N, Float32}, y::Vec{N, Float32}) where {N}
quote
Vec(Base.llvmcall($"""
%ret = fadd <$N x float> %0, %1
ret <$N x float> %ret
""", NTuple{N, Core.VecElement{Float32}}, NTuple{2, NTuple{N, Core.VecElement{Float32}}}, x.data, y.data))
end
end
# kernel entry point; `@noinline` keeps the call to `fadd` from being inlined
kernel(x...) = @noinline fadd(x...)
end

# 4 x Float32 = 128 bits: expect a scalar `<2 x i64>` alloca
# NOTE(review): FileCheck's CHECK-NOT only covers the input between the
# previous match and the next positive match, so the directive below may only
# guard the text before the kernel's `define` line -- confirm this is intended.
@test @filecheck begin
# TODO: should structs of `NTuple{VecElement{T}}` be passed by value instead of sret?
check"CHECK-NOT: i128"
check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}"
# the alloca pattern is only checked on Julia 1.12 and newer
@static VERSION >= v"1.12" && check"CHECK: alloca <2 x i64>, align 16"
SPIRV.code_llvm(mod.kernel, NTuple{2, mod.Vec{4, Float32}}; backend, dump_module=true)
end

# 8 x Float32 = 256 bits: expect an array of two `<2 x i64>` elements
@test @filecheck begin
check"CHECK-NOT: i128"
check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}"
# the alloca pattern is only checked on Julia 1.12 and newer
@static VERSION >= v"1.12" && check"CHECK: alloca [2 x <2 x i64>], align 16"
SPIRV.code_llvm(mod.kernel, NTuple{2, mod.Vec{8, Float32}}; backend, dump_module=true)
end
end

end
Loading