Broadcasting error diffeqflux

Built a model according to the augmented Neural ODE tutorial as follows:


@with_kw struct Args
	η = 1e-3                # learning rate
    λ = 0.01f0              # regularization paramater
    batch_size = 64        
    sample_size = 10        # sampling size for output    
    epochs = 5             
    seed = 2022                # random seed
    save_path = "output"    # results path
end

function build_model(input_dim, augment_dim)
	
	input_dim = input_dim + augment_dim
	
	nn = Chain(Dense(input_dim , 1024, elu),
	BatchNorm(1024),
	Dense(1024 , 512, selu),
	BatchNorm(512, elu),
	Dense(512 , 512, hardtanh),
	BatchNorm(512),
	Dropout(0.5))

	nn_ode = NeuralODE(nn, (0.f0, 1.f0), Tsit5(),
				   save_everystep = false,
				   reltol = 1e-3, abstol = 1e-3,
				   save_start = false)
	
    DiffEqArray_to_Array(x) = reshape(Array(x), size(x)[1:2])
	
    nn_ode = augment_dim == 0 ? nn_ode : AugmentedNDELayer(nn_ode, augment_dim)
	
	fc = Chain(Dense(512 , 64 , celu),
		Dropout(0.5),
		Dense(64, 12),
	    logsoftmax)

	model = Chain((x, p=nn_ode.p) -> nn_ode(x, p),
              nn_ode,
              DiffEqArray_to_Array,
              fc)
	
	return model , nn_ode.p
end


function train(; kws...)
	
	args = Args(kws...)
	
	args.seed > 0 && Random.seed!(args.seed)
	
	!ispath(args.save_path) && mkpath(args.save_path)
	
    train_loader = DataLoader((X_train',y_train'), batchsize=args.batch_size, shuffle=true)
	test_loader = DataLoader((X_test',y_test'),batchsize=args.batch_size, shuffle=true)
	
	@info "Building Model..."
	model, parameters = build_model(1024, 1)
	
	function loss(x, y)
		mask = y .≥ 0
		return mean(logitbinarycrossentropy(model(x),y) .* mask)
    end 

	opt = Optimiser(WeightDecay(args.λ), ADAM(args.η))

	cb = function()
		global iter += 1
		if iter % 10 == 1
			println("Iteration $iter || Loss = $(loss(train_loader.data[1], train_loader.data[2]))")
		end
	end

	@info "Start Training, total $(args.epochs) epochs"
	for epoch ∈ 1:args.epochs
        @info "Epoch $(epoch)"
		train!(loss, params([parameters, model]), train_loader, opt, cb = cb)
	end
end

for a test train()


DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 512 and 1025

_bcs1@broadcast.jl:516[inlined]
_bcs@broadcast.jl:510[inlined]
broadcast_shape@broadcast.jl:504[inlined]
combine_axes@broadcast.jl:499[inlined]
_axes@broadcast.jl:224[inlined]
axes@broadcast.jl:222[inlined]
combine_axes@broadcast.jl:499[inlined]
instantiate@broadcast.jl:281[inlined]
materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(/), Tuple{Matrix{Float32}, Matrix{Float64}}}, Float32}})@broadcast.jl:860
ode_determine_initdt(::Matrix{Float32}, ::Float32, ::Float32, ::Float32, ::Float64, ::Float64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::SciMLBase.ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, SciMLBase.ODEFunction{false, DiffEqFlux.var"#dudt_#118"{DiffEqFlux.NeuralODE{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}, Vector{Float32}, Flux.var"#66#68"{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}}, Tuple{Float32, Float32}, Tuple{OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float64, Float64, Bool}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::OrdinaryDiffEq.ODEIntegrator{OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Matrix{Float32}, Nothing, Float32, Vector{Float32}, Float32, Float32, Float32, Float32, Vector{Matrix{Float32}}, SciMLBase.ODESolution{Float32, 3, Vector{Matrix{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Matrix{Float32}}}, SciMLBase.ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, SciMLBase.ODEFunction{false, DiffEqFlux.var"#dudt_#118"{DiffEqFlux.NeuralODE{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}, Vector{Float32}, Flux.var"#66#68"{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}}, Tuple{Float32, Float32}, Tuple{OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float64, Float64, Bool}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, DiffEqFlux.var"#dudt_#118"
{DiffEqFlux.NeuralODE{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}, Vector{Float32}, Flux.var"#66#68"{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}}, Tuple{Float32, Float32}, Tuple{OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float64, Float64, Bool}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Matrix{Float32}}, Vector{Float32}, Vector{Vector{Matrix{Float32}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats}, SciMLBase.ODEFunction{false, DiffEqFlux.var"#dudt_#118"{DiffEqFlux.NeuralODE{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}, Vector{Float32}, Flux.var"#66#68"{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}}, Tuple{Float32, Float32}, Tuple{OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float64, Float64, Bool}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float32, Float32, OrdinaryDiffEq.PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Bool, SciMLBase.CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Matrix{Float32}, Float32, Nothing, OrdinaryDiffEq.DefaultInit})@initdt.jl:206
auto_dt_reset!(::OrdinaryDiffEq.ODEIntegrator{OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Matrix{Float32}, Nothing, Float32, Vector{Float32}, Float32, Float32, Float32, Float32, Vector{Matrix{Float32}}, SciMLBase.ODESolution{Float32, 3, Vector{Matrix{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Matrix{Float32}}}, SciMLBase.ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, SciMLBase.ODEFunction{false, DiffEqFlux.var"#dudt_#118"{DiffEqFlux.NeuralODE{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}, Vector{Float32}, Flux.var"#66#68"{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}}, Tuple{Float32, Float32}, Tuple{OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float64, Float64, Bool}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, DiffEqFlux.var"#dudt_#118"{DiffEqFlux.NeuralODE{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}, Vector{Float32}, Flux.var"#66#68"{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}}, Tuple{Float32, Float32}, Tuple{OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float64, Float64, Bool}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Matrix{Float32}}, Vector{Float32}, Vector{Vector{Matrix{Float32}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats}, SciMLBase.ODEFunction{false, DiffEqFlux.var"#dudt_#118"{DiffEqFlux.NeuralODE{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}, Vector{Float32}, Flux.var"#66#68"{Flux.Chain{Tuple{Flux.Dense{typeof(NNlib.elu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.selu), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(NNlib.elu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dense{typeof(NNlib.hardtanh), Matrix{Float32}, Vector{Float32}}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, Flux.Dropout{Float64, Colon}}}}, Tuple{Float32, Float32},

67 errors in total are too long, the above error is just an excerpt.

my X_train is a 12053×1024 Matrix{Float32} y_train is a 12053×12 Matrix{Float32}

What did I miss?

Tbh, the error messages are not helpful.

I know the stack traces can be overwhelming. In such cases I find it useful to dissect every piece on it’s own. Going through the stages of the model manually should lead to the following conclusions.

  1. nn needs to have input- equal to output dimension. It’s the right-hand-side of an ODE after all.
  2. model = Chain((x, p=nn_ode.p) -> nn_ode(x, p),
                  nn_ode, < *this is superfluous*
                  DiffEqArray_to_Array,
                  fc)
    
  3. You may want logitbinarycrossentropy(model(x), y.*mask) for the loss to be scalar.

1, Yeah, that’s why I transposed the X_train and y_train
2, so you mean first we need to a “down Layer” which reduces the real input dimension into output, in this case

down = Chain(Dense( 1024 => 12))  
nn = ...

could you please elaborate on this?

3, I did not change this part it is from the tutorial, and I am sure I understand that.

  1. thanks for the tip.

Thank you!

1, Yeah, that’s why I transposed the X_train and y_train

I saw that too by now :slight_smile:

To which tutorial are you referring? I know of Augmented Neural Ordinary Differential Equations · DiffEqFlux.jl where the model is constructed like

diffeqarray_to_array(x) = reshape(gpu(x), size(x)[1:2])

function construct_model(out_dim, input_dim, hidden_dim, augment_dim)
    input_dim = input_dim + augment_dim
    node = NeuralODE(Chain(Dense(input_dim, hidden_dim, relu),
                           Dense(hidden_dim, hidden_dim, relu),
                           Dense(hidden_dim, input_dim)) |> gpu,
                     (0.f0, 1.f0), Tsit5(), save_everystep = false,
                     reltol = 1e-3, abstol = 1e-3, save_start = false) |> gpu
    node = augment_dim == 0 ? node : (AugmentedNDELayer(node, augment_dim) |> gpu)
    return Chain((x, p=node.p) -> node(x, p),
                 diffeqarray_to_array,
                 Dense(input_dim, out_dim) |> gpu), node.p |> gpu
end

You can see that the output dimension of node is equal to its input dimension. This has to be the case, because it solves an ODE (\mathbf{h(t)}\in R^d).
\mathbf{h}'(t) = \mathbf{f}(\mathbf{h(t))} where \mathbf{f} is a map R^d\to R^d. f is nn and h(0) is the input X_train.
I don’t think I can make a qualified comment on the internal structure of the network.

The reduction to the dimension of y is done by the second network, but not by the NODE.

Also, the line I marked as ‘superfluous’ is not present.

1 Like

thanks for your clear explanation :smiley:

1 Like