diff --git a/src/check_result.jl b/src/check_result.jl index a0ed890..fee2943 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -40,7 +40,9 @@ for (T1, T2) in end test_approx(::AbstractZero, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...) +test_approx(::AbstractZero, x::AbstractArray{<:AbstractArray}, msg=""; kwargs...) = test_approx(map(zero, x), x, msg; kwargs...) test_approx(x, ::AbstractZero, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...) +test_approx(x::AbstractArray{<:AbstractArray}, ::AbstractZero, msg=""; kwargs...) = test_approx(x, map(zero, x), msg; kwargs...) test_approx(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true test_approx(x::NoTangent, y::NoTangent, msg=""; kwargs...) = @test true diff --git a/test/check_result.jl b/test/check_result.jl index eabb1af..a708990 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -36,6 +36,8 @@ end test_approx([1.0, 2.0], [1.0, 2.0]) test_approx([[1.0], [2.0]], [[1.0], [2.0]]) + test_approx([[0.0], [0.0]], ZeroTangent()) + test_approx(ZeroTangent(), [[0.0], [0.0]]) test_approx(Broadcast.broadcasted(identity, [1.0 2.0; 3.0 4.0]), [1.0 2.0; 3.0 4.0]) test_approx(@thunk(10 * 0.1 * [[1.0], [2.0]]), [[1.0], [2.0]]) @@ -108,6 +110,8 @@ end @test fails(() -> test_approx([1.0, 2.0], [1.0, 3.9])) @test fails(() -> test_approx([[1.0], [2.0]], [[1.1], [2.0]])) + @test fails(() -> test_approx([[0.0], [0.1]], ZeroTangent())) + @test fails(() -> test_approx(ZeroTangent(), [[0.1], [0.0]])) @test fails(() -> test_approx(@thunk(10 * [[1.0], [2.0]]), [[1.0], [2.0]]))