From 145db5ef950baf879c78312fd13033a48da5ec97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Sun, 7 Apr 2024 01:46:38 +0300 Subject: [PATCH] feat: add initial implementation --- .JuliaFormatter.toml | 5 +- .gitignore | 1 + Project.toml | 37 ++++++++++++- README.md | 2 + src/UDEComponents.jl | 44 ++++++++++++++- src/hacks.jl | 4 ++ src/utils.jl | 25 +++++++++ test/lotka_volterra.jl | 121 +++++++++++++++++++++++++++++++++++++++++ test/qa.jl | 20 +++++++ test/runtests.jl | 14 ++--- 10 files changed, 260 insertions(+), 13 deletions(-) create mode 100644 src/hacks.jl create mode 100644 src/utils.jl create mode 100644 test/lotka_volterra.jl create mode 100644 test/qa.jl diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 01bfab9..3e91a6c 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1 +1,4 @@ -# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options +style = "sciml" +format_markdown = true +format_docstrings = true +annotate_untyped_fields_with_any = false diff --git a/.gitignore b/.gitignore index 95731a5..5e276ba 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ /Manifest.toml /docs/Manifest.toml /docs/build/ +.vscode diff --git a/Project.toml b/Project.toml index d088727..068e8a8 100644 --- a/Project.toml +++ b/Project.toml @@ -3,13 +3,48 @@ uuid = "f162e290-f571-43a6-83d9-22ecc16da15f" authors = ["Sebastian Micluța-Câmpeanu and contributors"] version = "1.0.0-DEV" +[deps] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" + [compat] +Aqua = "0.8" +ComponentArrays = "0.15" +ForwardDiff = "0.10.36" +JET = "0.8" +Lux = "0.5.32" +LuxCore = "0.1.14" +ModelingToolkit = "9.9.0" +ModelingToolkitStandardLibrary = "2.6" +NNlib = "0.9" +Optimization = "3.22" +OptimizationOptimisers = "0.2" +OrdinaryDiffEq = "6.74" +Random = "1" +SafeTestsets = "0.1" +SciMLStructures = "1.1.0" +SymbolicIndexingInterface = "0.3.15" +Symbolics = "5.27" +Test = "1" julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "JET", "Test"] +test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "SymbolicIndexingInterface"] diff --git a/README.md b/README.md index 97857a2..971a1bb 100644 --- a/README.md +++ b/README.md @@ -6,3 +6,5 @@ [![Coverage](https://codecov.io/gh/SebastianM-C/UDEComponents.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/SebastianM-C/UDEComponents.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) + +## Build UDEs with ModelingToolkit diff --git a/src/UDEComponents.jl b/src/UDEComponents.jl index b935cd6..99347fd 100644 --- a/src/UDEComponents.jl +++ b/src/UDEComponents.jl @@ -1,5 +1,47 @@ module UDEComponents -# Write your package code here. +using ModelingToolkit: @parameters, @named, ODESystem, t_nounits +using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput +using Symbolics: Symbolics, @register_array_symbolic, @wrapped +using LuxCore: stateless_apply +using Lux: Lux +using Random: Xoshiro +using NNlib: softplus +using ComponentArrays: ComponentArray + +export create_ude_component, multi_layer_feed_forward + +include("utils.jl") +include("hacks.jl") # this should be removed / upstreamed + +""" + + create_ude_component(n_input = 1, n_output = 1; + chain = multi_layer_feed_forward(n_input, n_output), + rng = Xoshiro(0)) + +Create an `ODESystem` with a neural network inside. +""" +function create_ude_component(n_input = 1, + n_output = 1; + chain = multi_layer_feed_forward(n_input, n_output), + rng = Xoshiro(0)) + lux_p, st = Lux.setup(rng, chain) + ca = ComponentArray(lux_p) + + @parameters p[1:length(ca)] = Vector(ca) + @parameters T::typeof(typeof(p))=typeof(p) [tunable = false] + + @named input = RealInput(nin = n_input) + @named output = RealOutput(nout = n_output) + + out = stateless_apply(chain, input.u, lazyconvert(typeof(ca), p)) + + eqs = [output.u ~ out] + + @named ude_comp = ODESystem( + eqs, t_nounits, [], [p, T], systems = [input, output]) + return ude_comp +end end diff --git a/src/hacks.jl b/src/hacks.jl new file mode 100644 index 0000000..3f5c568 --- /dev/null +++ b/src/hacks.jl @@ -0,0 +1,4 @@ +lazyconvert(x, y) = convert(x, y) +lazyconvert(x, y::Symbolics.Arr) = Symbolics.array_term(convert, x, y) +Symbolics.propagate_ndims(::typeof(convert), x, y) = ndims(y) +Symbolics.propagate_shape(::typeof(convert), x, y) = Symbolics.shape(y) diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..58f59e2 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,25 @@ +function multi_layer_feed_forward(input_length, output_length; width::Int = 5, + depth::Int = 1, activation = softplus) + Lux.Chain(Lux.Dense(input_length, width, activation), + [Lux.Dense(width, width, activation) for _ in 1:(depth)]..., + Lux.Dense(width, output_length); disable_optimizations = true) +end + +# Symbolics.@register_array_symbolic print_input(x) begin +# size = size(x) +# eltype = eltype(x) +# end + +# function print_input(x) +# @info x +# x +# end + +# function debug_component(n_input, n_output) +# @named input = RealInput(nin = n_input) +# @named output = RealOutput(nout = n_output) + +# eqs = [output.u ~ print_input(input.u)] + +# @named dbg_comp = ODESystem(eqs, t_nounits, [], [], systems = [input, output]) +# end diff --git a/test/lotka_volterra.jl b/test/lotka_volterra.jl new file mode 100644 index 0000000..154f56e --- /dev/null +++ b/test/lotka_volterra.jl @@ -0,0 +1,121 @@ +using Test +using JET +using UDEComponents +using ModelingToolkit +using ModelingToolkitStandardLibrary.Blocks +using OrdinaryDiffEq +using SymbolicIndexingInterface +using Optimization +using OptimizationOptimisers: Adam +using SciMLStructures +using SciMLStructures: Tunable +using ForwardDiff + +function lotka_ude() + @variables t x(t)=3.1 y(t)=1.5 + @parameters α=1.3 β=0.9 γ=0.8 δ=1.8 + Dt = ModelingToolkit.D_nounits + @named nn_in = RealInput(nin = 2) + @named nn_out = RealOutput(nout = 2) + + eqs = [ + Dt(x) ~ α * x + nn_in.u[1], + Dt(y) ~ -δ * y + nn_in.u[2], + nn_out.u[1] ~ x, + nn_out.u[2] ~ y + ] + return ODESystem( + eqs, ModelingToolkit.t_nounits, name = :lotka, systems = [nn_in, nn_out]) +end + +function lotka_true() + @variables t x(t)=3.1 y(t)=1.5 + @parameters α=1.3 β=0.9 γ=0.8 δ=1.8 + Dt = ModelingToolkit.D_nounits + + eqs = [ + Dt(x) ~ α * x - β * x * y, + Dt(y) ~ -δ * y + δ * x * y + ] + return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka_true) +end + +model = lotka_ude() +nn = create_ude_component(2, 2) + +eqs = [ + connect(model.nn_in, nn.output) + connect(model.nn_out, nn.input) +] + +ude_sys = complete(ODESystem( + eqs, ModelingToolkit.t_nounits, systems = [model, nn], name = :ude_sys)) + +sys = structural_simplify(ude_sys) + +prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), []) + +model_true = structural_simplify(lotka_true()) +prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), []) +sol_ref = solve(prob_true, Rodas4()) + +x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys))) + +get_vars = getu(sys, [sys.lotka.x, sys.lotka.y]) +get_refs = getu(model_true, [model_true.x, model_true.y]) + +function loss(x, (prob, sol_ref, get_vars, get_refs)) + new_p = SciMLStructures.replace(Tunable(), prob.p, x) + new_prob = remake(prob, p = new_p) + ts = sol_ref.t + new_sol = solve(new_prob, Rodas4(), saveat = ts) + + loss = zero(eltype(x)) + + for i in eachindex(new_sol.u) + loss += sum(sqrt.(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))) + end + + if SciMLBase.successful_retcode(new_sol) + loss + else + Inf + end +end + + +of = OptimizationFunction{true}(loss, AutoForwardDiff()) + +ps = (prob, sol_ref, get_vars, get_refs); + +@test_call target_modules=(UDEComponents,) loss(x0, ps) +@test_opt target_modules=(UDEComponents,) loss(x0, ps) + +@test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0))) + +op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs)) + + +# using Plots + +# oh = [] + +plot_cb = (opt_state, loss) -> begin + @info "step $(opt_state.iter), loss: $loss" + # push!(oh, opt_state) + # new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u) + # new_prob = remake(prob, p = new_p) + # sol = solve(new_prob, Rodas4()) + # display(plot(sol)) + false +end + +res = solve(op, Adam(), maxiters = 2000)#, callback = plot_cb) + +@test res.objective < 1 + +res_p = SciMLStructures.replace(Tunable(), prob.p, res) +res_prob = remake(prob, p = res_p) +res_sol = solve(res_prob, Rodas4()) + +@test SciMLBase.successful_retcode(res_sol) diff --git a/test/qa.jl b/test/qa.jl new file mode 100644 index 0000000..002f0bd --- /dev/null +++ b/test/qa.jl @@ -0,0 +1,20 @@ +using Test +using UDEComponents +using Aqua +using JET + +@testset verbose = true "Code quality (Aqua.jl)" begin + Aqua.find_persistent_tasks_deps(UDEComponents) + Aqua.test_ambiguities(UDEComponents, recursive = false) + Aqua.test_deps_compat(UDEComponents) + # TODO: fix type piracy in propagate_ndims and propagate_shape + Aqua.test_piracies(UDEComponents, broken=true) + Aqua.test_project_extras(UDEComponents) + Aqua.test_stale_deps(UDEComponents, ignore = Symbol[]) + Aqua.test_unbound_args(UDEComponents) + Aqua.test_undefined_exports(UDEComponents) +end + +@testset "Code linting (JET.jl)" begin + JET.test_package(UDEComponents; target_defined_modules = true) +end diff --git a/test/runtests.jl b/test/runtests.jl index fb8db63..7391e01 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,8 @@ using UDEComponents using Test -using Aqua -using JET +using SafeTestsets -@testset "UDEComponents.jl" begin - @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(UDEComponents) - end - @testset "Code linting (JET.jl)" begin - JET.test_package(UDEComponents; target_defined_modules = true) - end - # Write your tests here. +@testset verbose=true "UDEComponents.jl" begin + @safetestset "QA" include("qa.jl") + @safetestset "Basic" include("lotka_volterra.jl") end