Skip to content

Conversation

@kshyatt
Copy link
Member

@kshyatt kshyatt commented Oct 22, 2025

No description provided.

@codecov
Copy link

codecov bot commented Oct 22, 2025

Codecov Report

❌ Patch coverage is 0% with 199 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...ixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl 0.00% 197 Missing ⚠️
src/pullbacks/eig.jl 0.00% 2 Missing ⚠️
Files with missing lines Coverage Δ
src/implementations/eigh.jl 75.53% <ø> (-18.75%) ⬇️
src/pullbacks/eig.jl 0.00% <0.00%> (-96.11%) ⬇️
...ixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl 0.00% <0.00%> (ø)

... and 31 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@github-actions
Copy link

github-actions bot commented Oct 22, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
index e18cc41..8ef0b4e 100644
--- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
+++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
@@ -198,8 +198,8 @@ function EnzymeRules.augmented_primal(
     svd_compact!(A.val, USVᴴ.val, alg.val.alg)
     cache_USVᴴ = copy.(USVᴴ.val)
     USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
-    ϵ.val      = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
-    primal     = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
+    ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
+    primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
     shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const)
         dU, dS, dVᴴ = USVᴴ.dval
         # This creates new output shadow matrices, we do this slicing
@@ -207,8 +207,8 @@ function EnzymeRules.augmented_primal(
         # These new shadow matrices are "filled in" with the accumulated
         # results from earlier in reverse-mode AD after this function exits
         # and before `reverse` is called.
-        dStrunc  = Diagonal(diagview(dS)[ind])
-        dUtrunc  = dU[:, ind]
+        dStrunc = Diagonal(diagview(dS)[ind])
+        dUtrunc = dU[:, ind]
         dVᴴtrunc = dVᴴ[ind, :]
         (dUtrunc, dStrunc, dVᴴtrunc)
     else
@@ -227,9 +227,9 @@ function EnzymeRules.reverse(
         alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
     ) where {RT}
     cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache
-    U, S, Vᴴ    = cache_USVᴴ
+    U, S, Vᴴ = cache_USVᴴ
     dU, dS, dVᴴ = shadow_USVᴴ
-    Aval        = isnothing(cache_A) ? A.val : cache_A
+    Aval = isnothing(cache_A) ? A.val : cache_A
     if !isa(A, Const) && !isa(USVᴴ, Const)
         svd_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
     end
@@ -246,14 +246,14 @@ function EnzymeRules.augmented_primal(
         alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
     )
     # form cache if needed
-    cache_A     = copy(A.val)
+    cache_A = copy(A.val)
     U, S, Vᴴ, ϵ = svd_trunc(A.val, USVᴴ.val, alg.val.alg)
-    primal      = EnzymeRules.needs_primal(config) ? (U, S, Vᴴ, ϵ) : nothing
-    dU  = zero(U)
-    dS  = zero(S)
+    primal = EnzymeRules.needs_primal(config) ? (U, S, Vᴴ, ϵ) : nothing
+    dU = zero(U)
+    dS = zero(S)
     dVᴴ = zero(Vᴴ)
-    dϵ  = zero(ϵ)
-    shadow      = EnzymeRules.needs_shadow(config) ? (dU, dS, dVᴴ, dϵ) : nothing
+    dϵ = zero(ϵ)
+    shadow = EnzymeRules.needs_shadow(config) ? (dU, dS, dVᴴ, dϵ) : nothing
     return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, (U, S, Vᴴ), (dU, dS, dVᴴ)))
 end
 function EnzymeRules.reverse(
@@ -265,9 +265,9 @@ function EnzymeRules.reverse(
         alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
     )
     cache_A, cache_USVᴴ, shadow_USVᴴ = cache
-    U, S, Vᴴ    = cache_USVᴴ
+    U, S, Vᴴ = cache_USVᴴ
     dU, dS, dVᴴ = shadow_USVᴴ
-    Aval        = isnothing(cache_A) ? A.val : cache_A
+    Aval = isnothing(cache_A) ? A.val : cache_A
     if !isa(A, Const) && !isa(USVᴴ, Const)
         svd_trunc_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
     end
diff --git a/test/enzyme.jl b/test/enzyme.jl
index 208f456..7c257bd 100644
--- a/test/enzyme.jl
+++ b/test/enzyme.jl
@@ -397,7 +397,7 @@ end
                     fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
                     # broken due to Enzyme -- copying in gaugefix????
                     test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
-                    test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
+                    test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act = RT)
                 end
                 U, S, Vᴴ = svd_compact(A)
                 ΔU = randn(rng, T, m, minmn)
@@ -416,7 +416,7 @@ end
                 fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
                 # broken due to Enzyme
                 test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
-                test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
+                test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act = RT)
             end
         end
     end

@kshyatt kshyatt force-pushed the ksh/enzyme branch 2 times, most recently from 4408849 to ec8354b Compare October 24, 2025 12:26
func::Const{typeof($f)},
::Type{RT},
A::Annotation{<:AbstractMatrix},
arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove AbstractMatrix restrictions here and elsewhere

) where {RT}
cache_arg = nothing
# form cache if needed
cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val) : nothing
Copy link
Member Author

@kshyatt kshyatt Oct 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe for the ! methods? Those should be tested directly

@kshyatt kshyatt force-pushed the ksh/enzyme branch 4 times, most recently from 7478c90 to 9091c56 Compare November 13, 2025 13:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants