Gradient of a loss function: struggling to avoid array mutation

Dear all,

I’m very new to Julia and mainly interested in its SciML ecosystem. I’m using Julia 1.5 through JuliaPro.

My problem is fitting the parameters of an ODE (e.g. the Lotka-Volterra ODE from the SciML tutorials) to data generated from several initial conditions. I have an array tab_u0 (100×2 Array{Float64,2}) with 100 rows and 2 columns; each row is an initial condition, since the state space is 2-dimensional (rabbits and wolves). I observe only the final point (say at t = 10) of the Lotka-Volterra trajectory for each of the 100 initial conditions in tab_u0, and the parameter p is the same for all 100 simulations. My goal is to fit p (4-element Array{Float64,1}) to the observed final points, whose true values are stored in tab_uT_sim. As a first step, I would like to differentiate the loss function of this fitting problem.
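Written out, the loss described above (implemented as loss_1 further down) is the sum of squared errors over all end points, with $T = 10$, $u(T; u_0^{(i)}, p)$ the ODE solution at time $T$ started from row $i$ of tab_u0, and $u_T^{(i)}$ row $i$ of tab_uT_sim. The chain rule then expresses the gradient through the $2 \times 4$ end-point sensitivities $\partial u / \partial p$, which is what the AD machinery has to compute:

$$
L(p) = \sum_{i=1}^{100} \left\| u\big(T; u_0^{(i)}, p\big) - u_T^{(i)} \right\|_2^2,
\qquad
\nabla_p L(p) = 2 \sum_{i=1}^{100} \left( \frac{\partial u\big(T; u_0^{(i)}, p\big)}{\partial p} \right)^{\top} \left( u\big(T; u_0^{(i)}, p\big) - u_T^{(i)} \right)
$$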

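using DifferentialEquations, Plots
using DiffEqSensitivity, OrdinaryDiffEq, Zygote, Flux
using Random, Distributions

## Lotka-Volterra right-hand side, written in place
function fiip(du,u,p,t)
  du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
end

p = [1.5,1.0,3.0,1.0]

## Base problem, remade for each initial condition below.
## Assumed definition (prob_ini is used but not shown in the post):
## placeholder u0, final time t = 10.
prob_ini = ODEProblem(fiip, [1.0, 1.0], (0.0, 10.0), p)

Random.seed!(123)
sampling_dist = Distributions.Uniform(0.2, 2)

## 100 initial conditions, one per row
tab_u0 = Array{Float64}(undef, 100, 2)
for i in 1:100
    tab_u0[i, :] = rand(sampling_dist, 2)
end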

I first tried a naïve version, ignoring the “Mutating arrays is not supported” error. To me, this is the most natural way to express my problem: define a function that fills an array of end values by looping over the initial values. I tried to install Zygote#mutate, but there has been no commit for 16 months and I couldn’t make it work on my computer. Here is the code for this first version. It doesn’t work (“Mutating arrays is not supported”).
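A minimal sketch of what that error means, using two hypothetical toy functions (f_mut and f_pure, not part of the original problem) instead of the ODE: Zygote refuses to differentiate a function that writes into an Array with setindex!, while the same computation built without mutation differentiates fine.

```julia
using Zygote

# Hypothetical toy functions: both compute x^2 + 3x.
function f_mut(x)
    y = zeros(2)
    y[1] = x^2      # in-place write into an Array: Zygote rejects this
    y[2] = 3x
    return sum(y)
end

# Same computation, building the array without mutation.
f_pure(x) = sum([x^2, 3x])

g_pure = Zygote.gradient(f_pure, 2.0)[1]   # 2x + 3 at x = 2, i.e. 7.0

# The mutating version throws ("Mutating arrays is not supported").
mut_failed = try
    Zygote.gradient(f_mut, 2.0)
    false
catch
    true
end
```

The loop in ode_p_test_1 hits the same wall at the line that writes into tab_uT.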

## Naïve array-mutating version
function ode_p_test_1(p_p)
    tab_uT = Array{Float64}(undef, size(tab_u0))
    for i in 1:100
        prob_test = remake(prob_ini, u0 = tab_u0[i, :], p = p_p)
        sol_test = Array(solve(prob_test, save_everystep = false, save_start = false))
        tab_uT[i, :] = sol_test   # in-place write: this is what Zygote rejects
    end
    tab_uT
end

## The loss function
function loss_1(p_p)
    sum( abs2, ode_p_test_1(p_p) - tab_uT_sim )
end 


## True end values
tab_uT_sim = ode_p_test_1(p)
## Value of the parameter for which we compute the gradient
p_test = p .+ 0.01

Flux.gradient( loss_1, p_test )

Then I tried to get rid of my “array-filling-loop” solution and based my computation on map instead. The ode_p_test_2 function simulates the end point of my ODE for given parameter values (p_p) and initial conditions (p_u0):

####### Version with map
function ode_p_test_2(p_p, p_u0)
    prob_test = remake( prob_ini, u0 = p_u0, p = p_p )
    sol_test = Array( solve(prob_test, save_everystep = false, save_start = false) )
    return sol_test
end

I can simulate the end values for the true parameter values:

## Simulate the true end-point values
tab_uT_2 = map( p_u0 -> ode_p_test_2(p, p_u0), eachrow(tab_u0) )

I can then define the following loss function, which uses map to compute the end values for a given parameter value over all the initial conditions in tab_u0:

function loss_2(p_p)
    diff_pred_obs = map( p_u0 -> ode_p_test_2(p_p, p_u0), eachrow(tab_u0) ) - tab_uT_2
    return sum( vcat( diff_pred_obs... ) .^2 )
end

I had some hope, but it didn’t work:

julia> Zygote.gradient( loss_2, p .+ 0.1 )

ERROR: MethodError: no method matching -(::NTuple{100,Array{Float64,2}})
Closest candidates are:
 -(::Any, ::VectorizationBase.Static{N}) where N at C:\Users\JC1116\.julia\packages\VectorizationBase\kIoqa\src\static.jl:153
 -(::Any, ::ChainRulesCore.DoesNotExist) at C:\Users\JC1116\.julia\packages\ChainRulesCore\AUdrf\src\differential_arithmetic.jl:25
 -(::Any, ::ChainRulesCore.Zero) at C:\Users\JC1116\.julia\packages\ChainRulesCore\AUdrf\src\differential_arithmetic.jl:58
 ...
Stacktrace:
[1] (::Zygote.var"#985#986")(::NTuple{100,Array{Float64,2}}) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\lib\array.jl:805
[2] (::Zygote.var"#3509#back#987"{Zygote.var"#985#986"})(::NTuple{100,Array{Float64,2}}) at C:\Users\JC1116\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
[3] loss_2 at .\untitled-0a03602544819dbf1e507937b8d0ba01:113 [inlined]
[4] (::typeof(∂(loss_2)))(::Float64) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
[5] (::Zygote.var"#41#42"{typeof(∂(loss_2))})(::Float64) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\compiler\interface.jl:45
[6] gradient(::Function, ::Array{Float64,1}) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\compiler\interface.jl:54
[7] top-level scope at none:1

A last variation on that theme, where the call to vcat is replaced by a comprehension:

function loss_2bis(p_p)
    diff_pred_obs = map( p_u0 -> ode_p_test_2(p_p, p_u0), eachrow(tab_u0) ) - tab_uT_2
    return sum( [ sum( x .^2) for x in diff_pred_obs ] )
end

It doesn’t work either:

julia> Zygote.gradient( loss_2bis, p .+ 0.1 )

ERROR: Need an adjoint for constructor Base.Generator{Base.OneTo{Int64},Base.var"#192#193"{Array{Float64,2}}}. Gradient is of type Array{Array{Float64,1},1}
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] (::Zygote.Jnew{Base.Generator{Base.OneTo{Int64},Base.var"#192#193"{Array{Float64,2}}},Nothing,false})(::Array{Array{Float64,1},1}) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\lib\lib.jl:288
 [3] (::Zygote.var"#1763#back#192"{Zygote.Jnew{Base.Generator{Base.OneTo{Int64},Base.var"#192#193"{Array{Float64,2}}},Nothing,false}})(::Array{Array{Float64,1},1}) at C:\Users\JC1116\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [4] Generator at .\generator.jl:32 [inlined]
 [5] (::typeof(∂(Base.Generator{Base.OneTo{Int64},Base.var"#192#193"{Array{Float64,2}}})))(::Array{Array{Float64,1},1}) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
 [6] Generator at .\generator.jl:32 [inlined]
 [7] (::typeof(∂(Base.Generator)))(::Array{Array{Float64,1},1}) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
 [8] eachrow at .\abstractarraymath.jl:446 [inlined]
 [9] (::typeof(∂(eachrow)))(::Array{Array{Float64,1},1}) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
 [10] loss_2bis at .\untitled-0a03602544819dbf1e507937b8d0ba01:122 [inlined]
 [11] (::typeof(∂(loss_2bis)))(::Float64) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
 [12] (::Zygote.var"#41#42"{typeof(∂(loss_2bis))})(::Float64) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\compiler\interface.jl:45
 [13] gradient(::Function, ::Array{Float64,1}) at C:\Users\JC1116\.julia\packages\Zygote\rqvFi\src\compiler\interface.jl:54
 [14] top-level scope at none:1

  1. Does anyone have an idea how to change my code to make it work?

  2. I don’t want to blame anyone, and I’m really amazed by Julia’s AD capabilities, but I feel really uncomfortable with the “Mutating arrays is not supported” policy, as it goes against the way I naturally think about problems. Do you think mutating arrays will be allowed in future versions? What happened to Zygote#mutate?

  3. Do you know of any workaround that avoids the array-mutation issue while staying close to my not-so-idiomatic style?
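Regarding question 1: the second stack trace ends in eachrow (frames [8] and [9]), so one hedged workaround is to build the list of rows once, outside the loss. tab_u0 does not depend on p, so Zygote then never has to differentiate through eachrow. A minimal sketch, with a hypothetical toy model (toy_model) standing in for ode_p_test_2 so it runs without DifferentialEquations, and reduce(vcat, ...) in place of the splatted vcat, as suggested later in this thread:

```julia
using Zygote

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]        # plays the role of tab_u0 (one row per case)
rows = [X[i, :] for i in axes(X, 1)]   # built once, outside the differentiated code

# Hypothetical stand-in for ode_p_test_2: elementwise scaling of one row.
toy_model(p, x) = p .* x

# Same shape as loss_2, but mapping over precomputed rows and using reduce(vcat, ...).
toy_loss(p) = sum(abs2, reduce(vcat, map(x -> toy_model(p, x), rows)))

# dL/dp_j = 2 p_j * sum_i X[i, j]^2, i.e. [70.0, 112.0] at p = [1, 1]
g = Zygote.gradient(toy_loss, [1.0, 1.0])[1]
```

In the real problem, rows would be [tab_u0[i, :] for i in 1:size(tab_u0, 1)] and toy_model would be replaced by ode_p_test_2.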

Thanks for any clue!
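Regarding question 3, one way to keep the fill-an-array loop of ode_p_test_1 is Zygote.Buffer, an array-like type that Zygote lets you write into while taking gradients; copy turns it back into an ordinary array. The sketch below assumes a current SciML stack in which DiffEqSensitivity has been renamed SciMLSensitivity, an explicit Tsit5 solver, tspan (0, 10), and only 10 random initial conditions to keep it quick. The names ode_p_buffer, tab_uT_obs and loss_buffer are hypothetical, and it has not been checked against the Julia 1.5 package versions in the question.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote, Random

# Lotka-Volterra right-hand side, as in the question.
function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

p_true = [1.5, 1.0, 3.0, 1.0]
prob_ini = ODEProblem(fiip, [1.0, 1.0], (0.0, 10.0), p_true)

Random.seed!(123)
tab_u0 = 0.2 .+ 1.8 .* rand(10, 2)   # 10 initial conditions, uniform on [0.2, 2)

# Same loop as ode_p_test_1, but writing into a Zygote.Buffer.
function ode_p_buffer(p_p)
    buf = Zygote.Buffer(tab_u0)      # same size and eltype as tab_u0
    for i in 1:size(tab_u0, 1)
        prob = remake(prob_ini, u0 = tab_u0[i, :], p = p_p)
        sol = solve(prob, Tsit5(), save_everystep = false, save_start = false)
        buf[i, :] = vec(Array(sol))  # end point written into row i
    end
    return copy(buf)                 # freeze the buffer into a plain Matrix
end

tab_uT_obs = ode_p_buffer(p_true)
loss_buffer(p_p) = sum(abs2, ode_p_buffer(p_p) .- tab_uT_obs)

grad = Zygote.gradient(loss_buffer, p_true .+ 0.01)[1]
```

The trade-off is that a Buffer cannot be read back as a normal array until copy is called, so it suits exactly this "fill, then use" pattern.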

So I spent hours running this only to realise that my PC with its AMD card doesn’t do CUDA! Never mind. Just do this experiment: outside the for loop, drop the Array(...) wrapper in the line
sol_test = Array( solve(prob_test, save_everystep = false, save_start = false) )
and instead write
sol_test = solve(prob_test, save_everystep = false, save_start = false)
then run
dump(sol_test)
and it will show you what you are trying to coerce into an array.

Roger
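To see what Roger means without scrolling through the full dump output, the sketch below (assuming OrdinaryDiffEq with an explicit Tsit5 solver and the same Lotka-Volterra setup) checks the shape directly. With save_everystep = false and save_start = false only the final time point is saved, so Array(sol) is a 2×1 matrix rather than a 2-element vector, which is consistent with the NTuple{100,Array{Float64,2}} in the first stack trace.

```julia
using OrdinaryDiffEq

function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

prob = ODEProblem(fiip, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])
sol = solve(prob, Tsit5(), save_everystep = false, save_start = false)

println(sol.t)              # only the final time point is saved
println(size(Array(sol)))   # one column per saved time point
u_end = sol.u[end]          # the end point as a plain 2-element vector
println(u_end)
```

Using sol.u[end] (or vec(Array(sol))) gives a flat vector per initial condition, which is usually what a loss over end points wants.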

The more natural way to do this would be to use reduce:

function loss_2(p_p)
    diff_pred_obs = map( p_u0 -> ode_p_test_2(p_p, p_u0), eachrow(tab_u0) ) - tab_uT_2
    return sum(reduce(vcat,diff_pred_obs) .^2)
end

That avoids the problems associated with splatting, which is something you rarely want to do on a big array. Also, sum(abs2,reduce(vcat,diff_pred_obs)) is slightly nicer style. Finally, instead of using map, you can use parallel ensembles to multithread the solves. This is demonstrated in the SDE parameter estimation tutorial:

https://diffeqflux.sciml.ai/dev/examples/optimization_sde/
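A minimal sketch of the ensemble route, assuming the EnsembleProblem API with a prob_func that swaps in one initial condition per trajectory and EnsembleThreads for multithreading. It only shows the forward solve and the loss; differentiating through the ensemble is what the linked tutorial covers. The names u0s, ensemble_end_points and ensemble_loss are hypothetical.

```julia
using OrdinaryDiffEq, Random

function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

p_true = [1.5, 1.0, 3.0, 1.0]
prob_ini = ODEProblem(fiip, [1.0, 1.0], (0.0, 10.0), p_true)

Random.seed!(123)
u0s = [0.2 .+ 1.8 .* rand(2) for _ in 1:100]   # 100 initial conditions

function ensemble_end_points(p_p)
    base = remake(prob_ini, p = p_p)
    # prob_func gives trajectory i its own initial condition
    prob_func = (prob, i, repeat) -> remake(prob, u0 = u0s[i])
    ens = EnsembleProblem(base, prob_func = prob_func)
    sim = solve(ens, Tsit5(), EnsembleThreads(), trajectories = length(u0s),
                save_everystep = false, save_start = false)
    # stack the end point of each trajectory into a 2×100 matrix
    return reduce(hcat, [s.u[end] for s in sim.u])
end

obs = ensemble_end_points(p_true)
ensemble_loss(p_p) = sum(abs2, ensemble_end_points(p_p) .- obs)
```

Start Julia with several threads (for example julia -t 4) for EnsembleThreads to actually run the trajectories in parallel.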

So, in total, all you need to do to make your code work is to use reduce instead of splatting (@dhairyagandhi96, it might be good to figure out what’s up with that adjoint anyway), but using ensembles has some advantages and makes sum(abs2,sol-data) work directly. Cheers!
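Putting the reduce suggestion together into one self-contained sketch: loss_2 with reduce(vcat, ...) and sum(abs2, ...). This assumes a current SciML stack (SciMLSensitivity, the renamed DiffEqSensitivity, supplies the adjoints for solve), an explicit Tsit5 solver, and initial conditions stored as a vector of 2-element vectors (u0s, a hypothetical name) built once outside the loss so that eachrow is never differentiated.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote, Random

function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

p_true = [1.5, 1.0, 3.0, 1.0]
prob_ini = ODEProblem(fiip, [1.0, 1.0], (0.0, 10.0), p_true)

Random.seed!(123)
u0s = [0.2 .+ 1.8 .* rand(2) for _ in 1:100]   # uniform on [0.2, 2), like tab_u0

# End point of the ODE for parameters p_p and initial condition p_u0 (a 2×1 matrix).
function ode_p_test_2(p_p, p_u0)
    prob_test = remake(prob_ini, u0 = p_u0, p = p_p)
    return Array(solve(prob_test, Tsit5(), save_everystep = false, save_start = false))
end

tab_uT_2 = map(p_u0 -> ode_p_test_2(p_true, p_u0), u0s)

function loss_2(p_p)
    diff_pred_obs = map(p_u0 -> ode_p_test_2(p_p, p_u0), u0s) - tab_uT_2
    return sum(abs2, reduce(vcat, diff_pred_obs))   # reduce instead of splatting
end

grad = Zygote.gradient(loss_2, p_true .+ 0.1)[1]
```

The only changes from the question's loss_2 are reduce(vcat, ...) instead of vcat(...), sum(abs2, ...) instead of .^2, and precomputed rows instead of eachrow.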


Thank you guys,
Julia is both simple and, at times, cryptic for a newcomer like me.
I will try your solutions and post back a working example.
Thanks again.

@mothmani
Here is a small example I wrote showing how to work around mutating arrays, or how to define a ChainRule for the offending piece of code. Maybe it can help you fix your problem.
https://github.com/rakeshvar/Zygote-Mutating-Arrays-WorkAround.jl
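As a separate minimal sketch of the ChainRule idea (not taken from that repository), here is a hypothetical mutating helper, scaled_squares, with a hand-written ChainRulesCore.rrule. Zygote uses the rule instead of tracing the loop, so the in-place writes never reach its "Mutating arrays is not supported" check.

```julia
using ChainRulesCore, Zygote

# Hypothetical helper that fills its output in place: out[i] = p * x[i]^2.
function scaled_squares(p::Real, x::AbstractVector)
    out = similar(x, promote_type(typeof(p), eltype(x)))
    for i in eachindex(x)
        out[i] = p * x[i]^2   # mutation: Zygote could not trace this loop
    end
    return out
end

# Hand-written reverse rule, so Zygote never looks inside the loop.
function ChainRulesCore.rrule(::typeof(scaled_squares), p::Real, x::AbstractVector)
    y = scaled_squares(p, x)
    function scaled_squares_pullback(ybar)
        ybar = unthunk(ybar)
        pbar = sum(ybar .* x .^ 2)     # d(out[i])/dp = x[i]^2
        xbar = ybar .* (2 * p) .* x    # d(out[i])/dx[i] = 2 p x[i]
        return NoTangent(), pbar, xbar
    end
    return y, scaled_squares_pullback
end

x = [1.0, 2.0, 3.0]
g_p = Zygote.gradient(p -> sum(scaled_squares(p, x)), 2.0)[1]     # sum(x .^ 2) = 14.0
g_x = Zygote.gradient(xx -> sum(scaled_squares(2.0, xx)), x)[1]   # 4.0 .* x = [4.0, 8.0, 12.0]
```

For an ODE end-point function the pullback would need the ODE sensitivities, which is why the SciML packages ship such rules for solve instead of asking users to write them.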
