Skip to content

Commit

Permalink
Default float type to float(Real), not Real (#685)
Browse files Browse the repository at this point in the history
* Default float type to float(Real), not Real

Closes #684

* Trigger CI on backport branches/PRs

* Add integration test for #684

* Bump Turing version to 0.34 in test subfolder
  • Loading branch information
penelopeysm committed Oct 11, 2024
1 parent c38e65f commit ef37de2
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ on:
push:
branches:
- master
- backport-*
pull_request:
branches:
- master
- backport-*
merge_group:
types: [checks_requested]

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.30"
version = "0.30.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -934,10 +934,10 @@ end
"""
float_type_with_fallback(x)
Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`.
Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`.
"""
float_type_with_fallback(::Type) = Real
float_type_with_fallback(::Type{Union{}}) = Real
float_type_with_fallback(::Type) = float(Real)
float_type_with_fallback(::Type{Union{}}) = float(Real)
float_type_with_fallback(::Type{T}) where {T<:Real} = float(T)

"""
Expand Down
2 changes: 1 addition & 1 deletion test/turing/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28, 0.29"
HypothesisTests = "0.11"
MCMCChains = "6"
ReverseDiff = "1.15"
Turing = "0.33, 0.34"
Turing = "0.34"
julia = "1.7"
15 changes: 15 additions & 0 deletions test/turing/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,19 @@
model = state_space(y, length(t))
@test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n
end

if Threads.nthreads() > 1
@testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin
@model function f(x)
ns ~ filldist(Normal(0, 2.0), 3)
m ~ Uniform(0, 1)
x ~ Normal(m, 1)
end
model = f(1)
chain = sample(model, NUTS(), MCMCThreads(), 10, 2);
loglikelihood(model, chain)
logprior(model, chain)
logjoint(model, chain)
end
end
end

0 comments on commit ef37de2

Please sign in to comment.