Skip to content

Commit 5b9ec32

Browse files
authored
Resolve test_approx ambiguities (#253)
1 parent 4050989 commit 5b9ec32

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "1.9.0"
3+
version = "1.9.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/check_result.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,16 @@ function test_approx(
2424
@test_msg msg isapprox(actual, expected; kwargs...)
2525
end
2626

27-
for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
27+
for (T1, T2) in
28+
(
29+
(AbstractThunk, Any),
30+
(AbstractThunk, AbstractThunk),
31+
(Any, AbstractThunk),
32+
(Tangent, AbstractThunk),
33+
(AbstractThunk, Tangent),
34+
(AbstractZero, AbstractThunk),
35+
(AbstractThunk, AbstractZero),
36+
)
2837
@eval function test_approx(actual::$T1, expected::$T2, msg=""; kwargs...)
2938
return test_approx(unthunk(actual), unthunk(expected), msg; kwargs...)
3039
end
@@ -123,9 +132,8 @@ function test_approx(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T
123132
end
124133
test_approx(x, y::Tangent, msg=""; kwargs...) = test_approx(y, x, msg; kwargs...)
125134

126-
function test_approx(actual::Tangent, expected::AbstractThunk, msg=""; kwargs...)
127-
return test_approx(actual, unthunk(expected), msg; kwargs...)
128-
end
135+
test_approx(z::NoTangent, t::Tangent, msg=""; kwargs...) = all(==(NoTangent()), t)
136+
test_approx(t::Tangent, z::NoTangent, msg=""; kwargs...) = all(==(NoTangent()), t)
129137

130138
# This catches comparisons of Tangents and Tuples/NamedTuple
131139
# and gives an error message complaining about that. the `@test` will definitely fail

test/check_result.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,20 @@ end
8383
Tangent{Tuple{Float64,Float64}}(1.0, 2.0),
8484
@thunk(Tangent{Tuple{Float64,Float64}}(1.0, 2.0)),
8585
)
86+
test_approx(
87+
@thunk(Tangent{Tuple{Float64,Float64}}(1.0, 2.0)),
88+
Tangent{Tuple{Float64,Float64}}(1.0, 2.0),
89+
)
90+
test_approx(@thunk(ZeroTangent()), ZeroTangent())
91+
test_approx(ZeroTangent(), @thunk(ZeroTangent()))
92+
test_approx(
93+
Tangent{Tuple{Float64,Float64}}(NoTangent(), NoTangent()),
94+
NoTangent(),
95+
)
96+
test_approx(
97+
NoTangent(),
98+
Tangent{Tuple{Float64,Float64}}(NoTangent(), NoTangent()),
99+
)
86100
end
87101
@testset "negative case" begin
88102
@test fails(() -> test_approx(1.0, 2.0))

0 commit comments

Comments
 (0)