JuliaDiffEq with custom types

I’m writing a program where I’m using my own array type, and I’d like to use DifferentialEquations to solve some ODEs that show up. I was under the impression that it should work with any AbstractArray, but either this is not the case or (most likely) I’m doing something wrong.

Below I illustrate the issue with a minimal working example.

using DifferentialEquations

struct my_type{A,N} <: AbstractArray{A,N}
    data :: A
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{typeof(data),N}(data, name)

Base.size(var::my_type) = size(var.data)

Base.getindex(var::my_type, i::Int) = var.data[i]
Base.getindex(var::my_type, I::Vararg{Int,N}) where {N} = var.data[I...]
Base.getindex(var::my_type, ::Colon) = var.data[:]
Base.getindex(var::my_type, kr::AbstractRange) = var.data[kr]

Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)
Base.setindex!(var::my_type, v, I::Vararg{Int,N}) where {N} = (var.data[I...] = v)
Base.setindex!(var::my_type, v, ::Colon) = (var.data[:] .= v)
Base.setindex!(var::my_type, v, kr::AbstractRange) = (var.data[kr] .= v)

function rhs_test(f, p, t)
    f
end

xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  600
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=600)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = my_type(f0, "foo")

tspan = (0.0, 1.0)

# this works fine
prob  = ODEProblem(rhs_test, f0, tspan)
sol   = solve(prob, RK4())

# this does not
prob2 = ODEProblem(rhs_test, foo, tspan)
sol2  = solve(prob2, RK4())

Basically, I’m defining my own subtype of AbstractArray and defining the getindex and setindex! methods as specified in the documentation, but this doesn’t seem to be enough, as I get the error

ERROR: LoadError: MethodError: Cannot `convert` an object of type Float64 to an object of type Array{Float64,1}

Am I missing some method, is this not possible at all, or is this the wrong way to approach the problem?

Custom number types are super easy. You just need the standard arithmetic (+, -, /, *, and then sqrt if you don’t use a custom internalnorm). Custom array types take a little bit more. They need to use a compatible number type, be validly defined, define similar (see the AbstractArray interface, which lists the similar methods you need), handle scalar multiplication, and support compatible addition. If you use an implicit method, you also need compatibility with your chosen linear solver or Jacobian. For in-place methods you need a valid in-place broadcast (this usually happens automatically if you define indexing).
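As a rough way to check the out-of-place requirements above, here is a minimal sketch of a hypothetical helper, check_state_closure (it is not part of DifferentialEquations.jl), that tests whether a candidate state keeps its type under the basic operations an explicit solver applies to states. It only covers the number-type, arithmetic and similar items, not linear-solver or broadcast compatibility.

```julia
# Hypothetical helper (not part of DifferentialEquations.jl): reports whether a
# candidate state u keeps its type under the basic operations an explicit,
# out-of-place solver performs on states.
function check_state_closure(u)
    T = typeof(u)
    return (
        numeric_eltype = eltype(u) <: Number,    # element type must be a number type
        add            = typeof(u + u) == T,     # state + state
        scalar_mul     = typeof(2 * u) == T,     # number * state
        scalar_div     = typeof(u / 2) == T,     # state / number
        similar_ok     = typeof(similar(u)) == T,
    )
end

check_state_closure([1.0, 2.0, 3.0])   # every field is true
check_state_closure([1, 2, 3])         # scalar_div is false: u / 2 is a Vector{Float64}
```

Running it on your own type before handing it to ODEProblem can point at which of the requirements is missing.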

However, your main problem is with your array definition. Look at

struct my_type{A,N} <: AbstractArray{A,N}
    data :: A
    name :: String
end

What you’re saying here is that the element type of your array is A, which in your example is Vector{Float64} (i.e. eltype(foo) == Vector{Float64}). Thus when it’s doing its setup, it thinks the “number type” you’re using is Vector{Float64}, and errors. What you meant to say was that your element type is Float64, as in:
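The eltype mix-up can be reproduced without DifferentialEquations at all. The sketch below uses two hypothetical wrappers: BadWrapper (the original definition, with the container type as the first parameter) and GoodWrapper (the element type as the first parameter). GoodWrapper stores an Array{T,N} rather than the Vector{A} used in the fix below; that is an assumption made so the N parameter stays meaningful.

```julia
# Only Base Julia is needed to see the problem.
struct BadWrapper{A,N} <: AbstractArray{A,N}    # first parameter is the *container* type
    data::A
    name::String
end
BadWrapper(data::AbstractArray{T,N}, name::String) where {T,N} =
    BadWrapper{typeof(data),N}(data, name)
Base.size(v::BadWrapper) = size(v.data)

struct GoodWrapper{T,N} <: AbstractArray{T,N}   # first parameter is the *element* type
    data::Array{T,N}
    name::String
end
GoodWrapper(data::AbstractArray{T,N}, name::String) where {T,N} =
    GoodWrapper{T,N}(collect(data), name)
Base.size(v::GoodWrapper) = size(v.data)

f0 = [1.0, 2.0, 3.0]
eltype(BadWrapper(f0, "foo"))    # Vector{Float64}: this is what gets treated as the "number type"
eltype(GoodWrapper(f0, "foo"))   # Float64
```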

struct my_type{A,N} <: AbstractArray{A,N}
    data :: Vector{A}
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

Then your array type now makes sense. However, it’s not closed under arithmetic. For example:

typeof(2foo) # Vector{Float64}
typeof(foo + foo) # Vector{Float64}

How is the integrator supposed to know how to propagate your my_type? We need to make some kind of definition choice here. For fun, I’m going to say that we keep only the name (string) of the left operand. So let’s define:

Base.:+(x::my_type,y::my_type) = my_type(x.data+y.data,x.name)
Base.:*(x::Number,y::my_type) = my_type(x*y.data,y.name)
Base.:/(x::my_type,y::Number) = my_type(x.data/y,x.name)
Base.similar(foo::my_type) = my_type(similar(foo.data),foo.name)
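To see why these methods matter, here is a minimal sketch with a hypothetical NamedVec type and a hypothetical hand-written rk4_step function. rk4_step is only a stand-in for what an out-of-place solver does internally (OrdinaryDiffEq.jl’s RK4 has its own implementation), but it uses the same kinds of operations: adding states and multiplying them by scalars. With methods like the ones above, every intermediate stays a NamedVec.

```julia
# Hypothetical wrapper: a Vector of numbers plus a name.
struct NamedVec{T} <: AbstractVector{T}
    data::Vector{T}
    name::String
end
Base.size(v::NamedVec) = size(v.data)
Base.getindex(v::NamedVec, i::Int) = v.data[i]
Base.setindex!(v::NamedVec, x, i::Int) = (v.data[i] = x)

# Arithmetic that returns a NamedVec again (for +, the left operand's name wins).
Base.:+(x::NamedVec, y::NamedVec) = NamedVec(x.data + y.data, x.name)
Base.:*(a::Number, y::NamedVec) = NamedVec(a * y.data, y.name)
Base.:/(x::NamedVec, a::Number) = NamedVec(x.data / a, x.name)
Base.similar(v::NamedVec) = NamedVec(similar(v.data), v.name)

# Hypothetical stand-in for one classic RK4 step; it only adds states and
# multiplies them by scalars, which is what the methods above provide.
function rk4_step(f, u, p, t, dt)
    k1 = f(u, p, t)
    k2 = f(u + (dt / 2) * k1, p, t + dt / 2)
    k3 = f(u + (dt / 2) * k2, p, t + dt / 2)
    k4 = f(u + dt * k3, p, t + dt)
    incr = (k1 + 2 * k2) + (2 * k3 + k4)   # pairwise sums keep every term a NamedVec
    return u + (dt / 6) * incr
end

rhs(u, p, t) = u                           # same right-hand side as rhs_test
u0 = NamedVec([1.0, 2.0], "foo")
u1 = rk4_step(rhs, u0, nothing, 0.0, 0.1)
typeof(u1)                                 # NamedVec{Float64}
```

Without the three arithmetic methods, an expression like u + (dt / 2) * k1 falls back to Base’s array arithmetic and returns a plain Vector{Float64}, which is the same loss of type shown by typeof(2foo) above.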

Technically, according to the AbstractArray interface you should add a few more similar methods, but this works. Together, the code is:

using DifferentialEquations

struct my_type{A,N} <: AbstractArray{A,N}
    data :: Vector{A}
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

Base.size(var::my_type) = size(var.data)

Base.getindex(var::my_type, i::Int) = var.data[i]
Base.getindex(var::my_type, I::Vararg{Int,N}) where {N} = var.data[I...]
Base.getindex(var::my_type, ::Colon) = var.data[:]
Base.getindex(var::my_type, kr::AbstractRange) = var.data[kr]

Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)
Base.setindex!(var::my_type, v, I::Vararg{Int,N}) where {N} = (var.data[I...] = v)
Base.setindex!(var::my_type, v, ::Colon) = (var.data[:] .= v)
Base.setindex!(var::my_type, v, kr::AbstractRange) = (var.data[kr] .= v)

function rhs_test(f, p, t)
    f
end

xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  600
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=600)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = my_type(f0, "foo")

tspan = (0.0, 1.0)

Base.:+(x::my_type,y::my_type) = my_type(x.data+y.data,x.name)
Base.:*(x::Number,y::my_type) = my_type(x*y.data,y.name)
Base.similar(foo::my_type) = my_type(similar(foo.data),foo.name)

prob  = ODEProblem(rhs_test, foo, tspan)
sol   = solve(prob, RK4())

(Note that to work on master you also need to define /, but I just put in a quick PR to remove that requirement.) Now, if you want your type to work with in-place functions, you don’t need the arithmetic, since that’s handled by broadcast. Here you just need one of your other missing similar methods:

Base.similar(foo::my_type,::Type{T}) where T = my_type(similar(foo.data,T),foo.name)
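A minimal Base-only sketch of why the two-argument similar matters, using a hypothetical TaggedVec type: without it, the AbstractArray fallback for similar(u, T) returns a plain Array and drops the wrapper. The in-place right-hand side here writes with du .= u through the AbstractArray interface (the post writes into df.data directly), so no arithmetic methods are involved.

```julia
# Hypothetical Vector-backed wrapper; mutable storage, so setindex! works.
struct TaggedVec{T} <: AbstractVector{T}
    data::Vector{T}
    name::String
end
Base.size(v::TaggedVec) = size(v.data)
Base.getindex(v::TaggedVec, i::Int) = v.data[i]
Base.setindex!(v::TaggedVec, x, i::Int) = (v.data[i] = x)

u = TaggedVec([1.0, 2.0, 3.0], "foo")
typeof(similar(u, Float32))    # Vector{Float32}: the AbstractArray fallback drops the wrapper

# The two similar methods from the post, adapted to TaggedVec.
Base.similar(v::TaggedVec) = TaggedVec(similar(v.data), v.name)
Base.similar(v::TaggedVec, ::Type{T}) where {T} = TaggedVec(similar(v.data, T), v.name)
typeof(similar(u, Float32))    # TaggedVec{Float32}

# In-place right-hand side (same derivative as rhs_test): only broadcast assignment.
function rhs_inplace!(du, u, p, t)
    du .= u        # writes through setindex!, so du stays a TaggedVec
    return nothing
end

du = similar(u)
rhs_inplace!(du, u, nothing, 0.0)
du.data          # [1.0, 2.0, 3.0]
```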

So the full working code for that is:

using DifferentialEquations

struct my_type{A,N} <: AbstractArray{A,N}
    data :: Vector{A}
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

Base.size(var::my_type) = size(var.data)

Base.getindex(var::my_type, i::Int) = var.data[i]
Base.getindex(var::my_type, I::Vararg{Int,N}) where {N} = var.data[I...]
Base.getindex(var::my_type, ::Colon) = var.data[:]
Base.getindex(var::my_type, kr::AbstractRange) = var.data[kr]

Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)
Base.setindex!(var::my_type, v, I::Vararg{Int,N}) where {N} = (var.data[I...] = v)
Base.setindex!(var::my_type, v, ::Colon) = (var.data[:] .= v)
Base.setindex!(var::my_type, v, kr::AbstractRange) = (var.data[kr] .= v)

function rhs_test(f, p, t)
    f
end

xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  600
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=600)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = my_type(f0, "foo")

tspan = (0.0, 1.0)

Base.similar(foo::my_type) = my_type(similar(foo.data),foo.name)
Base.similar(foo::my_type,::Type{T}) where T = my_type(similar(foo.data,T),foo.name)
function rhs_test2(df, f, p, t)
    df.data .= f.data
end
prob  = ODEProblem(rhs_test2, foo, tspan)
sol   = solve(prob, RK4())

So even for arrays it’s not bad: out-of-place only needed two methods beyond the AbstractArray interface to define the arithmetic (+ and scalar *), and in-place only needed the AbstractArray interface itself (but you have to make sure you define the similar part!).

Note that for these “extra” features we still have a bit more to do to get back to where we were in Julia v0.6. Since the AbstractArray and broadcast interfaces have changed three times over the last two years, we haven’t been able to sit down and properly define our interface here. However, with the stability of v1.0, I hope we can finally do so soon :)


When defining a wrapper type like your custom array type, I often find the @forward macro from Lazy.jl very convenient. It lets you “forward” a set of functions on your special type to the wrapped field inside, so you do not have to define each of those functions manually just to operate on the field.
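If you would rather not add a dependency, the same idea can be sketched in plain Base Julia with @eval. This is only roughly what a forwarding macro generates (the exact code Lazy.@forward emits may differ): each generated method unwraps the field and passes the remaining arguments along.

```julia
struct MM
    x::Vector{Int}
end

# Generate Base.length(m::MM, args...) = Base.length(m.x, args...), and so on.
for f in (:length, :iterate, :getindex)
    @eval Base.$f(m::MM, args...) = Base.$f(m.x, args...)
end

m = MM([1, 2, 3])
length(m)   # 3
m[3]        # 3
sum(m)      # 6, via the forwarded iterate
```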

Example

julia> using Lazy

julia> struct MM
           x
       end

julia> @forward MM.x (Base.length, Base.iterate, Base.getindex)

julia> m = MM([1,2,3])
MM([1, 2, 3])

julia> length(m)
3

julia> m[3]
3

This is it, thanks!

Is it possible to do this with an immutable type too? I.e. with a type:

using StaticArrays
struct my_type{N,A} <: AbstractVector{A}
    data :: SVector{N,A}
    name :: String
end
my_type(data::AbstractVector{T}, name::String) where {T} =
    my_type{length(data), eltype(data)}(SVector{length(data)}(data), name)

In this case setindex! cannot be defined, and similar will also have to return a different type. The fact that StaticArrays can be used in DiffEq suggests that this is possible, but a naive modification of the code above (not defining setindex! and similar) does not work.

Then you also need to switch to an out-of-place function, i.e. omit the first argument and return the derivative instead:

function rhs_test2_oop(f, p, t)
    f # return value is the derivative
end
prob  = ODEProblem(rhs_test2_oop, foo, tspan)
sol   = solve(prob, RK4())

Yes, I messed something up before. The example then runs, but typeof(sol[1]) == Vector rather than Mtype.

And thus, when trying to carry a parameter around in the type, it does not work:

using DifferentialEquations, StaticArrays

struct Mtype{N,A} <: AbstractVector{A}
    data :: SVector{N,A}
    name :: String
    p::Float64
end
Mtype(data::AbstractVector{T}, name::String, p) where {T} =
    Mtype{length(data), eltype(data)}(SVector{length(data)}(data), name, p)

Base.size(var::Mtype{N}) where N = (N,)

Base.getindex(var::Mtype, i::Int) = var.data[i]
Base.IndexStyle(::Mtype) = IndexLinear()

rhs_test(f, p, t) = f*f.p

xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  10
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=10)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = Mtype(f0, "foo", 2.0)

tspan = (0.0, 1.0)

Base.:+(x::Mtype,y::Mtype) = Mtype(x.data+y.data,x.name,x.p)
Base.:*(x::Number,y::Mtype) = Mtype(x*y.data,y.name,y.p)

prob  = ODEProblem(rhs_test, foo, tspan)
sol   = solve(prob, RK4())

errors with ERROR: LoadError: type Array has no field p. What seems to happen is that similar(::Mtype) is used, which creates an Array, and an Array has no field p.

Your problem is:

struct Mtype{N,A} <: AbstractVector{A}
    data :: SVector{N,A}
    name :: String
    p::Float64
end
Mtype(data::AbstractVector{T}, name::String, p) where {T} =
    Mtype{length(data), eltype(data)}(SVector{length(data)}(data), name, p)

Base.size(var::Mtype{N}) where N = (N,)

Base.getindex(var::Mtype, i::Int) = var.data[i]
Base.IndexStyle(::Mtype) = IndexLinear()

rhs_test(f, p, t) = f*f.p

xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  10
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=10)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = Mtype(f0, "foo", 2.0)

foo .+ foo

Broadcasted expressions on your type return arrays. You’ll want to overload broadcast so that it returns an Mtype.
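One Base-only way to do this for mutable storage is the pattern from the “Customizing broadcasting” section of the Julia manual (its ArrayAndChar example): give the type an ArrayStyle and overload similar for that style so it reattaches the metadata. The sketch below applies it to a hypothetical Vector-backed Tagged type; it is a different technique from the custom MtypeStyle/copy approach used with SVector in the code that follows, which is needed there because an SVector cannot be filled in place.

```julia
# Hypothetical Vector-backed type carrying one piece of metadata, p.
struct Tagged{T} <: AbstractVector{T}
    data::Vector{T}
    p::Float64
end
Base.size(v::Tagged) = size(v.data)
Base.getindex(v::Tagged, i::Int) = v.data[i]
Base.setindex!(v::Tagged, x, i::Int) = (v.data[i] = x)

# Opt into a custom broadcast style (Julia manual, "Customizing broadcasting").
Base.BroadcastStyle(::Type{<:Tagged}) = Broadcast.ArrayStyle{Tagged}()

# Allocate the broadcast result as a Tagged, copying p from the first Tagged input.
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{Tagged}}, ::Type{ElType}) where {ElType}
    t = find_tagged(bc)
    return Tagged(similar(Array{ElType}, axes(bc)), t.p)
end

# Walk the (possibly nested) broadcast arguments to find the first Tagged.
find_tagged(bc::Broadcast.Broadcasted) = find_tagged(bc.args)
find_tagged(args::Tuple) = find_tagged(find_tagged(args[1]), Base.tail(args))
find_tagged(x) = x
find_tagged(::Tuple{}) = nothing
find_tagged(a::Tagged, rest) = a
find_tagged(::Any, rest) = find_tagged(rest)

foo = Tagged([1.0, 2.0, 3.0], 2.0)
bar = foo .+ foo .* foo.p   # a Tagged{Float64} with p == 2.0 and data [3.0, 6.0, 9.0]
```

Because Base implements array + and number * array via broadcasting, foo + foo and 2 * foo also come back as a Tagged once this style is in place.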

This works:

using DiffEqBase, OrdinaryDiffEq, StaticArrays
using Base: @propagate_inbounds
using Base.Broadcast: Broadcasted

struct Mtype{N,T} <: AbstractVector{T}
    data :: SVector{N,T}
    name :: String
    p::Float64
end
Mtype(data::AbstractVector{T}, name::String, p) where {T} =
    Mtype{length(data), T}(SVector{length(data)}(data), name, p)
Base.copy(m::Mtype) = deepcopy(m)
Base.zero(m::Mtype{N,T}) where {N,T} = Mtype{N,T}(zero(SVector{N,T}), m.name, m.p)

# AbstractVector interface
Base.size(var::Mtype{N}) where N = (N,)
@propagate_inbounds Base.getindex(var::Mtype, i::Int) = var.data[i]
Base.IndexStyle(::Mtype) = IndexLinear()

# broadcast interface
struct MtypeStyle{N,T} <: Broadcast.AbstractArrayStyle{1} end
# Whenever you subtype AbstractArrayStyle, you also need to define rules for combining
# dimensionalities, by creating a constructor for your style that takes a Val(N) argument.
MtypeStyle{N,T}(::Val{0}) where {N,T} = MtypeStyle{N,T}()
MtypeStyle{N,T}(::Val{1}) where {N,T} = MtypeStyle{N,T}()

Base.BroadcastStyle(::Type{<:Mtype{N,T}}) where {N,T} = MtypeStyle{N,T}()
function Base.copy(bc::Broadcasted{<:MtypeStyle{N,T}}) where {N,T}
    # traverse Broadcasted tree to find Mtype instances
    name, p = get_name_p(bc)

    # TODO: this is what causes the slowness:
    Mtype{N,T}(SVector(tuple((v for v in bc)...)), name, p)
    # e.g compared with this:
    # Mtype{N,T}(zeros(SVector{10,Float64}), name, p)
end

"Retrieve the fields name and p from the Broadcasted instance and make sure they all match"
function get_name_p(bc::Broadcasted)
    name::Union{String, Nothing}=nothing
    p::Union{Float64, Nothing}=nothing
    for a in bc.args
        if a isa Mtype
            if name===nothing
                name = a.name
                p = a.p
            else
                name==a.name || error("Broadcasting of Mtypes with different name, p not supported.")
                p==a.p || error("Broadcasting of Mtypes with different name, p not supported.")
            end
        elseif a isa Broadcasted
            nn, pp = get_name_p(a)
            if name===nothing
                name = nn
                p = pp
            elseif nn!==nothing
                name==nn || error("Broadcasting of Mtypes with different name, p not supported.")
                p==pp || error("Broadcasting of Mtypes with different name, p not supported.")
            end
        end
    end
    return name, p
end

# Test it
rhs_test(f::Mtype, p, t) = f*f.p

f0    = rand(10)
sf0   = SVector{length(f0)}(f0)
foo   = Mtype(sf0, "foo", 2.0)

tspan = (0.0, 100.0)

prob  = ODEProblem(rhs_test, foo, tspan)
sol   = solve(prob, RK4())
@time solve(prob, RK4()) # 0.109459 seconds (1.95 M allocations: 71.409 MiB, 13.93% gc time)

# test against straight SVector version
rhs_test2(f, p, t) = f*2
prob2  = ODEProblem(rhs_test2, sf0, tspan)
sol2   = solve(prob2, RK4())
@time solve(prob2, RK4()) # 0.000486 seconds (3.62 k allocations: 427.281 KiB)

# against Vector version
prob2  = ODEProblem(rhs_test2, f0, tspan)
sol2   = solve(prob2, RK4())
@time solve(prob2, RK4()) # 0.005367 seconds (52.43 k allocations: 7.910 MiB)


using Test
for i = 1:length(sol)
    @test sol[i]==sol2[i]
end

but it is slow: about 200x slower than the plain SVector version (timings are in the comments). The reason is the line Mtype{N,T}(SVector(tuple((v for v in bc)...)), name, p) in the copy function.

You might want to make this <: StaticArray; that might give it some better dispatches. There are likely other ways to speed this up too.

The problem with subtyping StaticVector is that I then need to define similar_type (https://juliaarrays.github.io/StaticArrays.jl/stable/pages/api.html#Implementing-your-own-types-1). But as far as I can tell, that assumes all the fields are part of the array, i.e. there cannot be any metadata such as name and p.

Otherwise that would be perfect (as would something similar with FieldVector).

Edit: x-ref https://github.com/JuliaArrays/StaticArrays.jl/issues/592

Using Setfield.jl, this recovers the speed of static arrays:

function Base.copy(bc::Broadcasted{<:MtypeStyle{N,T}}) where {N,T}
    # traverse Broadcasted tree to find Mtype instances
    name, p = get_name_p(bc)
    # This trick works:
    # https://discourse.julialang.org/t/constructing-svector-with-a-loop/15372/7
    s = SVector{N,T}(1:N)
    for (i,v) in enumerate(bc)
        @set s[i] = v
    end
    Mtype{N,T}(s, name, p)
end

x-ref Constructing SVector with a loop - #7 by tkf


Would something like this be a good addition to the docs? Where would it go?

I think you need to use s = @set s[i] = v or its shortcut @set! s[i] = v, since @set returns an updated copy rather than modifying s in place.
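The same pitfall is easy to reproduce with Base Julia’s non-exported Base.setindex for tuples, which also returns an updated copy instead of mutating. The two hypothetical helpers below differ only in whether the result is assigned back; they are written as functions to sidestep the soft-scope rules for globals in scripts.

```julia
# Hypothetical helpers: fill an NTuple element by element.
function fill_wrong(n)
    s = ntuple(_ -> 0.0, n)
    for i in 1:n
        Base.setindex(s, Float64(i), i)       # result discarded: s is unchanged
    end
    return s
end

function fill_right(n)
    s = ntuple(_ -> 0.0, n)
    for i in 1:n
        s = Base.setindex(s, Float64(i), i)   # rebind s to the updated tuple
    end
    return s
end

fill_wrong(3)   # (0.0, 0.0, 0.0)
fill_right(3)   # (1.0, 2.0, 3.0)
```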

It would be a good addition to the docs, probably as an FAQ on how to define state types, or as a discussion on the problem interface page.


Oops, you’re right. That makes the timing worse, but it’s only about 20x slower than the SVector version instead of 200x (and still ~2x slower than the Vector version).

Alright, improving… The Mtype was not isbits (its name field was a String), which obviously hurt performance a lot. Now it’s only 1.5x slower than the SVector version (when skipping get_name_p). There is an ugly hack with the _copy methods; if anyone can tell me how to improve on that, that would be awesome.
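The isbits point is easy to check directly. This sketch uses NTuple in place of SVector (an SVector wraps an NTuple) and two hypothetical structs that differ only in the type of the name field:

```julia
struct WithString
    data::NTuple{3,Float64}
    name::String
    p::Float64
end

struct WithInt
    data::NTuple{3,Float64}
    name::Int
    p::Float64
end

isbitstype(WithString)   # false: a String is a heap-allocated object
isbitstype(WithInt)      # true: every field is plain data
```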

I’ll try to put a PR for the docs together.

using DiffEqBase, OrdinaryDiffEq, StaticArrays
using Base: @propagate_inbounds
using Base.Broadcast: Broadcasted
using Setfield
using Test, BenchmarkTools

struct Mtype{N,T} <: AbstractVector{T}
    data::SVector{N,T}
    name::Int
    p::Float64
end
Mtype(data::AbstractVector{T}, name::Int, p) where {T} =
    Mtype{length(data), T}(SVector{length(data)}(data), name, p)
Base.copy(m::Mtype) = deepcopy(m)
Base.zero(m::Mtype{N,T}) where {N,T} = Mtype{N,T}(zero(SVector{N,T}), m.name, m.p)
# NOTE: that only the array-parts are hashed for <:AbstractVector
tmp1   = Mtype(1:2, -1, 2.0)
tmp2   = Mtype(1:2, -1, 3.0)
@test_broken hash(tmp1)!=hash(tmp2)
@test isbits(tmp1)

# AbstractVector interface
Base.size(var::Mtype{N}) where N = (N,)
@propagate_inbounds Base.getindex(var::Mtype, i::Int) = var.data[i]
Base.IndexStyle(::Mtype) = IndexLinear()

# broadcast interface
struct MtypeStyle{N,T} <: Broadcast.AbstractArrayStyle{1} end
# Whenever you subtype AbstractArrayStyle, you also need to define rules for combining
# dimensionalities, by creating a constructor for your style that takes a Val(N) argument.
MtypeStyle{N,T}(::Val{0}) where {N,T} = MtypeStyle{N,T}()
MtypeStyle{N,T}(::Val{1}) where {N,T} = MtypeStyle{N,T}()

Base.BroadcastStyle(::Type{<:Mtype{N,T}}) where {N,T} = MtypeStyle{N,T}()
function Base.copy(bc::Broadcasted{<:MtypeStyle{N,T}}) where {N,T}
    return _copy(bc, bc.args...)
end
# these separate _copy functions speed things up by about 2x.
# But they sure are ugly.
@inline function _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1::Mtype{N,T}, a2::Mtype{N,T}) where {N,T}
    f = bc.f
    a1.name==a2.name || error("Broadcasting of Mtypes with different name, p not supported.")
    a1.p==a2.p || error("Broadcasting of Mtypes with different name, p not supported.")
    Mtype{N,T}(broadcast(f, a1.data, a2.data), a1.name, a1.p)
end
@inline _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1::Mtype{N,T}, a2) where {N,T} =
    Mtype{N,T}(broadcast(bc.f, a1.data, a2), a1.name, a1.p)
@inline _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1, a2::Mtype{N,T}) where {N,T} =
    Mtype{N,T}(broadcast(bc.f, a1, a2.data), a2.name, a2.p)
@inline function _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, args...) where {N,T}
    # traverse Broadcasted tree to find Mtype instances
    name, p = -1, 2.0 # get_name_p(bc) # this uses about 120μs and 0.4MiB

    # This is slow:
    # Mtype{N,T}(SVector(tuple((v for v in bc)...)), name, p)
    # thus use Setfield.jl trick:
    # https://discourse.julialang.org/t/constructing-svector-with-a-loop/15372/7
    s = SVector{N,T}(1:N)
    for (i,v) in enumerate(bc)
        @set! s[i] = v
    end
    Mtype{N,T}(s, name, p)
end


"""
Retrieve the fields name and p from the Broadcasted
instance and make sure they all match

TODO: a wee bit slow, this needs to be done better.
"""
function get_name_p(bc::Broadcasted)
    name::Union{Int, Nothing}=nothing
    p::Union{Float64, Nothing}=nothing
    for a in bc.args
        if a isa Mtype
            if name===nothing
                name = a.name
                p = a.p
            else
                name==a.name || error("Broadcasting of Mtypes with different name, p not supported.")
                p==a.p || error("Broadcasting of Mtypes with different name, p not supported.")
            end
        elseif a isa Broadcasted
            nn, pp = get_name_p(a)
            if name===nothing
                name = nn
                p = pp
            elseif nn!==nothing
                name==nn || error("Broadcasting of Mtypes with different name, p not supported.")
                p==pp || error("Broadcasting of Mtypes with different name, p not supported.")
            end
        end
    end
    return name::Int, p::Float64
end

# Test it
rhs_test(f::Mtype, p, t) = f*f.p

f0    = rand(10)
sf0   = SVector{length(f0)}(f0)
foo   = Mtype(sf0, -1, 2.0)

tspan = (0.0, 100)

prob  = ODEProblem(rhs_test, foo, tspan)
sol   = solve(prob, RK4())
@btime solve(prob, RK4()) #  568.960 μs (6553 allocations: 1.32 MiB)

# test against straight SVector version
rhs_test2(f, p, t) = f*2
prob2  = ODEProblem(rhs_test2, sf0, tspan)
sol2   = solve(prob2, RK4())
@btime solve(prob2, RK4()) # 366.731 μs (3615 allocations: 426.84 KiB)

# against Vector version
prob2  = ODEProblem(rhs_test2, f0, tspan)
sol2   = solve(prob2, RK4())
@btime solve(prob2, RK4()) # 2.246 ms (52423 allocations: 7.91 MiB)


@test length(sol)==length(sol2)
for i = 1:length(sol)
    @test sol[i]==sol2[i]
end

On this topic, there’s currently an open issue that ArrayPartition is not compatible with, in particular, stiff solvers, since linear algebra is not implemented for it. (It’s mentioned that PartitionedODEProblem uses them internally, so maybe it’s also affected? By the way, PartitionedODEProblem appears to be undefined; it looks like the code is absent.)

Is there a roadmap of what needs to be done if one wishes to help with it (overloading the linear algebra functions for ArrayPartition, or copying it in a linearized form into a cache)? It would be very useful for a set of coupled ODEs that includes matrices, vectors and numbers, e.g.

v' = f_v(m_i, v, s)
m_i' = f_m(m_j, v, s)
s' = f_s(m_i, v, s)

where v is a vector, m_1, …, m_n are matrices, and s is a number.

Right now I’m handling it by reshaping v, m and s to 1-D and vcat-ing them together. Speed is okay for small systems, but bigger systems may require preallocating arrays and passing them via the p parameter (which may still generate some allocations).
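For reference, a minimal sketch of the flatten-and-view approach, assuming the layout [v; vec(m_1); …; vec(m_n); s] and using hypothetical pack/unpack helpers. unpack returns views (and reshaped views) into the flat vector, so an in-place right-hand side can write into v, the m_i and s without copying anything back; the preallocation question is not addressed here.

```julia
# Hypothetical helpers; the flat layout is [v; vec(m_1); ...; vec(m_n); s].
pack(v, ms, s) = vcat(v, (vec(m) for m in ms)..., s)

# Return views into u, so writes to v, the m_i or s[] update u in place.
function unpack(u, nv, msizes)
    lens   = [prod(sz) for sz in msizes]
    starts = nv .+ cumsum([0; lens[1:end-1]])   # offset before each matrix block
    ms = [reshape(view(u, st+1:st+len), sz) for (st, len, sz) in zip(starts, lens, msizes)]
    s  = view(u, nv + sum(lens) + 1)            # 0-dimensional view: read/write with s[]
    return view(u, 1:nv), ms, s
end

v  = [1.0, 2.0]
ms = [[1.0 2.0; 3.0 4.0], [5.0 6.0; 7.0 8.0]]
u  = pack(v, ms, 3.0)                           # length 2 + 4 + 4 + 1 = 11
vv, mm, ss = unpack(u, length(v), [size(m) for m in ms])
mm[2][1, 2] = 60.0                              # writes through to u[9]
```

The views are cheap to create, but if the same u is unpacked on every right-hand-side call, the small allocations for the view objects and the ms vector remain.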

By the way, I suppose s = SVector{N,T}(ntuple(i -> (@inbounds bc[i]), Val(N))) is equivalent to the @set! solution in terms of type stability?
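A Base-only sketch of the ntuple idea, using a hypothetical TupVec type backed by an NTuple (which is what an SVector wraps) and a broadcast style set up like MtypeStyle above. The result tuple is built with a compile-time length via Val(N), with no generator splat and no Setfield. For brevity, the hypothetical first_p helper just takes p from the first TupVec it finds, without the consistency checks that get_name_p performs.

```julia
using Base.Broadcast: Broadcasted, AbstractArrayStyle

# Hypothetical NTuple-backed vector with one piece of metadata, p.
struct TupVec{N,T} <: AbstractVector{T}
    data::NTuple{N,T}
    p::Float64
end
Base.size(::TupVec{N}) where {N} = (N,)
Base.@propagate_inbounds Base.getindex(v::TupVec, i::Int) = v.data[i]
Base.IndexStyle(::Type{<:TupVec}) = IndexLinear()

# Broadcast style, mirroring MtypeStyle from the thread.
struct TupVecStyle{N} <: AbstractArrayStyle{1} end
TupVecStyle{N}(::Val{0}) where {N} = TupVecStyle{N}()
TupVecStyle{N}(::Val{1}) where {N} = TupVecStyle{N}()
Base.BroadcastStyle(::Type{<:TupVec{N}}) where {N} = TupVecStyle{N}()

# Hypothetical helper: p of the first TupVec found in the (nested) arguments.
first_p() = nothing
first_p(a::TupVec, rest...) = a.p
first_p(a::Broadcasted, rest...) = something(first_p(a.args...), first_p(rest...))
first_p(a, rest...) = first_p(rest...)

# Build the result with a compile-time length: no generator splat, no Setfield.
function Base.copy(bc::Broadcasted{TupVecStyle{N}}) where {N}
    data = ntuple(i -> @inbounds(bc[i]), Val(N))
    return TupVec(data, first_p(bc.args...))
end

foo = TupVec((1.0, 2.0, 3.0), 2.0)
bar = foo .+ foo .* foo.p   # a TupVec{3, Float64} with data (3.0, 6.0, 9.0) and p == 2.0
isbits(bar)                 # true
```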


For this case we would really want to do blocked linear algebra, but I am not sure of a good solution without something like Swizzles. We may just have to densify until we have more linear algebra tools.
