Obscure multiplication method error when using sciml_train

Julia newbie here.

I am solving a convection-diffusion-reaction problem in two variables, where the reaction term is replaced by an ANN with two inputs and one output. This ANN should be trained against measured data c_true.

Defining the problem is no issue: I use a centered finite difference scheme built with BandedMatrices.jl and stitch everything together in odesys:

using BandedMatrices, DiffEqFlux, Flux, Optim, OrdinaryDiffEq, DiffEqSensitivity, Plots, DelimitedFiles

#c_true = readdlm("onedatacol.txt") 
c_true=rand(400,1)

tend = 200
tspan = (0.0,tend)

D = 1.66e-6
v = .0166667

#IC/BC
c₀ = zeros(2*100,1)
c₀[1] = 0.34

## FUNS
function centraldifference(n)
    Δ = 1/(n-1)                                     # spacing between grid points
    l, u =  (1,1)                                   # lower and upper bandwidths
    A = BandedMatrix{Float64}(undef, (n,n), (l,u))  # initialize the banded matrix
    A[band(0)] .= -2*D/Δ^2                          # set the diagonal band
    A[band(1)] .= D/Δ^2-v/(2*Δ)                     # set the super-diagonal band
    A[band(-1)] .= D/Δ^2+v/(2*Δ)                    # set the  sub-diagonal band
    A[1,1:2] .= 0                                   #left BC
    A[n,n] = -D/Δ^2 - v/(2*Δ)                       #right BC
    A[n,n-1] = D/Δ^2+v/(2*Δ)
    return A
end

ann = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,1))

annpars = initial_params(ann)


function rhs(c,q,annpars)
     return  ann([c'; q'],annpars)'
end

odesyspars = annpars

function odesys(dc,c,odesyspars,t)
    npoints = 100
    dc[npoints+1:2*npoints] = rhs(c[1:npoints],c[npoints+1:2*npoints],odesyspars)
    dc[1:100] = centraldifference(npoints)*c[1:npoints] + dc[npoints+1:2*npoints] # dc[1:npoints].= centraldifference(npoints)*c does not work, why?
end

The system solves like a charm and also the loss function returns a value

prob = ODEProblem(odesys, c₀, tspan,odesyspars)
sol = solve(prob,KenCarp4(),sensealg=BacksolveAdjoint(checkpointing=true),saveat=0.5:0.5:tend)
#
#plot(sol)
#
θ = annpars
#
function predict_mod(θ)
  return solve(prob,KenCarp4(),u0=c₀,p=θ,reltol=1e-8,abstol=1e-8,sensealg=BacksolveAdjoint(checkpointing=true),saveat=0.5:0.5:tend)
end
#
function loss(θ)
    sol = predict_mod(θ)

    if any((s.retcode != :Success for s in sol))
        return Inf
    else
        return sum(abs2, Array(sol)[100,1,:].-c_true)
    end
end

l = loss(θ)

If I throw everything into sciml_train however:

cb = function (θ,l)
  println(l)
  return false
end

cb(θ,l)

res = DiffEqFlux.sciml_train(loss, θ, ADAM(0.01), cb = cb, maxiters=200)

I get a very obscure “ambiguous multiplication” error. I really do not understand it, cannot see what is going wrong, and do not understand the suggested fix:

ERROR: LoadError: MethodError: *(::BandedMatrix{Float64,Array{Float64,2},Base.OneTo{Int64}}, ::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) is ambiguous. Candidates:
  *(A::ArrayLayouts.LayoutArray{T,2} where T, B::AbstractArray{T,1} where T) in ArrayLayouts at /Users/someuser/.julia/packages/ArrayLayouts/x9nhz/src/muladd.jl:462
  *(x::AbstractArray{T,2} where T, y::ReverseDiff.TrackedArray{V,D,1,VA,DA} where DA where VA) where {V, D} in ReverseDiff at /Users/someuser/.julia/packages/ReverseDiff/Thhqg/src/derivatives/linalg/arithmetic.jl:214
Possible fix, define
  *(::ArrayLayouts.LayoutArray{T,2} where T, ::ReverseDiff.TrackedArray{V,D,1,VA,DA} where DA where VA) where {V, D}

Thanks for any help/inputs in advance.

Edit: formatting

using BandedMatrices, DiffEqFlux, Flux, Optim, OrdinaryDiffEq, DiffEqSensitivity, Plots, DelimitedFiles

#c_true = readdlm("onedatacol.txt")
c_true=rand(400)

tend = 200
tspan = (0.0,tend)

D = 1.66e-6
v = .0166667

#IC/BC
c₀ = zeros(2*100)
c₀[1] = 0.34

## FUNS
function centraldifference(n)
    Δ = 1/(n-1)                                     # spacing between grid points
    l, u =  (1,1)                                   # lower and upper bandwidths
    A = BandedMatrix{Float64}(undef, (n,n), (l,u))  # initialize the banded matrix
    A[band(0)] .= -2*D/Δ^2                          # set the diagonal band
    A[band(1)] .= D/Δ^2-v/(2*Δ)                     # set the super-diagonal band
    A[band(-1)] .= D/Δ^2+v/(2*Δ)                    # set the  sub-diagonal band
    A[1,1:2] .= 0                                   #left BC
    A[n,n] = -D/Δ^2 - v/(2*Δ)                       #right BC
    A[n,n-1] = D/Δ^2+v/(2*Δ)
    return A
end
const npoints=100
const A = centraldifference(npoints)

ann = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,1))

annpars = Float64.(initial_params(ann))


function rhs(c,q,annpars)
     return  vec(ann([c'; q'],annpars))
end

odesyspars = annpars

function odesys(c,odesyspars,t)
    x2 = rhs(c[1:npoints],c[npoints+1:2*npoints],odesyspars)
    x1 = A*c[1:npoints] + x2
    vcat(x1,x2)
end


prob = ODEProblem(odesys, c₀, tspan,odesyspars)
sol = solve(prob,KenCarp4(),saveat=0.5:0.5:tend)
#
#plot(sol)
#
θ = annpars
#
function predict_mod(θ)
  return solve(prob,KenCarp4(autodiff=false),p=θ,sensealg=InterpolatingAdjoint(),saveat=0.5:0.5:tend)
end
#
function loss(θ)
    sol = predict_mod(θ)

    if any((s.retcode != :Success for s in sol))
        return Inf
    else
        return sum(abs2, sol[100,:].-c_true)
    end
end

l = loss(θ)

cb = function (θ,l)
    println(l)
    return false
end

cb(θ,l)

res = DiffEqFlux.sciml_train(loss, θ, ADAM(0.0005), cb = cb,maxiters=200)

That’s a working version. There’s a lot to say here:

  • The MATLABisms, e.g. zeros(100,1) instead of zeros(100), were the main cause of your issues. Whatever is doing that in the adjoints should get fixed, but at the same time you can just make vectors if you want vectors and it will behave much better.
  • I removed checkpointing and the tolerances just to make it quicker to play with, but they can get added back of course.
  • The stiffness of the problem means that explicit optimizers struggle, so ADAM’s learning rate needs to be set low, or use BFGS (which is what I’d recommend here instead); see the sketch below.
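
For the BFGS route, the usual two-stage pattern from the DiffEqFlux tutorials is what I mean. This is just an untested sketch; the maxiters and initial_stepnorm values are placeholders to tune:

res1 = DiffEqFlux.sciml_train(loss, θ, ADAM(0.0005), cb = cb, maxiters = 100)  # rough pre-training
res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, BFGS(initial_stepnorm = 0.01),
                              cb = cb, maxiters = 100)                          # polish with BFGS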

That’s the gist of it at least.


Thank you very much.

So just to clarify: besides my newbie shortcomings, the main issue is that autodiff is not working in this particular case?

Yeah, it’s not really your fault. Your code is written in a way that technically works, but no Julia user would use those MATLABisms, and the autodiff library seems to get tripped up by them. So just taking off the ,1 fixes the autodiff. Technically, the autodiff should support that case, but in practice nobody really writes it that way, so I can see why it ran into a bug somewhere. @dhairyagandhi96 might want to take a look and see if he can find out where the sizing issue comes up, but I think he’s working on similar array-sizing adjoints right now anyway, which might be the solution.
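
For reference, the distinction in question is just Vector vs. one-column Matrix:

zeros(100)     # Vector{Float64}, size (100,):   what Julia code usually wants
zeros(100, 1)  # Matrix{Float64}, size (100, 1): a 100x1 matrix, the MATLAB habit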


Thank you for the clarification, it is really helpful for learning this amazing language. One additional question: is it possible to make this code work when A is not constant, e.g. if we also want to solve the inverse problem of fitting the diffusion coefficient D? Due to the mutation restrictions of Zygote I tried Zygote.Buffer():

function centraldifference(D)
    n=100
    Δ = 1/(n-1)                                     # spacing between grid points
    l, u =  (1,1)                                   # lower and upper bandwidths
    A = Zygote.Buffer(BandedMatrix{Float64}(undef, (n,n), (l,u)))  # initialize the banded matrix
    A[band(0)] .= -2*D/Δ^2                          # set the diagonal band
    A[band(1)] .= D/Δ^2-v/(2*Δ)                     # set the super-diagonal band
    A[band(-1)] .= D/Δ^2+v/(2*Δ)                    # set the  sub-diagonal band
    A[1,1:2] .= 0                                   #left BC
    A[n,n] = -D/Δ^2 - v/(2*Δ)                       #right BC
    A[n,n-1] = D/Δ^2+v/(2*Δ)
    return copy(A)
end
const npoints=100
#const A = centraldifference(npoints)

and of course adapted the ODE system

odesyspars = [annpars, D]

function odesys(c,odesyspars,t)
    x2 = rhs(c[1:npoints],c[npoints+1:2*npoints],odesyspars[1])
    x1 = centraldifference(odesyspars[2])*c[1:npoints] + x2
    vcat(x1,x2)
end

But then centraldifference(D) behaves really, really strangely: even repeated calls of centraldifference(some_constant_number) produce NaNs at different(!) random places in A, which of course kills the solver instantly. Is this related to #254?

I also tried Zygote.@nograd centraldifference, but also without success (besides the fact that it rather contradicts the purpose of a general AD system).

What would be a good workaround for this problem?

Edit: just to clarify, D is of course bounded and I’m using Fminbox(LBFGS()).

I would just not use Zygote.Buffer. Instead, you can define a banded matrix by creating all of the arrays for the bands and using them directly in the constructor. That will not require any mutation and will thus not have issues with Zygote; a sketch is below.
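
Roughly like this. It is an untested sketch that assumes BandedMatrices’ pair-based constructor (0 => diagonal, ±1 => off-diagonals) and that the boundary values mirror your original code, so double-check both:

function centraldifference(D)
    n = npoints
    Δ = 1/(n-1)
    main  = -2*D/Δ^2             # interior diagonal
    upper =  D/Δ^2 - v/(2*Δ)     # interior super-diagonal
    lower =  D/Δ^2 + v/(2*Δ)     # interior sub-diagonal
    # build each band as a plain vector with the boundary rows baked in (no mutation):
    d  = vcat(0.0, fill(main, n-2), -D/Δ^2 - v/(2*Δ))   # A[1,1] = 0 (left BC), A[n,n] (right BC)
    du = vcat(0.0, fill(upper, n-2))                     # A[1,2] = 0 (left BC)
    dl = vcat(fill(lower, n-2), D/Δ^2 + v/(2*Δ))         # A[n,n-1] (right BC)
    return BandedMatrix(-1 => dl, 0 => d, 1 => du)
end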


So what would be the Julia way to still have this with an in-place ODE system formulation? I’m asking because I need to speed up the solution significantly to make the optimisation feasible, and I understand from the tutorials that this would be a good approach.

If I benchmark the system as you formulated it above I get:

@btime sol = solve(prob,KenCarp4(),saveat=0.5:0.5:tend)
1.726 s (195896 allocations: 2.12 GiB)

but

function odesys(dc,c,odesyspars,t)
    x2 = rhs(c[1:npoints],c[npoints+1:2*npoints],odesyspars)
    x1 = A*c[1:npoints] + x2
    dc = vcat(x1,x2)
end

does not work somehow (SingularException, I guess due to NaNs like in my post above).
The only thing that works is

function odesys(dc,c,odesyspars,t)
    x2 = rhs(c[1:npoints],c[npoints+1:2*npoints],odesyspars)
    x1 = A*c[1:npoints] + x2
    dc[:] = vcat(x1,x2)
end

with:

181.850 ms (53735 allocations: 123.05 MiB)

which is already a much-needed 9x improvement (without all the other additional tricks from the tutorials, but one thing at a time…), but it does not work inside sciml_train / the AD for the same reasons you elaborated above.

That’s not a mutating function. It allocates a new vector and binds it to a variable with the same name dc, rather than actually changing the vector.

That’s better, but you’re still creating a vector. Watch https://youtu.be/M2i7sSRcSIw
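
For reference, a fully in-place version would look roughly like this. It is an untested sketch (the names odesys! and prob_ip are mine): the ANN forward pass still allocates its output, and as discussed above this mutating form will not work with the Zygote-based adjoints, so it is only useful for plain forward solves.

using LinearAlgebra: mul!

function odesys!(dc, c, odesyspars, t)
    c1  = @view c[1:npoints]
    c2  = @view c[npoints+1:2*npoints]
    dc1 = @view dc[1:npoints]
    dc2 = @view dc[npoints+1:2*npoints]
    dc2 .= rhs(c1, c2, odesyspars)   # reaction term from the ANN (this call still allocates)
    mul!(dc1, A, c1)                 # dc1 = A*c1 without a temporary
    dc1 .+= dc2                      # convection-diffusion plus reaction
    return nothing
end

prob_ip = ODEProblem(odesys!, c₀, tspan, odesyspars)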