From d6cb9ffed7c285bee795f23eccba4bc80099642c Mon Sep 17 00:00:00 2001 From: "ha.git" Date: Mon, 11 Aug 2025 18:56:58 +0200 Subject: [PATCH 1/3] Implement canonicalize and canonicalize_rref routines. --- benchmark/KernelAbstractions/README.md | 2 +- .../benchmark_platform_CUDA.jl | 3 +- .../benchmark_platform_OpenCL.jl | 3 +- .../benchmark_platform_ROCm.jl | 3 +- .../benchmark_KA_mul_leftright.jl | 55 +- .../implementation/benchmark_platform.jl | 1 + .../implementation/definitions.jl | 9 +- .../implementation/imports.jl | 1 + .../implementation/utilities.jl | 8 +- .../QuantumCliffordKAExt.jl | 1 + ext/QuantumCliffordKAExt/canonicalization.jl | 7 + .../canonicalization/canonicalize.jl | 207 +++++++ .../canonicalization/canonicalize_gott.jl | 5 + .../canonicalization/canonicalize_rref.jl | 213 +++++++ .../canonicalization/common.jl | 551 ++++++++++++++++++ ext/QuantumCliffordKAExt/definitions.jl | 40 +- .../definitions/default_parameters.jl | 15 + .../definitions/enumerations.jl | 58 ++ .../definitions/fixed_sizes.jl | 5 + .../definitions/mutex_configuration.jl | 11 + .../definitions/type_shorthands.jl | 29 + .../definitions/word_size_integers.jl | 7 + ext/QuantumCliffordKAExt/imports.jl | 2 +- ext/QuantumCliffordKAExt/mul_leftright.jl | 460 +-------------- .../mul_leftright/device_mul.jl | 139 +++++ .../mul_leftright/host_interface.jl | 430 ++++++++++++++ ext/QuantumCliffordKAExt/utilities.jl | 98 +--- .../utilities/bit_manipulation.jl | 23 + .../utilities/kernel_configuration.jl | 22 + .../utilities/mutex_management.jl | 87 +++ .../utilities/reductions.jl | 111 ++++ .../utilities/scan_step.jl | 103 ++++ .../utilities/snippets.jl | 220 +++++++ src/throws.jl | 4 + .../implementation/definitions.jl | 4 +- .../implementation/utilities.jl | 51 +- 36 files changed, 2354 insertions(+), 634 deletions(-) create mode 100644 ext/QuantumCliffordKAExt/canonicalization.jl create mode 100644 ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl create mode 100644 ext/QuantumCliffordKAExt/canonicalization/canonicalize_gott.jl create mode 100644 ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl create mode 100644 ext/QuantumCliffordKAExt/canonicalization/common.jl create mode 100644 ext/QuantumCliffordKAExt/definitions/default_parameters.jl create mode 100644 ext/QuantumCliffordKAExt/definitions/enumerations.jl create mode 100644 ext/QuantumCliffordKAExt/definitions/fixed_sizes.jl create mode 100644 ext/QuantumCliffordKAExt/definitions/mutex_configuration.jl create mode 100644 ext/QuantumCliffordKAExt/definitions/type_shorthands.jl create mode 100644 ext/QuantumCliffordKAExt/definitions/word_size_integers.jl create mode 100644 ext/QuantumCliffordKAExt/mul_leftright/device_mul.jl create mode 100644 ext/QuantumCliffordKAExt/mul_leftright/host_interface.jl create mode 100644 ext/QuantumCliffordKAExt/utilities/bit_manipulation.jl create mode 100644 ext/QuantumCliffordKAExt/utilities/kernel_configuration.jl create mode 100644 ext/QuantumCliffordKAExt/utilities/mutex_management.jl create mode 100644 ext/QuantumCliffordKAExt/utilities/reductions.jl create mode 100644 ext/QuantumCliffordKAExt/utilities/scan_step.jl create mode 100644 ext/QuantumCliffordKAExt/utilities/snippets.jl diff --git a/benchmark/KernelAbstractions/README.md b/benchmark/KernelAbstractions/README.md index 8716bd91f..517ae6833 100644 --- a/benchmark/KernelAbstractions/README.md +++ b/benchmark/KernelAbstractions/README.md @@ -4,7 +4,7 @@ 2. Ensure that all the packages listed in `implementation/imports.jl` are installed as this is not handled automatically. 3. Ensure that the backend package(s) listed in the pertinent `benchmark_platform_*.jl` script are properly setup and configured. 4. Pass said script as an argument to julia (optionally, also set the number of executing host threads) and await for the benchmark to conclude. -5. Navigate to the newly created directory in order to inspect the results. +5. Navigate to the newly created directory (hierarchy) and find the matching platform/timestamp pair in order to inspect the results. # Noteworthy Details diff --git a/benchmark/KernelAbstractions/benchmark_platform_CUDA.jl b/benchmark/KernelAbstractions/benchmark_platform_CUDA.jl index 0e1083f0d..702a50cac 100644 --- a/benchmark/KernelAbstractions/benchmark_platform_CUDA.jl +++ b/benchmark/KernelAbstractions/benchmark_platform_CUDA.jl @@ -1,9 +1,8 @@ include("implementation/benchmark_platform.jl") -using Dates: value, now, UNIXEPOCH using CUDA: CuArray, devices, synchronize const AT = CuArray -const path = "CUDA_benchmark_" * string(value(now()) - UNIXEPOCH) +const path = "QuantumClifford_benchmarks/CUDA" const can_run = length(devices()) > 0 diff --git a/benchmark/KernelAbstractions/benchmark_platform_OpenCL.jl b/benchmark/KernelAbstractions/benchmark_platform_OpenCL.jl index 6544a55e2..302c94421 100644 --- a/benchmark/KernelAbstractions/benchmark_platform_OpenCL.jl +++ b/benchmark/KernelAbstractions/benchmark_platform_OpenCL.jl @@ -1,10 +1,9 @@ include("implementation/benchmark_platform.jl") -using Dates: value, now, UNIXEPOCH import pocl_jll using OpenCL: CLArray, cl.devices, cl.platforms, cl.finish, cl.queue const AT = CLArray -const path = "OpenCL_benchmark_" * string(value(now()) - UNIXEPOCH) +const path = "QuantumClifford_benchmarks/OpenCL" const can_run = any(length(devices(platform)) > 0 for platform in platforms()) diff --git a/benchmark/KernelAbstractions/benchmark_platform_ROCm.jl b/benchmark/KernelAbstractions/benchmark_platform_ROCm.jl index cd0c2621f..50bc6868b 100644 --- a/benchmark/KernelAbstractions/benchmark_platform_ROCm.jl +++ b/benchmark/KernelAbstractions/benchmark_platform_ROCm.jl @@ -1,9 +1,8 @@ include("implementation/benchmark_platform.jl") -using Dates: value, now, UNIXEPOCH using AMDGPU: ROCArray, devices, synchronize const AT = ROCArray -const path = "ROCm_benchmark_" * string(value(now()) - UNIXEPOCH) +const path = "QuantumClifford_benchmarks/ROCm" const can_run = length(devices()) > 0 diff --git a/benchmark/KernelAbstractions/implementation/benchmark_KA_mul_leftright.jl b/benchmark/KernelAbstractions/implementation/benchmark_KA_mul_leftright.jl index 3b50b1bf6..16fefba65 100644 --- a/benchmark/KernelAbstractions/implementation/benchmark_KA_mul_leftright.jl +++ b/benchmark/KernelAbstractions/implementation/benchmark_KA_mul_leftright.jl @@ -1,19 +1,27 @@ # This must be done explicitly as they are not exported. using QuantumClifford: mul_left!, mul_right!, Tableau -@inline host_f(x, y; phases::Val{phase_B} = Val(true)) where {phase_B} = +@inline function host_f!( + x, y; + phases::Val{phase_B} = Val(default_phases) + ) where {phase_B} + mul_left!(x, y; phases = phases) -@inline function device_f( +end + +@inline function device_f!( x, y, synchronize; - phases::Val{phase_B} = Val(true), + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), block_size::Val{block_SZ} = Val(default_block_size), batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} mul_left!( x, y; - phases = phases, block_size = block_size, batch_size = batch_size + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size ) synchronize() @@ -21,7 +29,7 @@ end @inline function benchmark_KA_mul_pauli_pauli( AT, synchronize, path; - phases::Val{phase_B} = Val(true) + phases::Val{phase_B} = Val(default_phases) ) where {phase_B} host_time = zeros(Float64, length(n_MiB)) @@ -46,17 +54,17 @@ end d_p2 = copy(d_p1) synchronize() # Trigger compilation before benchmarking. - host_f(h_p1, h_p2; phases = phases) - host_time[i] = @belapsed host_f( + host_f!(h_p1, h_p2; phases = phases) + host_time[i] = @belapsed host_f!( $h_p1, $h_p2; phases = $phases ) evals = evals samples = samples seconds = seconds for (j, size) in enumerate(batch_sizes) - device_f( + device_f!( d_p1, d_p2, synchronize; phases = phases, batch_size = Val(size) ) device_time[j, i] = - @belapsed device_f( + @belapsed device_f!( $d_p1, $d_p2, $synchronize; phases = $phases, batch_size = Val($size) ) evals = evals samples = samples seconds = seconds @@ -72,6 +80,8 @@ end string(Sys.CPU_THREADS) * ", Device block size = $default_block_size" xlabel = "Pauli operator size (MiB)" label = hcat(("Device - batch size = " .* string.(batch_sizes))..., "Host") + path *= "/pauli_pauli" + mkpath(path) plot( n_MiB, 10^3 .* hcat(device_cat..., host_time); @@ -79,7 +89,7 @@ end title = title, label = label, xlabel = xlabel, ylabel = "Runtime (ms)", background_color = :transparent ) - savefig("$path/runtime_pauli_pauli_phase_$phase_B.$format") + savefig("$path/runtime.$format") plot( n_MiB, map(x -> host_time ./ x, device_cat); @@ -87,13 +97,13 @@ end label = hcat(label[1 : end - 1]...), xlabel = xlabel, ylabel = "Ratio (host/device)", background_color = :transparent ) - savefig("$path/ratio_pauli_pauli_phase_$phase_B.$format") + savefig("$path/ratio.$format") end @inline function benchmark_KA_mul_tableau_pauli( AT, synchronize, path; - phases::Val{phase_B} = Val(true) + phases::Val{phase_B} = Val(default_phases) ) where {phase_B} host_time = zeros(Float64, length(n_MiB)) @@ -126,17 +136,17 @@ end ) synchronize() # Trigger compilation before benchmarking. - host_f(h_t, h_p; phases = phases) - host_time[i] = @belapsed host_f( + host_f!(h_t, h_p; phases = phases) + host_time[i] = @belapsed host_f!( $h_t, $h_p; phases = $phases ) evals = evals samples = samples seconds = seconds for (j, size) in enumerate(batch_sizes) - device_f( + device_f!( d_t, d_p, synchronize; phases = phases, batch_size = Val(size) ) device_time[j, i] = - @belapsed device_f( + @belapsed device_f!( $d_t, $d_p, $synchronize; phases = $phases, batch_size = Val($size) ) evals = evals samples = samples seconds = seconds @@ -152,6 +162,8 @@ end string(Sys.CPU_THREADS) * ", Device block size = $default_block_size" xlabel = "Tableau size (MiB)" label = hcat(("Device - batch size = " .* string.(batch_sizes))..., "Host") + path *= "/tableau_pauli" + mkpath(path) plot( n_MiB, 10^3 .* hcat(device_cat..., host_time); @@ -159,7 +171,7 @@ end title = title, label = label, xlabel = xlabel, ylabel = "Runtime (ms)", background_color = :transparent ) - savefig("$path/runtime_tableau_pauli_phase_$phase_B.$format") + savefig("$path/runtime.$format") plot( n_MiB, map(x -> host_time ./ x, device_cat); @@ -167,17 +179,16 @@ end label = hcat(label[1 : end - 1]...), xlabel = xlabel, ylabel = "Ratio (host/device)", background_color = :transparent ) - savefig("$path/ratio_tableau_pauli_phase_$phase_B.$format") + savefig("$path/ratio.$format") end @inline function benchmark_KA_mul_leftright( AT, synchronize, path; - phases::Val{phase_B} = Val(true) + phases::Val{phase_B} = Val(default_phases) ) where {phase_B} - path = "$path/mul_leftright" - mkpath(path) + path *= "/mul_leftright/phase_$phase_B" benchmark_KA_mul_pauli_pauli(AT, synchronize, path; phases = phases) benchmark_KA_mul_tableau_pauli(AT, synchronize, path; phases = phases) diff --git a/benchmark/KernelAbstractions/implementation/benchmark_platform.jl b/benchmark/KernelAbstractions/implementation/benchmark_platform.jl index 212dff2d6..c5fa19723 100644 --- a/benchmark/KernelAbstractions/implementation/benchmark_platform.jl +++ b/benchmark/KernelAbstractions/implementation/benchmark_platform.jl @@ -4,6 +4,7 @@ include("utilities.jl") include("benchmark_KA_mul_leftright.jl") @inline function benchmark_platform(AT, synchronize, path) + path *= "/" * string(value(now()) - UNIXEPOCH) benchmark_KA_mul_leftright(AT, synchronize, path; phases = Val(true)) benchmark_KA_mul_leftright(AT, synchronize, path; phases = Val(false)) end diff --git a/benchmark/KernelAbstractions/implementation/definitions.jl b/benchmark/KernelAbstractions/implementation/definitions.jl index 9e6edf19c..bf9bd0597 100644 --- a/benchmark/KernelAbstractions/implementation/definitions.jl +++ b/benchmark/KernelAbstractions/implementation/definitions.jl @@ -16,6 +16,9 @@ const n_MiB = [2^i for i = 1:10] # TODO: Keep these or remove them now that a good default has been set? const batch_sizes = [1, 4, 8, 16, 32, 64] -# These values are inaccessible since they originate from a package extension. -const default_block_size = 256 -const default_batch_size = 32 +# These values originate from a package extension, hence the query. +const KAExt = Base.get_extension(QuantumClifford, :QuantumCliffordKAExt) +const default_phases = KAExt.default_phases +const default_primary_axis = KAExt.default_primary_axis +const default_block_size = KAExt.default_block_size +const default_batch_size = KAExt.default_batch_size diff --git a/benchmark/KernelAbstractions/implementation/imports.jl b/benchmark/KernelAbstractions/implementation/imports.jl index 9f0a80a73..c5d27490c 100644 --- a/benchmark/KernelAbstractions/implementation/imports.jl +++ b/benchmark/KernelAbstractions/implementation/imports.jl @@ -2,6 +2,7 @@ import Atomix, GPUArraysCore, KernelAbstractions using BenchmarkTools: @belapsed +using Dates: value, now, UNIXEPOCH # Assists in reducing resource demands. using GPUArrays: AllocCache, @cached, unsafe_free! using Plots: plot, savefig diff --git a/benchmark/KernelAbstractions/implementation/utilities.jl b/benchmark/KernelAbstractions/implementation/utilities.jl index 0107dc1fd..d55260606 100644 --- a/benchmark/KernelAbstractions/implementation/utilities.jl +++ b/benchmark/KernelAbstractions/implementation/utilities.jl @@ -1,5 +1,9 @@ # Works even when broadcasting on zero-dimensional arrays. -@inline u32(v) = map(x -> UInt32(x), v) +@inline function u32(v) + return map(x -> UInt32(x), v) +end # By definition, the size of (unsigned) char is set to unity. -@inline bit_count(::Type{T}) where {T} = sizeof(T) * count_zeros(zero(Cuchar)) +@inline function bit_count(::Type{T}) where {T} + return sizeof(T) * count_zeros(zero(Cuchar)) +end diff --git a/ext/QuantumCliffordKAExt/QuantumCliffordKAExt.jl b/ext/QuantumCliffordKAExt/QuantumCliffordKAExt.jl index daf6adbb1..ac2e0e649 100644 --- a/ext/QuantumCliffordKAExt/QuantumCliffordKAExt.jl +++ b/ext/QuantumCliffordKAExt/QuantumCliffordKAExt.jl @@ -7,6 +7,7 @@ include("definitions.jl") include("../../src/throws.jl") include("utilities.jl") include("mul_leftright.jl") +include("canonicalization.jl") end #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/canonicalization.jl b/ext/QuantumCliffordKAExt/canonicalization.jl new file mode 100644 index 000000000..3eb3ee2e6 --- /dev/null +++ b/ext/QuantumCliffordKAExt/canonicalization.jl @@ -0,0 +1,7 @@ + +#=============================================================================# +include("canonicalization/common.jl") +include("canonicalization/canonicalize.jl") +include("canonicalization/canonicalize_rref.jl") +include("canonicalization/canonicalize_gott.jl") +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl b/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl new file mode 100644 index 000000000..e98c304de --- /dev/null +++ b/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl @@ -0,0 +1,207 @@ + +#=============================================================================# +import QuantumClifford: canonicalize! + +function device_canonicalize!( + ph::AbstractArray{<: Unsigned}, xzs::AbstractArray{<: Unsigned}, + output_buffer::Union{Nothing, AbstractArray{S}} = nothing; + multiplication_order::MultiplicationOrder = default_multiplication_order, + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + )::Nothing where { + S <: Integer, phase_B, primary_axis_E, block_SZ, batch_SZ + } + + phase_B isa Bool && primary_axis_E isa PrimaryAxis && + block_SZ isa Integer && block_SZ > zero(block_SZ) && + batch_SZ isa Integer && batch_SZ > zero(batch_SZ) || + throw(ArgumentError(THROW_VALS)) + + backend = KA.get_backend(xzs) + + if primary_axis_E == primary_axis_rows + tile = (one(block_SZ), block_SZ) + space = tessellate( + (size(xzs, 0x2), cld(size(xzs, 0x1) >> 0x1, batch_SZ)), + tile + ) + elseif primary_axis_E == primary_axis_qubits + tile = (block_SZ, one(block_SZ)) + space = tessellate( + (cld(size(xzs, 0x1) >> 0x1, batch_SZ), size(xzs, 0x2)), + tile + ) + end + + # Utilised for loop management. + length_xzs = size(xzs, 0x1) >> 0x1 + row_count = size(xzs, 0x2) + toggle = false + cycles_until_sync = default_scheduling_limit + + # Required for safety whilst setting up for the proceeding iteration. + mutex = create_mutex(backend) + # Double buffered for present and preceeding/proceeding iteration. + stride_fill = tracker_element_count + tracker = similar(xzs, Csize_t, stride_fill << 0x1) + fill!(tracker, typemax(Csize_t)) + # The pivot row tracker is initialised differently. + begin_fill = Integer(tracker_content_swap_to) + @inbounds fill!( + (@view tracker[begin_fill : stride_fill : end]), + zero(Csize_t) + ) + + if !isnothing(output_buffer) + fill!(output_buffer, zero(S)) + end + + bit_scan = kernel_bit_scan(backend) + snippet! = kernel_snippet!(backend) + mul_and_scan! = kernel_mul_and_scan!(backend) + + bit_scan( + xzs, nothing, mutex, tracker, + Val(sort_order_pauli_bit_prefer_x), + primary_axis, block_size, batch_size; + workgroupsize = tile, ndrange = space + ) + snippet!( + snippet_track_pivot_canonicalize!, + output_buffer, tracker, toggle, sort_order_pauli_bit_prefer_x; + ndrange = 0x1 + ) + for _ in one(row_count) : row_count + snippet!( + snippet_swap_rows_prepare_tracker!, + ph, xzs, tracker, toggle; + ndrange = length_xzs + ) + snippet!( + snippet_set_row_phase_flag!, + ph, xzs, tracker, toggle; + ndrange = row_count + ) + mul_and_scan!( + ph, xzs, multiplication_order, scan_side_greater, + nothing, false, mutex, tracker, toggle, + phases, Val(sort_order_pauli_bit_prefer_x), + primary_axis, block_size, batch_size; + workgroupsize = tile, ndrange = space + ) + # Switching the toggle is intentional. + snippet!( + snippet_track_pivot_canonicalize!, + output_buffer, tracker, !toggle, sort_order_pauli_bit_prefer_x; + ndrange = 0x1 + ) + + toggle = !toggle + cycles_until_sync -= one(cycles_until_sync) + if cycles_until_sync == zero(cycles_until_sync) + KA.synchronize(backend) + host_tracker = Array(tracker) + offset = ifelse(toggle, tracker_element_count, 0x0) + @inbounds continue_flag = + host_tracker[offset + Integer(tracker_content_bit_type)] < + Integer(pauli_bit_invalid) + if continue_flag + cycles_until_sync = default_scheduling_limit + else + break + end + end + end + + # Remove all extraneous bits left behind by previous iterations. + snippet!(snippet_mod_4_phase!, ph; ndrange = row_count) + + return nothing + +end + +# Tableau +@inline function canonicalize!( + tab::DeviceTableau, + output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing; + multiplication_order::MultiplicationOrder = default_multiplication_order, + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + if !isnothing(output_buffer) + length(output_buffer) == 0x2 || + throw(ArgumentError(THROW_SIZE)) + end + end + + device_canonicalize!( + tab.phases, tab.xzs, output_buffer; + multiplication_order = multiplication_order, + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + + if isnothing(output_buffer) + return tab + else + return tab, output_buffer + end + +end + +# AbstractStabilizer +@inline function canonicalize!( + state::DeviceAbstractStabilizer, + output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing; + multiplication_order::MultiplicationOrder = default_multiplication_order, + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + if !isnothing(output_buffer) + length(output_buffer) == 0x2 || + throw(ArgumentError(THROW_SIZE)) + end + end + + if state isa Stabilizer + upper = size(state.tab.xzs, 0x2) + lower = one(upper) + elseif state isa Destabilizer + upper = size(state.tab.xzs, 0x2) + lower = (upper >> one(upper)) + one(upper) + elseif state isa MixedStabilizer + upper = state.rank + lower = one(upper) + elseif state isa MixedDestabilizer + upper = size(state.tab.xzs, 0x2) + lower = (upper >> one(upper)) + one(upper) + upper = (upper >> one(upper)) + state.rank + end + + @inbounds device_canonicalize!( + (@view state.tab.phases[lower : upper]), + (@view state.tab.xzs[:, lower : upper]), + output_buffer; + multiplication_order = multiplication_order, + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + + if isnothing(output_buffer) + return state + else + return state, output_buffer + end + +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/canonicalization/canonicalize_gott.jl b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_gott.jl new file mode 100644 index 000000000..85b64d785 --- /dev/null +++ b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_gott.jl @@ -0,0 +1,5 @@ + +#=============================================================================# +import QuantumClifford: canonicalize_gott! +# TODO: Implement. +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl new file mode 100644 index 000000000..f25087494 --- /dev/null +++ b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl @@ -0,0 +1,213 @@ + +#=============================================================================# +import QuantumClifford: canonicalize_rref! + +function device_canonicalize_rref!( + ph::AbstractArray{<: Unsigned}, xzs::AbstractArray{T}, + output_buffer::Union{Nothing, AbstractArray{<: Integer}} = nothing, + bit_masks::Union{Nothing, AbstractArray{T}} = nothing; + multiplication_order::MultiplicationOrder = default_multiplication_order, + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + )::Nothing where { + T <: Unsigned, phase_B, primary_axis_E, block_SZ, batch_SZ + } + + phase_B isa Bool && primary_axis_E isa PrimaryAxis && + block_SZ isa Integer && block_SZ > zero(block_SZ) && + batch_SZ isa Integer && batch_SZ > zero(batch_SZ) || + throw(ArgumentError(THROW_VALS)) + + backend = KA.get_backend(xzs) + + if primary_axis_E == primary_axis_rows + tile = (one(block_SZ), block_SZ) + space = tessellate( + (size(xzs, 0x2), cld(size(xzs, 0x1) >> 0x1, batch_SZ)), + tile + ) + elseif primary_axis_E == primary_axis_qubits + tile = (block_SZ, one(block_SZ)) + space = tessellate( + (cld(size(xzs, 0x1) >> 0x1, batch_SZ), size(xzs, 0x2)), + tile + ) + end + + # Utilised for loop management. + length_xzs = size(xzs, 0x1) >> 0x1 + row_count = size(xzs, 0x2) + toggle = false + cycles_until_sync = default_scheduling_limit + + # Required for safety whilst setting up for the proceeding iteration. + mutex = create_mutex(backend) + # Double buffered for present and preceeding/proceeding iteration. + stride_fill = tracker_element_count + tracker = similar(xzs, Csize_t, stride_fill << 0x1) + fill!(tracker, typemax(Csize_t)) + # The pivot row tracker is initialised differently. + begin_fill = Integer(tracker_content_swap_to) + @inbounds fill!( + (@view tracker[begin_fill : stride_fill : end]), + row_count + ) + + if !isnothing(output_buffer) + fill!(output_buffer, row_count) + end + + bit_scan = kernel_bit_scan(backend) + snippet! = kernel_snippet!(backend) + mul_and_scan! = kernel_mul_and_scan!(backend) + + bit_scan( + xzs, bit_masks, mutex, tracker, + Val(sort_order_qubit_number_prefer_x), + primary_axis, block_size, batch_size; + workgroupsize = tile, ndrange = space + ) + for _ in one(row_count) : row_count + snippet!( + snippet_swap_rows_prepare_tracker!, + ph, xzs, tracker, toggle; + ndrange = length_xzs + ) + snippet!( + snippet_set_row_phase_flag!, + ph, xzs, tracker, toggle; + ndrange = row_count + ) + mul_and_scan!( + ph, xzs, multiplication_order, scan_side_lesser, + bit_masks, isnothing(bit_masks), mutex, tracker, toggle, + phases, Val(sort_order_qubit_number_prefer_x), + primary_axis, block_size, batch_size; + workgroupsize = tile, ndrange = space + ) + # Switching the toggle is intentional. + snippet!( + snippet_track_pivot_canonicalize_rref!, + output_buffer, tracker, !toggle; + ndrange = 0x1 + ) + + toggle = !toggle + cycles_until_sync -= one(cycles_until_sync) + if cycles_until_sync == zero(cycles_until_sync) + KA.synchronize(backend) + host_tracker = Array(tracker) + offset = ifelse(toggle, tracker_element_count, 0x0) + @inbounds continue_flag = + host_tracker[offset + Integer(tracker_content_bit_type)] < + Integer(pauli_bit_invalid) + if continue_flag + cycles_until_sync = default_scheduling_limit + else + break + end + end + end + + # Remove all extraneous bits left behind by previous iterations. + snippet!(snippet_mod_4_phase!, ph; ndrange = row_count) + + return nothing + +end + +# Tableau +@inline function canonicalize_rref!( + tab::DeviceTableau, + output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing, + bit_masks::Union{Nothing, AbstractGPUArray{<: Unsigned}} = nothing; + multiplication_order::MultiplicationOrder = default_multiplication_order, + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + if !isnothing(output_buffer) + length(output_buffer) == 0x1 || + throw(ArgumentError(THROW_SIZE)) + end + if !isnothing(bit_masks) + length(bit_masks) == size(tab.xzs, 0x1) >> 0x1 || + throw(DimensionMismatch(THROW_SIZE)) + end + end + + device_canonicalize!( + tab.phases, tab.xzs, output_buffer, bit_masks; + multiplication_order = multiplication_order, + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + + if isnothing(output_buffer) + return tab + else + return tab, output_buffer + end + +end + +# AbstractStabilizer +@inline function canonicalize_rref!( + state::DeviceAbstractStabilizer, + output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing, + bit_masks::Union{Nothing, AbstractGPUArray{<: Unsigned}} = nothing; + multiplication_order::MultiplicationOrder = default_multiplication_order, + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + if !isnothing(output_buffer) + length(output_buffer) == 0x1 || + throw(ArgumentError(THROW_SIZE)) + end + if !isnothing(bit_masks) + length(bit_masks) == size(state.tab.xzs, 0x1) >> 0x1 || + throw(DimensionMismatch(THROW_SIZE)) + end + end + + if state isa Stabilizer + upper = size(state.tab.xzs, 0x2) + lower = one(upper) + elseif state isa Destabilizer + upper = size(state.tab.xzs, 0x2) + lower = (upper >> one(upper)) + one(upper) + elseif state isa MixedStabilizer + upper = state.rank + lower = one(upper) + elseif state isa MixedDestabilizer + upper = size(state.tab.xzs, 0x2) + lower = (upper >> one(upper)) + one(upper) + upper = (upper >> one(upper)) + state.rank + end + + @inbounds device_canonicalize_rref!( + (@view state.tab.phases[lower : upper]), + (@view state.tab.xzs[:, lower : upper]), + output_buffer, bit_masks; + multiplication_order = multiplication_order, + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + + if isnothing(output_buffer) + return state + else + return state, output_buffer + end + +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/canonicalization/common.jl b/ext/QuantumCliffordKAExt/canonicalization/common.jl new file mode 100644 index 000000000..e27cdddec --- /dev/null +++ b/ext/QuantumCliffordKAExt/canonicalization/common.jl @@ -0,0 +1,551 @@ + +#=============================================================================# +# TODO: Make the parameters keyword arguments once support becomes available. +KA.@kernel inbounds = true unsafe_indices = true function kernel_bit_scan( + xzs::AbstractArray{T}, bit_masks::Union{Nothing, AbstractArray{T}}, + mutex::AbstractMutex, tracker::AbstractArray{S}, + ::Val{sort_order}, + ::Val{primary_axis}, ::Val{block_size}, ::Val{batch_size} + ) where { + T <: Unsigned, S <: Unsigned, + sort_order, primary_axis, block_size, batch_size + } + + if primary_axis == primary_axis_rows + j, begin_i = global_index( + KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) + ) + stride_i = KA.@ndrange()[0x2] + elseif primary_axis == primary_axis_qubits + begin_i, j = global_index( + KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) + ) + stride_i = KA.@ndrange()[0x1] + end + end_i = KA.@uniform (size(xzs, 0x1) >> 0x1) + + index = typemax(S) + bit_shift = typemax(S) + bit_type = pauli_bit_invalid + index_buffer = KA.@localmem S block_size + bit_shift_buffer = KA.@localmem S block_size + bit_type_buffer = KA.@localmem DeviceUnsigned block_size + + scan_target = @view xzs[:, j] + + for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + x_bits = scan_target[i] + z_bits = scan_target[i + end_i] + if !isnothing(bit_masks) + mask = bit_masks[i] + x_bits &= mask + z_bits &= mask + end + + index, bit_shift, bit_type, break_flag = scan_step( + x_bits, z_bits, i, index, bit_shift, bit_type, Val(sort_order) + ) + break_flag && break + end + + local_index = KA.@index(Local, Linear) + bit_shift_buffer[local_index] = bit_shift + index_buffer[local_index] = index + if sort_order in ( + sort_order_pauli_bit_prefer_x, sort_order_qubit_number_prefer_x + ) + + bit_type_buffer[local_index] = Integer(bit_type) + + elseif sort_order in ( + sort_order_pauli_bit_prefer_z, sort_order_qubit_number_prefer_z + ) + + # This reverses the ordering of the Pauli bits. + bit_type_buffer[local_index] = xor(Integer(bit_type), 0x1) + + end + + if sort_order in ( + sort_order_pauli_bit_prefer_x, sort_order_pauli_bit_prefer_z + ) + + shared_memory_reduce!( + reduce_lexicographic_min!, local_index, Val(block_size), + bit_type_buffer, index_buffer, bit_shift_buffer + ) + + elseif sort_order in ( + sort_order_qubit_number_prefer_x, sort_order_qubit_number_prefer_z + ) + + shared_memory_reduce!( + reduce_lexicographic_min!, local_index, Val(block_size), + index_buffer, bit_shift_buffer, bit_type_buffer + ) + + end + + if local_index == one(local_index) + + if sort_order in ( + sort_order_pauli_bit_prefer_x, sort_order_qubit_number_prefer_x + ) + + enumless_bit_type = bit_type_buffer[local_index] + + elseif sort_order in ( + sort_order_pauli_bit_prefer_z, sort_order_qubit_number_prefer_z + ) + + # This reverses the ordering of the Pauli bits. + enumless_bit_type = xor(bit_type_buffer[local_index], 0x1) + + end + + # Avoid mutex contention if there is no valid contribution. + if enumless_bit_type < Integer(pauli_bit_invalid) + index = index_buffer[local_index] + bit_shift = bit_shift_buffer[local_index] + + lock_mutex!( + mutex, + (@view tracker[Integer(tracker_content_index)]), + (@view tracker[Integer(tracker_content_bit_shift)]), + (@view tracker[Integer(tracker_content_bit_type)]), + (@view tracker[Integer(tracker_content_swap_from)]) + ) + + temp_index = tracker[Integer(tracker_content_index)] + temp_bit_shift = tracker[Integer(tracker_content_bit_shift)] + temp_bit_type = tracker[Integer(tracker_content_bit_type)] + temp_row = tracker[Integer(tracker_content_swap_from)] + + if sort_order == sort_order_pauli_bit_prefer_x + + lower_than_current = isless( + (enumless_bit_type, index, bit_shift, j), + (temp_bit_type, temp_index, temp_bit_shift, temp_row) + ) + + elseif sort_order == sort_order_pauli_bit_prefer_z + + # This reverses the ordering of the Pauli bits. + temp_bit_type = xor(temp_bit_type, 0x1) + lower_than_current = isless( + (xor(enumless_bit_type, 0x1), index, bit_shift, j), + (temp_bit_type, temp_index, temp_bit_shift, temp_row) + ) + + elseif sort_order == sort_order_qubit_number_prefer_x + + lower_than_current = isless( + (index, bit_shift, enumless_bit_type, j), + (temp_index, temp_bit_shift, temp_bit_type, temp_row) + ) + + elseif sort_order == sort_order_qubit_number_prefer_z + + # This reverses the ordering of the Pauli bits. + temp_bit_type = xor(temp_bit_type, 0x1) + lower_than_current = isless( + (index, bit_shift, xor(enumless_bit_type, 0x1), j), + (temp_index, temp_bit_shift, temp_bit_type, temp_row) + ) + + end + + if lower_than_current + tracker[Integer(tracker_content_index)] = index + tracker[Integer(tracker_content_bit_shift)] = bit_shift + tracker[Integer(tracker_content_bit_type)] = enumless_bit_type + tracker[Integer(tracker_content_swap_from)] = j + end + + unlock_mutex!( + mutex, + (@view tracker[Integer(tracker_content_index)]), + (@view tracker[Integer(tracker_content_bit_shift)]), + (@view tracker[Integer(tracker_content_bit_type)]), + (@view tracker[Integer(tracker_content_swap_from)]) + ) + end + + end + +end + +# CAUTION: Keep in mind that the constants match the direction of the order. +# TODO: Make the parameters keyword arguments once support becomes available. +KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( + ph::AbstractArray{P}, xzs::AbstractArray{T}, + multiplication_order::MultiplicationOrder, scan_side::ScanSide, + bit_masks::Union{Nothing, AbstractArray{T}}, shrink_workspace::Bool, + mutex::AbstractMutex, tracker::AbstractArray{S}, toggle::Bool, + ::Val{phases}, ::Val{sort_order}, + ::Val{primary_axis}, ::Val{block_size}, ::Val{batch_size} + ) where { + P <: Unsigned, T <: Unsigned, S <: Unsigned, + phases, sort_order, primary_axis, block_size, batch_size + } + + if primary_axis == primary_axis_rows + j, begin_i = global_index( + KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) + ) + stride_i = KA.@ndrange()[0x2] + elseif primary_axis == primary_axis_qubits + begin_i, j = global_index( + KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) + ) + stride_i = KA.@ndrange()[0x1] + end + end_i = KA.@uniform (size(xzs, 0x1) >> 0x1) + + local_index = KA.@index(Local, Linear) + # Layout is [index, row]. + shared_parameters = KA.@localmem S 2 + # Dense storage as flags in a bit field. + shared_flags = KA.@localmem DeviceUnsigned 1 + if local_index == one(local_index) + + shared_flags[0x1] = ifelse( + begin_i == one(begin_i), + Integer(bit_field_flag_leader), + zero(DeviceUnsigned) + ) + + current = ifelse(toggle, tracker_element_count, 0x0) + temp_index = tracker[current + Integer(tracker_content_index)] + temp_bit_type = tracker[current + Integer(tracker_content_bit_type)] + temp_row = tracker[current + Integer(tracker_content_swap_to)] + + new_begin_i = ifelse( + shrink_workspace, + begin_i + temp_index - one(S), + begin_i + zero(S) + ) + + # Opt for an early exit if any of these conditions fail. + continue_flag = + temp_row != j && begin_i <= new_begin_i <= end_i && + temp_bit_type < Integer(pauli_bit_invalid) + + if continue_flag + shared_parameters[0x1] = temp_index + shared_parameters[0x2] = temp_row + + if scan_side == scan_side_lesser + flags = ifelse( + j < temp_row, + Integer(bit_field_flag_scan), + zero(DeviceUnsigned) + ) + elseif scan_side == scan_side_greater + flags = ifelse( + j > temp_row, + Integer(bit_field_flag_scan), + zero(DeviceUnsigned) + ) + end + + flags |= ifelse( + ph[j] & top_bits(0x1, P) != zero(P), + Integer(bit_field_flag_multiply), + zero(DeviceUnsigned) + ) + + shared_flags[0x1] |= flags + end + + end + KA.@synchronize() + flags = shared_flags[0x1] + + if flags & ( + Integer(bit_field_flag_scan) | Integer(bit_field_flag_multiply) + ) != zero(DeviceUnsigned) + + if phases + low = zero(T) + high = zero(T) + phase_buffer = KA.@localmem DeviceUnsigned block_size + end + + index = typemax(S) + bit_shift = typemax(S) + bit_type = pauli_bit_invalid + index_buffer = KA.@localmem S block_size + bit_shift_buffer = KA.@localmem S block_size + bit_type_buffer = KA.@localmem DeviceUnsigned block_size + + new_begin_i = ifelse( + shrink_workspace, + begin_i + shared_parameters[0x1] - one(S), + begin_i + zero(S) + ) + + if begin_i <= new_begin_i <= end_i + + begin_i = new_begin_i + + # Equivalent to bit_scan. + if flags & ( + Integer(bit_field_flag_scan) | Integer(bit_field_flag_multiply) + ) == Integer(bit_field_flag_scan) + + scan_target = @view xzs[:, j] + + for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + x_bits = scan_target[i] + z_bits = scan_target[i + end_i] + if !isnothing(bit_masks) + mask = bit_masks[i] + x_bits &= mask + z_bits &= mask + end + + index, bit_shift, bit_type, break_flag = scan_step( + x_bits, z_bits, i, index, bit_shift, bit_type, Val(sort_order) + ) + break_flag && break + end + + else + + write_xzs = @view xzs[:, j] + if multiplication_order == multiplication_order_left + left = @view xzs[:, shared_parameters[0x2]] + right = write_xzs + elseif multiplication_order == multiplication_order_right + left = write_xzs + right = @view xzs[:, shared_parameters[0x2]] + end + + # Equivalent to mul!. + if flags & ( + Integer(bit_field_flag_scan) | Integer(bit_field_flag_multiply) + ) == Integer(bit_field_flag_multiply) + + for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + x_left = left[i] + z_left = left[i + end_i] + x_right = right[i] + z_right = right[i + end_i] + + x_new = xor(x_left, x_right) + z_new = xor(z_left, z_right) + if phases + xl_zr = x_left & z_right + merged = xor(xl_zr, z_left & x_right) + high = xor(high, xor(low, x_new, z_new, xl_zr) & merged) + low = xor(low, merged) + end + + write_xzs[i] = x_new + write_xzs[i + end_i] = z_new + end + + # Equivalent to joint mul! and a modified bit_scan. + else + + for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + x_left = left[i] + z_left = left[i + end_i] + x_right = right[i] + z_right = right[i + end_i] + + x_new = xor(x_left, x_right) + z_new = xor(z_left, z_right) + if phases + xl_zr = x_left & z_right + merged = xor(xl_zr, z_left & x_right) + high = xor(high, xor(low, x_new, z_new, xl_zr) & merged) + low = xor(low, merged) + end + + write_xzs[i] = x_new + write_xzs[i + end_i] = z_new + + if !isnothing(bit_masks) + mask = bit_masks[i] + x_new &= mask + z_new &= mask + end + + index, bit_shift, bit_type = scan_step( + x_new, z_new, i, index, bit_shift, bit_type, Val(sort_order) + ) + + end + + # Marks the end for flags == Integer(bit_field_flag_multiply)/else + end + + # Marks the end for flags == Integer(bit_field_flag_scan)/else + end + + # Marks the end for begin_i <= new_begin_i <= end_i + end + + # Multiplication phase reduction and update. + if phases + if flags & Integer(bit_field_flag_multiply) != zero(DeviceUnsigned) + phase_buffer[local_index] = + ((count_ones(high) << 0x1) + count_ones(low)) & 0x3 + shared_memory_reduce!( + reduce_sum!, local_index, Val(block_size), phase_buffer + ) + + if local_index == one(local_index) + temp_ph = phase_buffer[local_index] + if flags & Integer(bit_field_flag_leader) != + zero(DeviceUnsigned) + temp_ph += ph[shared_parameters[0x2]] + end + # CAUTION: This is sufficient since only atomicity is required. + @atomic :monotonic ph[j] += temp_ph & 0x3 + # CAUTION: Avoid nullifying the multiplication flag bit! + @atomic :monotonic ph[j] &= 0x3 | top_bits(0x1, P) + end + end + end + + # Pauli bit scan reduction and update. + if flags & Integer(bit_field_flag_scan) != zero(DeviceUnsigned) + bit_shift_buffer[local_index] = bit_shift + index_buffer[local_index] = index + if sort_order in ( + sort_order_pauli_bit_prefer_x, sort_order_qubit_number_prefer_x + ) + + bit_type_buffer[local_index] = Integer(bit_type) + + elseif sort_order in ( + sort_order_pauli_bit_prefer_z, sort_order_qubit_number_prefer_z + ) + + # This reverses the ordering of the Pauli bits. + bit_type_buffer[local_index] = xor(Integer(bit_type), 0x1) + + end + + if sort_order in ( + sort_order_pauli_bit_prefer_x, sort_order_pauli_bit_prefer_z + ) + + shared_memory_reduce!( + reduce_lexicographic_min!, local_index, Val(block_size), + bit_type_buffer, index_buffer, bit_shift_buffer + ) + + elseif sort_order in ( + sort_order_qubit_number_prefer_x, sort_order_qubit_number_prefer_z + ) + + shared_memory_reduce!( + reduce_lexicographic_min!, local_index, Val(block_size), + index_buffer, bit_shift_buffer, bit_type_buffer + ) + + end + + if local_index == one(local_index) + + if sort_order in ( + sort_order_pauli_bit_prefer_x, sort_order_qubit_number_prefer_x + ) + + enumless_bit_type = bit_type_buffer[local_index] + + elseif sort_order in ( + sort_order_pauli_bit_prefer_z, sort_order_qubit_number_prefer_z + ) + + # This reverses the ordering of the Pauli bits. + enumless_bit_type = xor(bit_type_buffer[local_index], 0x1) + + end + + # Avoid mutex contention if there is no valid contribution. + if enumless_bit_type < Integer(pauli_bit_invalid) + index = index_buffer[local_index] + bit_shift = bit_shift_buffer[local_index] + next = ifelse(toggle, 0x0, tracker_element_count) + + lock_mutex!( + mutex, + (@view tracker[next + Integer(tracker_content_index)]), + (@view tracker[next + Integer(tracker_content_bit_shift)]), + (@view tracker[next + Integer(tracker_content_bit_type)]), + (@view tracker[next + Integer(tracker_content_swap_from)]) + ) + + temp_index = tracker[next + Integer(tracker_content_index)] + temp_bit_shift = + tracker[next + Integer(tracker_content_bit_shift)] + temp_bit_type = + tracker[next + Integer(tracker_content_bit_type)] + temp_row = tracker[next + Integer(tracker_content_swap_from)] + + if sort_order == sort_order_pauli_bit_prefer_x + + lower_than_current = isless( + (enumless_bit_type, index, bit_shift, j), + (temp_bit_type, temp_index, temp_bit_shift, temp_row) + ) + + elseif sort_order == sort_order_pauli_bit_prefer_z + + # This reverses the ordering of the Pauli bits. + temp_bit_type = xor(temp_bit_type, 0x1) + lower_than_current = isless( + (xor(enumless_bit_type, 0x1), index, bit_shift, j), + (temp_bit_type, temp_index, temp_bit_shift, temp_row) + ) + + elseif sort_order == sort_order_qubit_number_prefer_x + + lower_than_current = isless( + (index, bit_shift, enumless_bit_type, j), + (temp_index, temp_bit_shift, temp_bit_type, temp_row) + ) + + elseif sort_order == sort_order_qubit_number_prefer_z + + # This reverses the ordering of the Pauli bits. + temp_bit_type = xor(temp_bit_type, 0x1) + lower_than_current = isless( + (index, bit_shift, xor(enumless_bit_type, 0x1), j), + (temp_index, temp_bit_shift, temp_bit_type, temp_row) + ) + + end + + if lower_than_current + tracker[next + Integer(tracker_content_index)] = index + tracker[next + Integer(tracker_content_bit_shift)] = + bit_shift + tracker[next + Integer(tracker_content_bit_type)] = + enumless_bit_type + tracker[next + Integer(tracker_content_swap_from)] = j + end + + + unlock_mutex!( + mutex, + (@view tracker[next + Integer(tracker_content_index)]), + (@view tracker[next + Integer(tracker_content_bit_shift)]), + (@view tracker[next + Integer(tracker_content_bit_type)]), + (@view tracker[next + Integer(tracker_content_swap_from)]) + ) + end + + end + end + + # Marks the end for flags != zero(DeviceUnsigned) + end + +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions.jl b/ext/QuantumCliffordKAExt/definitions.jl index d146432cd..7a152ce91 100644 --- a/ext/QuantumCliffordKAExt/definitions.jl +++ b/ext/QuantumCliffordKAExt/definitions.jl @@ -1,37 +1,9 @@ #=============================================================================# -# Reasonable size that is generally ideal for most vendors and usecases. -const default_block_size = 256 -# Ameliorate overhead and enhance performance by doing more work per thread. -const default_batch_size = 32 - -# Most ideal type for shared reductions due to avoiding bank conflicts. -const DeviceUnsigned = UInt32 - -# Keeps function definitions succinct. -const DevicePauliOperator = PauliOperator{T_P, T_XZ} where { - T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray - } -const DeviceTableau = Tableau{T_P, T_XZ} where { - T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray - } -# This is a bit redundant but it keeps the REPL expansions more readable. -const DeviceUnionStabilizer = Union{ - Stabilizer{T}, MixedStabilizer{T} - } where { - T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray, - T <: Tableau{T_P, T_XZ} - } -const DeviceUnionDestabilizer = Union{ - Destabilizer{T}, MixedDestabilizer{T} - } where { - T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray, - T <: Tableau{T_P, T_XZ} - } -const DeviceAbstractStabilizer = Union{ - Stabilizer{T}, MixedStabilizer{T}, Destabilizer{T}, MixedDestabilizer{T} - } where { - T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray, - T <: Tableau{T_P, T_XZ} - } +include("definitions/word_size_integers.jl") +include("definitions/enumerations.jl") +include("definitions/fixed_sizes.jl") +include("definitions/default_parameters.jl") +include("definitions/type_shorthands.jl") +include("definitions/mutex_configuration.jl") #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/default_parameters.jl b/ext/QuantumCliffordKAExt/definitions/default_parameters.jl new file mode 100644 index 000000000..02a022154 --- /dev/null +++ b/ext/QuantumCliffordKAExt/definitions/default_parameters.jl @@ -0,0 +1,15 @@ + +#=============================================================================# +# Maintains compatibility with the main package unless explicitly specified. +const default_multiplication_order = multiplication_order_left +# Strict correctness unless there is an explicit opt-out. +const default_phases = true +# Potentially boosts cache hits and reduces atomic contention. +const default_primary_axis = primary_axis_rows +# Reasonable size that is generally ideal for most vendors and use cases. +const default_block_size = 256 +# Ameliorate overhead and enhance performance by doing more work per thread. +const default_batch_size = 32 +# TODO: Eliminate this in favour of complete asynchronicity. +const default_scheduling_limit = 64 +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/enumerations.jl b/ext/QuantumCliffordKAExt/definitions/enumerations.jl new file mode 100644 index 000000000..a8f55420b --- /dev/null +++ b/ext/QuantumCliffordKAExt/definitions/enumerations.jl @@ -0,0 +1,58 @@ + +#=============================================================================# +# There is vanishing overhead for supporting both of these. +@enum MultiplicationOrder::UInt8 begin + multiplication_order_left + multiplication_order_right +end + +# Determines whether the first grid dimension matches the rows or the qubits. +@enum PrimaryAxis::UInt8 begin + primary_axis_rows + primary_axis_qubits +end + +#============================================================================== +UTILISED INTERNALLY +==============================================================================# + +# Pauli bit: Highest X(Z) < Lowest Z(X). Qubit Number: 1X(Z) < 1Z(X) < 2X(Z)... +@enum SortOrder::UInt8 begin + sort_order_pauli_bit_prefer_x + sort_order_pauli_bit_prefer_z + sort_order_qubit_number_prefer_x + sort_order_qubit_number_prefer_z +end + +# Determines whether contraction proceeds from high to low or low to high. +@enum ScanSide::UInt8 begin + scan_side_lesser + scan_side_greater +end + +# Provides enhanced clarity over plain numerical values. +# CAUTION: The values are NOT arbitrary but utilised for proper indexing. +@enum TrackerContent::UInt8 begin + tracker_content_index = 0x1 + tracker_content_bit_shift = 0x2 + tracker_content_bit_type = 0x3 + tracker_content_swap_from = 0x4 + tracker_content_swap_to = 0x5 +end + +# Provides enhanced clarity over plain numerical values. +# CAUTION: The values are NOT arbitrary but utilised for proper ordering. +@enum PauliBit::UInt8 begin + pauli_bit_x = 0x0 + pauli_bit_z = 0x1 + pauli_bit_invalid = 0x2 +end + +# Provides enhanced clarity over plain numerical values. +# CAUTION: The values are NOT arbitrary but utilised for proper masking. +@enum BitFieldFlags::DeviceUnsigned begin + bit_field_flag_leader = 0x1 + bit_field_flag_scan = 0x2 + bit_field_flag_multiply = 0x4 +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/fixed_sizes.jl b/ext/QuantumCliffordKAExt/definitions/fixed_sizes.jl new file mode 100644 index 000000000..f6de0e5f2 --- /dev/null +++ b/ext/QuantumCliffordKAExt/definitions/fixed_sizes.jl @@ -0,0 +1,5 @@ + +#=============================================================================# +# TODO: Figure out a more elegant solution that Julia approves of. +const tracker_element_count = 0x5 +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/mutex_configuration.jl b/ext/QuantumCliffordKAExt/definitions/mutex_configuration.jl new file mode 100644 index 000000000..f9917975d --- /dev/null +++ b/ext/QuantumCliffordKAExt/definitions/mutex_configuration.jl @@ -0,0 +1,11 @@ + +#=============================================================================# +# TODO: Eliminate these once KernelAbstractions becomes more feature complete. +AbstractMutex = AbstractArray{DeviceUnsigned, 0} +# These should be confined to an enum but atomics require type conversion. +const mutex_state_locked::DeviceUnsigned = 0x0 +const mutex_state_unlocked::DeviceUnsigned = 0x1 +# Defines the mappings for the compare-and-swap(CAS)/@atomicreplace call. +const mutex_exchange_lock = mutex_state_unlocked => mutex_state_locked +const mutex_exchange_unlock = mutex_state_locked => mutex_state_unlocked +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/type_shorthands.jl b/ext/QuantumCliffordKAExt/definitions/type_shorthands.jl new file mode 100644 index 000000000..553052c3f --- /dev/null +++ b/ext/QuantumCliffordKAExt/definitions/type_shorthands.jl @@ -0,0 +1,29 @@ + +#=============================================================================# +# Keeps the function definitions succinct. +const DevicePauliOperator = PauliOperator{T_P, T_XZ} where { + T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray + } +const DeviceTableau = Tableau{T_P, T_XZ} where { + T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray + } +# This is a bit redundant but it keeps the REPL expansions more readable. +const DeviceUnionStabilizer = Union{ + Stabilizer{T}, MixedStabilizer{T} + } where { + T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray, + T <: Tableau{T_P, T_XZ} + } +const DeviceUnionDestabilizer = Union{ + Destabilizer{T}, MixedDestabilizer{T} + } where { + T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray, + T <: Tableau{T_P, T_XZ} + } +const DeviceAbstractStabilizer = Union{ + Stabilizer{T}, MixedStabilizer{T}, Destabilizer{T}, MixedDestabilizer{T} + } where { + T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray, + T <: Tableau{T_P, T_XZ} + } +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/word_size_integers.jl b/ext/QuantumCliffordKAExt/definitions/word_size_integers.jl new file mode 100644 index 000000000..96f1fd1c7 --- /dev/null +++ b/ext/QuantumCliffordKAExt/definitions/word_size_integers.jl @@ -0,0 +1,7 @@ + +#=============================================================================# +# Most ideal type for shared reductions due to avoiding bank conflicts. +# TODO: Should these be modified to Cint/Cuint? +const DeviceSigned = Int32 +const DeviceUnsigned = UInt32 +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/imports.jl b/ext/QuantumCliffordKAExt/imports.jl index 2f95adc5e..30bdf449c 100644 --- a/ext/QuantumCliffordKAExt/imports.jl +++ b/ext/QuantumCliffordKAExt/imports.jl @@ -2,7 +2,7 @@ #=============================================================================# import KernelAbstractions as KA -using Atomix: @atomic +using Atomix: @atomic, @atomicreplace using GPUArraysCore: AbstractGPUArray # Resolves issue due to KA comparing against the literal Symbol("@Const"). using KernelAbstractions: @Const diff --git a/ext/QuantumCliffordKAExt/mul_leftright.jl b/ext/QuantumCliffordKAExt/mul_leftright.jl index b3e9d3d6a..bde47ce2b 100644 --- a/ext/QuantumCliffordKAExt/mul_leftright.jl +++ b/ext/QuantumCliffordKAExt/mul_leftright.jl @@ -1,461 +1,5 @@ #=============================================================================# -# TODO: include the unsafe functions once the main package establishes them. -import QuantumClifford: mul_left!, mul_right! - -# CAUTION: Keep in mind that mutable = order_right_left ? right : left -# TODO: Make the parameters keyword arguments once support becomes available. -KA.@kernel inbounds = true unsafe_indices = true function kernel_mul!( - mutable_phases, mutable_xzs, @Const(const_xzs), - ::Val{order_right_left}, ::Val{phases}, - ::Val{block_size}, ::Val{batch_size} - ) where {order_right_left, phases, block_size, batch_size} - - # unsafe_indices is required for shared_memory_reduce, do this manually. - begin_i, j_mutable = global_index( - KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) - ) - stride_i = KA.@uniform (KA.@ndrange()[1]) - j_const = ifelse(size(const_xzs, 2) > 1, j_mutable, 1) - end_i = KA.@uniform (size(mutable_xzs, 1) >> 1) - if phases - low = KA.@uniform (zero(eltype(mutable_xzs))) - high = KA.@uniform (zero(eltype(mutable_xzs))) - end - - for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) - if order_right_left - x_left = const_xzs[i, j_const] - z_left = const_xzs[i + end_i, j_const] - x_right = mutable_xzs[i, j_mutable] - z_right = mutable_xzs[i + end_i, j_mutable] - else - x_left = mutable_xzs[i, j_mutable] - z_left = mutable_xzs[i + end_i, j_mutable] - x_right = const_xzs[i, j_const] - z_right = const_xzs[i + end_i, j_const] - end - - x_new = xor(x_left, x_right) - z_new = xor(z_left, z_right) - if phases - xl_zr = x_left & z_right - merged = xor(xl_zr, z_left & x_right) - high = - xor(high, xor(low, x_new, z_new, xl_zr) & merged) - low = xor(low, merged) - end - - mutable_xzs[i, j_mutable] = x_new - mutable_xzs[i + end_i, j_mutable] = z_new - end - - if phases - value::DeviceUnsigned = (count_ones(high) << 1) + count_ones(low) - buffer = KA.@localmem DeviceUnsigned (block_size,) - index = KA.@index(Local, Linear) - value = shared_memory_reduce!(+, buffer, value, index, Val(block_size)) - if index == one(index) - @atomic mutable_phases[j_mutable] += value - end - end - -end - -# CAUTION: Requires either rows(const) == 1 or rows(const) == rows(mutable) -function device_mul!( - mutable_phases::AbstractGPUArray{<: Unsigned}, - mutable_xzs::AbstractGPUArray{T}, - const_phases::AbstractGPUArray{<: Unsigned}, - const_xzs::AbstractGPUArray{T}; - order_right_left::Val{right_left}, phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - )::Nothing where {T <: Unsigned, right_left, phase_B, block_SZ, batch_SZ} - - backend = KA.get_backend(mutable_xzs) - dim_x = cld(size(mutable_xzs, 1) >> 1, batch_SZ) - dim_y = size(mutable_xzs, 2) - if phase_B - transform! = kernel_transform!(backend) - transform!( - mod_4_sum!, mutable_phases, const_phases; ndrange = dim_y - ) - end - mul! = kernel_mul!(backend, (block_SZ, 1)) - mul!( - mutable_phases, mutable_xzs, const_xzs, - order_right_left, phases, block_size, batch_size; - workgroupsize = (block_SZ, 1), - ndrange = tessellate((dim_x, dim_y), (block_SZ, 1)) - ) - if phase_B - transform!( - mod_4_identity!, mutable_phases, nothing; ndrange = dim_y - ) - end - return nothing - -end - -# CAUTION: Meta-programming is utilised to in order to avoid repetition. -for (safe_f_sym, unsafe_f_sym, right_left) in ( - (:mul_left!, :do_mul_left!, true), (:mul_right!, :do_mul_right!, false) - ) - -#============================================================================== -RETURNS PAULI OPERATOR -==============================================================================# - -# PauliOperator - PauliOperator -@eval @inline function $safe_f_sym( - u::DevicePauliOperator, v::DevicePauliOperator; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - u.nqubits == v.nqubits || throw(DimensionMismatch(THROW_NQUBITS)) - return $unsafe_f_sym( - u, v; - phases = phases, block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::DevicePauliOperator, v::DevicePauliOperator; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - device_mul!( - u.phase, u.xz, v.phase, v.xz; - order_right_left = Val($right_left), phases = phases, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -for (T_v_sym, v_tab_sym) in ( - (:DeviceTableau, :v), (:DeviceAbstractStabilizer, :(v.tab)) - ) - -# PauliOperator - Tableau/AbstractStabilizer[i] -@eval @inline function $safe_f_sym( - u::DevicePauliOperator, v::$T_v_sym, i::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - 1 <= i <= length($v_tab_sym.phases) || throw(BoundsError(THROW_BOUNDS)) - u.nqubits == $v_tab_sym.nqubits || throw(DimensionMismatch(THROW_NQUBITS)) - return $unsafe_f_sym( - u, v, i; - phases = phases, block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::DevicePauliOperator, v::$T_v_sym, i::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - @inbounds device_mul!( - u.phase, u.xz, - (@view $v_tab_sym.phases[i]), (@view $v_tab_sym.xzs[:, i]); - order_right_left = Val($right_left), phases = phases, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Marks the end for (T_v_sym, v_tab_sym) -end - -#============================================================================== -RETURNS TABLEAU / ABSTRACT STABILIZER -==============================================================================# - -for (T_u_sym, u_tab_sym) in ( - (:DeviceTableau, :u), (:DeviceAbstractStabilizer, :(u.tab)) - ) - -# Tableau/AbstractStabilizer - PauliOperator -@eval @inline function $safe_f_sym( - u::$T_u_sym, v::DevicePauliOperator; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - $u_tab_sym.nqubits == v.nqubits || throw(DimensionMismatch(THROW_NQUBITS)) - return $unsafe_f_sym( - u, v; - phases = phases, block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, v::DevicePauliOperator; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - device_mul!( - $u_tab_sym.phases, $u_tab_sym.xzs, v.phase, v.xz; - order_right_left = Val($right_left), phases = phases, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Tableau/AbstractStabilizer[i] - PauliOperator -@eval @inline function $safe_f_sym( - u::$T_u_sym, i::Integer, v::DevicePauliOperator; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - 1 <= i <= length($u_tab_sym.phases) || throw(BoundsError(THROW_BOUNDS)) - $u_tab_sym.nqubits == v.nqubits || throw(DimensionMismatch(THROW_NQUBITS)) - return $unsafe_f_sym( - u, i, v; - phases = phases, block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, i::Integer, v::DevicePauliOperator; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - @inbounds device_mul!( - (@view $u_tab_sym.phases[i]), (@view $u_tab_sym.xzs[:, i]), - v.phase, v.xz; - order_right_left = Val($right_left), phases = phases, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# CAUTION: (Mixed)Destabilizer is handled separately. -# Tableau/AbstractStabilizer[i] - Self[j] -@eval @inline function $safe_f_sym( - u::$T_u_sym, i::Integer, j::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - len = length($u_tab_sym.phases) - 1 <= i <= len && 1 <= j <= len || throw(BoundsError(THROW_BOUNDS)) - return $unsafe_f_sym( - u, i, j; - phases = phases, block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, i::Integer, j::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - @inbounds device_mul!( - (@view $u_tab_sym.phases[i]), (@view $u_tab_sym.xzs[:, i]), - (@view $u_tab_sym.phases[j]), (@view $u_tab_sym.xzs[:, j]); - order_right_left = Val($right_left), phases = phases, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -for (T_v_sym, v_tab_sym) in ( - (:DeviceTableau, :v), (:DeviceAbstractStabilizer, :(v.tab)) - ) - -# Tableau/AbstractStabilizer - Tableau/AbstractStabilizer -@eval @inline function $safe_f_sym( - u::$T_u_sym, v::$T_v_sym; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - length($u_tab_sym.phases) == length($v_tab_sym.phases) || - throw(DimensionMismatch(THROW_SIZE)) - $u_tab_sym.nqubits == $v_tab_sym.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - return $unsafe_f_sym( - u, v; - phases = phases, block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, v::$T_v_sym; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - device_mul!( - $u_tab_sym.phases, $u_tab_sym.xzs, $v_tab_sym.phases, $v_tab_sym.xzs; - order_right_left = Val($right_left), phases = phases, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Tableau/AbstractStabilizer - Tableau/AbstractStabilizer[i] -@eval @inline function $safe_f_sym( - u::$T_u_sym, v::$T_v_sym, i::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - 1 <= i <= length($v_tab_sym.phases) || throw(BoundsError(THROW_BOUNDS)) - $u_tab_sym.nqubits == $v_tab_sym.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - return $unsafe_f_sym( - u, v, i; - phases = phases, block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, v::$T_v_sym, i::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - @inbounds device_mul!( - $u_tab_sym.phases, $u_tab_sym.xzs, - (@view $v_tab_sym.phases[i]), (@view $v_tab_sym.xzs[:, i]); - order_right_left = Val($right_left), phases = phases, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Tableau/AbstractStabilizer[i] - Tableau/AbstractStabilizer[j] -@eval @inline function $safe_f_sym( - u::$T_u_sym, i::Integer, v::$T_v_sym, j::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - 1 <= i <= length($u_tab_sym.phases) && - 1 <= j <= length($v_tab_sym.phases) || - throw(BoundsError(THROW_BOUNDS)) - $u_tab_sym.nqubits == $v_tab_sym.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - return $unsafe_f_sym( - u, i, v, j; - phases = phases, block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, i::Integer, v::$T_v_sym, j::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - @inbounds device_mul!( - (@view $u_tab_sym.phases[i]), (@view $u_tab_sym.xzs[:, i]), - (@view $v_tab_sym.phases[j]), (@view $v_tab_sym.xzs[:, j]); - order_right_left = Val($right_left), phases = phases, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Marks the end for (T_v_sym, v_tab_sym) -end - -# Marks the end for (T_u_sym, u_tab_sym) -end - -#============================================================================== -RETURNS (MIXED) DESTABILIZER -==============================================================================# - -# CAUTION: Requires special handling. -# (Mixed)Destabilizer[i] - Self[j] -@eval @inline function $safe_f_sym( - u::DeviceUnionDestabilizer, i::Integer, j::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - len = length(u.tab.phases) - n = len >> 1 - all(x -> 1 <= x <= len, (i, j, i + n, j + n)) || - throw(BoundsError(THROW_BOUNDS)) - return $unsafe_f_sym( - u, i, j; - phases = phases, block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::DeviceUnionDestabilizer, i::Integer, j::Integer; - phases::Val{phase_B} = Val(true), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, block_SZ, batch_SZ} - - p, xzs = u.tab.phases, u.tab.xzs - n = length(p) >> 1 - # Swapping the order of the indices is intentional. - @inbounds device_mul!( - (@view p[j]), (@view xzs[:, j]), - (@view p[i]), (@view xzs[:, i]); - order_right_left = Val($right_left), phases = Val(false), - block_size = block_size, batch_size = batch_size - ) - @inbounds device_mul!( - (@view p[i + n]), (@view xzs[:, i + n]), - (@view p[j + n]), (@view xzs[:, j + n]); - order_right_left = Val($right_left), phases = phases, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Marks the end for (safe_f_sym, unsafe_f_sym, right_left) -end +include("mul_leftright/device_mul.jl") +include("mul_leftright/host_interface.jl") #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/mul_leftright/device_mul.jl b/ext/QuantumCliffordKAExt/mul_leftright/device_mul.jl new file mode 100644 index 000000000..451ef0e29 --- /dev/null +++ b/ext/QuantumCliffordKAExt/mul_leftright/device_mul.jl @@ -0,0 +1,139 @@ + +#=============================================================================# +# CAUTION: Keep in mind that the constants match the direction of the order. +# TODO: Make the parameters keyword arguments once support becomes available. +KA.@kernel inbounds = true unsafe_indices = true function kernel_mul!( + mutable_phases::AbstractArray{<: Unsigned}, mutable_xzs::AbstractArray{T}, + @Const(const_xzs::AbstractArray{T}), + multiplication_order::MultiplicationOrder, + ::Val{phases}, ::Val{primary_axis}, ::Val{block_size}, ::Val{batch_size} + ) where { + T <: Unsigned, phases, primary_axis, block_size, batch_size + } + + if primary_axis == primary_axis_rows + j_mutable, begin_i = global_index( + KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) + ) + stride_i = KA.@ndrange()[0x2] + elseif primary_axis == primary_axis_qubits + begin_i, j_mutable = global_index( + KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) + ) + stride_i = KA.@ndrange()[0x1] + end + end_i = KA.@uniform (size(mutable_xzs, 0x1) >> 0x1) + flag = KA.@uniform (size(const_xzs, 0x2) > 0x1) + j_const = ifelse(flag, j_mutable, one(j_mutable)) + + if phases + low = zero(T) + high = zero(T) + phase_buffer = KA.@localmem DeviceUnsigned block_size + end + + mutable_xzs = @view mutable_xzs[:, j_mutable] + if multiplication_order == multiplication_order_left + left = @view const_xzs[:, j_const] + right = mutable_xzs + elseif multiplication_order == multiplication_order_right + left = mutable_xzs + right = @view const_xzs[:, j_const] + end + + for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + x_left = left[i] + z_left = left[i + end_i] + x_right = right[i] + z_right = right[i + end_i] + + x_new = xor(x_left, x_right) + z_new = xor(z_left, z_right) + if phases + xl_zr = x_left & z_right + merged = xor(xl_zr, z_left & x_right) + high = xor(high, xor(low, x_new, z_new, xl_zr) & merged) + low = xor(low, merged) + end + + mutable_xzs[i] = x_new + mutable_xzs[i + end_i] = z_new + end + + if phases + local_index = KA.@index(Local, Linear) + phase_buffer[local_index] = + ((count_ones(high) << 0x1) + count_ones(low)) & 0x3 + shared_memory_reduce!( + reduce_sum!, local_index, Val(block_size), phase_buffer + ) + + if local_index == one(local_index) + # CAUTION: This is sufficient since only atomicity is required. + @atomic :monotonic mutable_phases[j_mutable] += + phase_buffer[local_index] & 0x3 + @atomic :monotonic mutable_phases[j_mutable] &= 0x3 + end + end + +end + +# CAUTION: Requires either rows(const) == 1 or rows(const) == rows(mutable) +function device_mul!( + mutable_phases::AbstractArray{<: Unsigned}, mutable_xzs::AbstractArray{T}, + const_phases::AbstractArray{<: Unsigned}, const_xzs::AbstractArray{T}, + multiplication_order::MultiplicationOrder; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + )::Nothing where { + T <: Unsigned, phase_B, primary_axis_E, block_SZ, batch_SZ + } + + phase_B isa Bool && primary_axis_E isa PrimaryAxis && + block_SZ isa Integer && block_SZ > zero(block_SZ) && + batch_SZ isa Integer && batch_SZ > zero(batch_SZ) || + throw(ArgumentError(THROW_VALS)) + + backend = KA.get_backend(mutable_xzs) + + if primary_axis_E == primary_axis_rows + tile = (one(block_SZ), block_SZ) + space = tessellate( + ( + size(mutable_xzs, 0x2), + cld(size(mutable_xzs, 0x1) >> 0x1, batch_SZ) + ), + tile + ) + elseif primary_axis_E == primary_axis_qubits + tile = (block_SZ, one(block_SZ)) + space = tessellate( + ( + cld(size(mutable_xzs, 0x1) >> 0x1, batch_SZ), + size(mutable_xzs, 0x2) + ), + tile + ) + end + + if phase_B + snippet! = kernel_snippet!(backend) + @inbounds snippet!( + snippet_mod_4_sum_phase!, + mutable_phases, const_phases; + ndrange = length(mutable_phases) + ) + end + mul! = kernel_mul!(backend) + mul!( + mutable_phases, mutable_xzs, const_xzs, multiplication_order, + phases, primary_axis, block_size, batch_size; + workgroupsize = tile, ndrange = space + ) + + return nothing + +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/mul_leftright/host_interface.jl b/ext/QuantumCliffordKAExt/mul_leftright/host_interface.jl new file mode 100644 index 000000000..68466f638 --- /dev/null +++ b/ext/QuantumCliffordKAExt/mul_leftright/host_interface.jl @@ -0,0 +1,430 @@ + +#=============================================================================# +# TODO: include the unsafe functions once the main package establishes them. +import QuantumClifford: mul_left!, mul_right! + +# CAUTION: Meta-programming is utilised to in order to avoid repetition. +for (safe_f_sym, unsafe_f_sym, multiplication_order) in ( + (:mul_left!, :do_mul_left!, multiplication_order_left), + (:mul_right!, :do_mul_right!, multiplication_order_right) + ) + +#============================================================================== +RETURNS PAULI OPERATOR +==============================================================================# + +# PauliOperator - PauliOperator +@eval @inline function $safe_f_sym( + u::DevicePauliOperator, v::DevicePauliOperator; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + u.nqubits == v.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return $unsafe_f_sym( + u, v; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DevicePauliOperator, v::DevicePauliOperator; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + device_mul!( + u.phase, u.xz, v.phase, v.xz, $multiplication_order; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +for (T_v_sym, v_tab_sym) in ( + (:DeviceTableau, :v), (:DeviceAbstractStabilizer, :(v.tab)) + ) + +# PauliOperator - Tableau/AbstractStabilizer[i] +@eval @inline function $safe_f_sym( + u::DevicePauliOperator, v::$T_v_sym, i::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + one(i) <= i <= length($v_tab_sym.phases) || + throw(BoundsError(THROW_BOUNDS)) + u.nqubits == $v_tab_sym.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return $unsafe_f_sym( + u, v, i; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DevicePauliOperator, v::$T_v_sym, i::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @inbounds device_mul!( + u.phase, u.xz, + (@view $v_tab_sym.phases[i]), (@view $v_tab_sym.xzs[:, i]), + $multiplication_order; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Marks the end for (T_v_sym, v_tab_sym) +end + +#============================================================================== +RETURNS TABLEAU / ABSTRACT STABILIZER +==============================================================================# + +for (T_u_sym, u_tab_sym) in ( + (:DeviceTableau, :u), (:DeviceAbstractStabilizer, :(u.tab)) + ) + +# Tableau/AbstractStabilizer - PauliOperator +@eval @inline function $safe_f_sym( + u::$T_u_sym, v::DevicePauliOperator; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + $u_tab_sym.nqubits == v.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return $unsafe_f_sym( + u, v; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::$T_u_sym, v::DevicePauliOperator; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + device_mul!( + $u_tab_sym.phases, $u_tab_sym.xzs, v.phase, v.xz, + $multiplication_order; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Tableau/AbstractStabilizer[i] - PauliOperator +@eval @inline function $safe_f_sym( + u::$T_u_sym, i::Integer, v::DevicePauliOperator; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + one(i) <= i <= length($u_tab_sym.phases) || + throw(BoundsError(THROW_BOUNDS)) + $u_tab_sym.nqubits == v.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return $unsafe_f_sym( + u, i, v; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::$T_u_sym, i::Integer, v::DevicePauliOperator; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @inbounds device_mul!( + (@view $u_tab_sym.phases[i]), (@view $u_tab_sym.xzs[:, i]), + v.phase, v.xz, + $multiplication_order; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# CAUTION: (Mixed)Destabilizer is handled separately. +# Tableau/AbstractStabilizer[i] - Self[j] +@eval @inline function $safe_f_sym( + u::$T_u_sym, i::Integer, j::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + len = length($u_tab_sym.phases) + one(i) <= i <= len && one(j) <= j <= len || + throw(BoundsError(THROW_BOUNDS)) + end + return $unsafe_f_sym( + u, i, j; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::$T_u_sym, i::Integer, j::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @inbounds device_mul!( + (@view $u_tab_sym.phases[i]), (@view $u_tab_sym.xzs[:, i]), + (@view $u_tab_sym.phases[j]), (@view $u_tab_sym.xzs[:, j]), + $multiplication_order; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +for (T_v_sym, v_tab_sym) in ( + (:DeviceTableau, :v), (:DeviceAbstractStabilizer, :(v.tab)) + ) + +# Tableau/AbstractStabilizer - Tableau/AbstractStabilizer +@eval @inline function $safe_f_sym( + u::$T_u_sym, v::$T_v_sym; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + length($u_tab_sym.phases) == length($v_tab_sym.phases) || + throw(DimensionMismatch(THROW_SIZE)) + $u_tab_sym.nqubits == $v_tab_sym.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return $unsafe_f_sym( + u, v; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::$T_u_sym, v::$T_v_sym; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + device_mul!( + $u_tab_sym.phases, $u_tab_sym.xzs, $v_tab_sym.phases, $v_tab_sym.xzs, + $multiplication_order; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Tableau/AbstractStabilizer - Tableau/AbstractStabilizer[i] +@eval @inline function $safe_f_sym( + u::$T_u_sym, v::$T_v_sym, i::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + one(i) <= i <= length($v_tab_sym.phases) || + throw(BoundsError(THROW_BOUNDS)) + $u_tab_sym.nqubits == $v_tab_sym.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return $unsafe_f_sym( + u, v, i; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::$T_u_sym, v::$T_v_sym, i::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @inbounds device_mul!( + $u_tab_sym.phases, $u_tab_sym.xzs, + (@view $v_tab_sym.phases[i]), (@view $v_tab_sym.xzs[:, i]), + $multiplication_order; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Tableau/AbstractStabilizer[i] - Tableau/AbstractStabilizer[j] +@eval @inline function $safe_f_sym( + u::$T_u_sym, i::Integer, v::$T_v_sym, j::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + one(i) <= i <= length($u_tab_sym.phases) && + one(j) <= j <= length($v_tab_sym.phases) || + throw(BoundsError(THROW_BOUNDS)) + $u_tab_sym.nqubits == $v_tab_sym.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return $unsafe_f_sym( + u, i, v, j; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::$T_u_sym, i::Integer, v::$T_v_sym, j::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @inbounds device_mul!( + (@view $u_tab_sym.phases[i]), (@view $u_tab_sym.xzs[:, i]), + (@view $v_tab_sym.phases[j]), (@view $v_tab_sym.xzs[:, j]), + $multiplication_order; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Marks the end for (T_v_sym, v_tab_sym) +end + +# Marks the end for (T_u_sym, u_tab_sym) +end + +#============================================================================== +RETURNS (MIXED) DESTABILIZER +==============================================================================# + +# CAUTION: Requires special handling. +# (Mixed)Destabilizer[i] - Self[j] +@eval @inline function $safe_f_sym( + u::DeviceUnionDestabilizer, i::Integer, j::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + @boundscheck begin + len = length(u.tab.phases) + n = len >> one(len) + all(x -> one(x) <= x <= len, (i, j, i + n, j + n)) || + throw(BoundsError(THROW_BOUNDS)) + end + return $unsafe_f_sym( + u, i, j; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DeviceUnionDestabilizer, i::Integer, j::Integer; + phases::Val{phase_B} = Val(default_phases), + primary_axis::Val{primary_axis_E} = Val(default_primary_axis), + block_size::Val{block_SZ} = Val(default_block_size), + batch_size::Val{batch_SZ} = Val(default_batch_size) + ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + + p, xzs = u.tab.phases, u.tab.xzs + n = length(p) + n >>= one(n) + # Swapping the order of the indices is intentional. + @inbounds device_mul!( + (@view p[j]), (@view xzs[:, j]), + (@view p[i]), (@view xzs[:, i]), + $multiplication_order; + phases = Val(false), primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + @inbounds device_mul!( + (@view p[i + n]), (@view xzs[:, i + n]), + (@view p[j + n]), (@view xzs[:, j + n]), + $multiplication_order; + phases = phases, primary_axis = primary_axis, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Marks the end for (safe_f_sym, unsafe_f_sym, multiplication_order) +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/utilities.jl b/ext/QuantumCliffordKAExt/utilities.jl index 703c4cad7..652b5efd3 100644 --- a/ext/QuantumCliffordKAExt/utilities.jl +++ b/ext/QuantumCliffordKAExt/utilities.jl @@ -1,95 +1,9 @@ #=============================================================================# -# Assists in cleanly translating into the grid-block model. -@inline tessellate(space, tile) = tile .* cld.(space, tile) - -# For use whenever KA.@index(Global, NTuple) is unavailable. -# At the moment, JET always complains unless unsafe_indices = true is set. -@inline global_index(block_index, block_dim, thread_index) = - (block_index .- one(eltype(block_index))) .* block_dim .+ thread_index - -# This would have been called "map" but that name is already reserved. -KA.@kernel inbounds = true unsafe_indices = true function kernel_transform!( - f!, target, @Const(auxiliary) - ) - - global_position = global_index( - KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) - ) - f!(target, auxiliary, global_position) - -end - -# CAUTION: Requires block_size == length(buffer) == prod(KA.@groupsize()) -# CAUTION: Requires unsafe_indices = true if num_active_threads < block_size -# TODO: Overhaul once __ctx__ is no longer necessary for runtime queries. -# TODO: Revisit once warp level primitives are supported. -@inline function shared_memory_reduce!( - f, buffer::AbstractArray{T}, value::T, index::Integer, ::Val{block_size} - ) where {T, block_size} - - @inbounds buffer[index] = value - - # This branch is a power of 2, take the quick route. - if count_ones(block_size) == one(block_size) - - # This is messy but only for-loop unrolling is supported. - KA.Extras.@unroll for bit = one(block_size) : trailing_zeros(block_size) - stride = KA.@uniform ((2)^(trailing_zeros(block_size) - bit)) - # The call to KA.@synchronize is ALL or NOTHING. - KA.@synchronize() - if index <= stride - @inbounds value = f(value, buffer[index + stride]) - @inbounds buffer[index] = value - end - end - - else - - current = KA.@uniform (block_size) - # TODO: Unroll this branch should it become possible. - while current > one(current) - # The call to KA.@synchronize is ALL or NOTHING. - KA.@synchronize() - # This splits into even/odd steps. Should target SALU. - if iseven(current) - current = KA.@uniform (current >> one(current)) - if index <= current - @inbounds value = f(value, buffer[index + current]) - @inbounds buffer[index] = value - end - else - current = KA.@uniform ((current >> one(current)) + one(current)) - # The strict inequality is intentional. - if index < current - @inbounds value = f(value, buffer[index + current]) - @inbounds buffer[index] = value - end - end - end - - end - return value - -end - -#============================================================================== -COMMON TRANSFORMATIONS -==============================================================================# -# Anonymous functions trigger recompilation. Hence, Separate them out. - -@inline function mod_4_sum!(target, auxiliary, global_position) - @inbounds i = global_position[1] - if i <= length(target) - j = ifelse(length(auxiliary) > 1, i, 1) - @inbounds target[i] = (target[i] + auxiliary[j]) & 0x3 - end -end - -@inline function mod_4_identity!(target, auxiliary, global_position) - @inbounds i = global_position[1] - if i <= length(target) - @inbounds target[i] &= 0x3 - end -end +include("utilities/kernel_configuration.jl") +include("utilities/bit_manipulation.jl") +include("utilities/scan_step.jl") +include("utilities/mutex_management.jl") +include("utilities/reductions.jl") +include("utilities/snippets.jl") #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/utilities/bit_manipulation.jl b/ext/QuantumCliffordKAExt/utilities/bit_manipulation.jl new file mode 100644 index 000000000..447fd0c66 --- /dev/null +++ b/ext/QuantumCliffordKAExt/utilities/bit_manipulation.jl @@ -0,0 +1,23 @@ + +#=============================================================================# +# By definition, the size of (unsigned) char is set to unity. +@inline function bit_count(::Type{T}) where {T} + return sizeof(T) * count_zeros(zero(Cuchar)) +end + +# CAUTION: Zero indexed shift, valid values are less than bit_count(T). +@inline function lowest_set_bit(bit_field::T) where {T <: Unsigned} + return T(trailing_zeros(bit_field)) +end + +# CAUTION: Unsigned typing is intentional for branchless code generation. +@inline function bit_mask(bit_shift::Unsigned, ::Type{T}) where {T <: Unsigned} + return one(T) << bit_shift +end + +# CAUTION: Unsigned typing is intentional for branchless code generation. +# CAUTION: Requires that count be lower than bit_count(T) for validity. +@inline function top_bits(count::Unsigned, ::Type{T}) where {T <: Unsigned} + return ~(~zero(T) >>> count) +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/utilities/kernel_configuration.jl b/ext/QuantumCliffordKAExt/utilities/kernel_configuration.jl new file mode 100644 index 000000000..65ee7e3ae --- /dev/null +++ b/ext/QuantumCliffordKAExt/utilities/kernel_configuration.jl @@ -0,0 +1,22 @@ + +#=============================================================================# +# Translates index set dimensions into the grid-block model. +@inline function tessellate( + space::NTuple{N, <: Integer}, tile::NTuple{N, <: Integer} + ) where {N} + + return tile .* cld.(space, tile) + +end + +# Commonly set unsafe_indices = true, hence replaces KA.@index(Global, NTuple). +@inline function global_index( + block_index::NTuple{N, T}, + block_dim::NTuple{N, <: Integer}, + thread_index::NTuple{N, <: Integer} + ) where {N, T <: Integer} + + return (block_index .- one(T)) .* block_dim .+ thread_index + +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/utilities/mutex_management.jl b/ext/QuantumCliffordKAExt/utilities/mutex_management.jl new file mode 100644 index 000000000..e294cacea --- /dev/null +++ b/ext/QuantumCliffordKAExt/utilities/mutex_management.jl @@ -0,0 +1,87 @@ + +#=============================================================================# +@inline function create_mutex(backend::KA.Backend) + output = KA.allocate(backend, DeviceUnsigned) + fill!(output, mutex_state_unlocked) + return output +end + +@inline function reset_mutex!(mutex::AbstractMutex) + fill!(mutex, mutex_state_unlocked) + return mutex +end + +# WARNING: Atomix has an API for atomic load/store, yet support is missing. +# WARNING: KernelAbstractions does not provide any sort of memory fence. + +#============================================================================== +SINFUL IMPLEMENTATION WROUGHT ABOUT BY THE FOLLY OF MANKIND +==============================================================================# + +# TODO: Overhaul this entirely once support becomes available. +@inline @generated function lock_mutex!( + mutex::M, data::AbstractArray... + )::Nothing where {M <: AbstractMutex} + + if length(data) > 0x0 + clause = :( + @atomicreplace :release :acquire mutex[0x1] mutex_exchange_lock + ) + # CAUTION: Atomics forces the necessary memory synchronisation barrier. + @inbounds nil = zero(eltype(data[0x1])) + fence = :( + # This is just a fancy atomic NOOP. + @atomicreplace :release :acquire data[0x1][0x1] $nil => $nil; + ) + end + for n in 0x2 : length(data) + @inbounds nil = zero(eltype(data[n])) + fence = :( + $fence; + @atomicreplace :release :acquire data[$n][0x1] $nil => $nil; + ) + end + return :( + @inbounds begin; + while true; + ($clause).success && break; + end; + $fence; + end; + return nothing; + ) + +end + +# TODO: Overhaul this entirely once support becomes available. +@inline @generated function unlock_mutex!( + mutex::M, data::AbstractArray... + )::Nothing where {M <: AbstractMutex} + + if length(data) > 0x0 + # CAUTION: Atomics forces the necessary memory synchronisation barrier. + fence = :( + temp_1 = data[0x1][0x1]; + # This is just a fancy atomic NOOP. Always succeeds and releases. + @atomicreplace :release :acquire data[0x1][0x1] temp_1 => temp_1; + ) + end + for n in 0x2 : length(data) + sym = Symbol(:temp_, n) + fence = :( + $fence; + $sym = data[$n][0x1]; + @atomicreplace :release :acquire data[$n][0x1] $sym => $sym; + ) + end + return :( + @inbounds begin; + $fence; + # This will always succeed. + @atomicreplace :release :acquire mutex[0x1] mutex_exchange_unlock; + end; + return nothing; + ) + +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/utilities/reductions.jl b/ext/QuantumCliffordKAExt/utilities/reductions.jl new file mode 100644 index 000000000..4aa0a52a7 --- /dev/null +++ b/ext/QuantumCliffordKAExt/utilities/reductions.jl @@ -0,0 +1,111 @@ + +#=============================================================================# +# CAUTION: Requires unsafe_indices = true if active_threads_count < block_size +# TODO: Overhaul once __ctx__ is no longer necessary for runtime queries. +# TODO: Revisit once warp level primitives are supported. +@inline @generated function shared_memory_reduce!( + f!::Function, index::Integer, ::Val{block_size}, arguments... + )::Nothing where {block_size} + + current = block_size + if current > zero(current) + stages = :(;;) + end + while current > one(current) + stages = :( + $stages; + # The call to KA.@synchronize is ALL or NOTHING. + KA.@synchronize(); + ) + if iseven(current) + current >>= one(current) + stages = :( + $stages; + if index <= $current; + f!(index, $current, arguments...); + end; + ) + else + current = (current >> one(current)) + one(current) + stages = :( + $stages; + # The strict inequality is intentional. + if index < $current; + f!(index, $current, arguments...); + end; + ) + end + end + return :( + $stages; + return nothing; + ) + +end + +#============================================================================== +REDUCTION CATALOGUE +==============================================================================# + +@inline @generated function reduce_sum!( + index::Integer, stride::Integer, arguments::AbstractArray... + )::Nothing + + if length(arguments) > 0x0 + reduction = :( + arguments[0x1][index] += arguments[0x1][index + stride]; + ) + end + for n in 0x2 : length(arguments) + reduction = :( + $reduction; + arguments[$n][index] += arguments[$n][index + stride]; + ) + end + return :( + @inbounds begin; + $reduction; + end; + return nothing; + ) + +end + +@inline @generated function reduce_lexicographic_min!( + index::Integer, stride::Integer, arguments::AbstractArray... + )::Nothing + + if length(arguments) > 0x0 + clause = :(arguments[0x1][index + stride] < arguments[0x1][index]) + body = :( + arguments[0x1][index] = arguments[0x1][index + stride]; + ) + end + for n in 0x2 : length(arguments) + subclause = :(arguments[0x1][index + stride] == arguments[0x1][index]) + for m in 0x2 : (n - 0x1) + subclause = :( + $subclause && + arguments[$m][index + stride] == arguments[$m][index] + ) + end + subclause = :( + $subclause && arguments[$n][index + stride] < arguments[$n][index] + ) + clause = :($clause || $subclause) + body = :( + $body; + arguments[$n][index] = arguments[$n][index + stride]; + ) + end + return :( + @inbounds begin; + if $clause; + $body; + end; + end; + return nothing; + ) + +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/utilities/scan_step.jl b/ext/QuantumCliffordKAExt/utilities/scan_step.jl new file mode 100644 index 000000000..afcb367d9 --- /dev/null +++ b/ext/QuantumCliffordKAExt/utilities/scan_step.jl @@ -0,0 +1,103 @@ + +#=============================================================================# +# CAUTION: Output layout is (index, bit_shift, bit_type, break_flag). +@inline function scan_step( + x_bits::T, z_bits::T, current_index::Integer, + index::Integer, bit_shift::Integer, bit_type::PauliBit, + ::Val{sort_order} + ) where {T <: Unsigned, sort_order} + + current_shift_x = lowest_set_bit(x_bits) + current_shift_z = lowest_set_bit(z_bits) + + if sort_order == sort_order_pauli_bit_prefer_x + + if current_shift_x < bit_count(T) && isless( + (pauli_bit_x, current_index, current_shift_x), + (bit_type, index, bit_shift) + ) + + output = (current_index, current_shift_x, pauli_bit_x, true) + + elseif current_shift_z < bit_count(T) && isless( + (pauli_bit_z, current_index, current_shift_z), + (bit_type, index, bit_shift) + ) + + output = (current_index, current_shift_z, pauli_bit_z, false) + + else + + output = (index, bit_shift, bit_type, false) + + end + + elseif sort_order == sort_order_pauli_bit_prefer_z + + # This reverses the ordering of the Pauli bits. + if current_shift_z < bit_count(T) && isless( + (xor(Integer(pauli_bit_z), 0x1), current_index, current_shift_z), + (xor(Integer(bit_type), 0x1), index, bit_shift) + ) + + output = (current_index, current_shift_z, pauli_bit_z, true) + + elseif current_shift_x < bit_count(T) && isless( + (xor(Integer(pauli_bit_x), 0x1), current_index, current_shift_x), + (xor(Integer(bit_type), 0x1), index, bit_shift) + ) + + output = (current_index, current_shift_x, pauli_bit_x, false) + + else + + output = (index, bit_shift, bit_type, false) + + end + + elseif sort_order == sort_order_qubit_number_prefer_x + + candidate = min( + (current_shift_x, pauli_bit_x), + (current_shift_z, pauli_bit_z) + ) + + if @inbounds candidate[0x1] < bit_count(T) && isless( + (current_index, candidate...), + (index, bit_shift, bit_type) + ) + + output = (current_index, candidate..., true) + + else + + output = (index, bit_shift, bit_type, false) + + end + + elseif sort_order == sort_order_qubit_number_prefer_z + + # This reverses the ordering of the Pauli bits. + candidate = min( + (current_shift_x, xor(Integer(pauli_bit_x), 0x1)), + (current_shift_z, xor(Integer(pauli_bit_z), 0x1)) + ) + + @inbounds if candidate[0x1] < bit_count(T) && isless( + (current_index, candidate...), + (index, bit_shift, xor(Integer(bit_type), 0x1)) + ) + + output = (current_index, candidate..., true) + + else + + output = (index, bit_shift, bit_type, false) + + end + + end + + return output +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/utilities/snippets.jl b/ext/QuantumCliffordKAExt/utilities/snippets.jl new file mode 100644 index 000000000..9a5b5ac9c --- /dev/null +++ b/ext/QuantumCliffordKAExt/utilities/snippets.jl @@ -0,0 +1,220 @@ + +#=============================================================================# +# CAUTION: Utilises unsafe_indices = true, hence demanding boundary validation. +KA.@kernel inbounds = true unsafe_indices = true function kernel_snippet!( + f!::Function, arguments... + ) + + global_position = global_index( + KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) + ) + f!(global_position, arguments...) + +end + +#============================================================================== +SNIPPET CATALOGUE +==============================================================================# + +@inline function snippet_mod_4_sum_phase!( + global_position::NTuple{N, <: Integer}, phases::AbstractArray{<: Unsigned}, + partner::Union{Unsigned, AbstractArray{<: Unsigned}} + )::Nothing where {N} + + @inbounds begin + i = global_position[0x1] + if i <= length(phases) + if partner isa Integer + phases[i] = (phases[i] + partner) & 0x3 + elseif partner isa AbstractArray + j = ifelse(length(partner) > 0x1, i, one(i)) + phases[i] = (phases[i] + partner[j]) & 0x3 + end + end + end + return nothing + +end + +@inline function snippet_mod_4_phase!( + global_position::NTuple{N, <: Integer}, phases::AbstractArray{<: Unsigned} + )::Nothing where {N} + + @inbounds begin + i = global_position[0x1] + if i <= length(phases) + phases[i] &= 0x3 + end + end + return nothing + +end + +@inline function snippet_track_pivot_canonicalize!( + global_position::NTuple{N, <: Integer}, + output_buffer::Union{Nothing, AbstractArray{<: Integer}}, + tracker::AbstractArray{<: Unsigned}, toggle::Bool, sort_order::SortOrder + )::Nothing where {N} + + @inbounds begin + current = KA.@uniform ( + ifelse(toggle, tracker_element_count, 0x0) + ) + previous = KA.@uniform ( + ifelse(toggle, 0x0, tracker_element_count) + ) + if global_position[0x1] == 0x1 + bit_type = tracker[current + Integer(tracker_content_bit_type)] + row = tracker[previous + Integer(tracker_content_swap_to)] + invalid = Integer(pauli_bit_invalid) + + if !isnothing(output_buffer) + + if sort_order == sort_order_pauli_bit_prefer_x + primary = Integer(pauli_bit_x) + secondary = Integer(pauli_bit_z) + elseif sort_order == sort_order_pauli_bit_prefer_z + primary = Integer(pauli_bit_z) + secondary = Integer(pauli_bit_x) + end + + previous_bit_type = + tracker[previous + Integer(tracker_content_bit_type)] + # Primary => Invalid + if bit_type >= invalid && previous_bit_type == primary + output_buffer[0x1] = row + output_buffer[0x2] = row + # Invalid/Primary => Secondary + elseif bit_type == secondary && previous_bit_type != secondary + output_buffer[0x1] = row + # Secondary => Invalid + elseif bit_type >= invalid && previous_bit_type == secondary + output_buffer[0x2] = row + end + + end + + row = ifelse(bit_type < invalid, row + one(row), row) + tracker[current + Integer(tracker_content_swap_to)] = row + end + end + return nothing + +end + +@inline function snippet_track_pivot_canonicalize_rref!( + global_position::NTuple{N, <: Integer}, + output_buffer::Union{Nothing, AbstractArray{<: Integer}}, + tracker::AbstractArray{<: Unsigned}, toggle::Bool + )::Nothing where {N} + + @inbounds begin + current = KA.@uniform ( + ifelse(toggle, tracker_element_count, 0x0) + ) + previous = KA.@uniform ( + ifelse(toggle, 0x0, tracker_element_count) + ) + if global_position[0x1] == 0x1 + bit_type = tracker[current + Integer(tracker_content_bit_type)] + row = tracker[previous + Integer(tracker_content_swap_to)] + invalid = Integer(pauli_bit_invalid) + + if !isnothing(output_buffer) + previous_bit_type = + tracker[previous + Integer(tracker_content_bit_type)] + # Valid => Invalid + if bit_type >= invalid && previous_bit_type < invalid + output_buffer[0x1] = row - one(row) + end + end + + row = ifelse(bit_type < invalid, row - one(row), row) + tracker[current + Integer(tracker_content_swap_to)] = row + end + end + return nothing + +end + +@inline function snippet_swap_rows_prepare_tracker!( + global_position::NTuple{N, <: Integer}, + phases::AbstractArray{<: Unsigned}, xzs::AbstractArray{<: Unsigned}, + tracker::AbstractArray{S}, toggle::Bool + )::Nothing where {N, S <: Unsigned} + + @inbounds begin + i = global_position[0x1] + end_i = KA.@uniform (size(xzs, 0x1) >> 0x1) + current = KA.@uniform ( + ifelse(toggle, tracker_element_count, 0x0) + ) + next = KA.@uniform ( + ifelse(toggle, 0x0, tracker_element_count) + ) + valid = + tracker[current + Integer(tracker_content_bit_type)] < + Integer(pauli_bit_invalid) + row_A = tracker[current + Integer(tracker_content_swap_from)] + row_B = tracker[current + Integer(tracker_content_swap_to)] + + if valid && row_A != row_B + if i <= end_i + temp_x = xzs[i, row_A] + temp_z = xzs[i + end_i, row_A] + xzs[i, row_A] = xzs[i, row_B] + xzs[i + end_i, row_A] = xzs[i + end_i, row_B] + xzs[i, row_B] = temp_x + xzs[i + end_i, row_B] = temp_z + end + if i == one(i) + temp_phase = phases[row_A] + phases[row_A] = phases[row_B] + phases[row_B] = temp_phase + end + end + if i == one(i) + tracker[next + Integer(tracker_content_index)] = typemax(S) + tracker[next + Integer(tracker_content_bit_shift)] = typemax(S) + tracker[next + Integer(tracker_content_bit_type)] = typemax(S) + tracker[next + Integer(tracker_content_swap_from)] = typemax(S) + end + end + return nothing + +end + +@inline function snippet_set_row_phase_flag!( + global_position::NTuple{N, <: Integer}, + phases::AbstractArray{P}, xzs::AbstractArray{T}, + tracker::AbstractArray{<: Unsigned}, toggle::Bool + )::Nothing where {N, P <: Unsigned, T <: Unsigned} + + @inbounds begin + z_offset = KA.@uniform (size(xzs, 0x1) >> 0x1) + end_rows = KA.@uniform (size(xzs, 0x2)) + current = KA.@uniform ( + ifelse(toggle, tracker_element_count, 0x0) + ) + row = global_position[0x1] + index = tracker[current + Integer(tracker_content_index)] + bit_shift = tracker[current + Integer(tracker_content_bit_shift)] + bit_type = tracker[current + Integer(tracker_content_bit_type)] + + if bit_type < Integer(pauli_bit_invalid) && row <= end_rows + if bit_type == Integer(pauli_bit_x) + status = xzs[index, row] & bit_mask(bit_shift, T) + elseif bit_type == Integer(pauli_bit_z) + status = xzs[index + z_offset, row] & bit_mask(bit_shift, T) + end + phases[row] &= 0x3 + if status != zero(T) + # Top bit is utilised as auxiliary storage. + phases[row] |= top_bits(0x1, P) + end + end + end + return nothing + +end +#=============================================================================# diff --git a/src/throws.jl b/src/throws.jl index 1adc43b29..765b6f88f 100644 --- a/src/throws.jl +++ b/src/throws.jl @@ -10,3 +10,7 @@ between the pertinent size(s) of the provided arguments." const THROW_NQUBITS = "Unable to perform the requested operation due to encountering a mismatch \ between the number of qubits in the provided arguments." + +const THROW_VALS = +"Unable to perform the requested operation due to encountering a mismatch \ +between the provided `::Val` parameter(s) and the range of supported value(s)." diff --git a/test/KernelAbstractions/implementation/definitions.jl b/test/KernelAbstractions/implementation/definitions.jl index 27e9e93b3..bdf4ec4a4 100644 --- a/test/KernelAbstractions/implementation/definitions.jl +++ b/test/KernelAbstractions/implementation/definitions.jl @@ -9,5 +9,5 @@ const max_rows = 1024 const round_count = 16 # Correctness should be independent of parameter values. # The omission of the const specifier is intentional, overridden in OpenCL. -block_sizes = rand(1:256, round_count) -const batch_sizes = rand(1:256, round_count) +block_sizes = rand(1 : 256, round_count) +const batch_sizes = rand(1 : 256, round_count) diff --git a/test/KernelAbstractions/implementation/utilities.jl b/test/KernelAbstractions/implementation/utilities.jl index 6447b6a82..c56f2a2c5 100644 --- a/test/KernelAbstractions/implementation/utilities.jl +++ b/test/KernelAbstractions/implementation/utilities.jl @@ -1,21 +1,46 @@ # Works even when broadcasting on zero-dimensional arrays. -@inline u32(v) = map(x -> UInt32(x), v) +@inline function u32(v) + return map(x -> UInt32(x), v) +end # Surprisingly, these do not already exist. -@inline get_pauli(t::Tableau, i::Integer) = - PauliOperator((@view t.phases[i]), t.nqubits, (@view t.xzs[:, i])) -@inline get_pauli(s::AbstractStabilizer, i::Integer) = - PauliOperator( +@inline function get_pauli(t::Tableau, i::Integer) + return PauliOperator( + (@view t.phases[i]), + t.nqubits, + (@view t.xzs[:, i]) + ) +end +@inline function get_pauli(s::AbstractStabilizer, i::Integer) + return PauliOperator( (@view s.tab.phases[i]), s.tab.nqubits, (@view s.tab.xzs[:, i]) ) +end + +@inline function phases(t::Tableau) + return t.phases +end +@inline function phases(t::Tableau, i::Integer) + return (@view t.phases[i]) +end +@inline function phases(s::AbstractStabilizer) + return s.tab.phases +end +@inline function phases(s::AbstractStabilizer, i::Integer) + return (@view s.tab.phases[i]) +end -@inline phases(t::Tableau) = t.phases -@inline phases(t::Tableau, i::Integer) = (@view t.phases[i]) -@inline phases(s::AbstractStabilizer) = s.tab.phases -@inline phases(s::AbstractStabilizer, i::Integer) = (@view s.tab.phases[i]) -@inline xzs(t::Tableau) = t.xzs -@inline xzs(t::Tableau, i::Integer) = (@view t.xzs[:, i]) -@inline xzs(s::AbstractStabilizer) = s.tab.xzs -@inline xzs(s::AbstractStabilizer, i::Integer) = (@view s.tab.xzs[:, i]) +@inline function xzs(t::Tableau) + return t.xzs +end +@inline function xzs(t::Tableau, i::Integer) + return (@view t.xzs[:, i]) +end +@inline function xzs(s::AbstractStabilizer) + return s.tab.xzs +end +@inline function xzs(s::AbstractStabilizer, i::Integer) + return (@view s.tab.xzs[:, i]) +end From e19e58a38c3dbc34284865bec2b6b6caac0faf5c Mon Sep 17 00:00:00 2001 From: "ha.git" Date: Mon, 18 Aug 2025 17:13:13 +0200 Subject: [PATCH 2/3] Introduce consistent interface. Overhaul test and benchmark suites. Introduce benchmark and test scripts for canonicalization routines. --- benchmark/KernelAbstractions/README.md | 2 +- .../benchmark_platform_CUDA.jl | 7 +- .../benchmark_platform_OpenCL.jl | 7 +- .../benchmark_platform_ROCm.jl | 7 +- .../benchmark_KA_mul_leftright.jl | 195 -------- .../implementation/benchmark_platform.jl | 18 +- .../implementation/definitions.jl | 28 +- .../definitions/benchmark_configuration.jl | 54 +++ .../definitions/plot_configuration.jl | 13 + .../definitions/tuning_parameters.jl | 29 ++ .../implementation/imports.jl | 17 +- .../suites/benchmark_KA_canonicalization.jl | 172 +++++++ .../implementation/suites/benchmark_KA_mul.jl | 292 +++++++++++ .../implementation/utilities.jl | 13 +- .../utilities/benchmark_management.jl | 65 +++ .../utilities/bit_manipulation.jl | 23 + .../utilities/memory_management.jl | 34 ++ .../QuantumCliffordAdaptExt.jl | 2 +- ext/QuantumCliffordAdaptExt/adapters.jl | 16 +- ext/QuantumCliffordAdaptExt/imports.jl | 1 - ext/QuantumCliffordAdaptExt/utilities.jl | 4 +- ext/QuantumCliffordGPUExt/canonicalization.jl | 4 +- .../QuantumCliffordKAExt.jl | 3 +- ext/QuantumCliffordKAExt/README.md | 2 +- .../canonicalization/canonicalize.jl | 167 +++---- .../canonicalization/canonicalize_gott.jl | 1 + .../canonicalization/canonicalize_rref.jl | 168 +++---- .../canonicalization/common.jl | 252 ++++------ .../definitions/default_parameters.jl | 16 +- .../definitions/enumerations.jl | 46 +- .../definitions/fixed_sizes.jl | 2 +- .../definitions/mutex_configuration.jl | 6 +- .../definitions/type_shorthands.jl | 16 +- .../definitions/word_size_integers.jl | 4 +- ext/QuantumCliffordKAExt/imports.jl | 10 +- .../{mul_leftright.jl => mul.jl} | 5 +- .../{mul_leftright => mul}/device_mul.jl | 90 ++-- ext/QuantumCliffordKAExt/mul/new_interface.jl | 457 ++++++++++++++++++ ext/QuantumCliffordKAExt/mul/old_interface.jl | 343 +++++++++++++ .../mul_leftright/host_interface.jl | 430 ---------------- .../utilities/bit_manipulation.jl | 19 +- .../utilities/kernel_configuration.jl | 6 +- .../utilities/mutex_management.jl | 46 +- .../utilities/reductions.jl | 22 +- .../utilities/scan_step.jl | 87 ++-- .../utilities/snippets.jl | 95 ++-- src/throws.jl | 4 +- .../implementation/definitions.jl | 18 +- .../definitions/test_configuration.jl | 15 + .../definitions/tuning_parameters.jl | 12 + .../implementation/imports.jl | 9 +- .../suites/test_KA_canonicalization.jl | 88 ++++ .../implementation/suites/test_KA_mul.jl | 223 +++++++++ .../implementation/test_KA_mul_leftright.jl | 238 --------- .../implementation/test_platform.jl | 20 +- .../implementation/utilities.jl | 51 +- .../utilities/bit_manipulation.jl | 13 + .../implementation/utilities/equalities.jl | 44 ++ .../utilities/memory_management.jl | 34 ++ .../implementation/utilities/views.jl | 26 + test/KernelAbstractions/test_platform_CUDA.jl | 5 +- .../test_platform_OpenCL.jl | 7 +- test/KernelAbstractions/test_platform_ROCm.jl | 5 +- test/test_gpu_canonicalization.jl | 18 +- 64 files changed, 2506 insertions(+), 1620 deletions(-) delete mode 100644 benchmark/KernelAbstractions/implementation/benchmark_KA_mul_leftright.jl create mode 100644 benchmark/KernelAbstractions/implementation/definitions/benchmark_configuration.jl create mode 100644 benchmark/KernelAbstractions/implementation/definitions/plot_configuration.jl create mode 100644 benchmark/KernelAbstractions/implementation/definitions/tuning_parameters.jl create mode 100644 benchmark/KernelAbstractions/implementation/suites/benchmark_KA_canonicalization.jl create mode 100644 benchmark/KernelAbstractions/implementation/suites/benchmark_KA_mul.jl create mode 100644 benchmark/KernelAbstractions/implementation/utilities/benchmark_management.jl create mode 100644 benchmark/KernelAbstractions/implementation/utilities/bit_manipulation.jl create mode 100644 benchmark/KernelAbstractions/implementation/utilities/memory_management.jl rename ext/QuantumCliffordKAExt/{mul_leftright.jl => mul.jl} (63%) rename ext/QuantumCliffordKAExt/{mul_leftright => mul}/device_mul.jl (54%) create mode 100644 ext/QuantumCliffordKAExt/mul/new_interface.jl create mode 100644 ext/QuantumCliffordKAExt/mul/old_interface.jl delete mode 100644 ext/QuantumCliffordKAExt/mul_leftright/host_interface.jl create mode 100644 test/KernelAbstractions/implementation/definitions/test_configuration.jl create mode 100644 test/KernelAbstractions/implementation/definitions/tuning_parameters.jl create mode 100644 test/KernelAbstractions/implementation/suites/test_KA_canonicalization.jl create mode 100644 test/KernelAbstractions/implementation/suites/test_KA_mul.jl delete mode 100644 test/KernelAbstractions/implementation/test_KA_mul_leftright.jl create mode 100644 test/KernelAbstractions/implementation/utilities/bit_manipulation.jl create mode 100644 test/KernelAbstractions/implementation/utilities/equalities.jl create mode 100644 test/KernelAbstractions/implementation/utilities/memory_management.jl create mode 100644 test/KernelAbstractions/implementation/utilities/views.jl diff --git a/benchmark/KernelAbstractions/README.md b/benchmark/KernelAbstractions/README.md index 517ae6833..51364a3f4 100644 --- a/benchmark/KernelAbstractions/README.md +++ b/benchmark/KernelAbstractions/README.md @@ -1,6 +1,6 @@ # Usage Instructions -1. Modify the parameters listed in `implementation/definitions.jl` as desired, keeping in mind the limitations of the available device memory and the extent of the benchmark runtime. +1. Modify the values listed in `implementation/definitions/[benchmark_configuration, tuning_parameters].jl` as desired, keeping in mind the limitations of the available device memory and the extent of the benchmark runtime. 2. Ensure that all the packages listed in `implementation/imports.jl` are installed as this is not handled automatically. 3. Ensure that the backend package(s) listed in the pertinent `benchmark_platform_*.jl` script are properly setup and configured. 4. Pass said script as an argument to julia (optionally, also set the number of executing host threads) and await for the benchmark to conclude. diff --git a/benchmark/KernelAbstractions/benchmark_platform_CUDA.jl b/benchmark/KernelAbstractions/benchmark_platform_CUDA.jl index 702a50cac..ccf64d4eb 100644 --- a/benchmark/KernelAbstractions/benchmark_platform_CUDA.jl +++ b/benchmark/KernelAbstractions/benchmark_platform_CUDA.jl @@ -1,13 +1,16 @@ + +#=============================================================================# include("implementation/benchmark_platform.jl") using CUDA: CuArray, devices, synchronize const AT = CuArray -const path = "QuantumClifford_benchmarks/CUDA" +const path = "benchmarks/QuantumCliffordKAExt/CUDA" const can_run = length(devices()) > 0 if can_run - benchmark_platform(AT, synchronize, path) + benchmark_platform(synchronize, AT, path) else @info "Unable to run CUDA benchmark. No suitable device was found." end +#=============================================================================# diff --git a/benchmark/KernelAbstractions/benchmark_platform_OpenCL.jl b/benchmark/KernelAbstractions/benchmark_platform_OpenCL.jl index 302c94421..67d3ee813 100644 --- a/benchmark/KernelAbstractions/benchmark_platform_OpenCL.jl +++ b/benchmark/KernelAbstractions/benchmark_platform_OpenCL.jl @@ -1,15 +1,18 @@ + +#=============================================================================# include("implementation/benchmark_platform.jl") import pocl_jll using OpenCL: CLArray, cl.devices, cl.platforms, cl.finish, cl.queue const AT = CLArray -const path = "QuantumClifford_benchmarks/OpenCL" +const path = "benchmarks/QuantumCliffordKAExt/OpenCL" const can_run = any(length(devices(platform)) > 0 for platform in platforms()) if can_run synchronize() = finish(queue()) - benchmark_platform(AT, synchronize, path) + benchmark_platform(synchronize, AT, path) else @info "Unable to run OpenCL benchmark. No suitable device was found." end +#=============================================================================# diff --git a/benchmark/KernelAbstractions/benchmark_platform_ROCm.jl b/benchmark/KernelAbstractions/benchmark_platform_ROCm.jl index 50bc6868b..dce2a26e6 100644 --- a/benchmark/KernelAbstractions/benchmark_platform_ROCm.jl +++ b/benchmark/KernelAbstractions/benchmark_platform_ROCm.jl @@ -1,13 +1,16 @@ + +#=============================================================================# include("implementation/benchmark_platform.jl") using AMDGPU: ROCArray, devices, synchronize const AT = ROCArray -const path = "QuantumClifford_benchmarks/ROCm" +const path = "benchmarks/QuantumCliffordKAExt/ROCm" const can_run = length(devices()) > 0 if can_run - benchmark_platform(AT, synchronize, path) + benchmark_platform(synchronize, AT, path) else @info "Unable to run ROCm benchmark. No suitable device was found." end +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/benchmark_KA_mul_leftright.jl b/benchmark/KernelAbstractions/implementation/benchmark_KA_mul_leftright.jl deleted file mode 100644 index 16fefba65..000000000 --- a/benchmark/KernelAbstractions/implementation/benchmark_KA_mul_leftright.jl +++ /dev/null @@ -1,195 +0,0 @@ -# This must be done explicitly as they are not exported. -using QuantumClifford: mul_left!, mul_right!, Tableau - -@inline function host_f!( - x, y; - phases::Val{phase_B} = Val(default_phases) - ) where {phase_B} - - mul_left!(x, y; phases = phases) - -end - -@inline function device_f!( - x, y, synchronize; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - mul_left!( - x, y; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - synchronize() - -end - -@inline function benchmark_KA_mul_pauli_pauli( - AT, synchronize, path; - phases::Val{phase_B} = Val(default_phases) - ) where {phase_B} - - host_time = zeros(Float64, length(n_MiB)) - device_time = zeros(Float64, length(batch_sizes), length(n_MiB)) - - # Keep the memory usage sane. - cache = AllocCache() - for (i, n) in enumerate(n_MiB) - @cached cache begin - # Each qubit requires 2 bits. - h_p1 = PauliOperator( - zeros(Cuchar), - n * MiB >> 1, - zeros(UInt, cld(n * MiB >> 1, bit_count(UInt)) << 1) - ) - h_p2 = copy(h_p1) - d_p1 = PauliOperator( - AT(u32(h_p1.phase)), - h_p1.nqubits, - AT(reinterpret(UInt32, h_p1.xz)) - ) - d_p2 = copy(d_p1) - synchronize() - # Trigger compilation before benchmarking. - host_f!(h_p1, h_p2; phases = phases) - host_time[i] = @belapsed host_f!( - $h_p1, $h_p2; phases = $phases - ) evals = evals samples = samples seconds = seconds - for (j, size) in enumerate(batch_sizes) - device_f!( - d_p1, d_p2, synchronize; - phases = phases, batch_size = Val(size) - ) - device_time[j, i] = - @belapsed device_f!( - $d_p1, $d_p2, $synchronize; - phases = $phases, batch_size = Val($size) - ) evals = evals samples = samples seconds = seconds - end - end - end - unsafe_free!(cache) - - device_cat = [device_time[i, :] for i = 1 : length(batch_sizes)] - title = - "Performance uplift - multiplication (phases = $phase_B)\n\ - Host threads = " * string(Threads.nthreads()) * " / " * - string(Sys.CPU_THREADS) * ", Device block size = $default_block_size" - xlabel = "Pauli operator size (MiB)" - label = hcat(("Device - batch size = " .* string.(batch_sizes))..., "Host") - path *= "/pauli_pauli" - mkpath(path) - - plot( - n_MiB, 10^3 .* hcat(device_cat..., host_time); - shape = :circle, xticks = n_MiB, xscale = :log2, yscale = :log10, - title = title, label = label, xlabel = xlabel, ylabel = "Runtime (ms)", - background_color = :transparent - ) - savefig("$path/runtime.$format") - - plot( - n_MiB, map(x -> host_time ./ x, device_cat); - shape = :circle, xticks = n_MiB, xscale = :log2, title = title, - label = hcat(label[1 : end - 1]...), xlabel = xlabel, - ylabel = "Ratio (host/device)", background_color = :transparent - ) - savefig("$path/ratio.$format") - -end - -@inline function benchmark_KA_mul_tableau_pauli( - AT, synchronize, path; - phases::Val{phase_B} = Val(default_phases) - ) where {phase_B} - - host_time = zeros(Float64, length(n_MiB)) - device_time = zeros(Float64, length(batch_sizes), length(n_MiB)) - # Keep the memory usage sane. - cache = AllocCache() - for (i, n) in enumerate(n_MiB) - # Each qubit requires 2 bits. - nqubits = round(Int, sqrt(n * MiB / 2), RoundUp) - @cached cache begin - h_p = PauliOperator( - zeros(Cuchar), - nqubits, - zeros(UInt, cld(nqubits, bit_count(UInt)) << 1) - ) - h_t = Tableau( - zeros(Cuchar, nqubits), - nqubits, - zeros(UInt, cld(nqubits, bit_count(UInt)) << 1, nqubits) - ) - d_p = PauliOperator( - AT(u32(h_p.phase)), - h_p.nqubits, - AT(reinterpret(UInt32, h_p.xz)) - ) - d_t = Tableau( - AT(u32(h_t.phases)), - h_t.nqubits, - AT(reinterpret(UInt32, h_t.xzs)) - ) - synchronize() - # Trigger compilation before benchmarking. - host_f!(h_t, h_p; phases = phases) - host_time[i] = @belapsed host_f!( - $h_t, $h_p; phases = $phases - ) evals = evals samples = samples seconds = seconds - for (j, size) in enumerate(batch_sizes) - device_f!( - d_t, d_p, synchronize; - phases = phases, batch_size = Val(size) - ) - device_time[j, i] = - @belapsed device_f!( - $d_t, $d_p, $synchronize; - phases = $phases, batch_size = Val($size) - ) evals = evals samples = samples seconds = seconds - end - end - end - unsafe_free!(cache) - - device_cat = [device_time[i, :] for i = 1 : length(batch_sizes)] - title = - "Performance uplift - multiplication (phases = $phase_B)\n\ - Host threads = " * string(Threads.nthreads()) * " / " * - string(Sys.CPU_THREADS) * ", Device block size = $default_block_size" - xlabel = "Tableau size (MiB)" - label = hcat(("Device - batch size = " .* string.(batch_sizes))..., "Host") - path *= "/tableau_pauli" - mkpath(path) - - plot( - n_MiB, 10^3 .* hcat(device_cat..., host_time); - shape = :circle, xticks = n_MiB, xscale = :log2, yscale = :log10, - title = title, label = label, xlabel = xlabel, ylabel = "Runtime (ms)", - background_color = :transparent - ) - savefig("$path/runtime.$format") - - plot( - n_MiB, map(x -> host_time ./ x, device_cat); - shape = :circle, xticks = n_MiB, xscale = :log2, title = title, - label = hcat(label[1 : end - 1]...), xlabel = xlabel, - ylabel = "Ratio (host/device)", background_color = :transparent - ) - savefig("$path/ratio.$format") - -end - -@inline function benchmark_KA_mul_leftright( - AT, synchronize, path; - phases::Val{phase_B} = Val(default_phases) - ) where {phase_B} - - path *= "/mul_leftright/phase_$phase_B" - benchmark_KA_mul_pauli_pauli(AT, synchronize, path; phases = phases) - benchmark_KA_mul_tableau_pauli(AT, synchronize, path; phases = phases) - -end diff --git a/benchmark/KernelAbstractions/implementation/benchmark_platform.jl b/benchmark/KernelAbstractions/implementation/benchmark_platform.jl index c5fa19723..d9708e2e5 100644 --- a/benchmark/KernelAbstractions/implementation/benchmark_platform.jl +++ b/benchmark/KernelAbstractions/implementation/benchmark_platform.jl @@ -1,10 +1,20 @@ + +#=============================================================================# include("imports.jl") include("definitions.jl") include("utilities.jl") -include("benchmark_KA_mul_leftright.jl") -@inline function benchmark_platform(AT, synchronize, path) +include("suites/benchmark_KA_mul.jl") +include("suites/benchmark_KA_canonicalization.jl") + +@inline function benchmark_platform(synchronize, AT, path)::Nothing + # BenchmarkTools evaluates the setup block at the global scope. + global cache = AllocCache() path *= "/" * string(value(now()) - UNIXEPOCH) - benchmark_KA_mul_leftright(AT, synchronize, path; phases = Val(true)) - benchmark_KA_mul_leftright(AT, synchronize, path; phases = Val(false)) + + benchmark_KA_mul(synchronize, AT, path) + benchmark_KA_canonicalization(synchronize, AT, path) + + return nothing end +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/definitions.jl b/benchmark/KernelAbstractions/implementation/definitions.jl index bf9bd0597..b37686718 100644 --- a/benchmark/KernelAbstractions/implementation/definitions.jl +++ b/benchmark/KernelAbstractions/implementation/definitions.jl @@ -1,24 +1,6 @@ -# (La)TeX hates SVG but the Plots package has issues with transparent PDFs. -const format = "svg" -# BenchmarkTools parameters. -# Evaluations per sample point. -const evals = 16 -# Maximum number of samples. -const samples = 2^10 -# Maximum runtime for each sample group. -const seconds = 60 - -# By definition, (unsigned) char is the smallest addressable unit of memory. -const MiB = 1024 * 1024 * count_zeros(zero(Cuchar)) -# Avoid consuming too many resources, 1 GiB is plenty. -const n_MiB = [2^i for i = 1:10] -# TODO: Keep these or remove them now that a good default has been set? -const batch_sizes = [1, 4, 8, 16, 32, 64] - -# These values originate from a package extension, hence the query. -const KAExt = Base.get_extension(QuantumClifford, :QuantumCliffordKAExt) -const default_phases = KAExt.default_phases -const default_primary_axis = KAExt.default_primary_axis -const default_block_size = KAExt.default_block_size -const default_batch_size = KAExt.default_batch_size +#=============================================================================# +include("definitions/benchmark_configuration.jl") +include("definitions/plot_configuration.jl") +include("definitions/tuning_parameters.jl") +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/definitions/benchmark_configuration.jl b/benchmark/KernelAbstractions/implementation/definitions/benchmark_configuration.jl new file mode 100644 index 000000000..9d57ef430 --- /dev/null +++ b/benchmark/KernelAbstractions/implementation/definitions/benchmark_configuration.jl @@ -0,0 +1,54 @@ + +#=============================================================================# + +#============================================================================== +BENCHMARK TOOLS +==============================================================================# + +# CAUTION: Functions mutate their arguments, induces disparity between runs. +# Evaluations per sample point. +const evals = 1 + +# Maximum number of sample points. +const samples = 2^14 + +# Maximum runtime for each trial. +const seconds = 60 + +#============================================================================== +SAMPLE EXTRAPOLATION +==============================================================================# + +# Maximum sampling period before before being aborted and extrapolated instead. +const extrapolation_threshold = seconds << 1 + +# Whether to include the aborted run in the data set used for extrapolation. +const include_threshold_point = false + +# In the absence of sufficient data points for a fit, perform O(n^k) scaling. +const host_permit_simple_scaling = true +const device_permit_simple_scaling = true + +#============================================================================== +PROBLEM SIZE +==============================================================================# + +# By definition, (unsigned) char is the smallest addressable unit of memory. +const MiB = 1024 * 1024 * count_zeros(zero(Cuchar)) + +# Avoid consuming too many resources, 1 GiB is plenty. +const sizes_MiB = [2^i for i in 1 : 10] + +#============================================================================== +TUNING PARAMETERS +==============================================================================# + +const benchmark_primary_axis = true + +const benchmark_phases = true + +# TODO: Enable this by default once the POCL code generation bugs are fixed. +const benchmark_block_size = false + +const benchmark_batch_size = true +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/definitions/plot_configuration.jl b/benchmark/KernelAbstractions/implementation/definitions/plot_configuration.jl new file mode 100644 index 000000000..cd5a3358f --- /dev/null +++ b/benchmark/KernelAbstractions/implementation/definitions/plot_configuration.jl @@ -0,0 +1,13 @@ + +#=============================================================================# +# The output should be both stylish and informative. +const plot_style = Dict( + :xticks => sizes_MiB, + :xscale => :log2, + :shape => :circle, + :background_color => :transparent + ) + +# (La)TeX hates SVG but the Plots package has issues with transparent PDFs. +const file_format = "svg" +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/definitions/tuning_parameters.jl b/benchmark/KernelAbstractions/implementation/definitions/tuning_parameters.jl new file mode 100644 index 000000000..a5149d6e2 --- /dev/null +++ b/benchmark/KernelAbstractions/implementation/definitions/tuning_parameters.jl @@ -0,0 +1,29 @@ + +#=============================================================================# +# These values originate from a package extension, hence requiring this query. +const KAExt = Base.get_extension(QuantumClifford, :QuantumCliffordKAExt) + +if benchmark_primary_axis + const primary_axes = [axis for axis in instances(KAExt.PrimaryAxis)] +else + const primary_axes = [KAExt.default_primary_axis] +end + +if benchmark_phases + const phases = [true, false] +else + const phases = [KAExt.default_phases] +end + +if benchmark_block_size + const block_sizes = [32, 64, 128, 256, 512] +else + const block_sizes = [KAExt.default_block_size] +end + +if benchmark_batch_size + const batch_sizes = [1, 4, 8, 16, 32, 64] +else + const batch_sizes = [KAExt.default_batch_size] +end +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/imports.jl b/benchmark/KernelAbstractions/implementation/imports.jl index c5d27490c..e7339688f 100644 --- a/benchmark/KernelAbstractions/implementation/imports.jl +++ b/benchmark/KernelAbstractions/implementation/imports.jl @@ -1,9 +1,24 @@ + +#=============================================================================# # Required for QuantumCliffordKAExt. import Atomix, GPUArraysCore, KernelAbstractions +# Required for QuantumCliffordAdaptExt. +using Adapt: adapt + using BenchmarkTools: @belapsed -using Dates: value, now, UNIXEPOCH + +using Dates: value, now, UNIXEPOCH, Second + # Assists in reducing resource demands. using GPUArrays: AllocCache, @cached, unsafe_free! + +# Utilised for extrapolating the runtime of exceedingly long benchmarks. +using LsqFit: curve_fit + using Plots: plot, savefig + using QuantumClifford +# This must be done explicitly as they are not exported. +using QuantumClifford: Tableau, AbstractStabilizer, random_tableau +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/suites/benchmark_KA_canonicalization.jl b/benchmark/KernelAbstractions/implementation/suites/benchmark_KA_canonicalization.jl new file mode 100644 index 000000000..809ac8560 --- /dev/null +++ b/benchmark/KernelAbstractions/implementation/suites/benchmark_KA_canonicalization.jl @@ -0,0 +1,172 @@ + +#=============================================================================# +@inline function host_rref!(args...; kwargs...)::Nothing + canonicalize_rref!(args...; kwargs...) + return nothing +end + +@inline function device_rref!(synchronize, args...; kwargs...)::Nothing + canonicalize_rref!(args...; kwargs...) + synchronize() + return nothing +end + +function benchmark_KA_canonicalize_rref!( + synchronize, AT, + index, parameters, + host_time, extrapolate_host, host_fit, + device_time, extrapolate_device, device_fit; + enable_masks = false, + phases = KAExt.default_phases, + primary_axis = KAExt.default_primary_axis + ) + + error_flag = false + + # These definitions ought to be kept local to this scope. + initial_fit = [0.0, 0.0, 0.0, 1.0] + @inline function model(n, p) + return @. p[1] + p[2] * n + p[3] * n^2 + p[4] * n^3 + end + + nqubits = nqubits_tableau(sizes_MiB[index]) + point_count = ifelse( + include_threshold_point, + index, + index - one(index)) + fit_nqubits = nqubits_tableau.(sizes_MiB[Base.OneTo(point_count)]) + + @cached cache begin + # BenchmarkTools evaluates the setup block at the global scope. + global host_stabilizer = Stabilizer(random_tableau(nqubits, nqubits)) + global device_stabilizer = adapt(AT, host_stabilizer) + global host_temp = copy(host_stabilizer) + global device_temp = copy(device_stabilizer) + + if enable_masks + colindices = one(nqubits) : (one(nqubits) << 1) : nqubits + xzs = host_stabilizer.tab.xzs + bit_masks = AT(zeros(eltype(xzs), size(xzs, 1) >> 1)) + fill!(bit_masks, alternating_bit_mask(eltype(xzs))) + else + colindices = Base.OneTo(nqubits) + bit_masks = nothing + end + + if extrapolate_host[1] + host_time[index] = model(nqubits, host_fit[1]) + time_span = Second(0) + else + # Trigger compilation before benchmarking. + host_rref!(host_temp, colindices; phases = phases) + start = now() + host_time[index] = @belapsed host_rref!( + $host_temp, $colindices; phases = $phases + ) evals = evals samples = samples seconds = seconds setup = ( + copy_to!(host_temp, host_stabilizer); + ) + time_span = now() - start + end + + if time_span >= Second(extrapolation_threshold) + if point_count >= length(initial_fit) + fit_time = host_time[Base.OneTo(point_count)] + fit = curve_fit(model, fit_nqubits, fit_time, initial_fit) + + extrapolate_host[1] = true + host_fit[1] = fit.param + + if !include_threshold_point + host_time[index] = model(nqubits, host_fit[1]) + end + elseif host_permit_simple_scaling + fit = copy(initial_fit) + fit[end] = host_time[index] / model(nqubits, fit) + extrapolate_host[1] = true + host_fit[1] = fit + else + return true + end + end + + for (parameters_index, (block, batch)) in pairs(parameters) + if extrapolate_device[parameters_index] + device_time[index, parameters_index] = + model(nqubits, device_fit[parameters_index]) + time_span = Second(0) + else + device_rref!( + synchronize, device_temp, nothing, bit_masks; + phases = phases, primary_axis = primary_axis, + block_size = block, batch_size = batch + ) + start = now() + device_time[index, parameters_index] = @belapsed device_rref!( + $synchronize, $device_temp, nothing, $bit_masks; + phases = $phases, primary_axis = $primary_axis, + block_size = $block, batch_size = $batch + ) evals = evals samples = samples seconds = seconds setup = ( + @cached cache copy_to!(device_temp, device_stabilizer); + ) + time_span = now() - start + end + + if time_span >= Second(extrapolation_threshold) + if point_count >= length(initial_fit) + fit_time = + device_time[Base.OneTo(point_count), parameters_index] + fit = curve_fit(model, fit_nqubits, fit_time, initial_fit) + + extrapolate_device[parameters_index] = true + device_fit[parameters_index] = fit.param + + if !include_threshold_point + device_time[index, parameters_index] = + model(nqubits, device_fit[parameters_index]) + end + elseif device_permit_simple_scaling + fit = copy(initial_fit) + fit[end] = device_time[index, parameters_index] / model(nqubits, fit) + extrapolate_device[parameters_index] = true + device_fit[parameters_index] = fit + else + return true + end + end + end + end + + host_stabilizer = nothing + device_stabilizer = nothing + host_temp = nothing + device_temp = nothing + unsafe_free!(cache) + GC.gc(true) + + return error_flag + +end + +@inline function benchmark_KA_canonicalization(synchronize, AT, path)::Nothing + for (phase, axis) in Iterators.product(phases, primary_axes) + temp = path * "/canonicalization/phases_$phase" * "_$axis" + run_benchmark( + benchmark_KA_canonicalize_rref!, + synchronize, AT, + temp * "/canonicalize_rref", + "Tableau canonicalization\n(phases_$phase, $axis)", + "Tableau size (MiB)"; + enable_masks = false, phases = phase, primary_axis = axis + ) + run_benchmark( + benchmark_KA_canonicalize_rref!, + synchronize, AT, + temp * "/canonicalize_rref_masked", + "Tableau canonicalization (masked)\n(phases_$phase, $axis)", + "Tableau size (MiB)"; + enable_masks = true, phases = phase, primary_axis = axis + ) + end + return nothing +end +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/suites/benchmark_KA_mul.jl b/benchmark/KernelAbstractions/implementation/suites/benchmark_KA_mul.jl new file mode 100644 index 000000000..ed7d98c3f --- /dev/null +++ b/benchmark/KernelAbstractions/implementation/suites/benchmark_KA_mul.jl @@ -0,0 +1,292 @@ + +#=============================================================================# +# This must be done explicitly as they are not exported. +using QuantumClifford: mul_left!, mul_right! + +@inline function host_mul!(args...; kwargs...)::Nothing + mul_left!(args...; kwargs...) + return nothing +end + +@inline function device_mul!(synchronize, args...; kwargs...)::Nothing + mul_left!(args...; kwargs...) + synchronize() + return nothing +end + +function benchmark_KA_mul_pauli_pauli!( + synchronize, AT, + index, parameters, + host_time, extrapolate_host, host_fit, + device_time, extrapolate_device, device_fit; + phases = KAExt.default_phases, + primary_axis = KAExt.default_primary_axis + ) + + # These definitions ought to be kept local to this scope. + initial_fit = [0.0, 1.0] + @inline function model(n, p) + return @. p[1] + p[2] * n + end + + # Each qubit requires 2 bits. + nqubits = nqubits_pauli(sizes_MiB[index]) + point_count = ifelse( + include_threshold_point, + index, + index - one(index)) + fit_nqubits = nqubits_pauli.(sizes_MiB[Base.OneTo(point_count)]) + + @cached cache begin + # BenchmarkTools evaluates the setup block at the global scope. + global host_pauli = random_pauli(nqubits) + global device_pauli = adapt(AT, host_pauli) + global host_temp = copy(host_pauli) + global device_temp = copy(device_pauli) + host_pauli_mul = random_pauli(nqubits) + device_pauli_mul = adapt(AT, host_pauli_mul) + + if extrapolate_host[1] + host_time[index] = model(nqubits, host_fit[1]) + time_span = Second(0) + else + # Trigger compilation before benchmarking. + host_mul!(host_temp, host_pauli_mul; phases = Val(phases)) + start = now() + host_time[index] = @belapsed host_mul!( + $host_temp, $host_pauli_mul; phases = Val($phases) + ) evals = evals samples = samples seconds = seconds setup = ( + copy_to!(host_temp, host_pauli); + ) + time_span = now() - start + end + + if time_span >= Second(extrapolation_threshold) + if point_count >= length(initial_fit) + fit_time = host_time[Base.OneTo(point_count)] + fit = curve_fit(model, fit_nqubits, fit_time, initial_fit) + + extrapolate_host[1] = true + host_fit[1] = fit.param + + if !include_threshold_point + host_time[index] = model(nqubits, host_fit[1]) + end + elseif host_permit_simple_scaling + fit = copy(initial_fit) + fit[end] = host_time[index] / model(nqubits, fit) + extrapolate_host[1] = true + host_fit[1] = fit + else + return true + end + end + + for (parameters_index, (block, batch)) in pairs(parameters) + if extrapolate_device[parameters_index] + device_time[index, parameters_index] = + model(nqubits, device_fit[parameters_index]) + time_span = Second(0) + else + device_mul!( + synchronize, device_temp, device_pauli_mul; + phases = Val(phases), primary_axis = primary_axis, + block_size = block, batch_size = batch + ) + start = now() + device_time[index, parameters_index] = @belapsed device_mul!( + $synchronize, $device_temp, $device_pauli_mul; + phases = Val($phases), primary_axis = $primary_axis, + block_size = $block, batch_size = $batch + ) evals = evals samples = samples seconds = seconds setup = ( + @cached cache copy_to!(device_temp, device_pauli); + ) + time_span = now() - start + end + + if time_span >= Second(extrapolation_threshold) + if point_count >= length(initial_fit) + fit_time = + device_time[Base.OneTo(point_count), parameters_index] + fit = curve_fit(model, fit_nqubits, fit_time, initial_fit) + + extrapolate_device[parameters_index] = true + device_fit[parameters_index] = fit.param + + if !include_threshold_point + device_time[index, parameters_index] = + model(nqubits, device_fit[parameters_index]) + end + elseif device_permit_simple_scaling + fit = copy(initial_fit) + fit[end] = device_time[index, parameters_index] / model(nqubits, fit) + extrapolate_device[parameters_index] = true + device_fit[parameters_index] = fit + else + return true + end + end + end + end + + host_pauli = nothing + device_pauli = nothing + host_temp = nothing + device_temp = nothing + unsafe_free!(cache) + GC.gc(true) + + return false + +end + +function benchmark_KA_mul_tableau_pauli!( + synchronize, AT, + index, parameters, + host_time, extrapolate_host, host_fit, + device_time, extrapolate_device, device_fit; + phases = KAExt.default_phases, + primary_axis = KAExt.default_primary_axis + ) + + error_flag = false + + # These definitions ought to be kept local to this scope. + initial_fit = [0.0, 0.0, 1.0] + @inline function model(n, p) + return @. p[1] + p[2] * n + p[3] * n^2 + end + + nqubits = nqubits_tableau(sizes_MiB[index]) + point_count = ifelse( + include_threshold_point, + index, + index - one(index)) + fit_nqubits = nqubits_tableau.(sizes_MiB[Base.OneTo(point_count)]) + + @cached cache begin + # BenchmarkTools evaluates the setup block at the global scope. + global host_tableau = random_tableau(nqubits, nqubits) + global device_tableau = adapt(AT, host_tableau) + global host_temp = copy(host_tableau) + global device_temp = copy(device_tableau) + host_pauli = random_pauli(nqubits) + device_pauli = adapt(AT, host_pauli) + + if extrapolate_host[1] + host_time[index] = model(nqubits, host_fit[1]) + time_span = Second(0) + else + # Trigger compilation before benchmarking. + host_mul!(host_temp, host_pauli; phases = Val(phases)) + start = now() + host_time[index] = @belapsed host_mul!( + $host_temp, $host_pauli; phases = Val($phases) + ) evals = evals samples = samples seconds = seconds setup = ( + copy_to!(host_temp, host_tableau); + ) + time_span = now() - start + end + + if time_span >= Second(extrapolation_threshold) + if point_count >= length(initial_fit) + fit_time = host_time[Base.OneTo(point_count)] + fit = curve_fit(model, fit_nqubits, fit_time, initial_fit) + + extrapolate_host[1] = true + host_fit[1] = fit.param + + if !include_threshold_point + host_time[index] = model(nqubits, host_fit[1]) + end + elseif host_permit_simple_scaling + fit = copy(initial_fit) + fit[end] = host_time[index] / model(nqubits, fit) + extrapolate_host[1] = true + host_fit[1] = fit + else + return true + end + end + + for (parameters_index, (block, batch)) in pairs(parameters) + if extrapolate_device[parameters_index] + device_time[index, parameters_index] = + model(nqubits, device_fit[parameters_index]) + time_span = Second(0) + else + device_mul!( + synchronize, device_temp, device_pauli; + phases = Val(phases), primary_axis = primary_axis, + block_size = block, batch_size = batch + ) + start = now() + device_time[index, parameters_index] = @belapsed device_mul!( + $synchronize, $device_temp, $device_pauli; + phases = Val($phases), primary_axis = $primary_axis, + block_size = $block, batch_size = $batch + ) evals = evals samples = samples seconds = seconds setup = ( + @cached cache copy_to!(device_temp, device_tableau); + ) + time_span = now() - start + end + + if time_span >= Second(extrapolation_threshold) + if point_count >= length(initial_fit) + fit_time = + device_time[Base.OneTo(point_count), parameters_index] + fit = curve_fit(model, fit_nqubits, fit_time, initial_fit) + + extrapolate_device[parameters_index] = true + device_fit[parameters_index] = fit.param + + if !include_threshold_point + device_time[index, parameters_index] = + model(nqubits, device_fit[parameters_index]) + end + elseif device_permit_simple_scaling + fit = copy(initial_fit) + fit[end] = device_time[index, parameters_index] / model(nqubits, fit) + extrapolate_device[parameters_index] = true + device_fit[parameters_index] = fit + else + return true + end + end + end + end + + host_tableau = nothing + device_tableau = nothing + host_temp = nothing + device_temp = nothing + unsafe_free!(cache) + GC.gc(true) + + return error_flag + +end + +@inline function benchmark_KA_mul(synchronize, AT, path)::Nothing + for (phase, axis) in Iterators.product(phases, primary_axes) + temp = path * "/mul/phases_$phase" * "_$axis" + run_benchmark( + benchmark_KA_mul_pauli_pauli!, + synchronize, AT, + temp * "/pauli_pauli", + "Pauli-Pauli multiplication\n(phases_$phase, $axis)", + "Pauli operator size (MiB)"; + phases = phase, primary_axis = axis + ) + run_benchmark( + benchmark_KA_mul_tableau_pauli!, + synchronize, AT, + temp * "/tableau_pauli", + "Tableau-Pauli multiplication\n(phases_$phase, $axis)", + "Tableau size (MiB)"; + phases = phase, primary_axis = axis + ) + end + return nothing +end +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/utilities.jl b/benchmark/KernelAbstractions/implementation/utilities.jl index d55260606..953106dee 100644 --- a/benchmark/KernelAbstractions/implementation/utilities.jl +++ b/benchmark/KernelAbstractions/implementation/utilities.jl @@ -1,9 +1,6 @@ -# Works even when broadcasting on zero-dimensional arrays. -@inline function u32(v) - return map(x -> UInt32(x), v) -end -# By definition, the size of (unsigned) char is set to unity. -@inline function bit_count(::Type{T}) where {T} - return sizeof(T) * count_zeros(zero(Cuchar)) -end +#=============================================================================# +include("utilities/bit_manipulation.jl") +include("utilities/memory_management.jl") +include("utilities/benchmark_management.jl") +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/utilities/benchmark_management.jl b/benchmark/KernelAbstractions/implementation/utilities/benchmark_management.jl new file mode 100644 index 000000000..a863ed8c0 --- /dev/null +++ b/benchmark/KernelAbstractions/implementation/utilities/benchmark_management.jl @@ -0,0 +1,65 @@ + +#=============================================================================# +function run_benchmark( + benchmark_step!, synchronize, AT, path, title, xlabel; kwargs... + )::Nothing + + parameters = collect(Iterators.product(block_sizes, batch_sizes)) + host_time = zeros(Float64, length(sizes_MiB)) + extrapolate_host = [false] + host_fit = Vector{Vector}(undef, 1) + device_time = zeros(Float64, length(sizes_MiB), size(parameters)...) + extrapolate_device = similar(extrapolate_host, size(parameters)...) + fill!(extrapolate_device, false) + device_fit = similar(host_fit, size(parameters)...) + + error_flag = false + for index in Base.OneTo(length(sizes_MiB)) + # This is intentional for short-circuit evaluation. + error_flag = error_flag || benchmark_step!( + synchronize, AT, + index, parameters, + host_time, extrapolate_host, host_fit, + device_time, extrapolate_device, device_fit; + kwargs... + ) + end + + if error_flag + error_string = + "Unable to conclude benchmark ($title) within the current thresholds. \ + Consider relaxing the constraints of the provided configuration." + @error error_string + return nothing + end + + mkpath(path) + + device_cat = [device_time[:, i] for i in CartesianIndices(parameters)] + marked_labels = [x ? "Device*" : "Device" for x in extrapolate_device] + label = hcat( + [ + "$(marked_labels[i]) - (block, batch) = $x" + for (i, x) in enumerate(parameters) + ]..., + extrapolate_host[1] ? "Host*" : "Host" + ) + + plot( + sizes_MiB, 10^3 .* hcat(device_cat..., host_time); + plot_style..., yscale = :log10, title = title, label = label, + xlabel = xlabel, ylabel = "Runtime (ms)" + ) + savefig("$path/runtime.$file_format") + + plot( + sizes_MiB, hcat(map(x -> host_time ./ x, device_cat)...); + plot_style..., title = title, label = hcat(label[1 : (end - 1)]...), + xlabel = xlabel, ylabel = "Ratio (host/device)" + ) + savefig("$path/ratio.$file_format") + + return nothing + +end +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/utilities/bit_manipulation.jl b/benchmark/KernelAbstractions/implementation/utilities/bit_manipulation.jl new file mode 100644 index 000000000..2c21b8dec --- /dev/null +++ b/benchmark/KernelAbstractions/implementation/utilities/bit_manipulation.jl @@ -0,0 +1,23 @@ + +#=============================================================================# +# Repeats 0x55 to fill out all the bits in the given type. +@inline function alternating_bit_mask(::Type{T}) where {T <: Unsigned} + counter = count_zeros(zero(T)) >> one(T) + pattern = one(T) + while counter > one(counter) + pattern |= pattern << counter + counter >>= one(counter) + end + return pattern +end + +@inline function nqubits_pauli(size_MiB::Integer) + # Each qubit requires 2 bits. + return cld(size_MiB * MiB, 2) +end + +@inline function nqubits_tableau(size_MiB::Integer) + # Each qubit requires 2 bits. + return round(Int, sqrt(cld(size_MiB * MiB, 2)), RoundUp) +end +#=============================================================================# diff --git a/benchmark/KernelAbstractions/implementation/utilities/memory_management.jl b/benchmark/KernelAbstractions/implementation/utilities/memory_management.jl new file mode 100644 index 000000000..4f44fc974 --- /dev/null +++ b/benchmark/KernelAbstractions/implementation/utilities/memory_management.jl @@ -0,0 +1,34 @@ + +#=============================================================================# +# TODO: Remove these once the main package establishes them. +@inline function copy_to!( + target::PauliOperator, source::PauliOperator + ) + + target.nqubits == source.nqubits || throw(ArgumentError("BAD COPY_TO!")) + copyto!(target.phase, source.phase) + copyto!(target.xz, source.xz) + return target + +end + +@inline function copy_to!( + target::Tableau, source::Tableau + ) + + target.nqubits == source.nqubits || throw(ArgumentError("BAD COPY_TO!")) + copyto!(target.phases, source.phases) + copyto!(target.xzs, source.xzs) + return target + +end + +@inline function copy_to!( + target::AbstractStabilizer, source::AbstractStabilizer + ) + + copy_to!(tab(target), tab(source)) + return target + +end +#=============================================================================# diff --git a/ext/QuantumCliffordAdaptExt/QuantumCliffordAdaptExt.jl b/ext/QuantumCliffordAdaptExt/QuantumCliffordAdaptExt.jl index 26ab2c85e..2c8a5b5f9 100644 --- a/ext/QuantumCliffordAdaptExt/QuantumCliffordAdaptExt.jl +++ b/ext/QuantumCliffordAdaptExt/QuantumCliffordAdaptExt.jl @@ -3,7 +3,7 @@ module QuantumCliffordAdaptExt include("imports.jl") -include("../QuantumCliffordKAExt/definitions.jl") +include("../QuantumCliffordKAExt/definitions/word_size_integers.jl") include("utilities.jl") include("adapters.jl") diff --git a/ext/QuantumCliffordAdaptExt/adapters.jl b/ext/QuantumCliffordAdaptExt/adapters.jl index d364190a4..4742ffb63 100644 --- a/ext/QuantumCliffordAdaptExt/adapters.jl +++ b/ext/QuantumCliffordAdaptExt/adapters.jl @@ -1,7 +1,7 @@ #=============================================================================# # PauliOperator -function adapt_structure( +@inline function adapt_structure( AT::Type{T}, pauli::PauliOperator ) where {T <: AbstractArray} @@ -13,7 +13,7 @@ function adapt_structure( end -function adapt_structure( +@inline function adapt_structure( AT::Type{T}, pauli::PauliOperator ) where {T <: AbstractGPUArray} @@ -26,7 +26,7 @@ function adapt_structure( end # Tableau -function adapt_structure( +@inline function adapt_structure( AT::Type{T}, tab::Tableau ) where {T <: AbstractArray} @@ -38,7 +38,7 @@ function adapt_structure( end -function adapt_structure( +@inline function adapt_structure( AT::Type{T}, tab::Tableau ) where {T <: AbstractGPUArray} @@ -51,7 +51,7 @@ function adapt_structure( end # Stabilizer -function adapt_structure( +@inline function adapt_structure( AT::Type{T}, state::Stabilizer ) where {T <: AbstractArray} @@ -60,7 +60,7 @@ function adapt_structure( end # Destabilizer -function adapt_structure( +@inline function adapt_structure( AT::Type{T}, state::Destabilizer ) where {T <: AbstractArray} @@ -69,7 +69,7 @@ function adapt_structure( end # MixedStabilizer -function adapt_structure( +@inline function adapt_structure( AT::Type{T}, state::MixedStabilizer ) where {T <: AbstractArray} @@ -78,7 +78,7 @@ function adapt_structure( end # MixedDestabilizer -function adapt_structure( +@inline function adapt_structure( AT::Type{T}, state::MixedDestabilizer ) where {T <: AbstractArray} diff --git a/ext/QuantumCliffordAdaptExt/imports.jl b/ext/QuantumCliffordAdaptExt/imports.jl index ba75c0cc6..aa8c1e167 100644 --- a/ext/QuantumCliffordAdaptExt/imports.jl +++ b/ext/QuantumCliffordAdaptExt/imports.jl @@ -1,7 +1,6 @@ #=============================================================================# import Adapt: adapt_structure - using Adapt: adapt using GPUArraysCore: AbstractGPUArray diff --git a/ext/QuantumCliffordAdaptExt/utilities.jl b/ext/QuantumCliffordAdaptExt/utilities.jl index f4d19bb9c..d3685541b 100644 --- a/ext/QuantumCliffordAdaptExt/utilities.jl +++ b/ext/QuantumCliffordAdaptExt/utilities.jl @@ -20,9 +20,7 @@ zero(T) ) temp = reinterpret(S, output) - @inbounds ( - @view temp[Base.OneTo(len), Base.OneTo.(dims)...] - ) .= source + @inbounds (@view temp[Base.OneTo(len), Base.OneTo.(dims)...]) .= source end return output diff --git a/ext/QuantumCliffordGPUExt/canonicalization.jl b/ext/QuantumCliffordGPUExt/canonicalization.jl index 496fd9796..ea8b3177b 100644 --- a/ext/QuantumCliffordGPUExt/canonicalization.jl +++ b/ext/QuantumCliffordGPUExt/canonicalization.jl @@ -159,8 +159,8 @@ function gpu_canonicalize!(tableau::QuantumClifford.Tableau{<:CuArray{P}, <:CuAr return tableau end -function canonicalize!(stab::Stabilizer{<:QuantumClifford.Tableau{<:CuArray{P}, <:CuArray{T}}}; phases::Bool=true) where {T, P} +function cuda_canonicalize!(stab::Stabilizer{<:QuantumClifford.Tableau{<:CuArray{P}, <:CuArray{T}}}; phases::Bool=true) where {T, P} gpu_canonicalize!(tab(stab), phases) CUDA.synchronize() return stab -end \ No newline at end of file +end diff --git a/ext/QuantumCliffordKAExt/QuantumCliffordKAExt.jl b/ext/QuantumCliffordKAExt/QuantumCliffordKAExt.jl index ac2e0e649..ef6c2baf1 100644 --- a/ext/QuantumCliffordKAExt/QuantumCliffordKAExt.jl +++ b/ext/QuantumCliffordKAExt/QuantumCliffordKAExt.jl @@ -6,7 +6,8 @@ include("imports.jl") include("definitions.jl") include("../../src/throws.jl") include("utilities.jl") -include("mul_leftright.jl") + +include("mul.jl") include("canonicalization.jl") end diff --git a/ext/QuantumCliffordKAExt/README.md b/ext/QuantumCliffordKAExt/README.md index 63cbd0eb5..c1ed4fad2 100644 --- a/ext/QuantumCliffordKAExt/README.md +++ b/ext/QuantumCliffordKAExt/README.md @@ -16,7 +16,7 @@ In order to actually invoke its features, it is also pivotal to import the perti - It cannot be stressed enough that **ALL** the accelerated functionality is strictly *asynchronous* and that synchronisation barriers should be inserted as required. Please consult the relevant backend documentation for detailed instructions on this matter. - Certain function calls are presently *synchronous* due to external limitations imposed by the toolchain dependencies. They should still be treated as *asynchronous* from a concurrency perspective as this will hopefully be resolved in a future release. - Hardware limitations impose certain restrictions that are not present in the base package. Namely, the bitwidth of the phase variable(s) must be compatible with the usage of atomic intrinsics. The Adapt extension automatically handles this conversion but explicitly initialised objects must ensure their own compatibility. -- Wherever feasible, tuning parameters are exposed via `::Val` keyword arguments. Whilst the chosen defaults strive to be as performant as possible whilst maintaining generality, it may prove beneficial to further refine them to be more optimal for the underlying hardware. +- Tuning parameters are exposed via keyword arguments wherever feasible. Whilst the chosen defaults strive to be as performant as possible whilst maintaining generality, it may prove beneficial to further refine them to be more optimal for the particular characteristics of the underlying hardware. # Warnings diff --git a/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl b/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl index e98c304de..78099590e 100644 --- a/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl +++ b/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl @@ -1,51 +1,46 @@ #=============================================================================# +# TODO: Include the unsafe functions once the main package establishes them. import QuantumClifford: canonicalize! +# TODO: Implement support for (Mixed)Destabilizers. function device_canonicalize!( ph::AbstractArray{<: Unsigned}, xzs::AbstractArray{<: Unsigned}, output_buffer::Union{Nothing, AbstractArray{S}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - )::Nothing where { - S <: Integer, phase_B, primary_axis_E, block_SZ, batch_SZ - } - - phase_B isa Bool && primary_axis_E isa PrimaryAxis && - block_SZ isa Integer && block_SZ > zero(block_SZ) && - batch_SZ isa Integer && batch_SZ > zero(batch_SZ) || - throw(ArgumentError(THROW_VALS)) + pauli_preferance::PauliPreferance = default_pauli_preferance, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + )::Nothing where {S <: Integer} backend = KA.get_backend(xzs) - if primary_axis_E == primary_axis_rows - tile = (one(block_SZ), block_SZ) + if primary_axis == primary_axis_rows + tile = (one(block_size), block_size) space = tessellate( - (size(xzs, 0x2), cld(size(xzs, 0x1) >> 0x1, batch_SZ)), + (size(xzs, 2), cld(size(xzs, 1) >> 1, batch_size)), tile ) - elseif primary_axis_E == primary_axis_qubits - tile = (block_SZ, one(block_SZ)) + elseif primary_axis == primary_axis_qubits + tile = (block_size, one(block_size)) space = tessellate( - (cld(size(xzs, 0x1) >> 0x1, batch_SZ), size(xzs, 0x2)), + (cld(size(xzs, 1) >> 1, batch_size), size(xzs, 2)), tile ) end # Utilised for loop management. - length_xzs = size(xzs, 0x1) >> 0x1 - row_count = size(xzs, 0x2) + length_xzs = size(xzs, 1) >> 1 + row_count = size(xzs, 2) toggle = false - cycles_until_sync = default_scheduling_limit # Required for safety whilst setting up for the proceeding iteration. mutex = create_mutex(backend) # Double buffered for present and preceeding/proceeding iteration. stride_fill = tracker_element_count - tracker = similar(xzs, Csize_t, stride_fill << 0x1) + tracker = similar(xzs, Csize_t, stride_fill << 1) fill!(tracker, typemax(Csize_t)) # The pivot row tracker is initialised differently. begin_fill = Integer(tracker_content_swap_to) @@ -58,22 +53,22 @@ function device_canonicalize!( fill!(output_buffer, zero(S)) end - bit_scan = kernel_bit_scan(backend) + bit_scan! = kernel_bit_scan!(backend) snippet! = kernel_snippet!(backend) mul_and_scan! = kernel_mul_and_scan!(backend) - bit_scan( - xzs, nothing, mutex, tracker, - Val(sort_order_pauli_bit_prefer_x), - primary_axis, block_size, batch_size; + bit_scan!( + tracker, mutex, xzs, nothing, primary_axis, + Val(pauli_preferance), Val(sort_order_pauli_bit), + Val(block_size), Val(batch_size); workgroupsize = tile, ndrange = space ) snippet!( snippet_track_pivot_canonicalize!, - output_buffer, tracker, toggle, sort_order_pauli_bit_prefer_x; - ndrange = 0x1 + output_buffer, tracker, toggle; + ndrange = 1 ) - for _ in one(row_count) : row_count + for _ in Base.OneTo(row_count) snippet!( snippet_swap_rows_prepare_tracker!, ph, xzs, tracker, toggle; @@ -81,38 +76,24 @@ function device_canonicalize!( ) snippet!( snippet_set_row_phase_flag!, - ph, xzs, tracker, toggle; + ph, xzs, tracker, toggle, pauli_preferance; ndrange = row_count ) mul_and_scan!( - ph, xzs, multiplication_order, scan_side_greater, - nothing, false, mutex, tracker, toggle, - phases, Val(sort_order_pauli_bit_prefer_x), - primary_axis, block_size, batch_size; + ph, xzs, tracker, toggle, mutex, nothing, false, + scan_side_greater, multiplication_order, primary_axis, + Val(pauli_preferance), Val(sort_order_pauli_bit), + Val(phases), Val(block_size), Val(batch_size); workgroupsize = tile, ndrange = space ) # Switching the toggle is intentional. snippet!( snippet_track_pivot_canonicalize!, - output_buffer, tracker, !toggle, sort_order_pauli_bit_prefer_x; - ndrange = 0x1 + output_buffer, tracker, !toggle; + ndrange = 1 ) toggle = !toggle - cycles_until_sync -= one(cycles_until_sync) - if cycles_until_sync == zero(cycles_until_sync) - KA.synchronize(backend) - host_tracker = Array(tracker) - offset = ifelse(toggle, tracker_element_count, 0x0) - @inbounds continue_flag = - host_tracker[offset + Integer(tracker_content_bit_type)] < - Integer(pauli_bit_invalid) - if continue_flag - cycles_until_sync = default_scheduling_limit - else - break - end - end end # Remove all extraneous bits left behind by previous iterations. @@ -122,78 +103,58 @@ function device_canonicalize!( end -# Tableau +# Tableau/AbstractStabilizer @inline function canonicalize!( - tab::DeviceTableau, + state::DeviceUnionTableau, output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + pauli_preferance::PauliPreferance = default_pauli_preferance, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) if !isnothing(output_buffer) - length(output_buffer) == 0x2 || + length(output_buffer) == 2 || throw(ArgumentError(THROW_SIZE)) end end - - device_canonicalize!( - tab.phases, tab.xzs, output_buffer; + return do_canonicalize!( + state, output_buffer; multiplication_order = multiplication_order, - phases = phases, primary_axis = primary_axis, + pauli_preferance = pauli_preferance, + primary_axis = primary_axis, phases = phases, block_size = block_size, batch_size = batch_size ) - if isnothing(output_buffer) - return tab - else - return tab, output_buffer - end - end -# AbstractStabilizer -@inline function canonicalize!( - state::DeviceAbstractStabilizer, +@inline function do_canonicalize!( + state::DeviceUnionTableau, output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - if !isnothing(output_buffer) - length(output_buffer) == 0x2 || - throw(ArgumentError(THROW_SIZE)) - end - end - - if state isa Stabilizer - upper = size(state.tab.xzs, 0x2) - lower = one(upper) - elseif state isa Destabilizer - upper = size(state.tab.xzs, 0x2) - lower = (upper >> one(upper)) + one(upper) - elseif state isa MixedStabilizer - upper = state.rank - lower = one(upper) - elseif state isa MixedDestabilizer - upper = size(state.tab.xzs, 0x2) - lower = (upper >> one(upper)) + one(upper) - upper = (upper >> one(upper)) + state.rank + pauli_preferance::PauliPreferance = default_pauli_preferance, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + if state isa AbstractStabilizer + state_tab = tab(stabilizerview(state)) + elseif state isa Tableau + state_tab = state end - @inbounds device_canonicalize!( - (@view state.tab.phases[lower : upper]), - (@view state.tab.xzs[:, lower : upper]), - output_buffer; + device_canonicalize!( + state_tab.phases, state_tab.xzs, output_buffer; multiplication_order = multiplication_order, - phases = phases, primary_axis = primary_axis, + pauli_preferance = pauli_preferance, + primary_axis = primary_axis, phases = phases, block_size = block_size, batch_size = batch_size ) diff --git a/ext/QuantumCliffordKAExt/canonicalization/canonicalize_gott.jl b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_gott.jl index 85b64d785..38b41c416 100644 --- a/ext/QuantumCliffordKAExt/canonicalization/canonicalize_gott.jl +++ b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_gott.jl @@ -1,5 +1,6 @@ #=============================================================================# +# TODO: Include the unsafe functions once the main package establishes them. import QuantumClifford: canonicalize_gott! # TODO: Implement. #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl index f25087494..f508dd903 100644 --- a/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl +++ b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl @@ -1,52 +1,48 @@ #=============================================================================# +# TODO: Include the unsafe functions once the main package establishes them. import QuantumClifford: canonicalize_rref! +# TODO: Implement support for (Mixed)Destabilizers. function device_canonicalize_rref!( ph::AbstractArray{<: Unsigned}, xzs::AbstractArray{T}, output_buffer::Union{Nothing, AbstractArray{<: Integer}} = nothing, bit_masks::Union{Nothing, AbstractArray{T}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - )::Nothing where { - T <: Unsigned, phase_B, primary_axis_E, block_SZ, batch_SZ - } - - phase_B isa Bool && primary_axis_E isa PrimaryAxis && - block_SZ isa Integer && block_SZ > zero(block_SZ) && - batch_SZ isa Integer && batch_SZ > zero(batch_SZ) || - throw(ArgumentError(THROW_VALS)) + pauli_preferance::PauliPreferance = default_pauli_preferance, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + )::Nothing where {T <: Unsigned} backend = KA.get_backend(xzs) - if primary_axis_E == primary_axis_rows - tile = (one(block_SZ), block_SZ) + if primary_axis == primary_axis_rows + tile = (one(block_size), block_size) space = tessellate( - (size(xzs, 0x2), cld(size(xzs, 0x1) >> 0x1, batch_SZ)), + (size(xzs, 2), cld(size(xzs, 1) >> 1, batch_size)), tile ) - elseif primary_axis_E == primary_axis_qubits - tile = (block_SZ, one(block_SZ)) + elseif primary_axis == primary_axis_qubits + tile = (block_size, one(block_size)) space = tessellate( - (cld(size(xzs, 0x1) >> 0x1, batch_SZ), size(xzs, 0x2)), + (cld(size(xzs, 1) >> 1, batch_size), size(xzs, 2)), tile ) end # Utilised for loop management. - length_xzs = size(xzs, 0x1) >> 0x1 - row_count = size(xzs, 0x2) + length_xzs = size(xzs, 1) >> 1 + row_count = size(xzs, 2) toggle = false - cycles_until_sync = default_scheduling_limit + shrink_workspace = isnothing(bit_masks) # Required for safety whilst setting up for the proceeding iteration. mutex = create_mutex(backend) # Double buffered for present and preceeding/proceeding iteration. stride_fill = tracker_element_count - tracker = similar(xzs, Csize_t, stride_fill << 0x1) + tracker = similar(xzs, Csize_t, stride_fill << 1) fill!(tracker, typemax(Csize_t)) # The pivot row tracker is initialised differently. begin_fill = Integer(tracker_content_swap_to) @@ -59,17 +55,17 @@ function device_canonicalize_rref!( fill!(output_buffer, row_count) end - bit_scan = kernel_bit_scan(backend) + bit_scan! = kernel_bit_scan!(backend) snippet! = kernel_snippet!(backend) mul_and_scan! = kernel_mul_and_scan!(backend) - bit_scan( - xzs, bit_masks, mutex, tracker, - Val(sort_order_qubit_number_prefer_x), - primary_axis, block_size, batch_size; + bit_scan!( + tracker, mutex, xzs, bit_masks, primary_axis, + Val(pauli_preferance), Val(sort_order_qubit_number), + Val(block_size), Val(batch_size); workgroupsize = tile, ndrange = space ) - for _ in one(row_count) : row_count + for _ in Base.OneTo(row_count) snippet!( snippet_swap_rows_prepare_tracker!, ph, xzs, tracker, toggle; @@ -77,38 +73,24 @@ function device_canonicalize_rref!( ) snippet!( snippet_set_row_phase_flag!, - ph, xzs, tracker, toggle; + ph, xzs, tracker, toggle, pauli_preferance; ndrange = row_count ) mul_and_scan!( - ph, xzs, multiplication_order, scan_side_lesser, - bit_masks, isnothing(bit_masks), mutex, tracker, toggle, - phases, Val(sort_order_qubit_number_prefer_x), - primary_axis, block_size, batch_size; + ph, xzs, tracker, toggle, mutex, bit_masks, shrink_workspace, + scan_side_lesser, multiplication_order, primary_axis, + Val(pauli_preferance), Val(sort_order_qubit_number), + Val(phases), Val(block_size), Val(batch_size); workgroupsize = tile, ndrange = space ) # Switching the toggle is intentional. snippet!( snippet_track_pivot_canonicalize_rref!, output_buffer, tracker, !toggle; - ndrange = 0x1 + ndrange = 1 ) toggle = !toggle - cycles_until_sync -= one(cycles_until_sync) - if cycles_until_sync == zero(cycles_until_sync) - KA.synchronize(backend) - host_tracker = Array(tracker) - offset = ifelse(toggle, tracker_element_count, 0x0) - @inbounds continue_flag = - host_tracker[offset + Integer(tracker_content_bit_type)] < - Integer(pauli_bit_invalid) - if continue_flag - cycles_until_sync = default_scheduling_limit - else - break - end - end end # Remove all extraneous bits left behind by previous iterations. @@ -118,88 +100,64 @@ function device_canonicalize_rref!( end -# Tableau +# Tableau/AbstractStabilizer @inline function canonicalize_rref!( - tab::DeviceTableau, + state::DeviceUnionTableau, output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing, bit_masks::Union{Nothing, AbstractGPUArray{<: Unsigned}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} + pauli_preferance::PauliPreferance = default_pauli_preferance, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) if !isnothing(output_buffer) - length(output_buffer) == 0x1 || + length(output_buffer) == 1 || throw(ArgumentError(THROW_SIZE)) end if !isnothing(bit_masks) - length(bit_masks) == size(tab.xzs, 0x1) >> 0x1 || + length(bit_masks) == size(tab(state).xzs, 1) >> 1 || throw(DimensionMismatch(THROW_SIZE)) end end - - device_canonicalize!( - tab.phases, tab.xzs, output_buffer, bit_masks; + return do_canonicalize_rref!( + state, output_buffer, bit_masks; multiplication_order = multiplication_order, - phases = phases, primary_axis = primary_axis, + pauli_preferance = pauli_preferance, + primary_axis = primary_axis, phases = phases, block_size = block_size, batch_size = batch_size ) - if isnothing(output_buffer) - return tab - else - return tab, output_buffer - end - end -# AbstractStabilizer -@inline function canonicalize_rref!( - state::DeviceAbstractStabilizer, +@inline function do_canonicalize_rref!( + state::DeviceUnionTableau, output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing, bit_masks::Union{Nothing, AbstractGPUArray{<: Unsigned}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - if !isnothing(output_buffer) - length(output_buffer) == 0x1 || - throw(ArgumentError(THROW_SIZE)) - end - if !isnothing(bit_masks) - length(bit_masks) == size(state.tab.xzs, 0x1) >> 0x1 || - throw(DimensionMismatch(THROW_SIZE)) - end - end - - if state isa Stabilizer - upper = size(state.tab.xzs, 0x2) - lower = one(upper) - elseif state isa Destabilizer - upper = size(state.tab.xzs, 0x2) - lower = (upper >> one(upper)) + one(upper) - elseif state isa MixedStabilizer - upper = state.rank - lower = one(upper) - elseif state isa MixedDestabilizer - upper = size(state.tab.xzs, 0x2) - lower = (upper >> one(upper)) + one(upper) - upper = (upper >> one(upper)) + state.rank + pauli_preferance::PauliPreferance = default_pauli_preferance, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + if state isa AbstractStabilizer + state_tab = tab(stabilizerview(state)) + elseif state isa Tableau + state_tab = state end - @inbounds device_canonicalize_rref!( - (@view state.tab.phases[lower : upper]), - (@view state.tab.xzs[:, lower : upper]), - output_buffer, bit_masks; + device_canonicalize_rref!( + state_tab.phases, state_tab.xzs, output_buffer, bit_masks; multiplication_order = multiplication_order, - phases = phases, primary_axis = primary_axis, + pauli_preferance = pauli_preferance, + primary_axis = primary_axis, phases = phases, block_size = block_size, batch_size = batch_size ) diff --git a/ext/QuantumCliffordKAExt/canonicalization/common.jl b/ext/QuantumCliffordKAExt/canonicalization/common.jl index e27cdddec..fc15df92b 100644 --- a/ext/QuantumCliffordKAExt/canonicalization/common.jl +++ b/ext/QuantumCliffordKAExt/canonicalization/common.jl @@ -1,31 +1,33 @@ #=============================================================================# # TODO: Make the parameters keyword arguments once support becomes available. -KA.@kernel inbounds = true unsafe_indices = true function kernel_bit_scan( - xzs::AbstractArray{T}, bit_masks::Union{Nothing, AbstractArray{T}}, - mutex::AbstractMutex, tracker::AbstractArray{S}, - ::Val{sort_order}, - ::Val{primary_axis}, ::Val{block_size}, ::Val{batch_size} +KA.@kernel inbounds = true unsafe_indices = true function kernel_bit_scan!( + tracker::AbstractArray{S}, mutex::AbstractMutex, + @Const(xzs::AbstractArray{T}), + @Const(bit_masks::Union{Nothing, AbstractArray{T}}), + @Const(primary_axis::PrimaryAxis), + ::Val{pauli_preferance}, ::Val{sort_order}, + ::Val{block_size}, ::Val{batch_size} ) where { - T <: Unsigned, S <: Unsigned, - sort_order, primary_axis, block_size, batch_size + S <: Unsigned, T <: Unsigned, + pauli_preferance, sort_order, block_size, batch_size } if primary_axis == primary_axis_rows j, begin_i = global_index( KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) ) - stride_i = KA.@ndrange()[0x2] + stride_i = KA.@ndrange()[2] elseif primary_axis == primary_axis_qubits begin_i, j = global_index( KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) ) - stride_i = KA.@ndrange()[0x1] + stride_i = KA.@ndrange()[1] end - end_i = KA.@uniform (size(xzs, 0x1) >> 0x1) + end_i = KA.@uniform (size(xzs, 1) >> 1) - index = typemax(S) - bit_shift = typemax(S) + index::S = typemax(S) + bit_shift::S = typemax(S) bit_type = pauli_bit_invalid index_buffer = KA.@localmem S block_size bit_shift_buffer = KA.@localmem S block_size @@ -33,7 +35,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_bit_scan( scan_target = @view xzs[:, j] - for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + for (i, _) in zip(begin_i : stride_i : end_i, Base.OneTo(batch_size)) x_bits = scan_target[i] z_bits = scan_target[i + end_i] if !isnothing(bit_masks) @@ -43,44 +45,30 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_bit_scan( end index, bit_shift, bit_type, break_flag = scan_step( - x_bits, z_bits, i, index, bit_shift, bit_type, Val(sort_order) + x_bits, z_bits, i, index, bit_shift, bit_type, + Val(pauli_preferance), Val(sort_order) ) break_flag && break end - local_index = KA.@index(Local, Linear) + local_index = DeviceUnsigned(KA.@index(Local, Linear)) bit_shift_buffer[local_index] = bit_shift index_buffer[local_index] = index - if sort_order in ( - sort_order_pauli_bit_prefer_x, sort_order_qubit_number_prefer_x - ) - - bit_type_buffer[local_index] = Integer(bit_type) - - elseif sort_order in ( - sort_order_pauli_bit_prefer_z, sort_order_qubit_number_prefer_z - ) - - # This reverses the ordering of the Pauli bits. - bit_type_buffer[local_index] = xor(Integer(bit_type), 0x1) + bit_type_buffer[local_index] = Integer(bit_type) - end - - if sort_order in ( - sort_order_pauli_bit_prefer_x, sort_order_pauli_bit_prefer_z - ) + if sort_order == sort_order_pauli_bit shared_memory_reduce!( - reduce_lexicographic_min!, local_index, Val(block_size), + reduce_lexicographic_min!, + local_index, Val(DeviceUnsigned(block_size)), bit_type_buffer, index_buffer, bit_shift_buffer ) - elseif sort_order in ( - sort_order_qubit_number_prefer_x, sort_order_qubit_number_prefer_z - ) + elseif sort_order == sort_order_qubit_number shared_memory_reduce!( - reduce_lexicographic_min!, local_index, Val(block_size), + reduce_lexicographic_min!, + local_index, Val(DeviceUnsigned(block_size)), index_buffer, bit_shift_buffer, bit_type_buffer ) @@ -88,21 +76,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_bit_scan( if local_index == one(local_index) - if sort_order in ( - sort_order_pauli_bit_prefer_x, sort_order_qubit_number_prefer_x - ) - - enumless_bit_type = bit_type_buffer[local_index] - - elseif sort_order in ( - sort_order_pauli_bit_prefer_z, sort_order_qubit_number_prefer_z - ) - - # This reverses the ordering of the Pauli bits. - enumless_bit_type = xor(bit_type_buffer[local_index], 0x1) - - end - + enumless_bit_type = bit_type_buffer[local_index] # Avoid mutex contention if there is no valid contribution. if enumless_bit_type < Integer(pauli_bit_invalid) index = index_buffer[local_index] @@ -121,38 +95,20 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_bit_scan( temp_bit_type = tracker[Integer(tracker_content_bit_type)] temp_row = tracker[Integer(tracker_content_swap_from)] - if sort_order == sort_order_pauli_bit_prefer_x + if sort_order == sort_order_pauli_bit lower_than_current = isless( (enumless_bit_type, index, bit_shift, j), (temp_bit_type, temp_index, temp_bit_shift, temp_row) ) - elseif sort_order == sort_order_pauli_bit_prefer_z - - # This reverses the ordering of the Pauli bits. - temp_bit_type = xor(temp_bit_type, 0x1) - lower_than_current = isless( - (xor(enumless_bit_type, 0x1), index, bit_shift, j), - (temp_bit_type, temp_index, temp_bit_shift, temp_row) - ) - - elseif sort_order == sort_order_qubit_number_prefer_x + elseif sort_order == sort_order_qubit_number lower_than_current = isless( (index, bit_shift, enumless_bit_type, j), (temp_index, temp_bit_shift, temp_bit_type, temp_row) ) - elseif sort_order == sort_order_qubit_number_prefer_z - - # This reverses the ordering of the Pauli bits. - temp_bit_type = xor(temp_bit_type, 0x1) - lower_than_current = isless( - (index, bit_shift, xor(enumless_bit_type, 0x1), j), - (temp_index, temp_bit_shift, temp_bit_type, temp_row) - ) - end if lower_than_current @@ -179,43 +135,45 @@ end # TODO: Make the parameters keyword arguments once support becomes available. KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( ph::AbstractArray{P}, xzs::AbstractArray{T}, - multiplication_order::MultiplicationOrder, scan_side::ScanSide, - bit_masks::Union{Nothing, AbstractArray{T}}, shrink_workspace::Bool, - mutex::AbstractMutex, tracker::AbstractArray{S}, toggle::Bool, - ::Val{phases}, ::Val{sort_order}, - ::Val{primary_axis}, ::Val{block_size}, ::Val{batch_size} + tracker::AbstractArray{S}, toggle::Bool, mutex::AbstractMutex, + @Const(bit_masks::Union{Nothing, AbstractArray{T}}), + @Const(shrink_workspace::Bool), @Const(scan_side::ScanSide), + @Const(multiplication_order::MultiplicationOrder), + @Const(primary_axis::PrimaryAxis), + ::Val{pauli_preferance}, ::Val{sort_order}, + ::Val{phases}, ::Val{block_size}, ::Val{batch_size} ) where { - P <: Unsigned, T <: Unsigned, S <: Unsigned, - phases, sort_order, primary_axis, block_size, batch_size + S <: Unsigned, P <: Unsigned, T <: Unsigned, + pauli_preferance, sort_order, phases, block_size, batch_size } if primary_axis == primary_axis_rows j, begin_i = global_index( KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) ) - stride_i = KA.@ndrange()[0x2] + stride_i = KA.@ndrange()[2] elseif primary_axis == primary_axis_qubits begin_i, j = global_index( KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) ) - stride_i = KA.@ndrange()[0x1] + stride_i = KA.@ndrange()[1] end - end_i = KA.@uniform (size(xzs, 0x1) >> 0x1) + end_i = KA.@uniform (size(xzs, 1) >> 1) - local_index = KA.@index(Local, Linear) + local_index = DeviceUnsigned(KA.@index(Local, Linear)) # Layout is [index, row]. shared_parameters = KA.@localmem S 2 # Dense storage as flags in a bit field. shared_flags = KA.@localmem DeviceUnsigned 1 if local_index == one(local_index) - shared_flags[0x1] = ifelse( + shared_flags[1] = ifelse( begin_i == one(begin_i), Integer(bit_field_flag_leader), zero(DeviceUnsigned) ) - current = ifelse(toggle, tracker_element_count, 0x0) + current = ifelse(toggle, tracker_element_count, zero(Csize_t)) temp_index = tracker[current + Integer(tracker_content_index)] temp_bit_type = tracker[current + Integer(tracker_content_bit_type)] temp_row = tracker[current + Integer(tracker_content_swap_to)] @@ -232,8 +190,8 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( temp_bit_type < Integer(pauli_bit_invalid) if continue_flag - shared_parameters[0x1] = temp_index - shared_parameters[0x2] = temp_row + shared_parameters[1] = temp_index + shared_parameters[2] = temp_row if scan_side == scan_side_lesser flags = ifelse( @@ -249,18 +207,19 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( ) end + # Top bit holds the multiplication flag. flags |= ifelse( ph[j] & top_bits(0x1, P) != zero(P), Integer(bit_field_flag_multiply), zero(DeviceUnsigned) ) - shared_flags[0x1] |= flags + shared_flags[1] |= flags end end KA.@synchronize() - flags = shared_flags[0x1] + flags = shared_flags[1] if flags & ( Integer(bit_field_flag_scan) | Integer(bit_field_flag_multiply) @@ -272,8 +231,8 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( phase_buffer = KA.@localmem DeviceUnsigned block_size end - index = typemax(S) - bit_shift = typemax(S) + index::S = typemax(S) + bit_shift::S = typemax(S) bit_type = pauli_bit_invalid index_buffer = KA.@localmem S block_size bit_shift_buffer = KA.@localmem S block_size @@ -281,7 +240,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( new_begin_i = ifelse( shrink_workspace, - begin_i + shared_parameters[0x1] - one(S), + begin_i + shared_parameters[1] - one(S), begin_i + zero(S) ) @@ -296,7 +255,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( scan_target = @view xzs[:, j] - for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + for (i, _) in zip(begin_i : stride_i : end_i, Base.OneTo(batch_size)) x_bits = scan_target[i] z_bits = scan_target[i + end_i] if !isnothing(bit_masks) @@ -306,7 +265,8 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( end index, bit_shift, bit_type, break_flag = scan_step( - x_bits, z_bits, i, index, bit_shift, bit_type, Val(sort_order) + x_bits, z_bits, i, index, bit_shift, bit_type, + Val(pauli_preferance), Val(sort_order) ) break_flag && break end @@ -315,11 +275,11 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( write_xzs = @view xzs[:, j] if multiplication_order == multiplication_order_left - left = @view xzs[:, shared_parameters[0x2]] + left = @view xzs[:, shared_parameters[2]] right = write_xzs elseif multiplication_order == multiplication_order_right left = write_xzs - right = @view xzs[:, shared_parameters[0x2]] + right = @view xzs[:, shared_parameters[2]] end # Equivalent to mul!. @@ -327,7 +287,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( Integer(bit_field_flag_scan) | Integer(bit_field_flag_multiply) ) == Integer(bit_field_flag_multiply) - for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + for (i, _) in zip(begin_i : stride_i : end_i, Base.OneTo(batch_size)) x_left = left[i] z_left = left[i + end_i] x_right = right[i] @@ -349,7 +309,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( # Equivalent to joint mul! and a modified bit_scan. else - for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + for (i, _) in zip(begin_i : stride_i : end_i, Base.OneTo(batch_size)) x_left = left[i] z_left = left[i + end_i] x_right = right[i] @@ -374,7 +334,8 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( end index, bit_shift, bit_type = scan_step( - x_new, z_new, i, index, bit_shift, bit_type, Val(sort_order) + x_new, z_new, i, index, bit_shift, bit_type, + Val(pauli_preferance), Val(sort_order) ) end @@ -392,21 +353,29 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( if phases if flags & Integer(bit_field_flag_multiply) != zero(DeviceUnsigned) phase_buffer[local_index] = - ((count_ones(high) << 0x1) + count_ones(low)) & 0x3 + ((count_ones(high) << 1) + count_ones(low)) & 0x3 shared_memory_reduce!( - reduce_sum!, local_index, Val(block_size), phase_buffer + reduce_sum!, + local_index, Val(DeviceUnsigned(block_size)), + phase_buffer ) if local_index == one(local_index) temp_ph = phase_buffer[local_index] if flags & Integer(bit_field_flag_leader) != zero(DeviceUnsigned) - temp_ph += ph[shared_parameters[0x2]] + temp_ph += ph[shared_parameters[2]] + end + # Avoid expensive operations when they are not required. + if stride_i > block_size + # CAUTION: This memory order is sufficient. + @atomic :monotonic ph[j] += temp_ph & 0x3 + # CAUTION: Avoid nullifying the multiplication flag bit! + @atomic :monotonic ph[j] &= 0x3 | top_bits(0x1, P) + else + ph[j] += temp_ph & 0x3 + ph[j] &= 0x3 end - # CAUTION: This is sufficient since only atomicity is required. - @atomic :monotonic ph[j] += temp_ph & 0x3 - # CAUTION: Avoid nullifying the multiplication flag bit! - @atomic :monotonic ph[j] &= 0x3 | top_bits(0x1, P) end end end @@ -415,36 +384,21 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( if flags & Integer(bit_field_flag_scan) != zero(DeviceUnsigned) bit_shift_buffer[local_index] = bit_shift index_buffer[local_index] = index - if sort_order in ( - sort_order_pauli_bit_prefer_x, sort_order_qubit_number_prefer_x - ) - - bit_type_buffer[local_index] = Integer(bit_type) - - elseif sort_order in ( - sort_order_pauli_bit_prefer_z, sort_order_qubit_number_prefer_z - ) - - # This reverses the ordering of the Pauli bits. - bit_type_buffer[local_index] = xor(Integer(bit_type), 0x1) + bit_type_buffer[local_index] = Integer(bit_type) - end - - if sort_order in ( - sort_order_pauli_bit_prefer_x, sort_order_pauli_bit_prefer_z - ) + if sort_order == sort_order_pauli_bit shared_memory_reduce!( - reduce_lexicographic_min!, local_index, Val(block_size), + reduce_lexicographic_min!, + local_index, Val(DeviceUnsigned(block_size)), bit_type_buffer, index_buffer, bit_shift_buffer ) - elseif sort_order in ( - sort_order_qubit_number_prefer_x, sort_order_qubit_number_prefer_z - ) + elseif sort_order == sort_order_qubit_number shared_memory_reduce!( - reduce_lexicographic_min!, local_index, Val(block_size), + reduce_lexicographic_min!, + local_index, Val(DeviceUnsigned(block_size)), index_buffer, bit_shift_buffer, bit_type_buffer ) @@ -452,26 +406,12 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( if local_index == one(local_index) - if sort_order in ( - sort_order_pauli_bit_prefer_x, sort_order_qubit_number_prefer_x - ) - - enumless_bit_type = bit_type_buffer[local_index] - - elseif sort_order in ( - sort_order_pauli_bit_prefer_z, sort_order_qubit_number_prefer_z - ) - - # This reverses the ordering of the Pauli bits. - enumless_bit_type = xor(bit_type_buffer[local_index], 0x1) - - end - + enumless_bit_type = bit_type_buffer[local_index] # Avoid mutex contention if there is no valid contribution. if enumless_bit_type < Integer(pauli_bit_invalid) index = index_buffer[local_index] bit_shift = bit_shift_buffer[local_index] - next = ifelse(toggle, 0x0, tracker_element_count) + next = ifelse(toggle, zero(Csize_t), tracker_element_count) lock_mutex!( mutex, @@ -488,38 +428,20 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( tracker[next + Integer(tracker_content_bit_type)] temp_row = tracker[next + Integer(tracker_content_swap_from)] - if sort_order == sort_order_pauli_bit_prefer_x + if sort_order == sort_order_pauli_bit lower_than_current = isless( (enumless_bit_type, index, bit_shift, j), (temp_bit_type, temp_index, temp_bit_shift, temp_row) ) - elseif sort_order == sort_order_pauli_bit_prefer_z - - # This reverses the ordering of the Pauli bits. - temp_bit_type = xor(temp_bit_type, 0x1) - lower_than_current = isless( - (xor(enumless_bit_type, 0x1), index, bit_shift, j), - (temp_bit_type, temp_index, temp_bit_shift, temp_row) - ) - - elseif sort_order == sort_order_qubit_number_prefer_x + elseif sort_order == sort_order_qubit_number lower_than_current = isless( (index, bit_shift, enumless_bit_type, j), (temp_index, temp_bit_shift, temp_bit_type, temp_row) ) - elseif sort_order == sort_order_qubit_number_prefer_z - - # This reverses the ordering of the Pauli bits. - temp_bit_type = xor(temp_bit_type, 0x1) - lower_than_current = isless( - (index, bit_shift, xor(enumless_bit_type, 0x1), j), - (temp_index, temp_bit_shift, temp_bit_type, temp_row) - ) - end if lower_than_current @@ -544,7 +466,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( end end - # Marks the end for flags != zero(DeviceUnsigned) + # Marks the end for flags & (...) != zero(DeviceUnsigned) end end diff --git a/ext/QuantumCliffordKAExt/definitions/default_parameters.jl b/ext/QuantumCliffordKAExt/definitions/default_parameters.jl index 02a022154..9a7c45d42 100644 --- a/ext/QuantumCliffordKAExt/definitions/default_parameters.jl +++ b/ext/QuantumCliffordKAExt/definitions/default_parameters.jl @@ -2,14 +2,20 @@ #=============================================================================# # Maintains compatibility with the main package unless explicitly specified. const default_multiplication_order = multiplication_order_left -# Strict correctness unless there is an explicit opt-out. -const default_phases = true + +# Customary choice to position X pauli operators before their Z counterparts. +const default_pauli_preferance = pauli_preferance_x + # Potentially boosts cache hits and reduces atomic contention. const default_primary_axis = primary_axis_rows + +# Strict correctness unless there is an explicit opt-out. +const default_phases = true + # Reasonable size that is generally ideal for most vendors and use cases. +# TODO: Modify this to DeviceUnsigned once CUDA bugs are resolved. const default_block_size = 256 + # Ameliorate overhead and enhance performance by doing more work per thread. -const default_batch_size = 32 -# TODO: Eliminate this in favour of complete asynchronicity. -const default_scheduling_limit = 64 +const default_batch_size = UInt8(32) #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/enumerations.jl b/ext/QuantumCliffordKAExt/definitions/enumerations.jl index a8f55420b..1f6af5fc4 100644 --- a/ext/QuantumCliffordKAExt/definitions/enumerations.jl +++ b/ext/QuantumCliffordKAExt/definitions/enumerations.jl @@ -1,11 +1,17 @@ #=============================================================================# -# There is vanishing overhead for supporting both of these. +# Dictates the direction of the multiplication operation(s). @enum MultiplicationOrder::UInt8 begin multiplication_order_left multiplication_order_right end +# Specifies whether ordering should prioritise X or Z pauli operators. +@enum PauliPreferance::UInt8 begin + pauli_preferance_x + pauli_preferance_z +end + # Determines whether the first grid dimension matches the rows or the qubits. @enum PrimaryAxis::UInt8 begin primary_axis_rows @@ -18,10 +24,8 @@ UTILISED INTERNALLY # Pauli bit: Highest X(Z) < Lowest Z(X). Qubit Number: 1X(Z) < 1Z(X) < 2X(Z)... @enum SortOrder::UInt8 begin - sort_order_pauli_bit_prefer_x - sort_order_pauli_bit_prefer_z - sort_order_qubit_number_prefer_x - sort_order_qubit_number_prefer_z + sort_order_pauli_bit + sort_order_qubit_number end # Determines whether contraction proceeds from high to low or low to high. @@ -30,29 +34,29 @@ end scan_side_greater end -# Provides enhanced clarity over plain numerical values. -# CAUTION: The values are NOT arbitrary but utilised for proper indexing. -@enum TrackerContent::UInt8 begin - tracker_content_index = 0x1 - tracker_content_bit_shift = 0x2 - tracker_content_bit_type = 0x3 - tracker_content_swap_from = 0x4 - tracker_content_swap_to = 0x5 -end - # Provides enhanced clarity over plain numerical values. # CAUTION: The values are NOT arbitrary but utilised for proper ordering. @enum PauliBit::UInt8 begin - pauli_bit_x = 0x0 - pauli_bit_z = 0x1 - pauli_bit_invalid = 0x2 + pauli_bit_primary = 0 + pauli_bit_secondary = 1 + pauli_bit_invalid = 2 end # Provides enhanced clarity over plain numerical values. # CAUTION: The values are NOT arbitrary but utilised for proper masking. @enum BitFieldFlags::DeviceUnsigned begin - bit_field_flag_leader = 0x1 - bit_field_flag_scan = 0x2 - bit_field_flag_multiply = 0x4 + bit_field_flag_leader = 1 + bit_field_flag_scan = 2 + bit_field_flag_multiply = 4 +end + +# Provides enhanced clarity over plain numerical values. +# CAUTION: The values are NOT arbitrary but utilised for proper indexing. +@enum TrackerContent::Csize_t begin + tracker_content_index = 1 + tracker_content_bit_shift = 2 + tracker_content_bit_type = 3 + tracker_content_swap_from = 4 + tracker_content_swap_to = 5 end #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/fixed_sizes.jl b/ext/QuantumCliffordKAExt/definitions/fixed_sizes.jl index f6de0e5f2..e56f7a17d 100644 --- a/ext/QuantumCliffordKAExt/definitions/fixed_sizes.jl +++ b/ext/QuantumCliffordKAExt/definitions/fixed_sizes.jl @@ -1,5 +1,5 @@ #=============================================================================# # TODO: Figure out a more elegant solution that Julia approves of. -const tracker_element_count = 0x5 +const tracker_element_count = Csize_t(length(instances(TrackerContent))) #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/mutex_configuration.jl b/ext/QuantumCliffordKAExt/definitions/mutex_configuration.jl index f9917975d..447242ccf 100644 --- a/ext/QuantumCliffordKAExt/definitions/mutex_configuration.jl +++ b/ext/QuantumCliffordKAExt/definitions/mutex_configuration.jl @@ -2,9 +2,11 @@ #=============================================================================# # TODO: Eliminate these once KernelAbstractions becomes more feature complete. AbstractMutex = AbstractArray{DeviceUnsigned, 0} + # These should be confined to an enum but atomics require type conversion. -const mutex_state_locked::DeviceUnsigned = 0x0 -const mutex_state_unlocked::DeviceUnsigned = 0x1 +const mutex_state_locked = zero(DeviceUnsigned) +const mutex_state_unlocked = one(DeviceUnsigned) + # Defines the mappings for the compare-and-swap(CAS)/@atomicreplace call. const mutex_exchange_lock = mutex_state_unlocked => mutex_state_locked const mutex_exchange_unlock = mutex_state_locked => mutex_state_unlocked diff --git a/ext/QuantumCliffordKAExt/definitions/type_shorthands.jl b/ext/QuantumCliffordKAExt/definitions/type_shorthands.jl index 553052c3f..6cb96aa1b 100644 --- a/ext/QuantumCliffordKAExt/definitions/type_shorthands.jl +++ b/ext/QuantumCliffordKAExt/definitions/type_shorthands.jl @@ -1,29 +1,25 @@ #=============================================================================# -# Keeps the function definitions succinct. const DevicePauliOperator = PauliOperator{T_P, T_XZ} where { T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray } + const DeviceTableau = Tableau{T_P, T_XZ} where { T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray } -# This is a bit redundant but it keeps the REPL expansions more readable. -const DeviceUnionStabilizer = Union{ - Stabilizer{T}, MixedStabilizer{T} + +const DeviceUnionTableau = Union{ + T, Stabilizer{T}, MixedStabilizer{T}, Destabilizer{T}, MixedDestabilizer{T} } where { T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray, T <: Tableau{T_P, T_XZ} } + +# Utilised to specialise dispatch. const DeviceUnionDestabilizer = Union{ Destabilizer{T}, MixedDestabilizer{T} } where { T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray, T <: Tableau{T_P, T_XZ} } -const DeviceAbstractStabilizer = Union{ - Stabilizer{T}, MixedStabilizer{T}, Destabilizer{T}, MixedDestabilizer{T} - } where { - T_P <: AbstractGPUArray, T_XZ <: AbstractGPUArray, - T <: Tableau{T_P, T_XZ} - } #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/definitions/word_size_integers.jl b/ext/QuantumCliffordKAExt/definitions/word_size_integers.jl index 96f1fd1c7..f18420de4 100644 --- a/ext/QuantumCliffordKAExt/definitions/word_size_integers.jl +++ b/ext/QuantumCliffordKAExt/definitions/word_size_integers.jl @@ -1,7 +1,5 @@ #=============================================================================# -# Most ideal type for shared reductions due to avoiding bank conflicts. -# TODO: Should these be modified to Cint/Cuint? -const DeviceSigned = Int32 +# Most ideal type due to matching the native hardware capabilities. const DeviceUnsigned = UInt32 #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/imports.jl b/ext/QuantumCliffordKAExt/imports.jl index 30bdf449c..e6e7e0146 100644 --- a/ext/QuantumCliffordKAExt/imports.jl +++ b/ext/QuantumCliffordKAExt/imports.jl @@ -1,12 +1,14 @@ #=============================================================================# import KernelAbstractions as KA +# Resolves issue due to KA comparing against the literal Symbol("@Const"). +using KernelAbstractions: @Const using Atomix: @atomic, @atomicreplace + using GPUArraysCore: AbstractGPUArray -# Resolves issue due to KA comparing against the literal Symbol("@Const"). -using KernelAbstractions: @Const + using QuantumClifford -# This must be done explicitly as it is not exported. -using QuantumClifford: Tableau +# This must be done explicitly as they are not exported. +using QuantumClifford: Tableau, AbstractStabilizer #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/mul_leftright.jl b/ext/QuantumCliffordKAExt/mul.jl similarity index 63% rename from ext/QuantumCliffordKAExt/mul_leftright.jl rename to ext/QuantumCliffordKAExt/mul.jl index bde47ce2b..9286acff8 100644 --- a/ext/QuantumCliffordKAExt/mul_leftright.jl +++ b/ext/QuantumCliffordKAExt/mul.jl @@ -1,5 +1,6 @@ #=============================================================================# -include("mul_leftright/device_mul.jl") -include("mul_leftright/host_interface.jl") +include("mul/device_mul.jl") +include("mul/new_interface.jl") +include("mul/old_interface.jl") #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/mul_leftright/device_mul.jl b/ext/QuantumCliffordKAExt/mul/device_mul.jl similarity index 54% rename from ext/QuantumCliffordKAExt/mul_leftright/device_mul.jl rename to ext/QuantumCliffordKAExt/mul/device_mul.jl index 451ef0e29..15f9d8a28 100644 --- a/ext/QuantumCliffordKAExt/mul_leftright/device_mul.jl +++ b/ext/QuantumCliffordKAExt/mul/device_mul.jl @@ -1,29 +1,27 @@ #=============================================================================# # CAUTION: Keep in mind that the constants match the direction of the order. -# TODO: Make the parameters keyword arguments once support becomes available. KA.@kernel inbounds = true unsafe_indices = true function kernel_mul!( mutable_phases::AbstractArray{<: Unsigned}, mutable_xzs::AbstractArray{T}, @Const(const_xzs::AbstractArray{T}), - multiplication_order::MultiplicationOrder, - ::Val{phases}, ::Val{primary_axis}, ::Val{block_size}, ::Val{batch_size} - ) where { - T <: Unsigned, phases, primary_axis, block_size, batch_size - } + @Const(multiplication_order::MultiplicationOrder), + @Const(primary_axis::PrimaryAxis), + ::Val{phases}, ::Val{block_size}, ::Val{batch_size} + ) where {T <: Unsigned, phases, block_size, batch_size} if primary_axis == primary_axis_rows j_mutable, begin_i = global_index( KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) ) - stride_i = KA.@ndrange()[0x2] + stride_i = KA.@ndrange()[2] elseif primary_axis == primary_axis_qubits begin_i, j_mutable = global_index( KA.@index(Group, NTuple), KA.@groupsize(), KA.@index(Local, NTuple) ) - stride_i = KA.@ndrange()[0x1] + stride_i = KA.@ndrange()[1] end - end_i = KA.@uniform (size(mutable_xzs, 0x1) >> 0x1) - flag = KA.@uniform (size(const_xzs, 0x2) > 0x1) + end_i = KA.@uniform (size(mutable_xzs, 1) >> 1) + flag = KA.@uniform (size(const_xzs, 2) > 1) j_const = ifelse(flag, j_mutable, one(j_mutable)) if phases @@ -41,7 +39,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul!( right = @view const_xzs[:, j_const] end - for (i, _) in zip(begin_i : stride_i : end_i, one(batch_size) : batch_size) + for (i, _) in zip(begin_i : stride_i : end_i, Base.OneTo(batch_size)) x_left = left[i] z_left = left[i + end_i] x_right = right[i] @@ -61,18 +59,26 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul!( end if phases - local_index = KA.@index(Local, Linear) + local_index = DeviceUnsigned(KA.@index(Local, Linear)) phase_buffer[local_index] = - ((count_ones(high) << 0x1) + count_ones(low)) & 0x3 + ((count_ones(high) << 1) + count_ones(low)) & 0x3 shared_memory_reduce!( - reduce_sum!, local_index, Val(block_size), phase_buffer + reduce_sum!, + local_index, Val(DeviceUnsigned(block_size)), + phase_buffer ) if local_index == one(local_index) - # CAUTION: This is sufficient since only atomicity is required. - @atomic :monotonic mutable_phases[j_mutable] += - phase_buffer[local_index] & 0x3 - @atomic :monotonic mutable_phases[j_mutable] &= 0x3 + # Avoid expensive operations when they are not required. + if stride_i > block_size + # CAUTION: This memory order is sufficient. + @atomic :monotonic mutable_phases[j_mutable] += + phase_buffer[local_index] & 0x3 + @atomic :monotonic mutable_phases[j_mutable] &= 0x3 + else + mutable_phases[j_mutable] += phase_buffer[local_index] & 0x3 + mutable_phases[j_mutable] &= 0x3 + end end end @@ -81,46 +87,33 @@ end # CAUTION: Requires either rows(const) == 1 or rows(const) == rows(mutable) function device_mul!( mutable_phases::AbstractArray{<: Unsigned}, mutable_xzs::AbstractArray{T}, - const_phases::AbstractArray{<: Unsigned}, const_xzs::AbstractArray{T}, - multiplication_order::MultiplicationOrder; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - )::Nothing where { - T <: Unsigned, phase_B, primary_axis_E, block_SZ, batch_SZ - } - - phase_B isa Bool && primary_axis_E isa PrimaryAxis && - block_SZ isa Integer && block_SZ > zero(block_SZ) && - batch_SZ isa Integer && batch_SZ > zero(batch_SZ) || - throw(ArgumentError(THROW_VALS)) + const_phases::AbstractArray{<: Unsigned}, const_xzs::AbstractArray{T}; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + )::Nothing where {T <: Unsigned} backend = KA.get_backend(mutable_xzs) - if primary_axis_E == primary_axis_rows - tile = (one(block_SZ), block_SZ) + if primary_axis == primary_axis_rows + tile = (one(block_size), block_size) space = tessellate( - ( - size(mutable_xzs, 0x2), - cld(size(mutable_xzs, 0x1) >> 0x1, batch_SZ) - ), + (size(mutable_xzs, 2), cld(size(mutable_xzs, 1) >> 1, batch_size)), tile ) - elseif primary_axis_E == primary_axis_qubits - tile = (block_SZ, one(block_SZ)) + elseif primary_axis == primary_axis_qubits + tile = (block_size, one(block_size)) space = tessellate( - ( - cld(size(mutable_xzs, 0x1) >> 0x1, batch_SZ), - size(mutable_xzs, 0x2) - ), + (cld(size(mutable_xzs, 1) >> 1, batch_size), size(mutable_xzs, 2)), tile ) end - if phase_B + if phases snippet! = kernel_snippet!(backend) - @inbounds snippet!( + snippet!( snippet_mod_4_sum_phase!, mutable_phases, const_phases; ndrange = length(mutable_phases) @@ -128,8 +121,9 @@ function device_mul!( end mul! = kernel_mul!(backend) mul!( - mutable_phases, mutable_xzs, const_xzs, multiplication_order, - phases, primary_axis, block_size, batch_size; + mutable_phases, mutable_xzs, const_xzs, + multiplication_order, primary_axis, + Val(phases), Val(block_size), Val(batch_size); workgroupsize = tile, ndrange = space ) diff --git a/ext/QuantumCliffordKAExt/mul/new_interface.jl b/ext/QuantumCliffordKAExt/mul/new_interface.jl new file mode 100644 index 000000000..24442f21e --- /dev/null +++ b/ext/QuantumCliffordKAExt/mul/new_interface.jl @@ -0,0 +1,457 @@ + +#=============================================================================# +# TODO: Import the functions once the main package establishes them. + +#============================================================================== +RETURNS PAULI OPERATOR +==============================================================================# + +# PauliOperator - PauliOperator +@inline function mul!( + u::DevicePauliOperator, v::DevicePauliOperator; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) + u.nqubits == v.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return do_mul!( + u, v; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + +end + +@inline function do_mul!( + u::DevicePauliOperator, v::DevicePauliOperator; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + device_mul!( + u.phase, u.xz, + v.phase, v.xz; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# PauliOperator - Tableau/AbstractStabilizer[i] +@inline function mul!( + u::DevicePauliOperator, v::DeviceUnionTableau, i::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) + v_tab = tab(v) + one(i) <= i <= length(v_tab.phases) || + throw(BoundsError(THROW_BOUNDS)) + u.nqubits == v_tab.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return do_mul!( + u, v, i; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + +end + +@inline function do_mul!( + u::DevicePauliOperator, v::DeviceUnionTableau, i::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + v_tab = tab(v) + @inbounds device_mul!( + u.phase, u.xz, + (@view v_tab.phases[i]), (@view v_tab.xzs[:, i]); + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +#============================================================================== +RETURNS TABLEAU / ABSTRACT STABILIZER +==============================================================================# + +# Tableau/AbstractStabilizer - PauliOperator +@inline function mul!( + u::DeviceUnionTableau, v::DevicePauliOperator; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) + u_tab = tab(u) + u_tab.nqubits == v.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return do_mul!( + u, v; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + +end + +@inline function do_mul!( + u::DeviceUnionTableau, v::DevicePauliOperator; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + u_tab = tab(u) + device_mul!( + u_tab.phases, u_tab.xzs, + v.phase, v.xz; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Tableau/AbstractStabilizer - Tableau/AbstractStabilizer[i] +@inline function mul!( + u::DeviceUnionTableau, v::DeviceUnionTableau, i::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) + u_tab, v_tab = tab(u), tab(v) + one(i) <= i <= length(v_tab.phases) || + throw(BoundsError(THROW_BOUNDS)) + u_tab.nqubits == v_tab.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return do_mul!( + u, v, i; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + +end + +@inline function do_mul!( + u::DeviceUnionTableau, v::DeviceUnionTableau, i::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + u_tab, v_tab = tab(u), tab(v) + @inbounds device_mul!( + u_tab.phases, u_tab.xzs, + (@view v_tab.phases[i]), (@view v_tab.xzs[:, i]); + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Tableau - Tableau/AbstractStabilizer +@inline function mul!( + u::DeviceTableau, v::DeviceUnionTableau; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) + v_tab = tab(v) + length(u.phases) == length(v_tab.phases) || + throw(DimensionMismatch(THROW_SIZE)) + u.nqubits == v_tab.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return do_mul!( + u, v; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + +end + +@inline function do_mul!( + u::DeviceTableau, v::DeviceUnionTableau; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + v_tab = tab(v) + device_mul!( + u.phases, u.xzs, + v_tab.phases, v_tab.xzs; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Tableau[i] - PauliOperator +@inline function mul!( + u::DeviceTableau, i::Integer, v::DevicePauliOperator; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) + one(i) <= i <= length(u.phases) || + throw(BoundsError(THROW_BOUNDS)) + u.nqubits == v.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return do_mul!( + u, i, v; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + +end + +@inline function do_mul!( + u::DeviceTableau, i::Integer, v::DevicePauliOperator; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @inbounds device_mul!( + (@view u.phases[i]), (@view u.xzs[:, i]), + v.phase, v.xz; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# Tableau[i] - Tableau/AbstractStabilizer[j] +@inline function mul!( + u::DeviceTableau, i::Integer, v::DeviceUnionTableau, j::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) + v_tab = tab(v) + one(i) <= i <= length(u.phases) && + one(j) <= j <= length(v_tab.phases) || + throw(BoundsError(THROW_BOUNDS)) + u.nqubits == v_tab.nqubits || + throw(DimensionMismatch(THROW_NQUBITS)) + end + return do_mul!( + u, i, v, j; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + +end + +@inline function do_mul!( + u::DeviceTableau, i::Integer, v::DeviceUnionTableau, j::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + v_tab = tab(v) + @inbounds device_mul!( + (@view u.phases[i]), (@view u.xzs[:, i]), + (@view v_tab.phases[j]), (@view v_tab.xzs[:, j]); + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +# CAUTION: (Mixed)Destabilizer is handled separately. +# Tableau/AbstractStabilizer[i] - Self[j] +@inline function mul!( + u::DeviceUnionTableau, i::Integer, j::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) + u_tab = tab(u) + len = length(u_tab.phases) + one(i) <= i <= len && one(j) <= j <= len || + throw(BoundsError(THROW_BOUNDS)) + end + return do_mul!( + u, i, j; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + +end + +@inline function do_mul!( + u::DeviceUnionTableau, i::Integer, j::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + u_tab = tab(u) + @inbounds device_mul!( + (@view u_tab.phases[i]), (@view u_tab.xzs[:, i]), + (@view u_tab.phases[j]), (@view u_tab.xzs[:, j]); + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + return u + +end + +#============================================================================== +RETURNS (MIXED) DESTABILIZER +==============================================================================# + +# CAUTION: Requires special handling. +# (Mixed)Destabilizer[i] - Self[j] +@inline function mul!( + u::DeviceUnionDestabilizer, i::Integer, j::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + @boundscheck begin + block_size > zero(block_size) && batch_size > zero(batch_size) || + throw(DomainError(THROW_PARAMETERS)) + len = length(u.tab.phases) >> 1 + one(i) <= i <= len && one(j) <= j <= len || + throw(BoundsError(THROW_BOUNDS)) + end + return do_mul!( + u, i, j; + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + +end + +@inline function do_mul!( + u::DeviceUnionDestabilizer, i::Integer, j::Integer; + multiplication_order::MultiplicationOrder = default_multiplication_order, + primary_axis::PrimaryAxis = default_primary_axis, + phases::Bool = default_phases, + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) + + len = length(u.tab.phases) >> 1 + # Swapping the order of the indices is intentional. + @inbounds device_mul!( + (@view u.tab.phases[j]), (@view u.tab.xzs[:, j]), + (@view u.tab.phases[i]), (@view u.tab.xzs[:, i]); + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = false, + block_size = block_size, batch_size = batch_size + ) + @inbounds device_mul!( + (@view u.tab.phases[i + len]), (@view u.tab.xzs[:, i + len]), + (@view u.tab.phases[j + len]), (@view u.tab.xzs[:, j + len]); + multiplication_order = multiplication_order, + primary_axis = primary_axis, phases = phases, + block_size = block_size, batch_size = batch_size + ) + return u + +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/mul/old_interface.jl b/ext/QuantumCliffordKAExt/mul/old_interface.jl new file mode 100644 index 000000000..4fba4c6f3 --- /dev/null +++ b/ext/QuantumCliffordKAExt/mul/old_interface.jl @@ -0,0 +1,343 @@ + +#=============================================================================# +# TODO: Include the unsafe functions once the main package establishes them. +import QuantumClifford: mul_left!, mul_right! + +# CAUTION: Meta-programming is utilised to in order to avoid repetition. +for (safe_f_sym, unsafe_f_sym, multiplication_order) in ( + (:mul_left!, :do_mul_left!, multiplication_order_left), + (:mul_right!, :do_mul_right!, multiplication_order_right) + ) + +#============================================================================== +RETURNS PAULI OPERATOR +==============================================================================# + +# PauliOperator - PauliOperator +@eval @inline function $safe_f_sym( + u::DevicePauliOperator, v::DevicePauliOperator; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return mul!( + u, v; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DevicePauliOperator, v::DevicePauliOperator; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return do_mul!( + u, v; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +# PauliOperator - Tableau/AbstractStabilizer[i] +@eval @inline function $safe_f_sym( + u::DevicePauliOperator, v::DeviceUnionTableau, i::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return mul!( + u, v, i; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DevicePauliOperator, v::DeviceUnionTableau, i::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return do_mul!( + u, v, i; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +#============================================================================== +RETURNS TABLEAU / ABSTRACT STABILIZER +==============================================================================# + +# Tableau/AbstractStabilizer - PauliOperator +@eval @inline function $safe_f_sym( + u::DeviceUnionTableau, v::DevicePauliOperator; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return mul!( + u, v; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DeviceUnionTableau, v::DevicePauliOperator; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return do_mul!( + u, v; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +# Tableau/AbstractStabilizer - Tableau/AbstractStabilizer[i] +@eval @inline function $safe_f_sym( + u::DeviceUnionTableau, v::DeviceUnionTableau, i::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return mul!( + u, v, i; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DeviceUnionTableau, v::DeviceUnionTableau, i::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return do_mul!( + u, v, i; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +# Tableau - Tableau/AbstractStabilizer +@eval @inline function $safe_f_sym( + u::DeviceTableau, v::DeviceUnionTableau; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return mul!( + u, v; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DeviceTableau, v::DeviceUnionTableau; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return do_mul!( + u, v; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +# Tableau[i] - PauliOperator +@eval @inline function $safe_f_sym( + u::DeviceTableau, i::Integer, v::DevicePauliOperator; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return mul!( + u, i, v; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DeviceTableau, i::Integer, v::DevicePauliOperator; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return do_mul!( + u, i, v; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +# Tableau[i] - Tableau/AbstractStabilizer[j] +@eval @inline function $safe_f_sym( + u::DeviceTableau, i::Integer, v::DeviceUnionTableau, j::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return mul!( + u, i, v, j; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DeviceTableau, i::Integer, v::DeviceUnionTableau, j::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return do_mul!( + u, i, v, j; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +# CAUTION: (Mixed)Destabilizer is handled separately. +# Tableau/AbstractStabilizer[i] - Self[j] +@eval @inline function $safe_f_sym( + u::DeviceUnionTableau, i::Integer, j::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return mul!( + u, i, j; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DeviceUnionTableau, i::Integer, j::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return do_mul!( + u, i, j; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +#============================================================================== +RETURNS (MIXED) DESTABILIZER +==============================================================================# + +# CAUTION: Requires special handling. +# (Mixed)Destabilizer[i] - Self[j] +@eval @inline function $safe_f_sym( + u::DeviceUnionDestabilizer, i::Integer, j::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return mul!( + u, i, j; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +@eval @inline function $unsafe_f_sym( + u::DeviceUnionDestabilizer, i::Integer, j::Integer; + primary_axis::PrimaryAxis = default_primary_axis, + phases::Val{phases_B} = Val(default_phases), + block_size::Integer = default_block_size, + batch_size::Integer = default_batch_size + ) where {phases_B} + + return do_mul!( + u, i, j; + multiplication_order = $multiplication_order, + primary_axis = primary_axis, phases = phases_B, + block_size = block_size, batch_size = batch_size + ) + +end + +# Marks the end for (safe_f_sym, unsafe_f_sym, multiplication_order) +end +#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/mul_leftright/host_interface.jl b/ext/QuantumCliffordKAExt/mul_leftright/host_interface.jl deleted file mode 100644 index 68466f638..000000000 --- a/ext/QuantumCliffordKAExt/mul_leftright/host_interface.jl +++ /dev/null @@ -1,430 +0,0 @@ - -#=============================================================================# -# TODO: include the unsafe functions once the main package establishes them. -import QuantumClifford: mul_left!, mul_right! - -# CAUTION: Meta-programming is utilised to in order to avoid repetition. -for (safe_f_sym, unsafe_f_sym, multiplication_order) in ( - (:mul_left!, :do_mul_left!, multiplication_order_left), - (:mul_right!, :do_mul_right!, multiplication_order_right) - ) - -#============================================================================== -RETURNS PAULI OPERATOR -==============================================================================# - -# PauliOperator - PauliOperator -@eval @inline function $safe_f_sym( - u::DevicePauliOperator, v::DevicePauliOperator; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - u.nqubits == v.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - end - return $unsafe_f_sym( - u, v; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::DevicePauliOperator, v::DevicePauliOperator; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - device_mul!( - u.phase, u.xz, v.phase, v.xz, $multiplication_order; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -for (T_v_sym, v_tab_sym) in ( - (:DeviceTableau, :v), (:DeviceAbstractStabilizer, :(v.tab)) - ) - -# PauliOperator - Tableau/AbstractStabilizer[i] -@eval @inline function $safe_f_sym( - u::DevicePauliOperator, v::$T_v_sym, i::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - one(i) <= i <= length($v_tab_sym.phases) || - throw(BoundsError(THROW_BOUNDS)) - u.nqubits == $v_tab_sym.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - end - return $unsafe_f_sym( - u, v, i; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::DevicePauliOperator, v::$T_v_sym, i::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @inbounds device_mul!( - u.phase, u.xz, - (@view $v_tab_sym.phases[i]), (@view $v_tab_sym.xzs[:, i]), - $multiplication_order; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Marks the end for (T_v_sym, v_tab_sym) -end - -#============================================================================== -RETURNS TABLEAU / ABSTRACT STABILIZER -==============================================================================# - -for (T_u_sym, u_tab_sym) in ( - (:DeviceTableau, :u), (:DeviceAbstractStabilizer, :(u.tab)) - ) - -# Tableau/AbstractStabilizer - PauliOperator -@eval @inline function $safe_f_sym( - u::$T_u_sym, v::DevicePauliOperator; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - $u_tab_sym.nqubits == v.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - end - return $unsafe_f_sym( - u, v; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, v::DevicePauliOperator; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - device_mul!( - $u_tab_sym.phases, $u_tab_sym.xzs, v.phase, v.xz, - $multiplication_order; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Tableau/AbstractStabilizer[i] - PauliOperator -@eval @inline function $safe_f_sym( - u::$T_u_sym, i::Integer, v::DevicePauliOperator; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - one(i) <= i <= length($u_tab_sym.phases) || - throw(BoundsError(THROW_BOUNDS)) - $u_tab_sym.nqubits == v.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - end - return $unsafe_f_sym( - u, i, v; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, i::Integer, v::DevicePauliOperator; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @inbounds device_mul!( - (@view $u_tab_sym.phases[i]), (@view $u_tab_sym.xzs[:, i]), - v.phase, v.xz, - $multiplication_order; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# CAUTION: (Mixed)Destabilizer is handled separately. -# Tableau/AbstractStabilizer[i] - Self[j] -@eval @inline function $safe_f_sym( - u::$T_u_sym, i::Integer, j::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - len = length($u_tab_sym.phases) - one(i) <= i <= len && one(j) <= j <= len || - throw(BoundsError(THROW_BOUNDS)) - end - return $unsafe_f_sym( - u, i, j; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, i::Integer, j::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @inbounds device_mul!( - (@view $u_tab_sym.phases[i]), (@view $u_tab_sym.xzs[:, i]), - (@view $u_tab_sym.phases[j]), (@view $u_tab_sym.xzs[:, j]), - $multiplication_order; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -for (T_v_sym, v_tab_sym) in ( - (:DeviceTableau, :v), (:DeviceAbstractStabilizer, :(v.tab)) - ) - -# Tableau/AbstractStabilizer - Tableau/AbstractStabilizer -@eval @inline function $safe_f_sym( - u::$T_u_sym, v::$T_v_sym; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - length($u_tab_sym.phases) == length($v_tab_sym.phases) || - throw(DimensionMismatch(THROW_SIZE)) - $u_tab_sym.nqubits == $v_tab_sym.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - end - return $unsafe_f_sym( - u, v; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, v::$T_v_sym; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - device_mul!( - $u_tab_sym.phases, $u_tab_sym.xzs, $v_tab_sym.phases, $v_tab_sym.xzs, - $multiplication_order; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Tableau/AbstractStabilizer - Tableau/AbstractStabilizer[i] -@eval @inline function $safe_f_sym( - u::$T_u_sym, v::$T_v_sym, i::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - one(i) <= i <= length($v_tab_sym.phases) || - throw(BoundsError(THROW_BOUNDS)) - $u_tab_sym.nqubits == $v_tab_sym.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - end - return $unsafe_f_sym( - u, v, i; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, v::$T_v_sym, i::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @inbounds device_mul!( - $u_tab_sym.phases, $u_tab_sym.xzs, - (@view $v_tab_sym.phases[i]), (@view $v_tab_sym.xzs[:, i]), - $multiplication_order; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Tableau/AbstractStabilizer[i] - Tableau/AbstractStabilizer[j] -@eval @inline function $safe_f_sym( - u::$T_u_sym, i::Integer, v::$T_v_sym, j::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - one(i) <= i <= length($u_tab_sym.phases) && - one(j) <= j <= length($v_tab_sym.phases) || - throw(BoundsError(THROW_BOUNDS)) - $u_tab_sym.nqubits == $v_tab_sym.nqubits || - throw(DimensionMismatch(THROW_NQUBITS)) - end - return $unsafe_f_sym( - u, i, v, j; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::$T_u_sym, i::Integer, v::$T_v_sym, j::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @inbounds device_mul!( - (@view $u_tab_sym.phases[i]), (@view $u_tab_sym.xzs[:, i]), - (@view $v_tab_sym.phases[j]), (@view $v_tab_sym.xzs[:, j]), - $multiplication_order; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Marks the end for (T_v_sym, v_tab_sym) -end - -# Marks the end for (T_u_sym, u_tab_sym) -end - -#============================================================================== -RETURNS (MIXED) DESTABILIZER -==============================================================================# - -# CAUTION: Requires special handling. -# (Mixed)Destabilizer[i] - Self[j] -@eval @inline function $safe_f_sym( - u::DeviceUnionDestabilizer, i::Integer, j::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - @boundscheck begin - len = length(u.tab.phases) - n = len >> one(len) - all(x -> one(x) <= x <= len, (i, j, i + n, j + n)) || - throw(BoundsError(THROW_BOUNDS)) - end - return $unsafe_f_sym( - u, i, j; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - -end - -@eval @inline function $unsafe_f_sym( - u::DeviceUnionDestabilizer, i::Integer, j::Integer; - phases::Val{phase_B} = Val(default_phases), - primary_axis::Val{primary_axis_E} = Val(default_primary_axis), - block_size::Val{block_SZ} = Val(default_block_size), - batch_size::Val{batch_SZ} = Val(default_batch_size) - ) where {phase_B, primary_axis_E, block_SZ, batch_SZ} - - p, xzs = u.tab.phases, u.tab.xzs - n = length(p) - n >>= one(n) - # Swapping the order of the indices is intentional. - @inbounds device_mul!( - (@view p[j]), (@view xzs[:, j]), - (@view p[i]), (@view xzs[:, i]), - $multiplication_order; - phases = Val(false), primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - @inbounds device_mul!( - (@view p[i + n]), (@view xzs[:, i + n]), - (@view p[j + n]), (@view xzs[:, j + n]), - $multiplication_order; - phases = phases, primary_axis = primary_axis, - block_size = block_size, batch_size = batch_size - ) - return u - -end - -# Marks the end for (safe_f_sym, unsafe_f_sym, multiplication_order) -end -#=============================================================================# diff --git a/ext/QuantumCliffordKAExt/utilities/bit_manipulation.jl b/ext/QuantumCliffordKAExt/utilities/bit_manipulation.jl index 447fd0c66..95defbcb0 100644 --- a/ext/QuantumCliffordKAExt/utilities/bit_manipulation.jl +++ b/ext/QuantumCliffordKAExt/utilities/bit_manipulation.jl @@ -5,19 +5,28 @@ return sizeof(T) * count_zeros(zero(Cuchar)) end +# CAUTION: Zero indexed shift, valid values are less than bit_count(T). +@inline function highest_set_bit(bit_field::T) where {T <: Unsigned} + return T(leading_zeros(bit_field)) +end + # CAUTION: Zero indexed shift, valid values are less than bit_count(T). @inline function lowest_set_bit(bit_field::T) where {T <: Unsigned} return T(trailing_zeros(bit_field)) end # CAUTION: Unsigned typing is intentional for branchless code generation. -@inline function bit_mask(bit_shift::Unsigned, ::Type{T}) where {T <: Unsigned} - return one(T) << bit_shift +@inline function top_bits(count::Unsigned, ::Type{T}) where {T <: Unsigned} + return ~(~zero(T) >> count) end # CAUTION: Unsigned typing is intentional for branchless code generation. -# CAUTION: Requires that count be lower than bit_count(T) for validity. -@inline function top_bits(count::Unsigned, ::Type{T}) where {T <: Unsigned} - return ~(~zero(T) >>> count) +@inline function bottom_bits(count::Unsigned, ::Type{T}) where {T <: Unsigned} + return ~(~zero(T) << count) +end + +# CAUTION: Unsigned typing is intentional for branchless code generation. +@inline function bit_mask(bit_shift::Unsigned, ::Type{T}) where {T <: Unsigned} + return one(T) << bit_shift end #=============================================================================# diff --git a/ext/QuantumCliffordKAExt/utilities/kernel_configuration.jl b/ext/QuantumCliffordKAExt/utilities/kernel_configuration.jl index 65ee7e3ae..999b9e03b 100644 --- a/ext/QuantumCliffordKAExt/utilities/kernel_configuration.jl +++ b/ext/QuantumCliffordKAExt/utilities/kernel_configuration.jl @@ -2,7 +2,7 @@ #=============================================================================# # Translates index set dimensions into the grid-block model. @inline function tessellate( - space::NTuple{N, <: Integer}, tile::NTuple{N, <: Integer} + space::NTuple{N, Integer}, tile::NTuple{N, Integer} ) where {N} return tile .* cld.(space, tile) @@ -12,8 +12,8 @@ end # Commonly set unsafe_indices = true, hence replaces KA.@index(Global, NTuple). @inline function global_index( block_index::NTuple{N, T}, - block_dim::NTuple{N, <: Integer}, - thread_index::NTuple{N, <: Integer} + block_dim::NTuple{N, Integer}, + thread_index::NTuple{N, Integer} ) where {N, T <: Integer} return (block_index .- one(T)) .* block_dim .+ thread_index diff --git a/ext/QuantumCliffordKAExt/utilities/mutex_management.jl b/ext/QuantumCliffordKAExt/utilities/mutex_management.jl index e294cacea..ce9580b8c 100644 --- a/ext/QuantumCliffordKAExt/utilities/mutex_management.jl +++ b/ext/QuantumCliffordKAExt/utilities/mutex_management.jl @@ -20,31 +20,27 @@ SINFUL IMPLEMENTATION WROUGHT ABOUT BY THE FOLLY OF MANKIND # TODO: Overhaul this entirely once support becomes available. @inline @generated function lock_mutex!( - mutex::M, data::AbstractArray... - )::Nothing where {M <: AbstractMutex} + mutex::AbstractMutex, data::AbstractArray... + )::Nothing - if length(data) > 0x0 - clause = :( - @atomicreplace :release :acquire mutex[0x1] mutex_exchange_lock - ) - # CAUTION: Atomics forces the necessary memory synchronisation barrier. - @inbounds nil = zero(eltype(data[0x1])) - fence = :( - # This is just a fancy atomic NOOP. - @atomicreplace :release :acquire data[0x1][0x1] $nil => $nil; - ) + if length(data) > 0 + swap = :(@atomicreplace :release :acquire mutex[1] mutex_exchange_lock) + # CAUTION: Atomics force the necessary memory synchronisation barrier. + @inbounds nil = zero(eltype(data[1])) + # This is just a fancy atomic NOOP. + fence = :(@atomicreplace :release :acquire data[1][1] $nil => $nil;) end - for n in 0x2 : length(data) + for n in 2 : length(data) @inbounds nil = zero(eltype(data[n])) fence = :( $fence; - @atomicreplace :release :acquire data[$n][0x1] $nil => $nil; + @atomicreplace :release :acquire data[$n][1] $nil => $nil; ) end return :( @inbounds begin; while true; - ($clause).success && break; + ($swap).success && break; end; $fence; end; @@ -55,30 +51,30 @@ end # TODO: Overhaul this entirely once support becomes available. @inline @generated function unlock_mutex!( - mutex::M, data::AbstractArray... - )::Nothing where {M <: AbstractMutex} + mutex::AbstractMutex, data::AbstractArray... + )::Nothing - if length(data) > 0x0 - # CAUTION: Atomics forces the necessary memory synchronisation barrier. + if length(data) > 0 + # CAUTION: Atomics force the necessary memory synchronisation barrier. fence = :( - temp_1 = data[0x1][0x1]; + temp_1 = data[1][1]; # This is just a fancy atomic NOOP. Always succeeds and releases. - @atomicreplace :release :acquire data[0x1][0x1] temp_1 => temp_1; + @atomicreplace :release :acquire data[1][1] temp_1 => temp_1; ) end - for n in 0x2 : length(data) + for n in 2 : length(data) sym = Symbol(:temp_, n) fence = :( $fence; - $sym = data[$n][0x1]; - @atomicreplace :release :acquire data[$n][0x1] $sym => $sym; + $sym = data[$n][1]; + @atomicreplace :release :acquire data[$n][1] $sym => $sym; ) end return :( @inbounds begin; $fence; # This will always succeed. - @atomicreplace :release :acquire mutex[0x1] mutex_exchange_unlock; + @atomicreplace :release :acquire mutex[1] mutex_exchange_unlock; end; return nothing; ) diff --git a/ext/QuantumCliffordKAExt/utilities/reductions.jl b/ext/QuantumCliffordKAExt/utilities/reductions.jl index 4aa0a52a7..c43938180 100644 --- a/ext/QuantumCliffordKAExt/utilities/reductions.jl +++ b/ext/QuantumCliffordKAExt/utilities/reductions.jl @@ -51,12 +51,10 @@ REDUCTION CATALOGUE index::Integer, stride::Integer, arguments::AbstractArray... )::Nothing - if length(arguments) > 0x0 - reduction = :( - arguments[0x1][index] += arguments[0x1][index + stride]; - ) + if length(arguments) > 0 + reduction = :(arguments[1][index] += arguments[1][index + stride];) end - for n in 0x2 : length(arguments) + for n in 2 : length(arguments) reduction = :( $reduction; arguments[$n][index] += arguments[$n][index + stride]; @@ -75,15 +73,13 @@ end index::Integer, stride::Integer, arguments::AbstractArray... )::Nothing - if length(arguments) > 0x0 - clause = :(arguments[0x1][index + stride] < arguments[0x1][index]) - body = :( - arguments[0x1][index] = arguments[0x1][index + stride]; - ) + if length(arguments) > 0 + clause = :(arguments[1][index + stride] < arguments[1][index]) + body = :(arguments[1][index] = arguments[1][index + stride];) end - for n in 0x2 : length(arguments) - subclause = :(arguments[0x1][index + stride] == arguments[0x1][index]) - for m in 0x2 : (n - 0x1) + for n in 2 : length(arguments) + subclause = :(arguments[1][index + stride] == arguments[1][index]) + for m in 2 : (n - 1) subclause = :( $subclause && arguments[$m][index + stride] == arguments[$m][index] diff --git a/ext/QuantumCliffordKAExt/utilities/scan_step.jl b/ext/QuantumCliffordKAExt/utilities/scan_step.jl index afcb367d9..0a9cb7330 100644 --- a/ext/QuantumCliffordKAExt/utilities/scan_step.jl +++ b/ext/QuantumCliffordKAExt/utilities/scan_step.jl @@ -4,50 +4,38 @@ @inline function scan_step( x_bits::T, z_bits::T, current_index::Integer, index::Integer, bit_shift::Integer, bit_type::PauliBit, - ::Val{sort_order} - ) where {T <: Unsigned, sort_order} - - current_shift_x = lowest_set_bit(x_bits) - current_shift_z = lowest_set_bit(z_bits) + ::Val{pauli_preferance}, ::Val{sort_order} + ) where {T <: Unsigned, pauli_preferance, sort_order} + + if pauli_preferance == pauli_preferance_x + current_shift_primary = lowest_set_bit(x_bits) + current_shift_secondary = lowest_set_bit(z_bits) + elseif pauli_preferance == pauli_preferance_z + current_shift_primary = lowest_set_bit(z_bits) + current_shift_secondary = lowest_set_bit(x_bits) + end - if sort_order == sort_order_pauli_bit_prefer_x + if sort_order == sort_order_pauli_bit - if current_shift_x < bit_count(T) && isless( - (pauli_bit_x, current_index, current_shift_x), + if current_shift_primary < bit_count(T) && isless( + (pauli_bit_primary, current_index, current_shift_primary), (bit_type, index, bit_shift) ) - output = (current_index, current_shift_x, pauli_bit_x, true) + output = ( + current_index, current_shift_primary, + pauli_bit_primary, true + ) - elseif current_shift_z < bit_count(T) && isless( - (pauli_bit_z, current_index, current_shift_z), + elseif current_shift_secondary < bit_count(T) && isless( + (pauli_bit_secondary, current_index, current_shift_secondary), (bit_type, index, bit_shift) ) - output = (current_index, current_shift_z, pauli_bit_z, false) - - else - - output = (index, bit_shift, bit_type, false) - - end - - elseif sort_order == sort_order_pauli_bit_prefer_z - - # This reverses the ordering of the Pauli bits. - if current_shift_z < bit_count(T) && isless( - (xor(Integer(pauli_bit_z), 0x1), current_index, current_shift_z), - (xor(Integer(bit_type), 0x1), index, bit_shift) - ) - - output = (current_index, current_shift_z, pauli_bit_z, true) - - elseif current_shift_x < bit_count(T) && isless( - (xor(Integer(pauli_bit_x), 0x1), current_index, current_shift_x), - (xor(Integer(bit_type), 0x1), index, bit_shift) - ) - - output = (current_index, current_shift_x, pauli_bit_x, false) + output = ( + current_index, current_shift_secondary, + pauli_bit_secondary, false + ) else @@ -55,14 +43,14 @@ end - elseif sort_order == sort_order_qubit_number_prefer_x + elseif sort_order == sort_order_qubit_number candidate = min( - (current_shift_x, pauli_bit_x), - (current_shift_z, pauli_bit_z) + (current_shift_primary, pauli_bit_primary), + (current_shift_secondary, pauli_bit_secondary) ) - if @inbounds candidate[0x1] < bit_count(T) && isless( + if @inbounds candidate[1] < bit_count(T) && isless( (current_index, candidate...), (index, bit_shift, bit_type) ) @@ -75,27 +63,6 @@ end - elseif sort_order == sort_order_qubit_number_prefer_z - - # This reverses the ordering of the Pauli bits. - candidate = min( - (current_shift_x, xor(Integer(pauli_bit_x), 0x1)), - (current_shift_z, xor(Integer(pauli_bit_z), 0x1)) - ) - - @inbounds if candidate[0x1] < bit_count(T) && isless( - (current_index, candidate...), - (index, bit_shift, xor(Integer(bit_type), 0x1)) - ) - - output = (current_index, candidate..., true) - - else - - output = (index, bit_shift, bit_type, false) - - end - end return output diff --git a/ext/QuantumCliffordKAExt/utilities/snippets.jl b/ext/QuantumCliffordKAExt/utilities/snippets.jl index 9a5b5ac9c..e68176c27 100644 --- a/ext/QuantumCliffordKAExt/utilities/snippets.jl +++ b/ext/QuantumCliffordKAExt/utilities/snippets.jl @@ -17,17 +17,17 @@ SNIPPET CATALOGUE ==============================================================================# @inline function snippet_mod_4_sum_phase!( - global_position::NTuple{N, <: Integer}, phases::AbstractArray{<: Unsigned}, + global_position::NTuple{N, Integer}, phases::AbstractArray{<: Unsigned}, partner::Union{Unsigned, AbstractArray{<: Unsigned}} )::Nothing where {N} @inbounds begin - i = global_position[0x1] + i = global_position[1] if i <= length(phases) if partner isa Integer phases[i] = (phases[i] + partner) & 0x3 elseif partner isa AbstractArray - j = ifelse(length(partner) > 0x1, i, one(i)) + j = ifelse(length(partner) > 1, i, one(i)) phases[i] = (phases[i] + partner[j]) & 0x3 end end @@ -37,11 +37,11 @@ SNIPPET CATALOGUE end @inline function snippet_mod_4_phase!( - global_position::NTuple{N, <: Integer}, phases::AbstractArray{<: Unsigned} + global_position::NTuple{N, Integer}, phases::AbstractArray{<: Unsigned} )::Nothing where {N} @inbounds begin - i = global_position[0x1] + i = global_position[1] if i <= length(phases) phases[i] &= 0x3 end @@ -51,50 +51,44 @@ end end @inline function snippet_track_pivot_canonicalize!( - global_position::NTuple{N, <: Integer}, + global_position::NTuple{N, Integer}, output_buffer::Union{Nothing, AbstractArray{<: Integer}}, - tracker::AbstractArray{<: Unsigned}, toggle::Bool, sort_order::SortOrder + tracker::AbstractArray{<: Unsigned}, toggle::Bool )::Nothing where {N} @inbounds begin current = KA.@uniform ( - ifelse(toggle, tracker_element_count, 0x0) + ifelse(toggle, tracker_element_count, zero(Csize_t)) ) previous = KA.@uniform ( - ifelse(toggle, 0x0, tracker_element_count) + ifelse(toggle, zero(Csize_t), tracker_element_count) ) - if global_position[0x1] == 0x1 + if global_position[1] == 1 bit_type = tracker[current + Integer(tracker_content_bit_type)] row = tracker[previous + Integer(tracker_content_swap_to)] invalid = Integer(pauli_bit_invalid) if !isnothing(output_buffer) - if sort_order == sort_order_pauli_bit_prefer_x - primary = Integer(pauli_bit_x) - secondary = Integer(pauli_bit_z) - elseif sort_order == sort_order_pauli_bit_prefer_z - primary = Integer(pauli_bit_z) - secondary = Integer(pauli_bit_x) - end - + primary = Integer(pauli_bit_primary) + secondary = Integer(pauli_bit_secondary) previous_bit_type = tracker[previous + Integer(tracker_content_bit_type)] # Primary => Invalid if bit_type >= invalid && previous_bit_type == primary - output_buffer[0x1] = row - output_buffer[0x2] = row - # Invalid/Primary => Secondary - elseif bit_type == secondary && previous_bit_type != secondary - output_buffer[0x1] = row + output_buffer[1] = row + output_buffer[2] = row # Secondary => Invalid elseif bit_type >= invalid && previous_bit_type == secondary - output_buffer[0x2] = row + output_buffer[2] = row + # Invalid/Primary => Secondary + elseif bit_type == secondary && previous_bit_type != secondary + output_buffer[1] = row end end - row = ifelse(bit_type < invalid, row + one(row), row) + row += ifelse(bit_type < invalid, one(row), zero(row)) tracker[current + Integer(tracker_content_swap_to)] = row end end @@ -103,19 +97,19 @@ end end @inline function snippet_track_pivot_canonicalize_rref!( - global_position::NTuple{N, <: Integer}, + global_position::NTuple{N, Integer}, output_buffer::Union{Nothing, AbstractArray{<: Integer}}, tracker::AbstractArray{<: Unsigned}, toggle::Bool )::Nothing where {N} @inbounds begin current = KA.@uniform ( - ifelse(toggle, tracker_element_count, 0x0) + ifelse(toggle, tracker_element_count, zero(Csize_t)) ) previous = KA.@uniform ( - ifelse(toggle, 0x0, tracker_element_count) + ifelse(toggle, zero(Csize_t), tracker_element_count) ) - if global_position[0x1] == 0x1 + if global_position[1] == 1 bit_type = tracker[current + Integer(tracker_content_bit_type)] row = tracker[previous + Integer(tracker_content_swap_to)] invalid = Integer(pauli_bit_invalid) @@ -125,11 +119,11 @@ end tracker[previous + Integer(tracker_content_bit_type)] # Valid => Invalid if bit_type >= invalid && previous_bit_type < invalid - output_buffer[0x1] = row - one(row) + output_buffer[1] = row - one(row) end end - row = ifelse(bit_type < invalid, row - one(row), row) + row -= ifelse(bit_type < invalid, one(row), zero(row)) tracker[current + Integer(tracker_content_swap_to)] = row end end @@ -138,19 +132,19 @@ end end @inline function snippet_swap_rows_prepare_tracker!( - global_position::NTuple{N, <: Integer}, + global_position::NTuple{N, Integer}, phases::AbstractArray{<: Unsigned}, xzs::AbstractArray{<: Unsigned}, tracker::AbstractArray{S}, toggle::Bool )::Nothing where {N, S <: Unsigned} @inbounds begin - i = global_position[0x1] - end_i = KA.@uniform (size(xzs, 0x1) >> 0x1) + i = global_position[1] + end_i = KA.@uniform (size(xzs, 1) >> 1) current = KA.@uniform ( - ifelse(toggle, tracker_element_count, 0x0) + ifelse(toggle, tracker_element_count, zero(Csize_t)) ) next = KA.@uniform ( - ifelse(toggle, 0x0, tracker_element_count) + ifelse(toggle, zero(Csize_t), tracker_element_count) ) valid = tracker[current + Integer(tracker_content_bit_type)] < @@ -185,27 +179,38 @@ end end @inline function snippet_set_row_phase_flag!( - global_position::NTuple{N, <: Integer}, + global_position::NTuple{N, Integer}, phases::AbstractArray{P}, xzs::AbstractArray{T}, - tracker::AbstractArray{<: Unsigned}, toggle::Bool + tracker::AbstractArray{<: Unsigned}, toggle::Bool, + pauli_preferance::PauliPreferance )::Nothing where {N, P <: Unsigned, T <: Unsigned} @inbounds begin - z_offset = KA.@uniform (size(xzs, 0x1) >> 0x1) - end_rows = KA.@uniform (size(xzs, 0x2)) + z_offset = KA.@uniform (size(xzs, 1) >> 1) + end_rows = KA.@uniform (size(xzs, 2)) current = KA.@uniform ( - ifelse(toggle, tracker_element_count, 0x0) + ifelse(toggle, tracker_element_count, zero(Csize_t)) ) - row = global_position[0x1] + row = global_position[1] index = tracker[current + Integer(tracker_content_index)] bit_shift = tracker[current + Integer(tracker_content_bit_shift)] bit_type = tracker[current + Integer(tracker_content_bit_type)] if bit_type < Integer(pauli_bit_invalid) && row <= end_rows - if bit_type == Integer(pauli_bit_x) + if bit_type == Integer(pauli_bit_primary) + index += ifelse( + pauli_preferance == pauli_preferance_x, + zero(z_offset), + z_offset + ) + status = xzs[index, row] & bit_mask(bit_shift, T) + elseif bit_type == Integer(pauli_bit_secondary) + index += ifelse( + pauli_preferance == pauli_preferance_x, + z_offset, + zero(z_offset) + ) status = xzs[index, row] & bit_mask(bit_shift, T) - elseif bit_type == Integer(pauli_bit_z) - status = xzs[index + z_offset, row] & bit_mask(bit_shift, T) end phases[row] &= 0x3 if status != zero(T) diff --git a/src/throws.jl b/src/throws.jl index 765b6f88f..3a0424802 100644 --- a/src/throws.jl +++ b/src/throws.jl @@ -11,6 +11,6 @@ const THROW_NQUBITS = "Unable to perform the requested operation due to encountering a mismatch \ between the number of qubits in the provided arguments." -const THROW_VALS = +const THROW_PARAMETERS = "Unable to perform the requested operation due to encountering a mismatch \ -between the provided `::Val` parameter(s) and the range of supported value(s)." +between the provided tuning parameter(s) and the range of supported value(s)." diff --git a/test/KernelAbstractions/implementation/definitions.jl b/test/KernelAbstractions/implementation/definitions.jl index bdf4ec4a4..0e5dfec5f 100644 --- a/test/KernelAbstractions/implementation/definitions.jl +++ b/test/KernelAbstractions/implementation/definitions.jl @@ -1,13 +1,5 @@ -# Small sizes for encoding issues, large sizes for race conditions. -const test_sizes = [ - 31, 32, 33, 63, 64, 65, 127, 128, 129, - 64 * 1023, 64 * 1024, 64 * 1025, 64 * 2047, 64 * 2048, 64 * 2049 - ] -# The tests are for correctness, not for device memory limits. -const max_rows = 1024 -# Keep it reasonable so that local testing remains accessible. -const round_count = 16 -# Correctness should be independent of parameter values. -# The omission of the const specifier is intentional, overridden in OpenCL. -block_sizes = rand(1 : 256, round_count) -const batch_sizes = rand(1 : 256, round_count) + +#=============================================================================# +include("definitions/test_configuration.jl") +include("definitions/tuning_parameters.jl") +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/definitions/test_configuration.jl b/test/KernelAbstractions/implementation/definitions/test_configuration.jl new file mode 100644 index 000000000..302a9b2b0 --- /dev/null +++ b/test/KernelAbstractions/implementation/definitions/test_configuration.jl @@ -0,0 +1,15 @@ + +#=============================================================================# +# Small counts for encoding issues, large counts for race conditions. +const qubit_counts = [ + 31, 32, 33, 63, 64, 65, 127, 128, 129, + 64 * 1023, 64 * 1024, 64 * 1025, 64 * 2047, 64 * 2048, 64 * 2049 + ] +# The tests are for correctness, not for device memory limits. +const max_rows = 1024 + +# Keep it reasonable so that local testing remains accessible. +const max_rounds = 16 +# Certain operations are quite expensive and so should run fewer iterations. +const min_rounds = min(4, max_rounds) +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/definitions/tuning_parameters.jl b/test/KernelAbstractions/implementation/definitions/tuning_parameters.jl new file mode 100644 index 000000000..63b31cb10 --- /dev/null +++ b/test/KernelAbstractions/implementation/definitions/tuning_parameters.jl @@ -0,0 +1,12 @@ + +#=============================================================================# +# Correctness should be independent of the tuning parameter values. +# They originate from a package extension, hence requiring this query. +const KAExt = Base.get_extension(QuantumClifford, :QuantumCliffordKAExt) + +const primary_axes = rand(instances(KAExt.PrimaryAxis), max_rounds) +# The omission of the const specifier is intentional, overridden in OpenCL. +# TODO: Revisit this once the POCL code generation issues are resolved. +block_sizes = rand(Base.OneTo(KAExt.default_block_size), max_rounds) +const batch_sizes = rand(Base.OneTo(KAExt.default_batch_size), max_rounds) +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/imports.jl b/test/KernelAbstractions/implementation/imports.jl index 6c440d17a..bcff054f4 100644 --- a/test/KernelAbstractions/implementation/imports.jl +++ b/test/KernelAbstractions/implementation/imports.jl @@ -1,8 +1,15 @@ + +#=============================================================================# # Required for QuantumCliffordKAExt. import Atomix, GPUArraysCore, KernelAbstractions +# Required for QuantumCliffordAdaptExt. +using Adapt: adapt + # Assists in reducing resource demands. using GPUArrays: AllocCache, @cached, unsafe_free! + using QuantumClifford # This must be done explicitly as they are not exported. -using QuantumClifford: Tableau, AbstractStabilizer +using QuantumClifford: Tableau, AbstractStabilizer, random_tableau +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/suites/test_KA_canonicalization.jl b/test/KernelAbstractions/implementation/suites/test_KA_canonicalization.jl new file mode 100644 index 000000000..2877a21b6 --- /dev/null +++ b/test/KernelAbstractions/implementation/suites/test_KA_canonicalization.jl @@ -0,0 +1,88 @@ + +#=============================================================================# +function test_KA_canonicalization(synchronize, AT, cache)::Nothing + for (round, nqubits) in Iterators.product( + Base.OneTo(min_rounds), qubit_counts + ) + + # Keep the memory usage sane. + rows = min(nqubits, max_rows) + axis = primary_axes[round] + block = block_sizes[round] + batch = batch_sizes[round] + + @cached cache begin + + # Stabilizer + host_stabilizer = Stabilizer(random_tableau(max_rows, nqubits)) + device_stabilizer = adapt(AT, host_stabilizer) + + # Placeholders + host_temp_stabilizer = zero(host_stabilizer) + device_temp_stabilizer = adapt(AT, host_temp_stabilizer) + + # Important optional argument. + if isodd(round) + colindices = one(nqubits) : (one(nqubits) << 1) : nqubits + xzs = host_stabilizer.tab.xzs + bit_masks = AT(zeros(eltype(xzs), size(xzs, 1) >> 1)) + fill!(bit_masks, alternating_bit_mask(eltype(xzs))) + else + colindices = Base.OneTo(nqubits) + bit_masks = nothing + end + + i = rand(Base.OneTo(rows)) + + # canonicalize + @test begin + copy_to!(host_temp_stabilizer, host_stabilizer) + copy_to!(device_temp_stabilizer, device_stabilizer) + + host_output, host_pivot_x, host_pivot_z = canonicalize!( + host_temp_stabilizer; ranks = true + ) + device_buffer = AT(zeros(typeof(host_pivot_x), 2)) + device_output, device_buffer = canonicalize!( + device_temp_stabilizer, device_buffer; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + result = equal_phases(host_output, device_output, i) + result &= equal_xzs(host_output, device_output, i) + result &= [host_pivot_x, host_pivot_z] == Array(device_buffer) + end + + # canonicalize_rref + @test begin + copy_to!(host_temp_stabilizer, host_stabilizer) + copy_to!(device_temp_stabilizer, device_stabilizer) + + host_output, host_pivot = canonicalize_rref!( + host_temp_stabilizer, colindices + ) + device_buffer = AT(zeros(typeof(host_pivot), 1)) + device_output, device_buffer = canonicalize_rref!( + device_temp_stabilizer, device_buffer, bit_masks; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + result = equal_phases(host_output, device_output, i) + result &= equal_xzs(host_output, device_output, i) + result &= [host_pivot] == Array(device_buffer) + end + + # canonicalize_gott + # TODO: Implement. + + # Marks the end for @cached + end + + # Marks the end for (round, nqubits) + end + + return nothing +end +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/suites/test_KA_mul.jl b/test/KernelAbstractions/implementation/suites/test_KA_mul.jl new file mode 100644 index 000000000..ce2c972a4 --- /dev/null +++ b/test/KernelAbstractions/implementation/suites/test_KA_mul.jl @@ -0,0 +1,223 @@ + +#=============================================================================# +# TODO: Revisit should the base package support more multiplication signatures. +# This must be done explicitly as they are not exported. +using QuantumClifford: mul_left!, mul_right! + +function test_KA_mul(synchronize, AT, cache)::Nothing + for (round, nqubits) in Iterators.product( + Base.OneTo(max_rounds), qubit_counts + ) + + # Keep the memory usage sane. + rows = min(nqubits, max_rows) + axis = primary_axes[round] + block = block_sizes[round] + batch = batch_sizes[round] + + for mul! in (mul_left!, mul_right!) + @cached cache begin + + # PauliOperator + host_pauli_1 = random_pauli(nqubits) + device_pauli_1 = adapt(AT, host_pauli_1) + host_pauli_2 = random_pauli(nqubits) + device_pauli_2 = adapt(AT, host_pauli_2) + + # Tableau + host_tableau = random_tableau(rows, nqubits) + device_tableau = adapt(AT, host_tableau) + + # Destabilizer + host_destabilizer = Destabilizer(random_tableau(rows << 1, nqubits)) + device_destabilizer = adapt(AT, host_destabilizer) + + # Placeholders + host_temp_pauli = zero(host_pauli_1) + device_temp_pauli = adapt(AT, host_temp_pauli) + host_temp_tableau = zero(host_tableau) + device_temp_tableau = adapt(AT, host_temp_tableau) + + i = rand(Base.OneTo(rows)) + j = rand(Base.OneTo(rows)) + + # Potential aliasing problem. + @test begin + copy_to!(host_temp_tableau, host_tableau) + copy_to!(device_temp_tableau, device_tableau) + + host_output = mul!(host_temp_tableau, i, i) + device_output = mul!( + device_temp_tableau, i, i; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + flag = equal_phases(host_output, device_output, i) + flag &= equal_xzs(host_output, device_output, i) + end + + # PauliOperator - PauliOperator + @test begin + copy_to!(host_temp_pauli, host_pauli_1) + copy_to!(device_temp_pauli, device_pauli_1) + + host_output = mul!(host_temp_pauli, host_pauli_2) + device_output = mul!( + device_temp_pauli, device_pauli_2; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + flag = host_output.phase == Array(device_output.phase) + flag &= host_output.xz == Array(device_output.xz) + end + + # PauliOperator - Tableau/AbstractStabilizer[i] + @test begin + copy_to!(host_temp_pauli, host_pauli_1) + copy_to!(device_temp_pauli, device_pauli_1) + + host_output = mul!(host_temp_pauli, host_tableau, i) + device_output = mul!( + device_temp_pauli, device_tableau, i; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + flag = host_output.phase == Array(device_output.phase) + flag &= host_output.xz == Array(device_output.xz) + end + + # Tableau/AbstractStabilizer - PauliOperator + @test begin + copy_to!(host_temp_tableau, host_tableau) + copy_to!(device_temp_tableau, device_tableau) + + host_output = mul!(host_temp_tableau, host_pauli_1) + device_output = mul!( + device_temp_tableau, device_pauli_1; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + flag = equal_phases(host_output, device_output, i) + flag &= equal_xzs(host_output, device_output, i) + end + + # Tableau/AbstractStabilizer - Tableau/AbstractStabilizer[i] + @test begin + copy_to!(host_temp_tableau, host_tableau) + copy_to!(device_temp_tableau, device_tableau) + + host_output = mul!(host_temp_tableau, view_pauli(host_tableau, i)) + device_output = mul!( + device_temp_tableau, device_tableau, i; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + flag = equal_phases(host_output, device_output, i) + flag &= equal_xzs(host_output, device_output, i) + end + + # Tableau - Tableau/AbstractStabilizer + @test begin + copy_to!(device_temp_tableau, device_tableau) + + device_output = mul!( + device_temp_tableau, device_tableau; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + # Rows become {+/-} x Identity since it is multiplied by itself. + flag = reduce(|, view_phases(device_output)) & 0x1 == 0 + flag &= reduce(|, view_xzs(device_output, i)) == 0 + end + + # Tableau[i] - PauliOperator + @test begin + copy_to!(host_temp_pauli, view_pauli(host_tableau, i)) + copy_to!(device_temp_tableau, device_tableau) + + host_output = mul!(host_temp_pauli, host_pauli_1) + device_output = mul!( + device_temp_tableau, i, device_pauli_1; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + flag = host_output.phase == Array(view_phases(device_output, i)) + flag &= host_output.xz == Array(view_xzs(device_output, i)) + end + + # Tableau[i] - Tableau/AbstractStabilizer[j] + @test begin + copy_to!(host_temp_tableau, host_tableau) + copy_to!(device_temp_tableau, device_tableau) + + host_output = mul!(host_temp_tableau, i, host_tableau, j) + device_output = mul!( + device_temp_tableau, i, device_tableau, j; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + flag = equal_phases(host_output, device_output, i) + flag &= equal_xzs(host_output, device_output, i) + end + + # Tableau/AbstractStabilizer[i] - Self[j] + @test begin + copy_to!(host_temp_tableau, host_tableau) + copy_to!(device_temp_tableau, device_tableau) + + host_output = mul!(host_temp_tableau, i, j) + device_output = mul!( + device_temp_tableau, i, j; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + flag = equal_phases(host_output, device_output, i) + flag &= equal_xzs(host_output, device_output, i) + end + + # (Mixed)Destabilizer[i] - Self[j] + @test begin + # There is no need to copy. + host_output_1 = mul!( + view_pauli(host_destabilizer, j), + view_pauli(host_destabilizer, i); + phases = Val(false) + ) + offset = length(host_destabilizer.tab.phases) >> 1 + host_output_2 = mul!( + view_pauli(host_destabilizer, i + offset), + view_pauli(host_destabilizer, j + offset) + ) + device_output = mul!( + device_destabilizer, i, j; + primary_axis = axis, block_size = block, batch_size = batch + ) + synchronize() + + flag = host_output_1.phase == Array(view_phases(device_output, j)) + flag &= host_output_1.xz == Array(view_xzs(device_output, j)) + i += offset + flag &= host_output_2.phase == Array(view_phases(device_output, i)) + flag &= host_output_2.xz == Array(view_xzs(device_output, i)) + end + + # Marks the end for @cached + end + # Marks the end for mul! + end + + # Marks the end for (round, nqubits) + end + + return nothing +end +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/test_KA_mul_leftright.jl b/test/KernelAbstractions/implementation/test_KA_mul_leftright.jl deleted file mode 100644 index ad55ec3bc..000000000 --- a/test/KernelAbstractions/implementation/test_KA_mul_leftright.jl +++ /dev/null @@ -1,238 +0,0 @@ -# TODO: Revisit should the base package support more multiplication signatures. -# This must be done explicitly as they are not exported. -using QuantumClifford: mul_left!, mul_right! - -@inline function test_KA_mul_leftright(AT, synchronize) - cache = AllocCache() - for n in test_sizes - # Keep the memory usage sane. - rows = min(n, max_rows) - for r in one(round_count) : round_count - block = block_sizes[r] - batch = batch_sizes[r] - for mul! in (mul_left!, mul_right!) - @cached cache begin - - # PauliOperator - h_p1 = random_pauli(n) - d_p1 = PauliOperator(AT(u32(h_p1.phase)), h_p1.nqubits, AT(h_p1.xz)) - h_p2 = random_pauli(n) - d_p2 = PauliOperator(AT(u32(h_p2.phase)), h_p2.nqubits, AT(h_p2.xz)) - - # Stabilizer - h_s = Stabilizer( - Tableau( - rand(eltype(h_p1.phase), rows) .& 0x3, - n, - rand(eltype(h_p1.xz), length(h_p1.xz), rows) - ) - ) - d_s = Stabilizer( - Tableau( - AT(u32(h_s.tab.phases)), - h_s.tab.nqubits, - AT(h_s.tab.xzs) - ) - ) - - # Destabilizer - h_d = Destabilizer( - Tableau( - rand(eltype(h_p1.phase), rows << 1) .& 0x3, - n, - rand(eltype(h_p1.xz), length(h_p1.xz), rows << 1) - ) - ) - d_d = Destabilizer( - Tableau( - AT(u32(h_d.tab.phases)), - h_d.tab.nqubits, - AT(h_d.tab.xzs) - ) - ) - i = rand(one(rows) : rows) - j = rand(one(rows) : rows) - - # Independent of phases. - d_o_true = mul!( - copy(d_p1), d_p2; - phases = Val(true), - block_size = Val(block), batch_size = Val(batch) - ) - d_o_false = mul!( - copy(d_p1), d_p2; - phases = Val(false), - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test Array(d_o_true.xz) == Array(d_o_false.xz) - - # Left/Right order. - d_L = mul_left!( - copy(d_p1), d_p2; - block_size = Val(block), batch_size = Val(batch) - ) - d_R = mul_right!( - copy(d_p1), d_p2; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - # Either commutes or anti-commutes. - @test begin - (Array(d_L.phase)[1] - Array(d_R.phase)[1]) & 0x3 == - (comm(h_p1, h_p2) << 1) && - Array(d_L.xz) == Array(d_R.xz) - end - - # Potential aliasing problem. - h_o = mul!(copy(get_pauli(h_s, i)), h_s, i) - d_o = mul!( - copy(d_s), i, i; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test begin - h_o.phase == Array(phases(d_o, i)) && - h_o.xz == Array(xzs(d_o, i)) - end - - # PauliOperator - PauliOperator - h_o = mul!(copy(h_p1), h_p2) - d_o = mul!( - copy(d_p1), d_p2; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test begin - h_o.phase == Array(d_o.phase) && - h_o.xz == Array(d_o.xz) - end - - for (h_v, d_v) in ((h_s.tab, d_s.tab), (h_s, d_s)) - - # PauliOperator - Tableau/AbstractStabilizer[i] - h_o = mul!(copy(h_p1), h_v, i) - d_o = mul!( - copy(d_p1), d_v, i; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test begin - h_o.phase == Array(d_o.phase) && - h_o.xz == Array(d_o.xz) - end - - # Marks the end for (h_v, d_v) - end - - for (h_u, d_u) in ((h_s.tab, d_s.tab), (h_s, d_s)) - - # Tableau/AbstractStabilizer - PauliOperator - h_o = mul!(copy(h_u), h_p1) - d_o = mul!( - copy(d_u), d_p1; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test begin - phases(h_o, i) == Array(phases(d_o, i)) && - xzs(h_o, i) == Array(xzs(d_o, i)) - end - - # Tableau/AbstractStabilizer[i] - PauliOperator - h_o = mul!(copy(get_pauli(h_u, i)), h_p1) - d_o = mul!( - copy(d_u), i, d_p1; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test begin - h_o.phase == Array(phases(d_o, i)) && - h_o.xz == Array(xzs(d_o, i)) - end - - # Tableau/AbstractStabilizer[i] - Self[j] - h_o = mul!(copy(get_pauli(h_u, i)), h_u, j) - d_o = mul!( - copy(d_u), i, j; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test begin - h_o.phase == Array(phases(d_o, i)) && - h_o.xz == Array(xzs(d_o, i)) - end - - for (h_v, d_v) in ((h_s.tab, d_s.tab), (h_s, d_s)) - - # Tableau/AbstractStabilizer - Tableau/AbstractStabilizer - d_o = mul!( - copy(d_u), d_v; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - # Rows become {+/-} x Identity since it is multiplied by itself. - @test begin - reduce(|, phases(d_o)) & 0x1 == 0 && - reduce(|, xzs(d_o, i)) == 0 - end - - # Tableau/AbstractStabilizer - Tableau/AbstractStabilizer[i] - h_o = mul!(copy(h_u), get_pauli(h_v, i)) - d_o = mul!( - copy(d_u), d_v, i; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test begin - phases(h_o, j) == Array(phases(d_o, j)) && - xzs(h_o, j) == Array(xzs(d_o, j)) - end - - # Tableau/AbstractStabilizer[i] - Tableau/AbstractStabilizer[j] - h_o = mul!(copy(get_pauli(h_u, i)), h_v, j) - d_o = mul!( - copy(d_u), i, d_v, j; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test begin - h_o.phase == Array(phases(d_o, i)) && - h_o.xz == Array(xzs(d_o, i)) - end - - # Marks the end for (h_v, d_v) - end - - # Marks the end for (h_u, d_u) - end - - # (Mixed)Destabilizer[i] - Self[j] - h_o1 = mul!( - copy(get_pauli(h_d, j)), get_pauli(h_d, i); - phases = Val(false) - ) - m = length(h_d.tab.phases) >> 1 - h_o2 = mul!(copy(get_pauli(h_d, i + m)), get_pauli(h_d, j + m)) - d_o = mul!( - d_d, i, j; - block_size = Val(block), batch_size = Val(batch) - ) - synchronize() - @test begin - h_o1.phase == Array(phases(d_o, j)) && - h_o1.xz == Array(xzs(d_o, j)) && - h_o2.phase == Array(phases(d_o, i + m)) && - h_o2.xz == Array(xzs(d_o, i + m)) - end - - # Marks the end for @cached - end - # Marks the end for mul! - end - # Marks the end for r in one(round_count) : round_count - end - # Marks the end for n in test_sizes - end - unsafe_free!(cache) -end diff --git a/test/KernelAbstractions/implementation/test_platform.jl b/test/KernelAbstractions/implementation/test_platform.jl index 7b25cff16..89d98377f 100644 --- a/test/KernelAbstractions/implementation/test_platform.jl +++ b/test/KernelAbstractions/implementation/test_platform.jl @@ -1,10 +1,22 @@ + +#=============================================================================# include("imports.jl") include("definitions.jl") include("utilities.jl") -include("test_KA_mul_leftright.jl") -@inline function test_platform(AT, synchronize) - @testset "mul_leftright" begin - test_KA_mul_leftright(AT, synchronize) +include("suites/test_KA_mul.jl") +include("suites/test_KA_canonicalization.jl") + +@inline function test_platform(synchronize, AT)::Nothing + cache = AllocCache() + + @testset "mul" begin + test_KA_mul(synchronize, AT, cache) end + @testset "canonicalization" begin + test_KA_canonicalization(synchronize, AT, cache) + end + + return nothing end +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/utilities.jl b/test/KernelAbstractions/implementation/utilities.jl index c56f2a2c5..1664a22d7 100644 --- a/test/KernelAbstractions/implementation/utilities.jl +++ b/test/KernelAbstractions/implementation/utilities.jl @@ -1,46 +1,7 @@ -# Works even when broadcasting on zero-dimensional arrays. -@inline function u32(v) - return map(x -> UInt32(x), v) -end -# Surprisingly, these do not already exist. -@inline function get_pauli(t::Tableau, i::Integer) - return PauliOperator( - (@view t.phases[i]), - t.nqubits, - (@view t.xzs[:, i]) - ) -end -@inline function get_pauli(s::AbstractStabilizer, i::Integer) - return PauliOperator( - (@view s.tab.phases[i]), - s.tab.nqubits, - (@view s.tab.xzs[:, i]) - ) -end - -@inline function phases(t::Tableau) - return t.phases -end -@inline function phases(t::Tableau, i::Integer) - return (@view t.phases[i]) -end -@inline function phases(s::AbstractStabilizer) - return s.tab.phases -end -@inline function phases(s::AbstractStabilizer, i::Integer) - return (@view s.tab.phases[i]) -end - -@inline function xzs(t::Tableau) - return t.xzs -end -@inline function xzs(t::Tableau, i::Integer) - return (@view t.xzs[:, i]) -end -@inline function xzs(s::AbstractStabilizer) - return s.tab.xzs -end -@inline function xzs(s::AbstractStabilizer, i::Integer) - return (@view s.tab.xzs[:, i]) -end +#=============================================================================# +include("utilities/bit_manipulation.jl") +include("utilities/memory_management.jl") +include("utilities/views.jl") +include("utilities/equalities.jl") +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/utilities/bit_manipulation.jl b/test/KernelAbstractions/implementation/utilities/bit_manipulation.jl new file mode 100644 index 000000000..1c87c8c3d --- /dev/null +++ b/test/KernelAbstractions/implementation/utilities/bit_manipulation.jl @@ -0,0 +1,13 @@ + +#=============================================================================# +# Repeats 0x55 to fill out all the bits in the given type. +@inline function alternating_bit_mask(::Type{T}) where {T <: Unsigned} + counter = count_zeros(zero(T)) >> one(T) + pattern = one(T) + while counter > one(counter) + pattern |= pattern << counter + counter >>= one(counter) + end + return pattern +end +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/utilities/equalities.jl b/test/KernelAbstractions/implementation/utilities/equalities.jl new file mode 100644 index 000000000..4c43b6271 --- /dev/null +++ b/test/KernelAbstractions/implementation/utilities/equalities.jl @@ -0,0 +1,44 @@ + +#=============================================================================# +@inline function equal_phases( + u::U, v::V + ) where { + U <: Union{Tableau, AbstractStabilizer}, + V <: Union{Tableau, AbstractStabilizer} + } + + return Array(view_phases(u)) == Array(view_phases(v)) + +end +@inline function equal_phases( + u::U, v::V, i::Integer + ) where { + U <: Union{Tableau, AbstractStabilizer}, + V <: Union{Tableau, AbstractStabilizer} + } + + return Array(view_phases(u, i)) == Array(view_phases(v, i)) + +end + +@inline function equal_xzs( + u::U, v::V + ) where { + U <: Union{Tableau, AbstractStabilizer}, + V <: Union{Tableau, AbstractStabilizer} + } + + return Array(view_xzs(u)) == Array(view_xzs(v)) + +end +@inline function equal_xzs( + u::U, v::V, i::Integer + ) where { + U <: Union{Tableau, AbstractStabilizer}, + V <: Union{Tableau, AbstractStabilizer} + } + + return Array(view_xzs(u, i)) == Array(view_xzs(v, i)) + +end +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/utilities/memory_management.jl b/test/KernelAbstractions/implementation/utilities/memory_management.jl new file mode 100644 index 000000000..4f44fc974 --- /dev/null +++ b/test/KernelAbstractions/implementation/utilities/memory_management.jl @@ -0,0 +1,34 @@ + +#=============================================================================# +# TODO: Remove these once the main package establishes them. +@inline function copy_to!( + target::PauliOperator, source::PauliOperator + ) + + target.nqubits == source.nqubits || throw(ArgumentError("BAD COPY_TO!")) + copyto!(target.phase, source.phase) + copyto!(target.xz, source.xz) + return target + +end + +@inline function copy_to!( + target::Tableau, source::Tableau + ) + + target.nqubits == source.nqubits || throw(ArgumentError("BAD COPY_TO!")) + copyto!(target.phases, source.phases) + copyto!(target.xzs, source.xzs) + return target + +end + +@inline function copy_to!( + target::AbstractStabilizer, source::AbstractStabilizer + ) + + copy_to!(tab(target), tab(source)) + return target + +end +#=============================================================================# diff --git a/test/KernelAbstractions/implementation/utilities/views.jl b/test/KernelAbstractions/implementation/utilities/views.jl new file mode 100644 index 000000000..ded07547c --- /dev/null +++ b/test/KernelAbstractions/implementation/utilities/views.jl @@ -0,0 +1,26 @@ + +#=============================================================================# +# Avoid scalar indexing to combat complaints from the device backend(s). +@inline function view_pauli(t::Union{Tableau, AbstractStabilizer}, i::Integer) + t_tab = tab(t) + return PauliOperator( + (@view t_tab.phases[i]), + t_tab.nqubits, + (@view t_tab.xzs[:, i]) + ) +end + +@inline function view_phases(t::Union{Tableau, AbstractStabilizer}) + return tab(t).phases +end +@inline function view_phases(t::Union{Tableau, AbstractStabilizer}, i::Integer) + return (@view tab(t).phases[i]) +end + +@inline function view_xzs(t::Union{Tableau, AbstractStabilizer}) + return tab(t).xzs +end +@inline function view_xzs(t::Union{Tableau, AbstractStabilizer}, i::Integer) + return (@view tab(t).xzs[:, i]) +end +#=============================================================================# diff --git a/test/KernelAbstractions/test_platform_CUDA.jl b/test/KernelAbstractions/test_platform_CUDA.jl index 4448ce93d..4bab9b939 100644 --- a/test/KernelAbstractions/test_platform_CUDA.jl +++ b/test/KernelAbstractions/test_platform_CUDA.jl @@ -1,3 +1,5 @@ + +#=============================================================================# @testitem "CUDA" tags = [:cuda] begin include("implementation/test_platform.jl") @@ -12,7 +14,8 @@ end if can_run - test_platform(AT, synchronize) + test_platform(synchronize, AT) end end +#=============================================================================# diff --git a/test/KernelAbstractions/test_platform_OpenCL.jl b/test/KernelAbstractions/test_platform_OpenCL.jl index 19409794b..bce986f3d 100644 --- a/test/KernelAbstractions/test_platform_OpenCL.jl +++ b/test/KernelAbstractions/test_platform_OpenCL.jl @@ -1,3 +1,5 @@ + +#=============================================================================# @testitem "OpenCL" tags = [:opencl] begin include("implementation/test_platform.jl") @@ -16,9 +18,10 @@ if can_run # TODO: Revisit this once the POCL code generation issues are resolved. - block_sizes = fill(256, round_count) + block_sizes = fill(KAExt.default_block_size, max_rounds) synchronize() = finish(queue()) - test_platform(AT, synchronize) + test_platform(synchronize, AT) end end +#=============================================================================# diff --git a/test/KernelAbstractions/test_platform_ROCm.jl b/test/KernelAbstractions/test_platform_ROCm.jl index 3b87658cf..65e501b66 100644 --- a/test/KernelAbstractions/test_platform_ROCm.jl +++ b/test/KernelAbstractions/test_platform_ROCm.jl @@ -1,3 +1,5 @@ + +#=============================================================================# @testitem "ROCm" tags = [:rocm] begin include("implementation/test_platform.jl") @@ -12,7 +14,8 @@ end if can_run - test_platform(AT, synchronize) + test_platform(synchronize, AT) end end +#=============================================================================# diff --git a/test/test_gpu_canonicalization.jl b/test/test_gpu_canonicalization.jl index 927e43e86..552037667 100644 --- a/test/test_gpu_canonicalization.jl +++ b/test/test_gpu_canonicalization.jl @@ -1,40 +1,42 @@ @testitem "GPU Canonicalization" tags=[:cuda] begin using CUDA using QuantumClifford + GPUExt = Base.get_extension(QuantumClifford, :QuantumCliffordGPUExt) using Random using QuantumClifford: to_cpu, to_gpu, random_tableau - + if CUDA.functional() @testset "GPU canonicalize! correctness" begin for n in [10, 50, 100] cpu_stab = random_stabilizer(n) gpu_stab = to_gpu(cpu_stab) cpu_result = canonicalize!(copy(cpu_stab)) - gpu_result = canonicalize!(copy(gpu_stab)) + gpu_result = GPUExt.cuda_canonicalize!(copy(gpu_stab)) cpu_from_gpu = to_cpu(gpu_result) @test cpu_result == cpu_from_gpu end end - + @testset "GPU canonicalize! performance" begin n = 6000 # Large enough to see GPU benefit # Only linear independence is required to achieve full rank. # Occurs with unit probability for arbitrarily large qubit counts. cpu_stab = Stabilizer(random_tableau(n, n)) + # Only algebraic independence is required for a full rank canonicalization. + # In the limit of large n, this occurs with a probability that tends to unity. gpu_stab = to_gpu(cpu_stab) - canonicalize!(copy(gpu_stab)) - gpu_time = @elapsed canonicalize!(copy(gpu_stab)) + GPUExt.cuda_canonicalize!(copy(gpu_stab)) + gpu_time = @elapsed GPUExt.cuda_canonicalize!(copy(gpu_stab)) # Sanity check, note that for 6000 size Stabilizer cpu version takes 1900-2400 ms on average @test 1000 * gpu_time < 1900 end - + @testset "GPU canonicalize! phases option" begin n = 50 cpu_stab = random_stabilizer(n) gpu_stab = to_gpu(cpu_stab) - gpu_result_no_phases = canonicalize!(copy(gpu_stab); phases=false) cpu_result_no_phases = canonicalize!(copy(cpu_stab); phases=false) - + gpu_result_no_phases = GPUExt.cuda_canonicalize!(copy(gpu_stab); phases=false) @test to_cpu(gpu_result_no_phases) == cpu_result_no_phases end else From 87c8e2f6bcb0f9d29e1ff5aff37b017e9ada3008 Mon Sep 17 00:00:00 2001 From: "ha.git" Date: Fri, 22 Aug 2025 18:39:37 +0200 Subject: [PATCH 3/3] Correct typo (preferance => preference). --- .../canonicalization/canonicalize.jl | 16 ++++++++-------- .../canonicalization/canonicalize_rref.jl | 16 ++++++++-------- .../canonicalization/common.jl | 14 +++++++------- .../definitions/default_parameters.jl | 2 +- .../definitions/enumerations.jl | 6 +++--- ext/QuantumCliffordKAExt/utilities/scan_step.jl | 8 ++++---- ext/QuantumCliffordKAExt/utilities/snippets.jl | 6 +++--- 7 files changed, 34 insertions(+), 34 deletions(-) diff --git a/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl b/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl index 78099590e..65e4edfc4 100644 --- a/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl +++ b/ext/QuantumCliffordKAExt/canonicalization/canonicalize.jl @@ -8,7 +8,7 @@ function device_canonicalize!( ph::AbstractArray{<: Unsigned}, xzs::AbstractArray{<: Unsigned}, output_buffer::Union{Nothing, AbstractArray{S}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - pauli_preferance::PauliPreferance = default_pauli_preferance, + pauli_preference::PauliPreference = default_pauli_preference, primary_axis::PrimaryAxis = default_primary_axis, phases::Bool = default_phases, block_size::Integer = default_block_size, @@ -59,7 +59,7 @@ function device_canonicalize!( bit_scan!( tracker, mutex, xzs, nothing, primary_axis, - Val(pauli_preferance), Val(sort_order_pauli_bit), + Val(pauli_preference), Val(sort_order_pauli_bit), Val(block_size), Val(batch_size); workgroupsize = tile, ndrange = space ) @@ -76,13 +76,13 @@ function device_canonicalize!( ) snippet!( snippet_set_row_phase_flag!, - ph, xzs, tracker, toggle, pauli_preferance; + ph, xzs, tracker, toggle, pauli_preference; ndrange = row_count ) mul_and_scan!( ph, xzs, tracker, toggle, mutex, nothing, false, scan_side_greater, multiplication_order, primary_axis, - Val(pauli_preferance), Val(sort_order_pauli_bit), + Val(pauli_preference), Val(sort_order_pauli_bit), Val(phases), Val(block_size), Val(batch_size); workgroupsize = tile, ndrange = space ) @@ -108,7 +108,7 @@ end state::DeviceUnionTableau, output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - pauli_preferance::PauliPreferance = default_pauli_preferance, + pauli_preference::PauliPreference = default_pauli_preference, primary_axis::PrimaryAxis = default_primary_axis, phases::Bool = default_phases, block_size::Integer = default_block_size, @@ -126,7 +126,7 @@ end return do_canonicalize!( state, output_buffer; multiplication_order = multiplication_order, - pauli_preferance = pauli_preferance, + pauli_preference = pauli_preference, primary_axis = primary_axis, phases = phases, block_size = block_size, batch_size = batch_size ) @@ -137,7 +137,7 @@ end state::DeviceUnionTableau, output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - pauli_preferance::PauliPreferance = default_pauli_preferance, + pauli_preference::PauliPreference = default_pauli_preference, primary_axis::PrimaryAxis = default_primary_axis, phases::Bool = default_phases, block_size::Integer = default_block_size, @@ -153,7 +153,7 @@ end device_canonicalize!( state_tab.phases, state_tab.xzs, output_buffer; multiplication_order = multiplication_order, - pauli_preferance = pauli_preferance, + pauli_preference = pauli_preference, primary_axis = primary_axis, phases = phases, block_size = block_size, batch_size = batch_size ) diff --git a/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl index f508dd903..e44c1279c 100644 --- a/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl +++ b/ext/QuantumCliffordKAExt/canonicalization/canonicalize_rref.jl @@ -9,7 +9,7 @@ function device_canonicalize_rref!( output_buffer::Union{Nothing, AbstractArray{<: Integer}} = nothing, bit_masks::Union{Nothing, AbstractArray{T}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - pauli_preferance::PauliPreferance = default_pauli_preferance, + pauli_preference::PauliPreference = default_pauli_preference, primary_axis::PrimaryAxis = default_primary_axis, phases::Bool = default_phases, block_size::Integer = default_block_size, @@ -61,7 +61,7 @@ function device_canonicalize_rref!( bit_scan!( tracker, mutex, xzs, bit_masks, primary_axis, - Val(pauli_preferance), Val(sort_order_qubit_number), + Val(pauli_preference), Val(sort_order_qubit_number), Val(block_size), Val(batch_size); workgroupsize = tile, ndrange = space ) @@ -73,13 +73,13 @@ function device_canonicalize_rref!( ) snippet!( snippet_set_row_phase_flag!, - ph, xzs, tracker, toggle, pauli_preferance; + ph, xzs, tracker, toggle, pauli_preference; ndrange = row_count ) mul_and_scan!( ph, xzs, tracker, toggle, mutex, bit_masks, shrink_workspace, scan_side_lesser, multiplication_order, primary_axis, - Val(pauli_preferance), Val(sort_order_qubit_number), + Val(pauli_preference), Val(sort_order_qubit_number), Val(phases), Val(block_size), Val(batch_size); workgroupsize = tile, ndrange = space ) @@ -106,7 +106,7 @@ end output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing, bit_masks::Union{Nothing, AbstractGPUArray{<: Unsigned}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - pauli_preferance::PauliPreferance = default_pauli_preferance, + pauli_preference::PauliPreference = default_pauli_preference, primary_axis::PrimaryAxis = default_primary_axis, phases::Bool = default_phases, block_size::Integer = default_block_size, @@ -128,7 +128,7 @@ end return do_canonicalize_rref!( state, output_buffer, bit_masks; multiplication_order = multiplication_order, - pauli_preferance = pauli_preferance, + pauli_preference = pauli_preference, primary_axis = primary_axis, phases = phases, block_size = block_size, batch_size = batch_size ) @@ -140,7 +140,7 @@ end output_buffer::Union{Nothing, AbstractGPUArray{<: Integer}} = nothing, bit_masks::Union{Nothing, AbstractGPUArray{<: Unsigned}} = nothing; multiplication_order::MultiplicationOrder = default_multiplication_order, - pauli_preferance::PauliPreferance = default_pauli_preferance, + pauli_preference::PauliPreference = default_pauli_preference, primary_axis::PrimaryAxis = default_primary_axis, phases::Bool = default_phases, block_size::Integer = default_block_size, @@ -156,7 +156,7 @@ end device_canonicalize_rref!( state_tab.phases, state_tab.xzs, output_buffer, bit_masks; multiplication_order = multiplication_order, - pauli_preferance = pauli_preferance, + pauli_preference = pauli_preference, primary_axis = primary_axis, phases = phases, block_size = block_size, batch_size = batch_size ) diff --git a/ext/QuantumCliffordKAExt/canonicalization/common.jl b/ext/QuantumCliffordKAExt/canonicalization/common.jl index fc15df92b..232cba6eb 100644 --- a/ext/QuantumCliffordKAExt/canonicalization/common.jl +++ b/ext/QuantumCliffordKAExt/canonicalization/common.jl @@ -6,11 +6,11 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_bit_scan!( @Const(xzs::AbstractArray{T}), @Const(bit_masks::Union{Nothing, AbstractArray{T}}), @Const(primary_axis::PrimaryAxis), - ::Val{pauli_preferance}, ::Val{sort_order}, + ::Val{pauli_preference}, ::Val{sort_order}, ::Val{block_size}, ::Val{batch_size} ) where { S <: Unsigned, T <: Unsigned, - pauli_preferance, sort_order, block_size, batch_size + pauli_preference, sort_order, block_size, batch_size } if primary_axis == primary_axis_rows @@ -46,7 +46,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_bit_scan!( index, bit_shift, bit_type, break_flag = scan_step( x_bits, z_bits, i, index, bit_shift, bit_type, - Val(pauli_preferance), Val(sort_order) + Val(pauli_preference), Val(sort_order) ) break_flag && break end @@ -140,11 +140,11 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( @Const(shrink_workspace::Bool), @Const(scan_side::ScanSide), @Const(multiplication_order::MultiplicationOrder), @Const(primary_axis::PrimaryAxis), - ::Val{pauli_preferance}, ::Val{sort_order}, + ::Val{pauli_preference}, ::Val{sort_order}, ::Val{phases}, ::Val{block_size}, ::Val{batch_size} ) where { S <: Unsigned, P <: Unsigned, T <: Unsigned, - pauli_preferance, sort_order, phases, block_size, batch_size + pauli_preference, sort_order, phases, block_size, batch_size } if primary_axis == primary_axis_rows @@ -266,7 +266,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( index, bit_shift, bit_type, break_flag = scan_step( x_bits, z_bits, i, index, bit_shift, bit_type, - Val(pauli_preferance), Val(sort_order) + Val(pauli_preference), Val(sort_order) ) break_flag && break end @@ -335,7 +335,7 @@ KA.@kernel inbounds = true unsafe_indices = true function kernel_mul_and_scan!( index, bit_shift, bit_type = scan_step( x_new, z_new, i, index, bit_shift, bit_type, - Val(pauli_preferance), Val(sort_order) + Val(pauli_preference), Val(sort_order) ) end diff --git a/ext/QuantumCliffordKAExt/definitions/default_parameters.jl b/ext/QuantumCliffordKAExt/definitions/default_parameters.jl index 9a7c45d42..7cfdf4b1a 100644 --- a/ext/QuantumCliffordKAExt/definitions/default_parameters.jl +++ b/ext/QuantumCliffordKAExt/definitions/default_parameters.jl @@ -4,7 +4,7 @@ const default_multiplication_order = multiplication_order_left # Customary choice to position X pauli operators before their Z counterparts. -const default_pauli_preferance = pauli_preferance_x +const default_pauli_preference = pauli_preference_x # Potentially boosts cache hits and reduces atomic contention. const default_primary_axis = primary_axis_rows diff --git a/ext/QuantumCliffordKAExt/definitions/enumerations.jl b/ext/QuantumCliffordKAExt/definitions/enumerations.jl index 1f6af5fc4..069945d90 100644 --- a/ext/QuantumCliffordKAExt/definitions/enumerations.jl +++ b/ext/QuantumCliffordKAExt/definitions/enumerations.jl @@ -7,9 +7,9 @@ end # Specifies whether ordering should prioritise X or Z pauli operators. -@enum PauliPreferance::UInt8 begin - pauli_preferance_x - pauli_preferance_z +@enum PauliPreference::UInt8 begin + pauli_preference_x + pauli_preference_z end # Determines whether the first grid dimension matches the rows or the qubits. diff --git a/ext/QuantumCliffordKAExt/utilities/scan_step.jl b/ext/QuantumCliffordKAExt/utilities/scan_step.jl index 0a9cb7330..c003c734f 100644 --- a/ext/QuantumCliffordKAExt/utilities/scan_step.jl +++ b/ext/QuantumCliffordKAExt/utilities/scan_step.jl @@ -4,13 +4,13 @@ @inline function scan_step( x_bits::T, z_bits::T, current_index::Integer, index::Integer, bit_shift::Integer, bit_type::PauliBit, - ::Val{pauli_preferance}, ::Val{sort_order} - ) where {T <: Unsigned, pauli_preferance, sort_order} + ::Val{pauli_preference}, ::Val{sort_order} + ) where {T <: Unsigned, pauli_preference, sort_order} - if pauli_preferance == pauli_preferance_x + if pauli_preference == pauli_preference_x current_shift_primary = lowest_set_bit(x_bits) current_shift_secondary = lowest_set_bit(z_bits) - elseif pauli_preferance == pauli_preferance_z + elseif pauli_preference == pauli_preference_z current_shift_primary = lowest_set_bit(z_bits) current_shift_secondary = lowest_set_bit(x_bits) end diff --git a/ext/QuantumCliffordKAExt/utilities/snippets.jl b/ext/QuantumCliffordKAExt/utilities/snippets.jl index e68176c27..5fdfa28ca 100644 --- a/ext/QuantumCliffordKAExt/utilities/snippets.jl +++ b/ext/QuantumCliffordKAExt/utilities/snippets.jl @@ -182,7 +182,7 @@ end global_position::NTuple{N, Integer}, phases::AbstractArray{P}, xzs::AbstractArray{T}, tracker::AbstractArray{<: Unsigned}, toggle::Bool, - pauli_preferance::PauliPreferance + pauli_preference::PauliPreference )::Nothing where {N, P <: Unsigned, T <: Unsigned} @inbounds begin @@ -199,14 +199,14 @@ end if bit_type < Integer(pauli_bit_invalid) && row <= end_rows if bit_type == Integer(pauli_bit_primary) index += ifelse( - pauli_preferance == pauli_preferance_x, + pauli_preference == pauli_preference_x, zero(z_offset), z_offset ) status = xzs[index, row] & bit_mask(bit_shift, T) elseif bit_type == Integer(pauli_bit_secondary) index += ifelse( - pauli_preferance == pauli_preferance_x, + pauli_preference == pauli_preference_x, z_offset, zero(z_offset) )