Hello everybody,
I am using NeuralPDE to train neural networks, and I want to save them to files so I can use them later. This works fine with the standard Lux layers. However, I want to add the following custom layer at the end of the model to enforce the initial conditions directly in the architecture of the neural network:
```julia
using Lux

struct ICLayer_1D <: Lux.AbstractExplicitContainerLayer{(:model,)}
    model       # Wrapped model (2 inputs -> 3 outputs)
    IC_fct      # Initial condition function
    t0          # Initial time
    layers      # Layers of the wrapped model
end

function ICLayer_1D(model, IC_fct, t0)
    # Convenience constructor; assumes `model` is a Lux.Chain, since it reads `model.layers`
    return ICLayer_1D(model, IC_fct, t0, model.layers)
end

function (n::ICLayer_1D)(input, ps, st)
    # Forward pass of the layer.
    # The wrapped model must have 2 inputs (t, x) and 3 outputs.
    # IC_fct must take 1 input.
    # Split the inputs
    t = input[1, :]
    x = input[2, :]
    # Evaluate the wrapped model
    output, st_new = n.model(input, ps, st)
    # Terms chosen so that α = 1, ξ_x = x and β = 0 at t = t0
    α   = (t .- n.t0) .* output[1, :] .+ 1.0
    ξ_x = (t .- n.t0) .* output[2, :] .+ x
    β   = (t .- n.t0) .* output[3, :]
    # Output: α * IC_fct(ξ_x) + β, which equals IC_fct(x) at t = t0
    return α .* n.IC_fct.(ξ_x) .+ β, st_new
end
```
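For context, this is what the layer computes. Here $g$ stands for `IC_fct` and $N_1, N_2, N_3$ for the three outputs of the wrapped model (notation introduced only for this summary):

$$
\begin{aligned}
\alpha &= 1 + (t - t_0)\,N_1(t, x), \\
\xi &= x + (t - t_0)\,N_2(t, x), \\
\beta &= (t - t_0)\,N_3(t, x), \\
u(t, x) &= \alpha\, g(\xi) + \beta .
\end{aligned}
$$

At $t = t_0$ every factor $(t - t_0)$ vanishes, so $\alpha = 1$, $\xi = x$ and $\beta = 0$, giving $u(t_0, x) = g(x)$ whatever the network parameters are.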
I have saved the parameters and the model in a JLD2 file, and when I try to set the model up again with the command
```julia
ps, st = Lux.setup(Random.default_rng(), nn)
```
where `nn` is the neural network loaded from the file, I get the following error:
```
ERROR: LoadError: MethodError: no method matching initialparameters(::TaskLocalRNG, ::JLD2.ReconstructedMutable{:ICLayer_1D, (:model, :IC_fct, :t0, :layers), NTuple{4, Any}})
Closest candidates are:
  initialparameters(::AbstractRNG, ::Any)
   @ LuxCore ~/.julia/packages/LuxCore/biwfu/src/LuxCore.jl:66
  initialparameters(::AbstractRNG, ::InstanceNorm)
   @ Lux ~/.julia/packages/Lux/JXc6P/src/layers/normalize.jl:354
  initialparameters(::AbstractRNG, ::NeuralPDE.dgm_lstm_layer)
   @ NeuralPDE ~/.julia/packages/NeuralPDE/IUP3N/src/dgm.jl:18
  ...
```
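A possible workaround, sketched under assumptions: the `JLD2.ReconstructedMutable` in the error suggests that `ICLayer_1D` was not defined in the session that loaded the file, so JLD2 could only rebuild a generic placeholder instead of the real layer. One way around this is to store only `ps` and `st` (nested NamedTuples of arrays) and rebuild the architecture from code. The sketch below uses a plain `Chain` to stay self-contained; the layer sizes and the file name are made up for illustration.

```julia
using Lux, Random, JLD2

rng = Xoshiro(0)

# Plain Chain standing in for the real model; in the actual setup this would be
# ICLayer_1D(chain, IC_fct, t0), with the struct defined before this point.
chain = Chain(Dense(2 => 16, tanh), Dense(16 => 3))
ps, st = Lux.setup(rng, chain)

# Save only the parameters and states, not the model object itself.
path = joinpath(mktempdir(), "model_params.jld2")  # hypothetical file name
jldopen(path, "w") do file
    file["ps"] = ps
    file["st"] = st
end

# What a later session would do: rebuild the architecture in code,
# then load the stored parameters and states.
chain_rebuilt = Chain(Dense(2 => 16, tanh), Dense(16 => 3))
ps_loaded, st_loaded = jldopen(path, "r") do file
    file["ps"], file["st"]
end

# Both models give the same output on the same (t, x) samples
x = rand(rng, Float32, 2, 5)
y, _ = chain(x, ps, st)
y_loaded, _ = chain_rebuilt(x, ps_loaded, st_loaded)
```

With the custom layer the pattern would be the same: evaluate the `ICLayer_1D` definition (for example via `include`) in every session, construct `ICLayer_1D(chain, IC_fct, t0)` there, and pass it the loaded `ps` and `st`. Evaluating the struct definition before loading the original file may also be enough for JLD2 to rebuild the saved model directly.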
Does anyone know how to fix this?
Thank you for your answers!