JuliaDiffEq with custom types

diffeq

#1

i’m writing a program where i’m using my own array type, and i’d like to use DifferentialEquations to solve some ODEs that show up. i was under the impression that it should work with any AbstractArray, but this is either not the case or (most likely) i’m doing something wrong.

below i illustrate the issue with a minimal working example.

using DifferentialEquations

struct my_type{A,N} <: AbstractArray{A,N}
    data :: A
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{typeof(data),N}(data, name)

Base.size(var::my_type) = size(var.data)

Base.getindex(var::my_type, i::Int) = var.data[i]
Base.getindex(var::my_type, I::Vararg{Int,N}) where {N} = var.data[I...]
Base.getindex(var::my_type, ::Colon) = var.data[:]
Base.getindex(var::my_type, kr::AbstractRange) = var.data[kr]

Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)
Base.setindex!(var::my_type, v, I::Vararg{Int,N}) where {N} = (var.data[I...] = v)
Base.setindex!(var::my_type, v, ::Colon) = (var.data[:] .= v)
Base.setindex!(var::my_type, v, kr::AbstractRange) = (var.data[kr] .= v)

function rhs_test(f, p, t)
    f
end

xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  600
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=600)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = my_type(f0, "foo")

tspan = (0.0, 1.0)

# this works fine
prob  = ODEProblem(rhs_test, f0, tspan)
sol   = solve(prob, RK4())

# this does not
prob2 = ODEProblem(rhs_test, foo, tspan)
sol2  = solve(prob2, RK4())

basically i’m defining my own subtype of AbstractArray and defining the getindex and setindex! methods as specified in the documentation, but this doesn’t seem to be enough, as i get the error

ERROR: LoadError: MethodError: Cannot `convert` an object of type Float64 to an object of type Array{Float64,1}

am i missing some method, is this not possible at all, or is it the wrong way of approaching the problem?


#2

Custom number types are super easy. You just need the standard arithmetic (+, -, /, *, and then sqrt if you don’t use a custom internalnorm). Custom array types take a little bit more. They need to use a compatible number type as their element type, be validly defined, define similar (see the AbstractArray interface, which lists the similar methods you need), and support scalar multiplication and addition that return compatible types. If you use an implicit method, you also need compatibility with your chosen linear solver or Jacobian. For in-place methods you need a valid in-place broadcast (usually this comes automatically once you define indexing).
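A quick way to see where a given array type stands against that list is to probe it directly. The sketch below is only an illustration: probe_state and the Bare wrapper are hypothetical names made up here, not part of DifferentialEquations, and the checks cover only the points mentioned above (element type, similar, scalar multiples, additions).

```julia
# Hypothetical helper: reports which of the properties listed above hold for `u`.
function probe_state(u::AbstractArray)
    return (
        eltype_is_number    = eltype(u) <: Number,
        similar_keeps_type  = typeof(similar(u)) == typeof(u),
        scaling_keeps_type  = typeof(2 * u) == typeof(u),
        addition_keeps_type = typeof(u + u) == typeof(u),
    )
end

# A bare wrapper (hypothetical) that defines only size and scalar indexing.
struct Bare{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::Bare) = size(v.data)
Base.getindex(v::Bare, i::Int) = v.data[i]
Base.setindex!(v::Bare, x, i::Int) = (v.data[i] = x)

plain = probe_state([1.0, 2.0, 3.0])        # a plain Vector passes every check
bare  = probe_state(Bare([1.0, 2.0, 3.0]))  # the wrapper falls back to Vector
```

For the bare wrapper only the element-type check passes: similar, 2 * u and u + u all fall back to the generic AbstractArray methods, which return a plain Vector.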

However, your main problem is with your array definition. Look at

struct my_type{A,N} <: AbstractArray{A,N}
    data :: A
    name :: String
end

What you’re saying here is that the element type of your array is A, which in your example means eltype(foo) == Vector{Float64}. Thus when the solver does its setup, it thinks that the “number type” you’re using is Vector{Float64} and errors. What you meant to say was that your element type was Float64, as in:

struct my_type{A,N} <: AbstractArray{A,N}
    data :: Vector{A}
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)
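The difference between the two definitions can be checked with eltype alone, without involving the solver. This is a minimal side-by-side sketch; WrapA and WrapB are hypothetical names standing in for the original and the corrected my_type, and only the methods needed for the check are defined.

```julia
# Type parameter is the container type (mirrors the original definition).
struct WrapA{A,N} <: AbstractArray{A,N}
    data::A
    name::String
end
WrapA(data::AbstractArray{T,N}, name::String) where {T,N} =
    WrapA{typeof(data),N}(data, name)
Base.size(v::WrapA) = size(v.data)
Base.getindex(v::WrapA, i::Int) = v.data[i]

# Type parameter is the scalar type (mirrors the corrected definition).
struct WrapB{A,N} <: AbstractArray{A,N}
    data::Vector{A}
    name::String
end
WrapB(data::AbstractArray{T,N}, name::String) where {T,N} =
    WrapB{eltype(data),N}(data, name)
Base.size(v::WrapB) = size(v.data)
Base.getindex(v::WrapB, i::Int) = v.data[i]

f0 = [1.0, 2.0, 3.0]
a = WrapA(f0, "foo")
b = WrapB(f0, "foo")

eltype_a = eltype(a)   # Vector{Float64}: the array claims to hold vectors
eltype_b = eltype(b)   # Float64: the array holds numbers
```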

Then your array type now makes sense. However, it’s not closed under arithmetic. For example:

typeof(2foo) # Vector{Float64}
typeof(foo + foo) # Vector{Float64}

How is the integrator supposed to know how to propagate your my_type? We need some kind of definition choice here. For fun, I’m going to say that we only take the left strings. So let’s define:

Base.:+(x::my_type,y::my_type) = my_type(x.data+y.data,x.name)
Base.:*(x::Number,y::my_type) = my_type(x*y.data,y.name)
Base.:/(x::my_type,y::Number) = my_type(x.data/y,x.name)
Base.similar(foo::my_type) = my_type(similar(foo.data),foo.name)
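To see why an out-of-place integrator needs this kind of arithmetic, it can help to write one step by hand. The following is a rough sketch of a single classical RK4 step, not the code DifferentialEquations actually runs; Named and rk4_step are hypothetical names, and the wrapper is restricted to vectors to keep it short.

```julia
struct Named{T} <: AbstractVector{T}
    data::Vector{T}
    name::String
end
Base.size(v::Named) = size(v.data)
Base.getindex(v::Named, i::Int) = v.data[i]
Base.setindex!(v::Named, x, i::Int) = (v.data[i] = x)

# Arithmetic that keeps the wrapper and takes the left operand's name.
Base.:+(x::Named, y::Named) = Named(x.data + y.data, x.name)
Base.:*(a::Number, y::Named) = Named(a * y.data, y.name)

# One classical RK4 step for u' = f(u, p, t), written out of place.
# Every intermediate is built from number * state and state + state.
function rk4_step(f, u, p, t, dt)
    k1 = f(u, p, t)
    k2 = f(u + (dt / 2) * k1, p, t + dt / 2)
    k3 = f(u + (dt / 2) * k2, p, t + dt / 2)
    k4 = f(u + dt * k3, p, t + dt)
    return u + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
end

rhs(u, p, t) = u                      # same right-hand side as rhs_test
u0 = Named([1.0, 2.0], "foo")
u1 = rk4_step(rhs, u0, nothing, 0.0, 0.1)
```

Here u1 is again a Named carrying the name "foo". Without the two arithmetic methods the same step would still run, but it would return a plain Vector and the name would be lost after the first stage.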

Technically, according to the AbstractArray interface you should add a few more similar methods, but this works. Together, the code is:

using DifferentialEquations

struct my_type{A,N} <: AbstractArray{A,N}
    data :: Vector{A}
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

Base.size(var::my_type) = size(var.data)

Base.getindex(var::my_type, i::Int) = var.data[i]
Base.getindex(var::my_type, I::Vararg{Int,N}) where {N} = var.data[I...]
Base.getindex(var::my_type, ::Colon) = var.data[:]
Base.getindex(var::my_type, kr::AbstractRange) = var.data[kr]

Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)
Base.setindex!(var::my_type, v, I::Vararg{Int,N}) where {N} = (var.data[I...] = v)
Base.setindex!(var::my_type, v, ::Colon) = (var.data[:] .= v)
Base.setindex!(var::my_type, v, kr::AbstractRange) = (var.data[kr] .= v)

function rhs_test(f, p, t)
    f
end

xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  600
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=600)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = my_type(f0, "foo")

tspan = (0.0, 1.0)

Base.:+(x::my_type,y::my_type) = my_type(x.data+y.data,x.name)
Base.:*(x::Number,y::my_type) = my_type(x*y.data,y.name)
Base.similar(foo::my_type) = my_type(similar(foo.data),foo.name)

prob  = ODEProblem(rhs_test, foo, tspan)
sol   = solve(prob, RK4())

(Note that to work on master you also need to define /, but I just put in a quick PR to remove that requirement.) Now if you want to make your type work with in-place functions, then you don’t need the arithmetic, since the updates are done by broadcast. Here you just need one of your other missing similar methods:

Base.similar(foo::my_type,::Type{T}) where T = my_type(similar(foo.data,T),foo.name)

So a full working code for that is:

using DifferentialEquations

struct my_type{A,N} <: AbstractArray{A,N}
    data :: Vector{A}
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

Base.size(var::my_type) = size(var.data)

Base.getindex(var::my_type, i::Int) = var.data[i]
Base.getindex(var::my_type, I::Vararg{Int,N}) where {N} = var.data[I...]
Base.getindex(var::my_type, ::Colon) = var.data[:]
Base.getindex(var::my_type, kr::AbstractRange) = var.data[kr]

Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)
Base.setindex!(var::my_type, v, I::Vararg{Int,N}) where {N} = (var.data[I...] = v)
Base.setindex!(var::my_type, v, ::Colon) = (var.data[:] .= v)
Base.setindex!(var::my_type, v, kr::AbstractRange) = (var.data[kr] .= v)

function rhs_test(f, p, t)
    f
end

xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  600
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=600)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = my_type(f0, "foo")

tspan = (0.0, 1.0)

Base.similar(foo::my_type) = my_type(similar(foo.data),foo.name)
Base.similar(foo::my_type,::Type{T}) where T = my_type(similar(foo.data,T),foo.name)
function rhs_test2(df, f, p, t)
    df.data .= f.data
end
prob  = ODEProblem(rhs_test2, foo, tspan)
sol   = solve(prob, RK4())
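As an illustration of why the in-place path gets by without + and *, here is a hand-written forward Euler loop that touches the state only through similar and in-place broadcast. It is a sketch under simplifying assumptions (fixed step size, a vector-only wrapper), and Tagged and euler! are hypothetical names, not solver API.

```julia
struct Tagged{T} <: AbstractVector{T}
    data::Vector{T}
    name::String
end
Base.size(v::Tagged) = size(v.data)
Base.getindex(v::Tagged, i::Int) = v.data[i]
Base.setindex!(v::Tagged, x, i::Int) = (v.data[i] = x)
Base.similar(v::Tagged) = Tagged(similar(v.data), v.name)

# Fixed-step forward Euler that updates `u` in place.
function euler!(f!, u, p, tspan, nsteps)
    t0, t1 = tspan
    dt = (t1 - t0) / nsteps
    du = similar(u)                  # scratch array of the same wrapper type
    for k in 0:nsteps-1
        f!(du, u, p, t0 + k * dt)
        u .= u .+ dt .* du           # in-place broadcast, goes through getindex/setindex!
    end
    return u
end

rhs!(du, u, p, t) = (du .= u)        # in-place version of u' = u
u = Tagged([1.0, 2.0], "foo")
euler!(rhs!, u, nothing, (0.0, 1.0), 1000)
```

No arithmetic methods are defined for Tagged, yet the loop runs and u keeps its type and name, because the dotted assignment writes into the existing array element by element.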

So even for arrays it’s not bad: only two methods not on the AbstractArray interface were needed for out-of-place to define arithmetic, and in-place only needed the AbstractArray interface (but needed to make sure you define the similar part!).
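On the similar part: rather than adding the one- and two-argument methods separately, the AbstractArray interface section of the Julia manual suggests implementing the three-argument form, and the shorter forms then fall back to it. A minimal sketch, assuming an N-dimensional variant of the wrapper (Labeled is a hypothetical name):

```julia
struct Labeled{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
    name::String
end
Base.size(v::Labeled) = size(v.data)
Base.getindex(v::Labeled, I::Int...) = v.data[I...]
Base.setindex!(v::Labeled, x, I::Int...) = (v.data[I...] = x)

# The one method that similar(v) and similar(v, T) both fall back to.
Base.similar(v::Labeled, ::Type{T}, dims::Dims) where {T} =
    Labeled(Array{T}(undef, dims), v.name)

a  = Labeled([1.0, 2.0, 3.0], "foo")
s1 = similar(a)                    # same element type and shape
s2 = similar(a, Int)               # different element type
s3 = similar(a, Float32, (2, 2))   # different element type and shape
c  = copy(a)                       # generic copy goes through similar
```

All four results are Labeled arrays that carry the name "foo", so generic Base functions that allocate via similar keep the wrapper.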

Note that for these “extra” features we have a little bit more to do in order to get back to where we were in Julia v0.6. Since the definition of an AbstractArray and Broadcast has changed 3 times over the last 2 years, we haven’t been able to sit down and properly define our interface here. However, with the stability of v1.0 I hope we can finally do so soon.


#3

When defining a wrapper type like your custom array type, I often find the macro @forward from Lazy.jl very convenient. It allows you to “forward” a set of functions on your special type to the wrapped field inside, so you do not have to manually define all those functions just to operate on the field.

Example

julia> using Lazy

julia> struct MM
           x
       end

julia> @forward MM.x (Base.length, Base.iterate, Base.getindex)

julia> m = MM([1,2,3])
MM([1, 2, 3])

julia> length(m)
3

julia> m[3]
3
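For anyone who prefers not to add a dependency, roughly the same effect can be had with a short @eval loop in plain Julia. This is a sketch of the idea only, not what Lazy.jl's macro actually expands to, and Wrapped is a hypothetical name.

```julia
import Base: length, iterate, getindex

struct Wrapped
    x
end

# Generate one forwarding method per function name:
# f(w::Wrapped, args...) simply calls f on the wrapped field.
for f in (:length, :iterate, :getindex)
    @eval $f(w::Wrapped, args...) = $f(w.x, args...)
end

w = Wrapped([1, 2, 3])
n     = length(w)   # forwarded to length(w.x)
third = w[3]        # forwarded to getindex(w.x, 3)
total = sum(w)      # works because iterate is forwarded
```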

#4

this is it, thanks!