Can no longer pass extra variables from loss function to callback #835

Open
John-Boik opened this issue Sep 27, 2024 · 12 comments
Labels: bug (Something isn't working)

Comments

@John-Boik

On a new install of Julia, my previously working neural SDE code no longer allows extra variables to be passed from the loss function to the callback function. The documentation says that this should still be possible. The following is test code from the optimization test library. It runs if unmodified, but fails if I try to pass extra variables from loss function to callback. Is there a new way to accomplish the passing of extra variables?

module Test

using Optimization, OptimizationOptimisers, DiffEqFlux.Lux, Zygote, MLUtils, Random,
          ComponentArrays

x = rand(10000)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 100)

# Define the neural network
model = Chain(Dense(1, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(Random.default_rng(), model)
ps_ca = ComponentArray(ps)
smodel = StatefulLuxLayer{true}(model, nothing, st)

function callback(state, l, extra)  # this fails
#function callback(state, l)  # this works
    state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
    return l < 1e-4
end

function loss(ps, data)
    ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
    extra = 5
    return sum(abs2, ypred .- data[2]), extra   # this fails
    #return sum(abs2, ypred .- data[2])  # this works
end

optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)

res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)

end  # --module

Error & Stacktrace ⚠️
With the modified code, the error is:

ERROR: LoadError: MethodError: no method matching callback(::Optimization.OptimizationState{…}, ::Float64)

Closest candidates are:
  callback(::Any, ::Any, ::Any)
   @ Main.Test ~/Devel/GridSim/DS.GridSim/julia/gridsim/src/Xfmr/SMUD/NCSDE1/test.jl:19

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/OptimizationOptimisers/864RB/src/OptimizationOptimisers.jl:101 [inlined]
 [2] macro expansion
   @ ~/.julia/packages/Optimization/bmAND/src/utils.jl:32 [inlined]
 [3] __solve(cache::OptimizationBase.OptimizationCache{…})
   @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/864RB/src/OptimizationOptimisers.jl:82
 [4] solve!(cache::OptimizationBase.OptimizationCache{…})
   @ SciMLBase ~/.julia/packages/SciMLBase/SQnVC/src/solve.jl:188
 [5] solve(::SciMLBase.OptimizationProblem{…}, ::Optimisers.Adam; kwargs::@Kwargs{})
   @ SciMLBase ~/.julia/packages/SciMLBase/SQnVC/src/solve.jl:96
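
The `MethodError` above shows the solver invoking the callback with exactly two arguments, `(state, loss)`. Below is a minimal pure-Julia sketch of that dispatch failure, together with one possible way to write a callback that tolerates both call shapes. The names `strict_callback` and `tolerant_callback`, and the `NamedTuple` standing in for the optimizer state, are hypothetical and used only for illustration; they are not part of Optimization.jl.

```julia
# Stand-in for the solver's state object (hypothetical; the real one is an OptimizationState).
state = (iter = 1,)

# Three required positional arguments: calling it with two raises a MethodError.
strict_callback(state, l, extra) = l < 1e-4

# Trailing varargs accept zero or more extra values, so both call shapes dispatch.
tolerant_callback(state, l, extra...) = l < 1e-4

# Mimic the two-argument call seen in the stack trace.
failed = try
    strict_callback(state, 0.5)
    false
catch err
    err isa MethodError
end

println("strict callback raised MethodError: ", failed)
println("tolerant callback returned: ", tolerant_callback(state, 0.5))
```

The varargs form only keeps old callback code from erroring; it does not bring back the extra return values, which the solver no longer forwards.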

Environment (please complete the following information):

  • Output of using Pkg; Pkg.status()
  [336ed68f] CSV v0.10.14
  [b0b7db55] ComponentArrays v0.15.17
  [2569d6c7] ConcreteStructs v0.2.3
  [a93c6f00] DataFrames v1.7.0
  [82cc6244] DataInterpolations v6.4.1
  [2b5f629d] DiffEqBase v6.155.3 `~/.julia/dev/DiffEqBase`
⌃ [aae7a2af] DiffEqFlux v3.5.1
  [31c24e10] Distributions v0.25.112
⌃ [7da242da] Enzyme v0.12.36
  [1fa38f19] Format v1.3.7
  [c27321d9] Glob v1.3.1
  [5903a43b] Infiltrator v1.8.3
  [842dd82b] InlineStrings v1.4.2
  [c8e1da08] IterTools v1.10.0
  [f1d291b0] MLUtils v0.4.4
  [7f7a1694] Optimization v4.0.2
  [42dfb2eb] OptimizationOptimisers v0.3.2
  [0bf8acf5] PPInterpolation v0.7.0
  [91a5bcdd] Plots v1.40.8
  [295af30f] Revise v3.6.0
  [0aa819cd] SQLite v1.6.1
  [0bca4576] SciMLBase v2.54.1
  [2913bbd2] StatsBase v0.34.3
  [789caeaf] StochasticDiffEq v6.69.1
  [f269a46b] TimeZones v1.18.1
  [e88e6eb3] Zygote v0.6.71
  [ade2ca70] Dates
  [9a3f8284] Random
  [9e88b42a] Serialization
  [10745b16] Statistics v1.10.0
  • Output of using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
  [47edcb42] ADTypes v1.9.0
  [621f4979] AbstractFFTs v1.5.0
  [1520ce14] AbstractTrees v0.4.5
  [7d9f7c33] Accessors v0.1.38
  [79e6a3ab] Adapt v4.0.4
  [66dad0bd] AliasTables v1.1.3
  [dce04be8] ArgCheck v2.3.0
  [ec485272] ArnoldiMethod v0.4.0
  [4fba245c] ArrayInterface v7.16.0
  [4c555306] ArrayLayouts v1.10.3
  [a9b6321e] Atomix v0.1.0
  [198e06fe] BangBang v0.4.3
  [9718e550] Baselet v0.1.1
  [6e4b80f9] BenchmarkTools v1.5.0
  [e2ed5e7c] Bijections v0.1.9
  [d1d4a3ce] BitFlags v0.1.9
  [62783981] BitTwiddlingConvenienceFunctions v0.1.6
  [fa961155] CEnum v0.5.0
  [2a0fbf3d] CPUSummary v0.2.6
  [336ed68f] CSV v0.10.14
  [7057c7e9] Cassette v0.3.13
  [082447d4] ChainRules v1.71.0
  [d360d2e6] ChainRulesCore v1.25.0
  [fb6a15b2] CloseOpenIntervals v0.1.13
  [da1fd8a2] CodeTracking v1.3.6
  [944b1d66] CodecZlib v0.7.6
  [35d6a980] ColorSchemes v3.26.0
  [3da002f7] ColorTypes v0.11.5
  [c3611d14] ColorVectorSpace v0.10.0
  [5ae59095] Colors v0.12.11
  [861a8166] Combinatorics v1.0.2
  [38540f10] CommonSolve v0.2.4
  [bbf7d656] CommonSubexpressions v0.3.1
  [f70d9fcc] CommonWorldInvalidations v1.0.0
  [34da2185] Compat v4.16.0
  [b0b7db55] ComponentArrays v0.15.17
  [b152e2b5] CompositeTypes v0.1.4
  [a33af91c] CompositionsBase v0.1.2
  [2569d6c7] ConcreteStructs v0.2.3
  [f0e56b4a] ConcurrentUtilities v2.4.2
  [88cd18e8] ConsoleProgressMonitor v0.1.2
  [187b0558] ConstructionBase v1.5.8
  [6add18c4] ContextVariablesX v0.1.3
  [d38c429a] Contour v0.6.3
  [adafc99b] CpuId v0.3.1
  [a8cc5b0e] Crayons v4.1.1
  [a10d1c49] DBInterface v2.6.1
  [717857b8] DSP v0.7.10
  [9a962f9c] DataAPI v1.16.0
  [a93c6f00] DataFrames v1.7.0
  [82cc6244] DataInterpolations v6.4.1
  [864edb3b] DataStructures v0.18.20
  [e2d170a0] DataValueInterfaces v1.0.0
  [244e2a9f] DefineSingletons v0.1.2
  [8bb1440f] DelimitedFiles v1.9.1
  [85a47980] Dictionaries v0.4.2
  [2b5f629d] DiffEqBase v6.155.3 `~/.julia/dev/DiffEqBase`
  [459566f4] DiffEqCallbacks v3.9.1
⌃ [aae7a2af] DiffEqFlux v3.5.1
  [77a26b50] DiffEqNoiseProcess v5.23.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
⌅ [a0c0ee7d] DifferentiationInterface v0.5.17
  [8d63f2c5] DispatchDoctor v0.4.15
  [b4f34e82] Distances v0.10.11
  [31c24e10] Distributions v0.25.112
  [ced4e74d] DistributionsAD v0.6.55
  [ffbed154] DocStringExtensions v0.9.3
  [5b8099bc] DomainSets v0.7.14
  [7c1d4256] DynamicPolynomials v0.6.0
  [b7d42ee7] Einsum v0.4.1
  [da5c29d0] EllipsisNotation v1.8.0
  [4e289a0a] EnumX v1.0.4
⌃ [7da242da] Enzyme v0.12.36
⌅ [f151be2c] EnzymeCore v0.7.8
  [460bff9d] ExceptionUnwrapping v0.1.10
  [d4d017d3] ExponentialUtilities v1.26.1
  [e2ba6199] ExprTools v0.1.10
⌅ [6b7a57c9] Expronicon v0.8.5
  [c87230d0] FFMPEG v0.4.2
  [7a1cc6ca] FFTW v1.8.0
  [cc61a311] FLoops v0.2.2
  [b9860ae5] FLoopsBase v0.1.1
  [7034ab61] FastBroadcast v0.3.5
  [9aa1b823] FastClosures v0.3.2
  [29a986be] FastLapackInterface v2.0.4
  [48062228] FilePathsBase v0.9.22
  [1a297f60] FillArrays v1.13.0
  [64ca27bc] FindFirstFunctions v1.4.1
  [6a86dc24] FiniteDiff v2.24.0
  [53c48c17] FixedPointNumbers v0.8.5
  [1fa38f19] Format v1.3.7
  [f6369f11] ForwardDiff v0.10.36
  [f62d2435] FunctionProperties v0.1.2
  [069b7b12] FunctionWrappers v1.1.3
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [d9f16b24] Functors v0.4.12
  [0c68f7d7] GPUArrays v10.3.1
  [46192b85] GPUArraysCore v0.1.6
⌃ [61eb1bfa] GPUCompiler v0.27.5
  [28b8d3ca] GR v0.73.7
  [c145ed77] GenericSchur v0.5.4
  [c27321d9] Glob v1.3.1
  [86223c79] Graphs v1.11.2
  [42e2da0e] Grisu v1.0.2
  [cd3eb016] HTTP v1.10.8
  [3e5b6fbb] HostCPUFeatures v0.1.17
  [0e44f5e4] Hwloc v3.3.0
  [1baab800] HybridArrays v0.4.16
  [34004b35] HypergeometricFunctions v0.3.24
  [7869d1d1] IRTools v0.4.14
  [615f187c] IfElse v0.1.1
  [313cdc1a] Indexing v1.1.1
  [5903a43b] Infiltrator v1.8.3
  [d25df0c9] Inflate v0.1.5
  [22cec73e] InitialValues v0.3.1
  [842dd82b] InlineStrings v1.4.2
  [18e54dd8] IntegerMathUtils v0.1.2
  [8197267c] IntervalSets v0.7.10
  [3587e190] InverseFunctions v0.1.17
  [41ab1584] InvertedIndices v1.3.0
  [92d709cd] IrrationalConstants v0.2.2
  [c8e1da08] IterTools v1.10.0
  [82899510] IteratorInterfaceExtensions v1.0.0
  [1019f520] JLFzf v0.1.8
  [692b3bcd] JLLWrappers v1.6.0
  [682c06a0] JSON v0.21.4
  [aa1ae85d] JuliaInterpreter v0.9.36
  [b14d175d] JuliaVariables v0.2.4
  [ccbc3e58] JumpProcesses v9.13.7
  [ef3ab10e] KLU v0.6.0
⌃ [63c18a36] KernelAbstractions v0.9.26
  [2c470bb0] Kronecker v0.5.5
  [ba0b0d4f] Krylov v0.9.6
  [5be7bae1] LBFGSB v0.4.1
⌅ [929cbde3] LLVM v9.0.0
  [b964fa9f] LaTeXStrings v1.3.1
  [984bce1d] LambertW v0.4.6
  [23fbe1c1] Latexify v0.16.5
  [10f19ff3] LayoutPointers v0.1.17
⌃ [5078a376] LazyArrays v1.10.0
  [1d6d02ad] LeftChildRightSiblingTrees v0.2.0
  [2d8b4e74] LevyArea v1.0.0
  [d3d80556] LineSearches v7.3.0
  [7a12625a] LinearMaps v3.11.3
⌃ [7ed4a6bd] LinearSolve v2.34.0
  [2ab3a3ac] LogExpFunctions v0.3.28
  [e6f89c97] LoggingExtras v1.0.3
  [bdcacae8] LoopVectorization v0.12.171
  [30fc2ffe] LossFunctions v0.11.2
  [6f1432cf] LoweredCodeUtils v3.0.2
⌅ [b2108857] Lux v0.5.68
⌅ [bb33d45b] LuxCore v0.1.25
  [34f89e08] LuxDeviceUtils v0.1.27
⌅ [82251201] LuxLib v0.3.51
  [7e8f7934] MLDataDevices v1.1.1
  [d8e11817] MLStyle v0.4.17
  [f1d291b0] MLUtils v0.4.4
  [1914dd2f] MacroTools v0.5.13
  [af67fdf4] ManifoldDiff v0.3.12
⌅ [1cead3c2] Manifolds v0.9.20
  [3362f125] ManifoldsBase v0.15.16
  [d125e4d3] ManualMemory v0.1.8
  [99c1a7ee] MatrixEquations v2.4.2
⌅ [a3b82374] MatrixFactorizations v2.2.0
  [bb5d69b7] MaybeInplace v0.1.4
  [739be429] MbedTLS v1.1.9
  [442fdcdd] Measures v0.3.2
  [128add7d] MicroCollections v0.2.0
  [e1d29d7a] Missings v1.2.0
  [78c3b35d] Mocking v0.8.1
  [46d2c3a1] MuladdMacro v0.2.4
  [102ac46a] MultivariatePolynomials v0.5.6
  [d8a4904e] MutableArithmetics v1.5.0
  [d41bc354] NLSolversBase v7.8.3
  [2774e3e8] NLsolve v4.5.1
  [872c559c] NNlib v0.9.24
  [77ba4419] NaNMath v1.0.2
  [71a1bf82] NameResolution v0.1.5
  [356022a1] NamedDims v1.2.2
  [8913a72c] NonlinearSolve v3.14.0
  [d8793406] ObjectFile v0.4.2
  [6fd5a793] Octavian v0.3.28
  [6fe1bfb0] OffsetArrays v1.14.1
  [4d8831e6] OpenSSL v1.4.3
  [429524aa] Optim v1.9.4
  [3bd65402] Optimisers v0.3.3
  [7f7a1694] Optimization v4.0.2
  [bca83a33] OptimizationBase v2.0.4
  [42dfb2eb] OptimizationOptimisers v0.3.2
  [bac558e1] OrderedCollections v1.6.3
  [1dea7af3] OrdinaryDiffEq v6.89.0
  [89bda076] OrdinaryDiffEqAdamsBashforthMoulton v1.1.0
  [6ad6398a] OrdinaryDiffEqBDF v1.1.2
  [bbf590c4] OrdinaryDiffEqCore v1.6.0
  [50262376] OrdinaryDiffEqDefault v1.1.0
  [4302a76b] OrdinaryDiffEqDifferentiation v1.1.0
  [9286f039] OrdinaryDiffEqExplicitRK v1.1.0
  [e0540318] OrdinaryDiffEqExponentialRK v1.1.0
  [becaefa8] OrdinaryDiffEqExtrapolation v1.1.0
  [5960d6e9] OrdinaryDiffEqFIRK v1.1.1
  [101fe9f7] OrdinaryDiffEqFeagin v1.1.0
  [d3585ca7] OrdinaryDiffEqFunctionMap v1.1.1
  [d28bc4f8] OrdinaryDiffEqHighOrderRK v1.1.0
  [9f002381] OrdinaryDiffEqIMEXMultistep v1.1.0
  [521117fe] OrdinaryDiffEqLinear v1.1.0
  [1344f307] OrdinaryDiffEqLowOrderRK v1.2.0
  [b0944070] OrdinaryDiffEqLowStorageRK v1.2.1
  [127b3ac7] OrdinaryDiffEqNonlinearSolve v1.2.1
  [c9986a66] OrdinaryDiffEqNordsieck v1.1.0
  [5dd0a6cf] OrdinaryDiffEqPDIRK v1.1.0
  [5b33eab2] OrdinaryDiffEqPRK v1.1.0
  [04162be5] OrdinaryDiffEqQPRK v1.1.0
  [af6ede74] OrdinaryDiffEqRKN v1.1.0
  [43230ef6] OrdinaryDiffEqRosenbrock v1.2.0
  [2d112036] OrdinaryDiffEqSDIRK v1.1.0
  [669c94d9] OrdinaryDiffEqSSPRK v1.2.0
  [e3e12d00] OrdinaryDiffEqStabilizedIRK v1.1.0
  [358294b1] OrdinaryDiffEqStabilizedRK v1.1.0
  [fa646aed] OrdinaryDiffEqSymplecticRK v1.1.0
  [b1df2697] OrdinaryDiffEqTsit5 v1.1.0
  [79d7bb75] OrdinaryDiffEqVerner v1.1.1
  [90014a1f] PDMats v0.11.31
  [0bf8acf5] PPInterpolation v0.7.0
  [65ce6f38] PackageExtensionCompat v1.0.2
  [d96e819e] Parameters v0.12.3
  [69de0a69] Parsers v2.8.1
  [b98c9c47] Pipe v1.3.0
  [ccf2f8ad] PlotThemes v3.2.0
  [995b91a9] PlotUtils v1.4.1
  [91a5bcdd] Plots v1.40.8
  [e409e4f3] PoissonRandom v0.4.4
  [f517fe37] Polyester v0.7.16
  [1d0040c9] PolyesterWeave v0.2.2
  [f27b6e38] Polynomials v4.0.11
  [2dfb63ee] PooledArrays v1.4.3
  [85a6dd25] PositiveFactorizations v0.2.4
  [d236fae5] PreallocationTools v0.4.24
  [aea7be01] PrecompileTools v1.2.1
  [21216c6a] Preferences v1.4.3
  [8162dcfd] PrettyPrint v0.2.0
  [08abe8d2] PrettyTables v2.4.0
  [27ebfcd6] Primes v0.5.6
  [33c8b6b6] ProgressLogging v0.1.4
  [92933f4c] ProgressMeter v1.10.2
  [43287f4e] PtrArrays v1.2.1
  [1fd47b50] QuadGK v2.11.1
  [94ee1d12] Quaternions v0.7.6
  [74087812] Random123 v1.7.0
  [e6cf234a] RandomNumbers v1.6.0
  [c1ae055f] RealDot v0.1.0
  [3cdcf5f2] RecipesBase v1.3.4
  [01d81517] RecipesPipeline v0.6.12
  [731186ca] RecursiveArrayTools v3.27.0
  [f2c3362d] RecursiveFactorization v0.2.23
  [189a3867] Reexport v1.2.2
  [05181044] RelocatableFolders v1.0.1
  [ae029012] Requires v1.3.0
  [ae5879a3] ResettableStacks v1.1.1
  [37e2e3b7] ReverseDiff v1.15.3
  [295af30f] Revise v3.6.0
  [79098fc4] Rmath v0.8.0
  [7e49a35a] RuntimeGeneratedFunctions v0.5.13
  [94e857df] SIMDTypes v0.1.0
  [476501e8] SLEEFPirates v0.6.43
  [0aa819cd] SQLite v1.6.1
  [0bca4576] SciMLBase v2.54.1
  [c0aeaf25] SciMLOperators v0.3.10
  [1ed8b502] SciMLSensitivity v7.68.0
  [53ae85a6] SciMLStructures v1.5.0
  [6c6a2e73] Scratch v1.2.1
  [91c51154] SentinelArrays v1.4.5
  [efcf1570] Setfield v1.1.1
  [605ecd9f] ShowCases v0.1.0
  [992d4aef] Showoff v1.0.3
  [777ac1f9] SimpleBufferStream v1.2.0
⌃ [727e6d20] SimpleNonlinearSolve v1.12.2
  [699a6c99] SimpleTraits v0.9.4
  [ce78b400] SimpleUnPack v1.1.0
  [47aef6b3] SimpleWeightedGraphs v1.4.0
  [a2af1166] SortingAlgorithms v1.2.1
  [9f842d2f] SparseConnectivityTracer v0.6.5
⌃ [47a9eef4] SparseDiffTools v2.20.0
  [dc90abb0] SparseInverseSubset v0.1.2
  [0a514795] SparseMatrixColorings v0.4.1
  [e56a9233] Sparspak v0.3.9
  [276daf66] SpecialFunctions v2.4.0
  [171d559e] SplittablesBase v0.1.15
  [aedffcd0] Static v1.1.1
  [0d7ed370] StaticArrayInterface v1.8.0
  [90137ffa] StaticArrays v1.9.7
  [1e83bf80] StaticArraysCore v1.4.3
  [82ae8749] StatsAPI v1.7.0
  [2913bbd2] StatsBase v0.34.3
  [4c63d2b9] StatsFuns v1.3.2
  [789caeaf] StochasticDiffEq v6.69.1
  [7792a7ef] StrideArraysCore v0.5.7
  [892a3eda] StringManipulation v0.4.0
  [09ab397b] StructArrays v0.6.18
  [53d494c1] StructIO v0.3.1
  [4297ee4d] SymbolicAnalysis v0.3.0
  [2efcf032] SymbolicIndexingInterface v0.3.31
  [19f23fe9] SymbolicLimits v0.2.2
  [d1185830] SymbolicUtils v3.7.1
⌃ [0c5d862f] Symbolics v6.7.0
  [dc5dba14] TZJData v1.3.0+2024b
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.12.0
  [62fd8b95] TensorCore v0.1.1
  [8ea1fca8] TermInterface v2.0.0
  [5d786b92] TerminalLoggers v0.1.7
  [8290d209] ThreadingUtilities v0.5.2
  [f269a46b] TimeZones v1.18.1
  [a759f4b9] TimerOutputs v0.5.24
  [9f7883ad] Tracker v0.2.35
  [3bb67fe8] TranscodingStreams v0.11.2
  [28d57a85] Transducers v0.4.82
  [d5829a12] TriangularSolve v0.2.1
  [410a4b4d] Tricks v0.1.9
  [781d530d] TruncatedStacktraces v1.4.0
  [5c2747f8] URIs v1.5.1
  [3a884ed6] UnPack v1.0.2
  [1cfade01] UnicodeFun v0.4.1
  [1986cc42] Unitful v1.21.0
  [45397f5d] UnitfulLatexify v1.6.4
  [a7c27f48] Unityper v0.1.6
  [0fe1646c] UnrolledUtilities v0.1.5
  [013be700] UnsafeAtomics v0.2.1
  [d80eeb9a] UnsafeAtomicsLLVM v0.2.1
  [41fe7b60] Unzip v0.2.0
  [3d5dd08c] VectorizationBase v0.21.70
  [19fa3120] VertexSafeGraphs v0.2.0
  [ea10d353] WeakRefStrings v1.4.2
  [d49dbf32] WeightInitializers v1.0.3
  [76eceee3] WorkerUtilities v1.6.1
  [e88e6eb3] Zygote v0.6.71
  [700de1a5] ZygoteRules v0.2.5
  [6e34b625] Bzip2_jll v1.0.8+1
  [83423d85] Cairo_jll v1.18.0+2
  [ee1fde0b] Dbus_jll v1.14.10+0
⌅ [7cc45869] Enzyme_jll v0.0.148+0
  [2702e6a9] EpollShim_jll v0.0.20230411+0
  [2e619515] Expat_jll v2.6.2+0
⌅ [b22a6f82] FFMPEG_jll v4.4.4+1
  [f5851436] FFTW_jll v3.3.10+1
  [a3f928ae] Fontconfig_jll v2.13.96+0
  [d7e528f0] FreeType2_jll v2.13.2+0
  [559328eb] FriBidi_jll v1.0.14+0
  [0656b61e] GLFW_jll v3.4.0+1
  [d2c73de3] GR_jll v0.73.7+0
  [78b55507] Gettext_jll v0.21.0+0
  [7746bdde] Glib_jll v2.80.2+0
  [3b182d85] Graphite2_jll v1.3.14+0
  [2e76f6c2] HarfBuzz_jll v8.3.1+0
  [e33a78d0] Hwloc_jll v2.11.2+0
  [1d5cc7b8] IntelOpenMP_jll v2024.2.1+0
  [aacddb02] JpegTurbo_jll v3.0.4+0
  [c1c5ebd0] LAME_jll v3.100.2+0
⌅ [88015f11] LERC_jll v3.0.0+1
⌅ [dad2f222] LLVMExtra_jll v0.0.33+0
  [1d63c593] LLVMOpenMP_jll v18.1.7+0
  [dd4b983a] LZO_jll v2.10.2+1
  [81d17ec3] L_BFGS_B_jll v3.0.1+0
⌅ [e9f186c6] Libffi_jll v3.2.2+1
  [d4300ac3] Libgcrypt_jll v1.8.11+0
  [7e76a0d4] Libglvnd_jll v1.6.0+0
  [7add5ba3] Libgpg_error_jll v1.49.0+0
  [94ce4f54] Libiconv_jll v1.17.0+0
  [4b2f31a3] Libmount_jll v2.40.1+0
⌅ [89763e89] Libtiff_jll v4.5.1+1
  [38a345b3] Libuuid_jll v2.40.1+0
  [856f044c] MKL_jll v2024.2.0+0
  [e7412a2a] Ogg_jll v1.3.5+1
  [458c3c95] OpenSSL_jll v3.0.15+1
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [91d4177d] Opus_jll v1.3.3+0
  [36c8627f] Pango_jll v1.54.1+0
  [30392449] Pixman_jll v0.43.4+0
  [c0090381] Qt6Base_jll v6.7.1+1
  [629bc702] Qt6Declarative_jll v6.7.1+2
  [ce943373] Qt6ShaderTools_jll v6.7.1+1
  [e99dba38] Qt6Wayland_jll v6.7.1+1
  [f50d1b31] Rmath_jll v0.5.1+0
  [76ed43ae] SQLite_jll v3.45.3+0
  [a44049a8] Vulkan_Loader_jll v1.3.243+0
  [a2964d1f] Wayland_jll v1.21.0+1
  [2381bf8a] Wayland_protocols_jll v1.31.0+0
  [02c8fc9c] XML2_jll v2.13.3+0
  [aed1982a] XSLT_jll v1.1.41+0
  [ffd25f8a] XZ_jll v5.4.6+0
  [f67eecfb] Xorg_libICE_jll v1.1.1+0
  [c834827a] Xorg_libSM_jll v1.2.4+0
  [4f6342f7] Xorg_libX11_jll v1.8.6+0
  [0c0b7dd1] Xorg_libXau_jll v1.0.11+0
  [935fb764] Xorg_libXcursor_jll v1.2.0+4
  [a3789734] Xorg_libXdmcp_jll v1.1.4+0
  [1082639a] Xorg_libXext_jll v1.3.6+0
  [d091e8ba] Xorg_libXfixes_jll v5.0.3+4
  [a51aa0fd] Xorg_libXi_jll v1.7.10+4
  [d1454406] Xorg_libXinerama_jll v1.1.4+4
  [ec84b674] Xorg_libXrandr_jll v1.5.2+4
  [ea2f1a96] Xorg_libXrender_jll v0.9.11+0
  [14d82f49] Xorg_libpthread_stubs_jll v0.1.1+0
  [c7cfdc94] Xorg_libxcb_jll v1.17.0+0
  [cc61e674] Xorg_libxkbfile_jll v1.1.2+0
  [e920d4aa] Xorg_xcb_util_cursor_jll v0.1.4+0
  [12413925] Xorg_xcb_util_image_jll v0.4.0+1
  [2def613f] Xorg_xcb_util_jll v0.4.0+1
  [975044d2] Xorg_xcb_util_keysyms_jll v0.4.0+1
  [0d47668e] Xorg_xcb_util_renderutil_jll v0.3.9+1
  [c22f9ab0] Xorg_xcb_util_wm_jll v0.4.1+1
  [35661453] Xorg_xkbcomp_jll v1.4.6+0
  [33bec58e] Xorg_xkeyboard_config_jll v2.39.0+0
  [c5fb5394] Xorg_xtrans_jll v1.5.0+0
  [3161d3a3] Zstd_jll v1.5.6+1
  [35ca27e7] eudev_jll v3.2.9+0
  [214eeab7] fzf_jll v0.53.0+0
  [1a1c6b14] gperf_jll v3.1.1+0
  [a4ae2306] libaom_jll v3.9.0+0
  [0ac62f75] libass_jll v0.15.2+0
  [1183f4f0] libdecor_jll v0.2.2+0
  [2db6ffa8] libevdev_jll v1.11.0+0
  [f638f0a6] libfdk_aac_jll v2.0.3+0
  [36db933b] libinput_jll v1.18.0+0
  [b53b4c65] libpng_jll v1.6.44+0
  [f27f6e37] libvorbis_jll v1.3.7+2
  [009596ad] mtdev_jll v1.1.6+0
  [1317d2d5] oneTBB_jll v2021.12.0+0
⌅ [1270edf5] x264_jll v2021.5.5+0
⌅ [dfaa095f] x265_jll v3.5.0+0
  [d8fb68d0] xkbcommon_jll v1.4.1+1
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.4
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.10.0
  [de0858da] Printf
  [9abbd945] Profile
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets
  [2f01184e] SparseArrays v1.10.0
  [10745b16] Statistics v1.10.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.1.1+0
  [deac9b47] LibCURL_jll v8.4.0+0
  [e37daf67] LibGit2_jll v1.6.4+0
  [29816b5a] LibSSH2_jll v1.11.0+1
  [c8ffd9c3] MbedTLS_jll v2.28.2+1
  [14a3606d] MozillaCACerts_jll v2023.1.10
  [4536629a] OpenBLAS_jll v0.3.23+4
  [05823500] OpenLibm_jll v0.8.1+2
  [efcefdf7] PCRE2_jll v10.42.0+1
  [bea87d4a] SuiteSparse_jll v7.2.1+1
  [83775a58] Zlib_jll v1.2.13+1
  [8e850b90] libblastrampoline_jll v5.11.0+0
  [8e850ede] nghttp2_jll v1.52.0+1
  [3f19e933] p7zip_jll v17.4.0+2

Version:

Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 36 × Intel(R) Xeon(R) W-2195 CPU @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake-avx512)
Threads: 25 default, 0 interactive, 12 GC (on 36 virtual cores)
Environment:
  JULIA_NUM_THREADS = 25
John-Boik added the bug label Sep 27, 2024
@John-Boik
Author

John-Boik commented Sep 27, 2024

@ChrisRackauckas, I wouldn't think so, but could this be related to the new DiffEqBase issue: SciML/DiffEqBase.jl#1088?

@Vaibhavdixit02
Member

Hey, support for this has been removed. Can you point to where in the docs it says this is still possible? I thought I had removed it, but I might have missed some places.

In the example above, extra is not used at all. Can you specify what you use the extra return value for otherwise?

@John-Boik
Author

John-Boik commented Sep 27, 2024

@Vaibhavdixit02 Thanks for the reply. I use the extra variable(s) to send the current batch of training data to the callback function for plotting. It is very useful to me, and I'm not sure how I would plot the current solution vs. training data if I could not send data from the loss function. Can you tell me where in the code base it has been removed? Perhaps I can edit my dev version to put it back in. Or, is there some other way to pass data from the loss function to the callback?

More specifically, I use a dataloader to send an index of data rows to the loss function. In the loss function, I create data matrices based on the indexes. I would either need the row indexes or the constructed data in order to plot the current solution vs. training data in the callback function. Is there some other way to accomplish this? I imagine that many users would want to pass data from the loss function to the callback, for the very reason specified in the documentation.

Here is what the documentation says:

Callback Functions
The callback function callback is a function which is called after every optimizer step. Its signature is:
callback = (state, loss_val, other_args) -> false
where state is a OptimizationState and stores information for the current iteration of the solver and loss_val is loss/objective value. For more information about the fields of the state look at the OptimizationState documentation. The other_args can be the extra things returned from the optimization f. This allows for saving values from the optimization and using them for plotting and display without recalculating.

@John-Boik
Author

With the new changes, perhaps the recommended way to pass data from the loss function to the callback is via a global variable. That seems to work fine. Is there a better way to accomplish this?

@Vaibhavdixit02
Member

Yeah, you'd have to use global variables one way or another for this now. There are a few options: you could recover the index by checking the iter field of the state argument of the callback (its first argument), or simply use a global variable that gets overwritten in the loss function with the indices you access there. The places in the docs that weren't updated should be cleaned up by next week, which could give you more options just by looking at how it's done there (though it would be pretty much the same as what we have discussed).
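
One way to get the effect of a global without an untyped global variable is to have the loss and the callback close over a shared `Ref`. The sketch below is plain Julia, with a hypothetical `mock_solve` loop standing in for `Optimization.solve`; `make_loss_and_callback`, `last_batch`, `seen`, and the toy loss are assumptions made for illustration only. When the loss is differentiated with Zygote, the write to the `Ref` may need to be excluded from differentiation (for example with `ChainRulesCore.ignore_derivatives`); that is not tested here.

```julia
# Build a loss and a callback that share one Ref, instead of using a global.
function make_loss_and_callback()
    last_batch = Ref{Any}(nothing)          # most recent batch seen by the loss
    seen = Any[]                            # what the callback observed, for inspection

    function loss(ps, batch)
        last_batch[] = batch                # stash the batch for the callback
        x, y = batch
        return sum(abs2, ps[1] .* x .- y)   # toy loss: fit y ≈ ps[1] * x
    end

    function callback(state, l)             # two-argument form, as in the error message
        push!(seen, (state.iter, last_batch[], l))
        return false                        # false means "keep going"
    end

    return loss, callback, seen
end

# Hypothetical stand-in for the solver: evaluate the loss, then call the callback.
function mock_solve(loss, callback, ps, batches)
    for (i, batch) in enumerate(batches)
        l = loss(ps, batch)
        callback((iter = i,), l) && break
    end
end

loss, callback, seen = make_loss_and_callback()
batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
mock_solve(loss, callback, [2.0], batches)

for (iter, batch, l) in seen
    println("iter = ", iter, ", batch x = ", batch[1], ", loss = ", l)
end
```

In a real run, the callback would plot `last_batch[]` against the current prediction instead of pushing to `seen`.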

@Vaibhavdixit02
Member

Also, the data handling has changed, so make sure to look at the new tutorial on that (https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/). Basically, instead of passing your dataloader to the solve call, you now pass the DataLoader object as the p field of OptimizationProblem.

@John-Boik
Author

@Vaibhavdixit02, the minibatch documentation imports ncycle but then never uses it. Does ncycle need to be called, or is ncycle imported for some other reason?

using IterTools: ncycle
res1 = Optimization.solve(
    optprob, Optimisers.ADAM(0.05); callback = callback, epochs = 1000)

@Vaibhavdixit02
Member

No, it's not needed; it should be removed from there.

@John-Boik
Author

I think the keyword epochs is being used as maxiters or something like that. Adapting the package's test code a bit, the code below produces the following output. The expected output is two full passes over the training data, or 10 iterations in total, not two iterations in each epoch (with "epoch" meaning one pass over the full data set, using each record once).

Iteration= 1, 1, Loss= 2.197e+00
Iteration= 2, 2, Loss= 9.193e-01
Iteration= 2, 3, Loss= 9.193e-01
Iteration= 1, 4, Loss= 2.091e+00
Iteration= 2, 5, Loss= 8.688e-01
Iteration= 2, 6, Loss= 8.688e-01

The code is:

module Test

using Optimization, OptimizationOptimisers, DiffEqFlux.Lux, Zygote, MLUtils, Random,
          ComponentArrays

using Format

x = rand(25)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 5)

# Define the neural network
model = Chain(Dense(1, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(Random.default_rng(), model)
ps_ca = ComponentArray(ps)
smodel = StatefulLuxLayer{true}(model, nothing, st)

Iter = 0

function callback(state, l)
    global Iter
    Iter += 1
    printfmtln("Iteration= {}, {}, Loss= {:.3e}", state.iter, Iter, l)
    return false
end

function loss(ps, data)
    ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
    return sum(abs2, ypred .- data[2]) 
end

optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)

res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 2)
end  # --module
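
For reference, the expected count follows from the setup: 25 samples with `batchsize = 5` give 5 batches per epoch, so `epochs = 2` should produce 10 callback invocations. Here is a plain-Julia sketch of that counting convention, using `Iterators.partition` as a stand-in for the `DataLoader`; the helper name `count_iterations` is hypothetical.

```julia
# Count optimizer steps when one epoch means one full pass over all batches.
function count_iterations(nsamples, batchsize, epochs)
    batches = collect(Iterators.partition(1:nsamples, batchsize))
    iters = 0
    for epoch in 1:epochs, batch in batches
        iters += 1          # one optimizer step (and one callback call) per batch
    end
    return length(batches), iters
end

nbatches, total = count_iterations(25, 5, 2)
println("batches per epoch = ", nbatches, ", total iterations = ", total)
```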

@vpuri3
Member

vpuri3 commented Oct 4, 2024

> Hey, the support for it has been removed, can you specify where in the docs you say this is still there? I thought I had removed it but might have missed some places.
>
> In the example above extra is not used at all, can you specify what you use the extra return for otherwise?

@Vaibhavdixit02 the docs below need to be updated too.

https://github.com/SciML/SciMLBase.jl/blob/master/src/scimlfunctions.jl#L1813-L1815

Also, the docs for the callback API need to be updated:
https://docs.sciml.ai/Optimization/stable/API/solve/

@vpuri3
Member

vpuri3 commented Oct 4, 2024

The minibatch tutorial is still misleading: the function loss_adjoint returns two values.

@Vaibhavdixit02
Member

Yeah, sorry about the mess; #838 fixes all of these things. @John-Boik, the iteration count was corrected but not yet released. You should have it available in the next release (in a couple of hours).
