Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop macro refactoring #84

Merged
merged 13 commits into from
May 12, 2020
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
*.cov
.ipynb_checkpoints
demo/*.jl
docs/build
docs/build
Manifest.toml
44 changes: 22 additions & 22 deletions src/algorithms/expectation_propagation/expectation_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ expectationPropagationAlgorithm,
"""
Create a sum-product algorithm to infer marginals over `variables`, and compile it to Julia code
"""
function expectationPropagationAlgorithm(variables::Vector{Variable};
function expectationPropagationAlgorithm(variables::Vector{Variable};
id=Symbol(""),
free_energy=false)

Expand All @@ -22,12 +22,12 @@ function expectationPropagationAlgorithm(variables::Vector{Variable};
schedule = expectationPropagationSchedule(pf)
pf.schedule = condense(flatten(schedule)) # Inline all internal message passing and remove clamp node entries
pf.marginal_table = marginalTable(variables)

# Populate fields for algorithm compilation
algo = InferenceAlgorithm(pfz, id=id)
assembleInferenceAlgorithm!(algo)
free_energy && assembleFreeEnergy!(algo)

return algo
end
expectationPropagationAlgorithm(variable::Variable; id=Symbol(""), free_energy=false) = expectationPropagationAlgorithm([variable], id=id, free_energy=free_energy)
Expand All @@ -37,16 +37,16 @@ A non-specific expectation propagation update
"""
abstract type ExpectationPropagationRule{factor_type} <: MessageUpdateRule end

"""
"""
`expectationPropagationSchedule()` generates a expectation propagation
message passing schedule.
"""
message passing schedule.
"""
function expectationPropagationSchedule(pf::PosteriorFactor)
ep_sites = collectEPSites(nodes(current_graph))
breaker_sites = Interface[site.partner for site in ep_sites]
breaker_types = breakerTypes(breaker_sites)

schedule = summaryPropagationSchedule(sort(collect(pf.target_variables), rev=true),
schedule = summaryPropagationSchedule(sort(collect(pf.target_variables), rev=true),
sort(collect(pf.target_clusters), rev=true);
target_sites=[breaker_sites; ep_sites])

Expand Down Expand Up @@ -102,7 +102,7 @@ function inferUpdateRule!(entry::ScheduleEntry,

# Find outbound id
outbound_id = something(findfirst(isequal(entry.interface), entry.interface.node.interfaces), 0)

# Find applicable rule(s)
applicable_rules = Type[]
for rule in leaftypes(entry.message_update_rule)
Expand Down Expand Up @@ -140,9 +140,9 @@ function collectInboundTypes(entry::ScheduleEntry,
end

"""
`@expectationPropagationRule` registers a expectation propagation update
rule by defining the rule type and the corresponding methods for the `outboundType`
and `isApplicable` functions. If no name (type) for the new rule is passed, a
`@expectationPropagationRule` registers a expectation propagation update
rule by defining the rule type and the corresponding methods for the `outboundType`
and `isApplicable` functions. If no name (type) for the new rule is passed, a
unique name (type) will be generated. Returns the rule type.
"""
macro expectationPropagationRule(fields...)
Expand Down Expand Up @@ -182,23 +182,23 @@ macro expectationPropagationRule(fields...)
end

# Build validators for isApplicable
input_type_validators =
String["length(input_types) == $(length(inbound_types.args))"]
input_type_validators = Expr[]

push!(input_type_validators, :(length(input_types) == $(length(inbound_types.args))))
for (i, i_type) in enumerate(inbound_types.args)
if i_type != :Nothing
# Only validate inbounds required for message update
push!(input_type_validators, "ForneyLab.matches(input_types[$i], $i_type)")
push!(input_type_validators, :(ForneyLab.matches(input_types[$i], $i_type)))
end
end

expr = parse("""
begin
mutable struct $name <: ExpectationPropagationRule{$node_type} end
ForneyLab.outboundType(::Type{$name}) = $outbound_type
ForneyLab.isApplicable(::Type{$name}, input_types::Vector{<:Type}, outbound_id::Int64) = $(join(input_type_validators, " && ")) && (outbound_id == $outbound_id)
$name
expr = quote
struct $name <: ExpectationPropagationRule{$node_type} end
ForneyLab.outboundType(::Type{$name}) = $outbound_type
ForneyLab.isApplicable(::Type{$name}, input_types::Vector{<:Type}, outbound_id::Int64) = begin
$(reduce((current, item) -> :($current && $item), input_type_validators, init = :(outbound_id === $outbound_id)))
end
""")
end

return esc(expr)
end
Expand All @@ -223,4 +223,4 @@ function collectInbounds(entry::ScheduleEntry, ::Type{T}) where T<:ExpectationPr
end

return inbounds
end
end
26 changes: 13 additions & 13 deletions src/algorithms/joint_marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ through a node-specific joint marginal update rule.
function MarginalEntry(target::Cluster, outbound_types::Dict{Interface, Type})
inbound_types = collectInboundTypes(target, outbound_types)
marginal_update_rule = inferMarginalRule(target, inbound_types)
# Collect inbound interfaces

# Collect inbound interfaces
inbound_interfaces = Interface[]
for edge in target.edges
if edge.a in target.node.interfaces
Expand Down Expand Up @@ -135,22 +135,22 @@ macro marginalRule(fields...)
end

# Build validators for isApplicable
input_type_validators =
String["length(input_types) == $(length(inbound_types.args))"]
input_type_validators = Expr[]

push!(input_type_validators, :(length(input_types) == $(length(inbound_types.args))))
for (i, i_type) in enumerate(inbound_types.args)
if i_type != :Nothing
# Only validate inbounds required for update
push!(input_type_validators, "ForneyLab.matches(input_types[$i], $i_type)")
push!(input_type_validators, :(ForneyLab.matches(input_types[$i], $i_type)))
end
end

expr = parse("""
begin
mutable struct $name <: MarginalRule{$node_type} end
ForneyLab.isApplicable(::Type{$name}, input_types::Vector{<:Type}) = $(join(input_type_validators, " && "))
$name
expr = quote
struct $name <: MarginalRule{$node_type} end
ForneyLab.isApplicable(::Type{$name}, input_types::Vector{<:Type}) = begin
$(reduce((current, item) -> :($current && $item), input_type_validators, init = :true))
end
""")
end

return esc(expr)
end
Expand All @@ -171,7 +171,7 @@ function collectMarginalNodeInbounds(::FactorNode, entry::MarginalEntry)
entry_pf = posteriorFactor(first(entry.target.edges))
encountered_external_regions = Set{Region}()
for node_interface in entry.target.node.interfaces
current_region = region(inbound_cluster.node, node_interface.edge) # Note: edges that are not assigned to a posterior factor are assumed mean-field
current_region = region(inbound_cluster.node, node_interface.edge) # Note: edges that are not assigned to a posterior factor are assumed mean-field
current_pf = posteriorFactor(node_interface.edge) # Returns an Edge if no posterior factor is assigned
inbound_interface = ultimatePartner(node_interface)

Expand All @@ -189,4 +189,4 @@ function collectMarginalNodeInbounds(::FactorNode, entry::MarginalEntry)
end

return inbounds
end
end
36 changes: 18 additions & 18 deletions src/algorithms/sum_product/sum_product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ function sumProductAlgorithm(variables::Vector{Variable};
pfz = PosteriorFactorization()
# Contain the entire graph in a single posterior factor
pf = PosteriorFactor(pfz, id=Symbol(""))

# Set the target regions (variables and clusters) of the posterior factor
setTargets!(pf, pfz, variables, free_energy=free_energy, external_targets=false)

# Infer schedule and marginal computations
schedule = sumProductSchedule(pf) # For free energy computation, additional targets might be required
pf.schedule = condense(flatten(schedule)) # Inline all internal message passing and remove clamp node entries from schedule
Expand All @@ -37,13 +37,13 @@ A non-specific sum-product update
"""
abstract type SumProductRule{factor_type} <: MessageUpdateRule end

"""
"""
`sumProductSchedule()` generates a sum-product message passing schedule that
computes the marginals for each of the posterior factor targets.
"""
"""
function sumProductSchedule(pf::PosteriorFactor)
# Generate a feasible summary propagation schedule
schedule = summaryPropagationSchedule(sort(collect(pf.target_variables), rev=true),
schedule = summaryPropagationSchedule(sort(collect(pf.target_variables), rev=true),
sort(collect(pf.target_clusters), rev=true))

# Assign the sum-product update rule to each of the schedule entries
Expand Down Expand Up @@ -162,12 +162,12 @@ macro sumProductRule(fields...)
outbound_type = :unknown
inbound_types = :unknown
name = :auto # Triggers automatic naming unless overwritten

# Loop over fields because order is unknown
for arg in fields

(arg.args[1] == :(=>)) || error("Invalid call to @sumProductRule")

if arg.args[2].value == :node_type
node_type = arg.args[3]
elseif arg.args[2].value == :outbound_type
Expand All @@ -191,23 +191,23 @@ macro sumProductRule(fields...)
end

# Build validators for isApplicable
input_type_validators =
String["length(input_types) == $(length(inbound_types.args))"]
input_type_validators = Expr[]

push!(input_type_validators, :(length(input_types) == $(length(inbound_types.args))))
for (i, i_type) in enumerate(inbound_types.args)
if i_type != :Nothing
# Only validate inbounds required for message update
push!(input_type_validators, "ForneyLab.matches(input_types[$i], $i_type)")
push!(input_type_validators, :(ForneyLab.matches(input_types[$i], $i_type)))
end
end

expr = parse("""
begin
mutable struct $name <: SumProductRule{$node_type} end
ForneyLab.outboundType(::Type{$name}) = $outbound_type
ForneyLab.isApplicable(::Type{$name}, input_types::Vector{<:Type}) = $(join(input_type_validators, " && "))
$name
expr = quote
struct $name <: SumProductRule{$node_type} end
ForneyLab.outboundType(::Type{$name}) = $outbound_type
ForneyLab.isApplicable(::Type{$name}, input_types::Vector{<:Type}) = begin
$(reduce((current, item) -> :($current && $item), input_type_validators, init = :true))
end
""")
end

return esc(expr)
end
Expand Down Expand Up @@ -241,4 +241,4 @@ function collectSumProductNodeInbounds(::FactorNode, entry::ScheduleEntry)
end

return inbounds
end
end
44 changes: 22 additions & 22 deletions src/algorithms/variational_bayes/naive_variational_bayes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ variationalAlgorithm,
"""
Create a variational algorithm to infer marginals over a posterior distribution, and compile it to Julia code
"""
function variationalAlgorithm(pfz::PosteriorFactorization=currentPosteriorFactorization();
id=Symbol(""),
function variationalAlgorithm(pfz::PosteriorFactorization=currentPosteriorFactorization();
id=Symbol(""),
free_energy=false)

(length(pfz.posterior_factors) > 0) || error("No factors defined on posterior factorization.")
Expand Down Expand Up @@ -52,8 +52,8 @@ function variationalSchedule(posterior_factor::PosteriorFactor)
nodes_connected_to_external_edges = nodesConnectedToExternalEdges(posterior_factor)

# Schedule messages towards posterior factors and target sites, limited to the internal edges
schedule = summaryPropagationSchedule(sort(collect(posterior_factor.target_variables), rev=true),
sort(collect(posterior_factor.target_clusters), rev=true),
schedule = summaryPropagationSchedule(sort(collect(posterior_factor.target_variables), rev=true),
sort(collect(posterior_factor.target_clusters), rev=true),
limit_set=posterior_factor.internal_edges)
for entry in schedule
if (entry.interface.node in nodes_connected_to_external_edges) && !isa(entry.interface.node, DeltaFactor)
Expand All @@ -62,7 +62,7 @@ function variationalSchedule(posterior_factor::PosteriorFactor)
entry.message_update_rule = NaiveVariationalRule{typeof(entry.interface.node)}
else
entry.message_update_rule = StructuredVariationalRule{typeof(entry.interface.node)}
end
end
else
entry.message_update_rule = SumProductRule{typeof(entry.interface.node)}
end
Expand All @@ -81,7 +81,7 @@ function inferUpdateRule!( entry::ScheduleEntry,
inferred_outbound_types::Dict{Interface, Type}) where T<:NaiveVariationalRule
# Collect inbound types
inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types)

# Find applicable rule(s)
applicable_rules = Type[]
for rule in leaftypes(entry.message_update_rule)
Expand Down Expand Up @@ -115,20 +115,20 @@ function collectInboundTypes( entry::ScheduleEntry,
push!(inbound_types, Nothing)
else
# Edge is external, accept marginal
push!(inbound_types, ProbabilityDistribution)
push!(inbound_types, ProbabilityDistribution)
end
end

return inbound_types
end

"""
"""
`@naiveVariationalRule` registers a variational update rule for the naive
(mean-field) factorization by defining the rule type and the corresponding
methods for the `outboundType` and `isApplicable` functions. If no name (type)
for the new rule is passed, a unique name (type) will be generated. Returns the
rule type.
"""
rule type.
"""
macro naiveVariationalRule(fields...)
# Init required fields in macro scope
node_type = :unknown
Expand Down Expand Up @@ -163,23 +163,23 @@ macro naiveVariationalRule(fields...)
end

# Build validators for isApplicable
input_type_validators =
String["length(input_types) == $(length(inbound_types.args))"]
input_type_validators = Expr[]

push!(input_type_validators, :(length(input_types) == $(length(inbound_types.args))))
for (i, i_type) in enumerate(inbound_types.args)
if i_type != :Nothing
# Only validate inbounds required for message update
push!(input_type_validators, "ForneyLab.matches(input_types[$i], $i_type)")
push!(input_type_validators, :(ForneyLab.matches(input_types[$i], $i_type)))
end
end

expr = parse("""
begin
mutable struct $name <: NaiveVariationalRule{$node_type} end
ForneyLab.outboundType(::Type{$name}) = $outbound_type
ForneyLab.isApplicable(::Type{$name}, input_types::Vector{<:Type}) = $(join(input_type_validators, " && "))
$name
expr = quote
struct $name <: NaiveVariationalRule{$node_type} end
ForneyLab.outboundType(::Type{$name}) = $outbound_type
ForneyLab.isApplicable(::Type{$name}, input_types::Vector{<:Type}) = begin
$(reduce((current, item) -> :($current && $item), input_type_validators, init = :true))
end
""")
end

return esc(expr)
end
Expand All @@ -196,7 +196,7 @@ Returns a vector with inbounds that correspond with required interfaces.
"""
function collectNaiveVariationalNodeInbounds(::FactorNode, entry::ScheduleEntry)
target_to_marginal_entry = current_inference_algorithm.target_to_marginal_entry

inbounds = Any[]
for node_interface in entry.interface.node.interfaces
inbound_interface = ultimatePartner(node_interface)
Expand All @@ -213,4 +213,4 @@ function collectNaiveVariationalNodeInbounds(::FactorNode, entry::ScheduleEntry)
end

return inbounds
end
end
Loading