Hi, Iām trying to make my neural network compatible with a new syntax of Flux.
Note that the network I used, PLSE
, receives two arguments as PLSE(x, u)
.
- Code
using Test
using ParametrisedConvexApproximators
using LinearAlgebra
using Flux
function main()
n, m = 3, 2
d = 10
# dataset
X = rand(n, d)
Y = rand(m, d)
Z = hcat([norm(X[:, i])+norm(Y[:, i]) for i in 1:d]...)
# network construction
i_max = 20
T = 1e-0
h_array = [64, 64]
act = Flux.leakyrelu
N = 1_000 # The result may be poor if it's too low
model = PLSE(n, m, i_max, T, h_array, act)
params_init = deepcopy(Flux.params(model))
@test all(Flux.params(model) .== params_init)
# training
data = Flux.DataLoader((X, Y, Z), batchsize=32)
opt_state = Flux.setup(Adam(1e-4), model)
for epoch in 1:10
for (x, y, z) in data
val, grads = Flux.withgradient(model) do m
pred = m(x, y)
Flux.Losses.mse(pred, z)
end
Flux.update!(opt_state, model, grads[1])
end
end
@test all(Flux.params(model) .!= params_init)
end
@testset "dataset" begin
main()
end
- Error
julia> include("test/pure_train.jl")
ā Warning: setup found no trainable parameters in this model
ā @ Optimisers ~/.julia/packages/Optimisers/BT5bT/src/interface.jl:27
ā Warning: Number of observations less than batch-size, decreasing the batch-size to 10
ā @ MLUtils ~/.julia/packages/MLUtils/NDuSY/src/batchview.jl:95
dataset: Error During Test at /Users/jinrae/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:39
Got exception outside of a @test
MethodError: no method matching +(::NamedTuple{(:contents,), Tuple{Array{Float64, 3}}}, ::Base.RefValue{Any})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
+(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at ~/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
+(::Union{MathOptInterface.ScalarAffineFunction{T}, MathOptInterface.ScalarQuadraticFunction{T}}, ::T) where T at ~/.julia/packages/MathOptInterface/fTxO0/src/Utilities/functions.jl:1783
...
Stacktrace:
[1] accum(x::NamedTuple{(:contents,), Tuple{Array{Float64, 3}}}, y::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:17
[2] accum(x::NamedTuple{(:contents,), Tuple{Array{Float64, 3}}}, y::Nothing, zs::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:22
[3] Pullback
@ ~/.julia/dev/ParametrisedConvexApproximators/src/approximators/parametrised_convex_approximators/parametrised_convex_approximators.jl:15 [inlined]
[4] (::typeof(ā(affine_map)))(Ī::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[5] Pullback
@ ~/.julia/dev/ParametrisedConvexApproximators/src/approximators/parametrised_convex_approximators/PLSE.jl:36 [inlined]
[6] (::typeof(ā(Ī»)))(Ī::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[7] Pullback
@ ~/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:29 [inlined]
[8] (::typeof(ā(Ī»)))(Ī::Float64)
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[9] (::Zygote.var"#60#61"{typeof(ā(Ī»))})(Ī::Float64)
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:45
[10] withgradient(f::Function, args::PLSE)
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:133
[11] main()
@ Main ~/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:28
[12] macro expansion
@ ~/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:40 [inlined]
[13] macro expansion
@ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:1363 [inlined]
[14] top-level scope
@ ~/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:40
[15] include(fname::String)
@ Base.MainInclude ./client.jl:476
[16] top-level scope
@ REPL[2]:1
[17] top-level scope
@ ~/.julia/packages/Infiltrator/r3Hf5/src/Infiltrator.jl:710
[18] top-level scope
@ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
[19] eval
@ ./boot.jl:368 [inlined]
[20] eval_user_input(ast::Any, backend::REPL.REPLBackend)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
[21] repl_backend_loop(backend::REPL.REPLBackend)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
[22] start_repl_backend(backend::REPL.REPLBackend, consumer::Any)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
[23] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
[24] run_repl(repl::REPL.AbstractREPL, consumer::Any)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:355
[25] (::Base.var"#967#969"{Bool, Bool, Bool})(REPL::Module)
@ Base ./client.jl:419
[26] #invokelatest#2
@ ./essentials.jl:729 [inlined]
[27] invokelatest
@ ./essentials.jl:726 [inlined]
[28] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
@ Base ./client.jl:404
[29] exec_options(opts::Base.JLOptions)
@ Base ./client.jl:318
[30] _start()
@ Base ./client.jl:522
Test Summary: | Pass Error Total Time
dataset | 1 1 2 5.1s
ERROR: LoadError: Some tests did not pass: 1 passed, 0 failed, 1 errored, 0 broken.
in expression starting at /Users/jinrae/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:39
You can find the code here:
./test/pure_train.jl in branch hotfix/compatibility-flux