Skip to content

Commit

Permalink
feat: add initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianM-C committed Apr 6, 2024
1 parent 0bd6e2e commit 145db5e
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 13 deletions.
5 changes: 4 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options
style = "sciml"
format_markdown = true
format_docstrings = true
annotate_untyped_fields_with_any = false
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
/Manifest.toml
/docs/Manifest.toml
/docs/build/
.vscode
37 changes: 36 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,48 @@ uuid = "f162e290-f571-43a6-83d9-22ecc16da15f"
authors = ["Sebastian Micluța-Câmpeanu <[email protected]> and contributors"]
version = "1.0.0-DEV"

[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
Aqua = "0.8"
ComponentArrays = "0.15"
ForwardDiff = "0.10.36"
JET = "0.8"
Lux = "0.5.32"
LuxCore = "0.1.14"
ModelingToolkit = "9.9.0"
ModelingToolkitStandardLibrary = "2.6"
NNlib = "0.9"
Optimization = "3.22"
OptimizationOptimisers = "0.2"
OrdinaryDiffEq = "6.74"
Random = "1"
SafeTestsets = "0.1"
SciMLStructures = "1.1.0"
SymbolicIndexingInterface = "0.3.15"
Symbolics = "5.27"
Test = "1"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "JET", "Test"]
test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "SymbolicIndexingInterface"]
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
[![Coverage](https://codecov.io/gh/SebastianM-C/UDEComponents.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/SebastianM-C/UDEComponents.jl)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

## Build UDEs with ModelingToolkit
44 changes: 43 additions & 1 deletion src/UDEComponents.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,47 @@
module UDEComponents

# Write your package code here.
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
using LuxCore: stateless_apply
using Lux: Lux
using Random: Xoshiro
using NNlib: softplus
using ComponentArrays: ComponentArray

export create_ude_component, multi_layer_feed_forward

include("utils.jl")
include("hacks.jl") # this should be removed / upstreamed

"""
create_ude_component(n_input = 1, n_output = 1;
chain = multi_layer_feed_forward(n_input, n_output),
rng = Xoshiro(0))
Create an `ODESystem` with a neural network inside.
"""
function create_ude_component(n_input = 1,
n_output = 1;
chain = multi_layer_feed_forward(n_input, n_output),
rng = Xoshiro(0))
lux_p, st = Lux.setup(rng, chain)
ca = ComponentArray(lux_p)

@parameters p[1:length(ca)] = Vector(ca)
@parameters T::typeof(typeof(p))=typeof(p) [tunable = false]

@named input = RealInput(nin = n_input)
@named output = RealOutput(nout = n_output)

out = stateless_apply(chain, input.u, lazyconvert(typeof(ca), p))

eqs = [output.u ~ out]

@named ude_comp = ODESystem(
eqs, t_nounits, [], [p, T], systems = [input, output])
return ude_comp
end

end
4 changes: 4 additions & 0 deletions src/hacks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
lazyconvert(x, y) = convert(x, y)
lazyconvert(x, y::Symbolics.Arr) = Symbolics.array_term(convert, x, y)
Symbolics.propagate_ndims(::typeof(convert), x, y) = ndims(y)
Symbolics.propagate_shape(::typeof(convert), x, y) = Symbolics.shape(y)
25 changes: 25 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
function multi_layer_feed_forward(input_length, output_length; width::Int = 5,
depth::Int = 1, activation = softplus)
Lux.Chain(Lux.Dense(input_length, width, activation),
[Lux.Dense(width, width, activation) for _ in 1:(depth)]...,
Lux.Dense(width, output_length); disable_optimizations = true)
end

# Symbolics.@register_array_symbolic print_input(x) begin
# size = size(x)
# eltype = eltype(x)
# end

# function print_input(x)
# @info x
# x
# end

# function debug_component(n_input, n_output)
# @named input = RealInput(nin = n_input)
# @named output = RealOutput(nout = n_output)

# eqs = [output.u ~ print_input(input.u)]

# @named dbg_comp = ODESystem(eqs, t_nounits, [], [], systems = [input, output])
# end
121 changes: 121 additions & 0 deletions test/lotka_volterra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
using Test
using JET
using UDEComponents
using ModelingToolkit
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEq
using SymbolicIndexingInterface
using Optimization
using OptimizationOptimisers: Adam
using SciMLStructures
using SciMLStructures: Tunable
using ForwardDiff

function lotka_ude()
@variables t x(t)=3.1 y(t)=1.5
@parameters α=1.3 β=0.9 γ=0.8 δ=1.8
Dt = ModelingToolkit.D_nounits
@named nn_in = RealInput(nin = 2)
@named nn_out = RealOutput(nout = 2)

eqs = [
Dt(x) ~ α * x + nn_in.u[1],
Dt(y) ~ -δ * y + nn_in.u[2],
nn_out.u[1] ~ x,
nn_out.u[2] ~ y
]
return ODESystem(
eqs, ModelingToolkit.t_nounits, name = :lotka, systems = [nn_in, nn_out])
end

function lotka_true()
@variables t x(t)=3.1 y(t)=1.5
@parameters α=1.3 β=0.9 γ=0.8 δ=1.8
Dt = ModelingToolkit.D_nounits

eqs = [
Dt(x) ~ α * x - β * x * y,
Dt(y) ~ -δ * y + δ * x * y
]
return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka_true)
end

model = lotka_ude()
nn = create_ude_component(2, 2)

eqs = [
connect(model.nn_in, nn.output)
connect(model.nn_out, nn.input)
]

ude_sys = complete(ODESystem(
eqs, ModelingToolkit.t_nounits, systems = [model, nn], name = :ude_sys))

sys = structural_simplify(ude_sys)

prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])

model_true = structural_simplify(lotka_true())
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), [])
sol_ref = solve(prob_true, Rodas4())

x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys)))

get_vars = getu(sys, [sys.lotka.x, sys.lotka.y])
get_refs = getu(model_true, [model_true.x, model_true.y])

function loss(x, (prob, sol_ref, get_vars, get_refs))
new_p = SciMLStructures.replace(Tunable(), prob.p, x)
new_prob = remake(prob, p = new_p)
ts = sol_ref.t
new_sol = solve(new_prob, Rodas4(), saveat = ts)

loss = zero(eltype(x))

for i in eachindex(new_sol.u)
loss += sum(sqrt.(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i))))
end

if SciMLBase.successful_retcode(new_sol)
loss
else
Inf
end
end


of = OptimizationFunction{true}(loss, AutoForwardDiff())

ps = (prob, sol_ref, get_vars, get_refs);

@test_call target_modules=(UDEComponents,) loss(x0, ps)
@test_opt target_modules=(UDEComponents,) loss(x0, ps)

@test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0)))

op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))


# using Plots

# oh = []

plot_cb = (opt_state, loss) -> begin
@info "step $(opt_state.iter), loss: $loss"
# push!(oh, opt_state)
# new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u)
# new_prob = remake(prob, p = new_p)
# sol = solve(new_prob, Rodas4())
# display(plot(sol))
false
end

res = solve(op, Adam(), maxiters = 2000)#, callback = plot_cb)

@test res.objective < 1

res_p = SciMLStructures.replace(Tunable(), prob.p, res)
res_prob = remake(prob, p = res_p)
res_sol = solve(res_prob, Rodas4())

@test SciMLBase.successful_retcode(res_sol)
20 changes: 20 additions & 0 deletions test/qa.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using Test
using UDEComponents
using Aqua
using JET

@testset verbose = true "Code quality (Aqua.jl)" begin
Aqua.find_persistent_tasks_deps(UDEComponents)
Aqua.test_ambiguities(UDEComponents, recursive = false)
Aqua.test_deps_compat(UDEComponents)
# TODO: fix type piracy in propagate_ndims and propagate_shape
Aqua.test_piracies(UDEComponents, broken=true)
Aqua.test_project_extras(UDEComponents)
Aqua.test_stale_deps(UDEComponents, ignore = Symbol[])
Aqua.test_unbound_args(UDEComponents)
Aqua.test_undefined_exports(UDEComponents)
end

@testset "Code linting (JET.jl)" begin
JET.test_package(UDEComponents; target_defined_modules = true)
end
14 changes: 4 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
using UDEComponents
using Test
using Aqua
using JET
using SafeTestsets

@testset "UDEComponents.jl" begin
@testset "Code quality (Aqua.jl)" begin
Aqua.test_all(UDEComponents)
end
@testset "Code linting (JET.jl)" begin
JET.test_package(UDEComponents; target_defined_modules = true)
end
# Write your tests here.
@testset verbose=true "UDEComponents.jl" begin
@safetestset "QA" include("qa.jl")
@safetestset "Basic" include("lotka_volterra.jl")
end

0 comments on commit 145db5e

Please sign in to comment.