Skip to content

Commit b1803db

Browse files
committed
rework tests to become a testsuite
1 parent 5a2083c commit b1803db

File tree

3 files changed

+157
-43
lines changed

3 files changed

+157
-43
lines changed

test/runtests.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ const sectorlist = (
4646
TimeReversed{FermionParity SU2Irrep NewSU2Irrep},
4747
)
4848

49-
@testset "$(TensorKitSectors.type_repr(I))" for I in sectorlist
50-
@include("sectors.jl")
51-
end
49+
include("testsuite.jl")
50+
using .SectorTestSuite
51+
52+
foreach(SectorTestSuite.test, sectorlist)
5253

5354
@testset "Deligne product" begin
5455
sectorlist′ = (Trivial, sectorlist...)

test/sectors.jl

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,62 @@
1-
Istr = TKS.type_repr(I)
2-
@testset "Sector $Istr: Basic properties" begin
1+
using .TestSetup: smallset, randsector, hasfusiontensor
2+
using .SectorTestSuite: @testsuite
3+
using TensorKitSectors
4+
using TensorOperations
5+
using LinearAlgebra
6+
7+
@testsuite "Basic properties" I -> begin
38
s = (randsector(I), randsector(I), randsector(I))
4-
@test eval(Meta.parse(sprint(show, I))) == I
5-
@test eval(Meta.parse(TKS.type_repr(I))) == I
6-
@test eval(Meta.parse(sprint(show, s[1]))) == s[1]
7-
@test @constinferred(hash(s[1])) == hash(deepcopy(s[1]))
8-
@test @constinferred(unit(s[1])) == @constinferred(unit(I))
9-
@constinferred dual(s[1])
10-
@constinferred dim(s[1])
11-
@constinferred frobenius_schur_phase(s[1])
12-
@constinferred frobenius_schur_indicator(s[1])
13-
@constinferred Nsymbol(s...)
14-
@constinferred Asymbol(s...)
15-
B = @constinferred Bsymbol(s...)
16-
F = @constinferred Fsymbol(s..., s...)
9+
@test Base.eval(Main, Meta.parse(sprint(show, I))) == I
10+
@test Base.eval(Main, Meta.parse(TensorKitSectors.type_repr(I))) == I
11+
@test Base.eval(Main, Meta.parse(sprint(show, s[1]))) == s[1]
12+
@test @testinferred(hash(s[1])) == hash(deepcopy(s[1]))
13+
@test @testinferred(unit(s[1])) == @testinferred(unit(I))
14+
@testinferred dual(s[1])
15+
@testinferred dim(s[1])
16+
@testinferred frobenius_schur_phase(s[1])
17+
@testinferred frobenius_schur_indicator(s[1])
18+
@testinferred Nsymbol(s...)
19+
@testinferred Asymbol(s...)
20+
B = @testinferred Bsymbol(s...)
21+
F = @testinferred Fsymbol(s..., s...)
1722
if BraidingStyle(I) isa HasBraiding
18-
R = @constinferred Rsymbol(s...)
23+
R = @testinferred Rsymbol(s...)
1924
if FusionStyle(I) === SimpleFusion()
20-
@test typeof(R * F) <: @constinferred sectorscalartype(I)
25+
@test typeof(R * F) <: @testinferred sectorscalartype(I)
2126
else
22-
@test Base.promote_op(*, eltype(R), eltype(F)) <: @constinferred sectorscalartype(I)
27+
@test Base.promote_op(*, eltype(R), eltype(F)) <: @testinferred sectorscalartype(I)
2328
end
2429
else
2530
if FusionStyle(I) === SimpleFusion()
26-
@test typeof(F) <: @constinferred sectorscalartype(I)
31+
@test typeof(F) <: @testinferred sectorscalartype(I)
2732
else
28-
@test eltype(F) <: @constinferred sectorscalartype(I)
33+
@test eltype(F) <: @testinferred sectorscalartype(I)
2934
end
3035
end
31-
it = @constinferred s[1] s[2]
32-
@constinferred (s..., s...)
36+
@testinferred(s[1] s[2])
37+
@testinferred((s..., s...))
3338
end
34-
@testset "Sector $Istr: Value iterator" begin
39+
40+
@testsuite "Value iterator" I -> begin
3541
@test eltype(values(I)) == I
3642
sprev = unit(I)
3743
for (i, s) in enumerate(values(I))
38-
@test !isless(s, sprev) # confirm compatibility with sort order
39-
@test s == @constinferred (values(I)[i])
44+
@test !isless(s, sprev)
45+
@test s == @testinferred(values(I)[i])
4046
@test findindex(values(I), s) == i
4147
sprev = s
4248
i >= 10 && break
4349
end
4450
@test unit(I) == first(values(I))
4551
@test length(allunits(I)) == 1
46-
@test (@constinferred findindex(values(I), unit(I))) == 1
52+
@test (@testinferred findindex(values(I), unit(I))) == 1
4753
for s in smallset(I)
48-
@test (@constinferred values(I)[findindex(values(I), s)]) == s
54+
@test (@testinferred values(I)[findindex(values(I), s)]) == s
4955
end
5056
end
51-
if BraidingStyle(I) isa Bosonic && hasfusiontensor(I)
52-
@testset "Sector $Istr: fusion tensor and F-move and R-move" begin
57+
58+
@testsuite "fusion tensor and F-move and R-move" I -> begin
59+
if BraidingStyle(I) isa Bosonic && hasfusiontensor(I)
5360
for a in smallset(I), b in smallset(I)
5461
for c in (a, b)
5562
X1 = permutedims(fusiontensor(a, b, c), (2, 1, 3, 4))
@@ -80,8 +87,9 @@ if BraidingStyle(I) isa Bosonic && hasfusiontensor(I)
8087
end
8188
end
8289
end
83-
if hasfusiontensor(I)
84-
@testset "Orthogonality of fusiontensors" begin
90+
91+
@testsuite "Orthogonality of fusiontensors" I -> begin
92+
if hasfusiontensor(I)
8593
for a in smallset(I), b in smallset(I)
8694
cs = vec(collect(a b))
8795
CGCs = map(c -> reshape(fusiontensor(a, b, c), :, dim(c)), cs)
@@ -93,7 +101,7 @@ if hasfusiontensor(I)
93101
end
94102
end
95103

96-
@testset "Sector $Istr: Unitarity of F-move" begin
104+
@testsuite "Unitarity of F-move" I -> begin
97105
for a in smallset(I), b in smallset(I), c in smallset(I)
98106
for d in (a, b, c)
99107
es = collect(intersect((a, b), map(dual, (c, dual(d)))))
@@ -105,29 +113,29 @@ end
105113
Fblocks = Vector{Any}()
106114
for e in es, f in fs
107115
Fs = Fsymbol(a, b, c, d, e, f)
108-
push!(
109-
Fblocks,
110-
reshape(Fs, (size(Fs, 1) * size(Fs, 2), size(Fs, 3) * size(Fs, 4)))
111-
)
116+
push!(Fblocks, reshape(Fs, (size(Fs, 1) * size(Fs, 2), size(Fs, 3) * size(Fs, 4))))
112117
end
113118
F = hvcat(length(fs), Fblocks...)
114119
end
115120
@test isapprox(F' * F, one(F); atol = 1.0e-12, rtol = 1.0e-12)
116121
end
117122
end
118123
end
119-
@testset "Sector $Istr: Triangle equation" begin
124+
125+
@testsuite "Triangle equation" I -> begin
120126
for a in smallset(I), b in smallset(I)
121127
@test triangle_equation(a, b; atol = 1.0e-12, rtol = 1.0e-12)
122128
end
123129
end
124-
@testset "Sector $Istr: Pentagon equation" begin
130+
131+
@testsuite "Pentagon equation" I -> begin
125132
for a in smallset(I), b in smallset(I), c in smallset(I), d in smallset(I)
126133
@test pentagon_equation(a, b, c, d; atol = 1.0e-12, rtol = 1.0e-12)
127134
end
128135
end
129-
if BraidingStyle(I) isa HasBraiding
130-
@testset "Sector $Istr: Hexagon equation" begin
136+
137+
@testsuite "Hexagon equation" I -> begin
138+
if BraidingStyle(I) isa HasBraiding
131139
for a in smallset(I), b in smallset(I), c in smallset(I)
132140
@test hexagon_equation(a, b, c; atol = 1.0e-12, rtol = 1.0e-12)
133141
end

test/testsuite.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
module SectorTestSuite
3+
4+
Lightweight testsuite registration for sector tests, inspired by the GPUArrays
5+
testsuite style. Each logical group of tests is registered under a string key
6+
and can be iterated for every sector type.
7+
"""
8+
module SectorTestSuite
9+
10+
export tests, @testsuite, @testinferred
11+
12+
using Test
13+
using TestExtras
14+
using TensorKitSectors
15+
const TKS = TensorKitSectors
16+
17+
const tests = Dict()
18+
19+
"""
20+
@testsuite name I -> begin
21+
# test code here
22+
end
23+
24+
Register a sector testsuite. The body is executed with a single argument `I`, the concrete `Sector` type under test.
25+
"""
26+
macro testsuite(name, ex)
27+
safe_name = lowercase(replace(replace(name, " " => "_"), "/" => "_"))
28+
fn = Symbol("test_$(safe_name)")
29+
return quote
30+
$(esc(fn))(I) = $(esc(ex))(I)
31+
@assert !haskey(tests, $name)
32+
tests[$name] = $fn
33+
end
34+
end
35+
36+
"""
37+
Runs the entire TensorKitSectors test suite on sector type `I`
38+
"""
39+
function test(I::Type)
40+
return @testset "$(TKS.type_repr(I))" begin
41+
for (name, fun) in tests
42+
code = quote
43+
$fun($I)
44+
end
45+
@eval @testset $name $code
46+
end
47+
end
48+
end
49+
50+
macro testinferred(ex)
51+
return _inferred(ex, __module__)
52+
end
53+
macro testinferred(ex, allow)
54+
return _inferred(ex, __module__, allow)
55+
end
56+
57+
function _inferred(ex, mod, allow = :(Union{}))
58+
if Meta.isexpr(ex, :ref)
59+
ex = Expr(:call, :getindex, ex.args...)
60+
end
61+
Meta.isexpr(ex, :call)|| error("@testinferred requires a call expression")
62+
farg = ex.args[1]
63+
if isa(farg, Symbol) && farg !== :.. && first(string(farg)) == '.'
64+
farg = Symbol(string(farg)[2:end])
65+
ex = Expr(
66+
:call, GlobalRef(Test, :_materialize_broadcasted),
67+
farg, ex.args[2:end]...
68+
)
69+
end
70+
result = let ex = ex
71+
quote
72+
let allow = $(esc(allow))
73+
allow isa Type || throw(ArgumentError("@inferred requires a type as second argument"))
74+
$(
75+
if any(@nospecialize(a) -> (Meta.isexpr(a, :kw) || Meta.isexpr(a, :parameters)), ex.args)
76+
# Has keywords
77+
# Create the call expression with escaped user expressions
78+
call_expr = :($(esc(ex.args[1]))(args...; kwargs...))
79+
quote
80+
args, kwargs, result = $(esc(Expr(:call, _args_and_call, ex.args[2:end]..., ex.args[1])))
81+
# wrap in dummy hygienic-scope to work around scoping issues with `call_expr` already having `esc` on the necessary parts
82+
inftype = $(Expr(:var"hygienic-scope", Base.gen_call_with_extracted_types(mod, Base.infer_return_type, call_expr; is_source_reflection = false), Test))
83+
end
84+
else
85+
# No keywords
86+
quote
87+
args = ($([esc(ex.args[i]) for i in 2:length(ex.args)]...),)
88+
result = $(esc(ex.args[1]))(args...)
89+
inftype = Base.infer_return_type($(esc(ex.args[1])), Base.typesof(args...))
90+
end
91+
end
92+
)
93+
rettype = result isa Type ? Type{result} : typeof(result)
94+
@test rettype <: allow || rettype == Base.typesplit(inftype, allow)
95+
result
96+
end
97+
end
98+
end
99+
return Base.remove_linenums!(result)
100+
end
101+
102+
include("testsetup.jl")
103+
include("sectors.jl")
104+
105+
end # module SectorTestSuite

0 commit comments

Comments
 (0)