Thank you very much for your help! My project, however, is to do this from scratch, although I have drawn inspiration from your GitHub. I am still struggling, though: when I call `OptimizationFunction`, I get an error that I do not understand. Is this implementation on the right path? I know there are a few inconsistencies in the code, but right now I just want it to at least run through:
```julia
using Optimization
using Zygote, Lux, Optimisers, Random, Statistics

ad = Lux.Training.AutoZygote()
ps, st = Lux.setup(Xoshiro(0), model_pre)

# Adapted from (hnn::HamiltonianNN{<:LuxCore.AbstractExplicitLayer})(x, ps, st)
function hnn(ad, model, x, ps, st)
    # Wrap the model (the argument, not the global model_pre) with fixed ps and st
    model2 = Lux.Experimental.StatefulLuxLayer(model, ps, st)
    # Gradient of the learned Hamiltonian w.r.t. x (helper not shown in this post)
    H = __hamiltonian_forward(model2, x)
    n = size(x, 1) ÷ 2
    # Hamilton's equations: dq/dt = ∂H/∂p, dp/dt = -∂H/∂q
    return vcat(selectdim(H, 1, (n + 1):(2n)), -selectdim(H, 1, 1:n)), model2.st
end

function loss_function2(ps, data, target)
    # Call hnn with all five arguments it is defined with
    pred, st_ = hnn(ad, model_pre, data, ps, st)
    return mean(abs2, pred .- target), pred
end

opt_func = OptimizationFunction((ps, _, data, target) -> loss_function2(ps, data, target),
    AutoForwardDiff())
opt_prob = OptimizationProblem(opt_func, ps_c)
```
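The `vcat(...)` line in `hnn` only gives Hamilton's equations if `H` is the gradient of a scalar Hamiltonian with the state stored as `x = [q; p]`. The sketch below checks that rearrangement in isolation. It assumes a hand-written toy Hamiltonian (a harmonic oscillator) in place of the Lux model and uses ForwardDiff.jl for the gradient; `toy_H` and `hamiltonian_field` are hypothetical names for illustration only.

```julia
using ForwardDiff

# Toy Hamiltonian (assumption for illustration): H(q, p) = (q^2 + p^2) / 2,
# with the state stored as x = [q; p].
toy_H(x) = sum(abs2, x) / 2

# Rearrange the gradient of H into Hamilton's equations,
# dq/dt = ∂H/∂p and dp/dt = -∂H/∂q, mirroring the vcat line in hnn.
function hamiltonian_field(H, x)
    dH = ForwardDiff.gradient(H, x)
    n = size(x, 1) ÷ 2
    return vcat(selectdim(dH, 1, (n + 1):(2n)), -selectdim(dH, 1, 1:n))
end

x = [1.0, 2.0]                      # q = 1, p = 2
v = hamiltonian_field(toy_H, x)     # [2.0, -1.0], i.e. (dq/dt, dp/dt) = (p, -q)
println(v)
```

For the oscillator the gradient is simply `x`, so the field comes out as `(p, -q)`, which is easy to verify by hand before swapping in the learned Hamiltonian.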
Error:
```
MethodError: no method matching (OptimizationFunction{true})(::var"#19#20", ::AutoForwardDiff{nothing, Nothing})
Closest candidates are:
  (OptimizationFunction{iip})(::Any) where iip at ~/.julia/packages/SciMLBase/QqtZA/src/scimlfunctions.jl:3583
  (OptimizationFunction{iip})(::Any, ::SciMLBase.AbstractADType; grad, hess, hv, cons, cons_j, cons_h, lag_h, hess_prototype, cons_jac_prototype, cons_hess_prototype, lag_hess_prototype, syms, paramsyms, observed, hess_colorvec, cons_jac_colorvec, cons_hess_colorvec, lag_hess_colorvec, expr, cons_expr, sys) where iip at ~/.julia/packages/SciMLBase/QqtZA/src/scimlfunctions.jl:3583
Stacktrace:
 [1] OptimizationFunction(::Function, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ SciMLBase ~/.julia/packages/SciMLBase/QqtZA/src/scimlfunctions.jl:3581
 [2] OptimizationFunction(::Function, ::Vararg{Any})
   @ SciMLBase ~/.julia/packages/SciMLBase/QqtZA/src/scimlfunctions.jl:3581
 [3] top-level scope
   @ In[48]:2
```
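The closest candidate in the error expects a `SciMLBase.AbstractADType` as the second argument, while the object passed is an `AutoForwardDiff{nothing, Nothing}`. A quick way to narrow this down is to check which package actually defines the type of the `AutoForwardDiff()` object. The sketch below is a hedged diagnostic that assumes ADTypes.jl is installed; a version mismatch between the ADTypes.jl types that Lux/Optimization load and the `AbstractADType` this SciMLBase version expects is one possible explanation, not a confirmed one.

```julia
using ADTypes

adtype = ADTypes.AutoForwardDiff()

# Which module defines the concrete type of this AD choice?
owner = parentmodule(typeof(adtype))
println(owner)

# Is it a subtype of ADTypes' abstract AD type? Compare this with the
# abstract type named in the MethodError's closest candidate.
println(adtype isa ADTypes.AbstractADType)
```

In the session that produced the error, running `parentmodule(typeof(AutoForwardDiff()))` on the object actually passed, together with `AutoForwardDiff() isa SciMLBase.AbstractADType`, should show whether the two abstract types are the same.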