Skip to content

Commit 2797f4b

Browse files
authored
[SPIRV] convert i128 allocas to <2 x i64> (#734)
1 parent 47da204 commit 2797f4b

File tree

3 files changed

+93
-1
lines changed

3 files changed

+93
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GPUCompiler"
22
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
3-
version = "1.7.2"
3+
version = "1.7.3"
44
authors = ["Tim Besard <tim.besard@gmail.com>"]
55

66
[deps]

src/spirv.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ function finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
110110
entry = wrap_byval(job, mod, entry)
111111
end
112112

113+
# SPIR-V does not support i128, convert alloca arrays to vector types
114+
convert_i128_allocas!(mod)
115+
113116
# add module metadata
114117
## OpenCL 2.0
115118
push!(metadata(mod)["opencl.ocl.version"],
@@ -283,6 +286,62 @@ function rm_freeze!(mod::LLVM.Module)
283286
return changed
284287
end
285288

289+
# convert alloca [N x i128] to alloca [N x <2 x i64>]
290+
# SPIR-V doesn't support i128 types, but we can represent them as vectors
291+
function convert_i128_allocas!(mod::LLVM.Module)
292+
job = current_job::CompilerJob
293+
changed = false
294+
@tracepoint "convert i128 allocas" begin
295+
296+
for f in functions(mod), bb in blocks(f)
297+
for inst in instructions(bb)
298+
if inst isa LLVM.AllocaInst
299+
alloca_type = LLVMType(LLVM.API.LLVMGetAllocatedType(inst))
300+
301+
# Check if this is an i128 or an array of i128
302+
if alloca_type isa LLVM.ArrayType
303+
T = eltype(alloca_type)
304+
else
305+
T = alloca_type
306+
end
307+
if T isa LLVM.IntegerType && width(T) == 128
308+
# replace i128 with <2 x i64>
309+
vec_type = LLVM.VectorType(LLVM.Int64Type(), 2)
310+
311+
if alloca_type isa LLVM.ArrayType
312+
array_size = length(alloca_type)
313+
new_alloca_type = LLVM.ArrayType(vec_type, array_size)
314+
else
315+
new_alloca_type = vec_type
316+
end
317+
align_val = alignment(inst)
318+
319+
# Create new alloca with vector type
320+
@dispose builder=IRBuilder() begin
321+
position!(builder, inst)
322+
new_alloca = alloca!(builder, new_alloca_type)
323+
alignment!(new_alloca, align_val)
324+
325+
# Bitcast the new alloca back to the original pointer type
326+
# XXX: The issue only seems to manifest itself on LLVM >= 18
327+
# where we use opaque pointers anyways, so not sure this
328+
# is needed
329+
old_ptr_type = LLVMType(LLVM.API.LLVMTypeOf(inst.ref))
330+
bitcast_ptr = bitcast!(builder, new_alloca, old_ptr_type)
331+
332+
replace_uses!(inst, bitcast_ptr)
333+
erase!(inst)
334+
changed = true
335+
end
336+
end
337+
end
338+
end
339+
end
340+
341+
end
342+
return changed
343+
end
344+
286345
# wrap byval pointers in a single-value struct
287346
function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
288347
ft = function_type(f)::LLVM.FunctionType

test/spirv.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,37 @@ end
112112

113113
end
114114

115+
@testset "replace i128 allocas" begin
116+
mod = @eval module $(gensym())
117+
# reimplement some of SIMD.jl
118+
struct Vec{N, T}
119+
data::NTuple{N, Core.VecElement{T}}
120+
end
121+
@generated function fadd(x::Vec{N, Float32}, y::Vec{N, Float32}) where {N}
122+
quote
123+
Vec(Base.llvmcall($"""
124+
%ret = fadd <$N x float> %0, %1
125+
ret <$N x float> %ret
126+
""", NTuple{N, Core.VecElement{Float32}}, NTuple{2, NTuple{N, Core.VecElement{Float32}}}, x.data, y.data))
127+
end
128+
end
129+
kernel(x...) = @noinline fadd(x...)
130+
end
131+
132+
@test @filecheck begin
133+
# TODO: should structs of `NTuple{VecElement{T}}` be passed by value instead of sret?
134+
check"CHECK-NOT: i128"
135+
check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}"
136+
@static VERSION >= v"1.12" && check"CHECK: alloca <2 x i64>, align 16"
137+
SPIRV.code_llvm(mod.kernel, NTuple{2, mod.Vec{4, Float32}}; backend, dump_module=true)
138+
end
139+
140+
@test @filecheck begin
141+
check"CHECK-NOT: i128"
142+
check"CHECK-LABEL: define void @{{(julia|j)_kernel_[0-9]+}}"
143+
@static VERSION >= v"1.12" && check"CHECK: alloca [2 x <2 x i64>], align 16"
144+
SPIRV.code_llvm(mod.kernel, NTuple{2, mod.Vec{8, Float32}}; backend, dump_module=true)
145+
end
146+
end
147+
115148
end

0 commit comments

Comments
 (0)