Error with the new syntax of Flux

Hi, I'm trying to make my neural network compatible with the new explicit-gradient syntax of Flux.

Note that the network I used, PLSE, takes two arguments, as in `PLSE(x, u)`.

  • Code
using Test
using ParametrisedConvexApproximators
using LinearAlgebra
using Flux


# Reproduction script: train a PLSE network (from ParametrisedConvexApproximators)
# with the new explicit-gradient Flux API (Flux.setup / Flux.withgradient /
# Flux.update!) and check that training actually changes the parameters.
function main()
    n, m = 3, 2  # input dimensions: x has n rows, y has m rows
    d = 10       # number of data points (columns)
    # dataset
    X = rand(n, d)
    Y = rand(m, d)
    # target: a 1×d row whose i-th entry is norm(X[:, i]) + norm(Y[:, i])
    Z = hcat([norm(X[:, i])+norm(Y[:, i]) for i in 1:d]...)
    # network construction
    # NOTE(review): i_max and T are PLSE hyperparameters whose semantics are
    # defined in ParametrisedConvexApproximators — not visible from this file.
    i_max = 20
    T = 1e-0
    h_array = [64, 64]  # hidden-layer widths
    act = Flux.leakyrelu
    N = 1_000  # The result may be poor if it's too low
    model = PLSE(n, m, i_max, T, h_array, act)
    # snapshot of the initial parameters via the old implicit-params API;
    # deepcopy so later in-place updates don't mutate the snapshot
    params_init = deepcopy(Flux.params(model))
    @test all(Flux.params(model) .== params_init)
    # training
    # batchsize (32) exceeds d (10), which triggers the MLUtils warning seen below
    data = Flux.DataLoader((X, Y, Z), batchsize=32)
    opt_state = Flux.setup(Adam(1e-4), model)  # new explicit optimiser-state API
    for epoch in 1:10
        for (x, y, z) in data
            # explicit-gradient style: differentiate with respect to the model itself
            val, grads = Flux.withgradient(model) do m
                pred = m(x, y)  # PLSE is called with two inputs
                Flux.Losses.mse(pred, z)
            end
            Flux.update!(opt_state, model, grads[1])
        end
    end
    # expect every parameter array to differ from its initial snapshot after training
    @test all(Flux.params(model) .!= params_init)
end


# Entry point: run the reproduction inside a named test set so failures
# are reported under "dataset" (as seen in the error summary below).
@testset "dataset" begin
    main()
end
  • Error
julia> include("test/pure_train.jl")
ā”Œ Warning: setup found no trainable parameters in this model
ā”” @ Optimisers ~/.julia/packages/Optimisers/BT5bT/src/interface.jl:27
ā”Œ Warning: Number of observations less than batch-size, decreasing the batch-size to 10
ā”” @ MLUtils ~/.julia/packages/MLUtils/NDuSY/src/batchview.jl:95
dataset: Error During Test at /Users/jinrae/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:39
  Got exception outside of a @test
  MethodError: no method matching +(::NamedTuple{(:contents,), Tuple{Array{Float64, 3}}}, ::Base.RefValue{Any})
  Closest candidates are:
    +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
    +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at ~/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
    +(::Union{MathOptInterface.ScalarAffineFunction{T}, MathOptInterface.ScalarQuadraticFunction{T}}, ::T) where T at ~/.julia/packages/MathOptInterface/fTxO0/src/Utilities/functions.jl:1783
    ...
  Stacktrace:
    [1] accum(x::NamedTuple{(:contents,), Tuple{Array{Float64, 3}}}, y::Base.RefValue{Any})
      @ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:17
    [2] accum(x::NamedTuple{(:contents,), Tuple{Array{Float64, 3}}}, y::Nothing, zs::Base.RefValue{Any})
      @ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:22
    [3] Pullback
      @ ~/.julia/dev/ParametrisedConvexApproximators/src/approximators/parametrised_convex_approximators/parametrised_convex_approximators.jl:15 [inlined]
    [4] (::typeof(āˆ‚(affine_map)))(Ī”::Matrix{Float64})
      @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
    [5] Pullback
      @ ~/.julia/dev/ParametrisedConvexApproximators/src/approximators/parametrised_convex_approximators/PLSE.jl:36 [inlined]
    [6] (::typeof(āˆ‚(Ī»)))(Ī”::Matrix{Float64})
      @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
    [7] Pullback
      @ ~/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:29 [inlined]
    [8] (::typeof(āˆ‚(Ī»)))(Ī”::Float64)
      @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
    [9] (::Zygote.var"#60#61"{typeof(āˆ‚(Ī»))})(Ī”::Float64)
      @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:45
   [10] withgradient(f::Function, args::PLSE)
      @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:133
   [11] main()
      @ Main ~/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:28
   [12] macro expansion
      @ ~/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:40 [inlined]
   [13] macro expansion
      @ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:1363 [inlined]
   [14] top-level scope
      @ ~/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:40
   [15] include(fname::String)
      @ Base.MainInclude ./client.jl:476
   [16] top-level scope
      @ REPL[2]:1
   [17] top-level scope
      @ ~/.julia/packages/Infiltrator/r3Hf5/src/Infiltrator.jl:710
   [18] top-level scope
      @ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
   [19] eval
      @ ./boot.jl:368 [inlined]
   [20] eval_user_input(ast::Any, backend::REPL.REPLBackend)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
   [21] repl_backend_loop(backend::REPL.REPLBackend)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
   [22] start_repl_backend(backend::REPL.REPLBackend, consumer::Any)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
   [23] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
   [24] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:355
   [25] (::Base.var"#967#969"{Bool, Bool, Bool})(REPL::Module)
      @ Base ./client.jl:419
   [26] #invokelatest#2
      @ ./essentials.jl:729 [inlined]
   [27] invokelatest
      @ ./essentials.jl:726 [inlined]
   [28] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:404
   [29] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:318
   [30] _start()
      @ Base ./client.jl:522
Test Summary: | Pass  Error  Total  Time
dataset       |    1      1      2  5.1s
ERROR: LoadError: Some tests did not pass: 1 passed, 0 failed, 1 errored, 0 broken.
in expression starting at /Users/jinrae/.julia/dev/ParametrisedConvexApproximators/test/pure_train.jl:39

You can find the code here:
./test/pure_train.jl in branch hotfix/compatibility-flux

The problem comes from the fact that in your package (e.g., in the approximator definitions) you improperly overload Flux.params for your types, instead of using Flux.@functor and overloading Flux.trainable. Besides the warning you observe —
Warning: setup found no trainable parameters in this model
— your model is also not transferable to the GPU, among other things.

See
https://fluxml.ai/Flux.jl/stable/models/advanced/
for a full guide.

3 Likes

I see, thank you!