Skip to content

Commit c57de02

Browse files
committed
Merge remote-tracking branch 'origin/main' into breaking
2 parents 77af4eb + 80cf12d commit c57de02

File tree

14 files changed

+383
-193
lines changed

14 files changed

+383
-193
lines changed

.github/workflows/Benchmarking.yml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,19 @@ jobs:
3333
echo "$version_info" >> $GITHUB_ENV
3434
echo "EOF" >> $GITHUB_ENV
3535
36-
# Capture benchmark output into a variable
36+
# Capture benchmark output into a variable. The sed and tail calls cut out anything but the
37+
# final block of results.
3738
echo "Running Benchmarks..."
38-
benchmark_output=$(julia --project=benchmarks benchmarks/benchmarks.jl)
39-
39+
benchmark_output=$(\
40+
julia --project=benchmarks benchmarks/benchmarks.jl \
41+
| sed -n '/Final results:/,$p' \
42+
| tail -n +2\
43+
)
44+
4045
# Print benchmark results directly to the workflow log
4146
echo "Benchmark Results:"
4247
echo "$benchmark_output"
43-
48+
4449
# Set the benchmark output as an env var for later steps
4550
echo "BENCHMARK_OUTPUT<<EOF" >> $GITHUB_ENV
4651
echo "$benchmark_output" >> $GITHUB_ENV

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
66

7+
## 0.38.4
8+
9+
Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on.
10+
711
## 0.38.3
812

913
Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`.

benchmarks/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ LogDensityProblems = "2.1.2"
3030
Mooncake = "0.4"
3131
PrettyTables = "3"
3232
ReverseDiff = "1.15.3"
33-
StableRNGs = "1"
33+
StableRNGs = "1"

benchmarks/benchmarks.jl

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,27 @@ using StableRNGs: StableRNG
77

88
rng = StableRNG(23)
99

10+
function print_results(results_table)
11+
table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
12+
header = [
13+
"Model",
14+
"Dim",
15+
"AD Backend",
16+
"VarInfo",
17+
"Linked",
18+
"t(eval)/t(ref)",
19+
"t(grad)/t(eval)",
20+
]
21+
return pretty_table(
22+
table_matrix;
23+
column_labels=header,
24+
backend=:text,
25+
formatters=[fmt__printf("%.1f", [6, 7])],
26+
fit_table_in_display_horizontally=false,
27+
fit_table_in_display_vertically=false,
28+
)
29+
end
30+
1031
# Create DynamicPPL.Model instances to run benchmarks on.
1132
smorgasbord_instance = Models.smorgasbord(randn(rng, 100), randn(rng, 100))
1233
loop_univariate1k, multivariate1k = begin
@@ -41,6 +62,8 @@ chosen_combinations = [
4162
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
4263
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
4364
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
65+
("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
66+
("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
4467
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
4568
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
4669
("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true),
@@ -82,17 +105,9 @@ for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinati
82105
relative_ad_eval_time,
83106
),
84107
)
108+
println("Results so far:")
109+
print_results(results_table)
85110
end
86111

87-
table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
88-
header = [
89-
"Model", "Dim", "AD Backend", "VarInfo", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)"
90-
]
91-
pretty_table(
92-
table_matrix;
93-
column_labels=header,
94-
backend=:text,
95-
formatters=[fmt__printf("%.1f", [6, 7])],
96-
fit_table_in_display_horizontally=false,
97-
fit_table_in_display_vertically=false,
98-
)
112+
println("Final results:")
113+
print_results(results_table)

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8080
retvals = model(rng)
8181
vns = [VarName{k}() for k in keys(retvals)]
8282
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
83+
elseif varinfo_choice == :typed_vector
84+
DynamicPPL.typed_vector_varinfo(rng, model)
85+
elseif varinfo_choice == :untyped_vector
86+
DynamicPPL.untyped_vector_varinfo(rng, model)
8387
else
8488
error("Unknown varinfo choice: $varinfo_choice")
8589
end

docs/src/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ DynamicPPL.reset!
414414
DynamicPPL.update!
415415
DynamicPPL.insert!
416416
DynamicPPL.loosen_types!!
417-
DynamicPPL.tighten_types
417+
DynamicPPL.tighten_types!!
418418
```
419419

420420
```@docs

src/contexts/init.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ function tilde_assume!!(
180180
end
181181
# Neither of these set the `trans` flag so we have to do it manually if
182182
# necessary.
183-
insert_transformed_value && set_transformed!!(vi, true, vn)
183+
if insert_transformed_value
184+
vi = set_transformed!!(vi, true, vn)
185+
end
184186
# `accumulate_assume!!` wants untransformed values as the second argument.
185187
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
186188
# We always return the untransformed value here, as that will determine

src/debug_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ add_io_context(io::IO) = IOContext(io, :compact => true, :limit => true)
2727
show_varname(io::IO, varname::VarName) = print(io, varname)
2828
function show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
2929
# Attempt to make the type concrete in case the symbol is shared.
30-
return _show_varname(io, map(identity, varname))
30+
return _show_varname(io, [vn for vn in varname])
3131
end
3232
function _show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
3333
# Print the first and last element of the array.
@@ -407,7 +407,7 @@ julia> @model function demo_incorrect()
407407
end
408408
demo_incorrect (generic function with 2 methods)
409409
410-
julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
410+
julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
411411
# alert us to the issue of `x` being sampled twice.
412412
model = demo_incorrect(); varinfo = VarInfo(model);
413413

src/logdensityfunction.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ box:
4949
- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring
5050
any effects of linking
5151
- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected
52-
by linking, since transforms are only applied to random variables)
52+
by linking, since transforms are only applied to random variables)
5353
5454
!!! note
5555
By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the
@@ -146,7 +146,7 @@ struct LogDensityFunction{
146146
is_supported(adtype) ||
147147
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
148148
# Get a set of dummy params to use for prep
149-
x = map(identity, varinfo[:])
149+
x = [val for val in varinfo[:]]
150150
if use_closure(adtype)
151151
prep = DI.prepare_gradient(
152152
LogDensityAt(model, getlogdensity, varinfo), adtype, x
@@ -282,7 +282,7 @@ function LogDensityProblems.logdensity_and_gradient(
282282
) where {M,F,V,AD<:ADTypes.AbstractADType}
283283
f.prep === nothing &&
284284
error("Gradient preparation not available; this should not happen")
285-
x = map(identity, x) # Concretise type
285+
x = [val for val in x] # Concretise type
286286
# Make branching statically inferrable, i.e. type-stable (even if the two
287287
# branches happen to return different types)
288288
return if use_closure(f.adtype)

src/simple_varinfo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
484484
"Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
485485
)
486486
end
487+
return vi
487488
end
488489

489490
is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)

0 commit comments

Comments
 (0)