JuliaDiffEq with custom types

Custom number types are super easy. You just need the standard arithmetic (+, -, /, *), plus sqrt if you don’t use a custom internalnorm. Custom array types take a little more work. They need to use a compatible number type, be validly defined, define similar (see the AbstractArray interface, which lists the similar methods you need), handle scalar multiplication, and support compatible additions. If you use an implicit method, the type also has to be compatible with your chosen linear solver or Jacobian. For in-place methods you need a valid in-place broadcast (usually this works automatically once you define indexing).
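To make the arithmetic requirement concrete, here is a minimal sketch using a hypothetical wrapper type `MyNum` and a hypothetical hand-written `rk4_step`. Neither comes from DifferentialEquations.jl, and the stepper is a simplified stand-in for what an explicit out-of-place method does, not OrdinaryDiffEq's actual implementation. It only ever adds two states and scales a state by a real number:

```julia
# Hypothetical custom number type: a thin wrapper around a Float64.
struct MyNum <: Number
    v::Float64
end

# The standard arithmetic mentioned above, plus sqrt for a default norm.
Base.:+(x::MyNum, y::MyNum) = MyNum(x.v + y.v)
Base.:-(x::MyNum, y::MyNum) = MyNum(x.v - y.v)
Base.:*(a::Real, x::MyNum) = MyNum(a * x.v)
Base.:*(x::MyNum, y::MyNum) = MyNum(x.v * y.v)
Base.:/(x::MyNum, a::Real) = MyNum(x.v / a)
Base.:/(x::MyNum, y::MyNum) = MyNum(x.v / y.v)
Base.sqrt(x::MyNum) = MyNum(sqrt(x.v))

# Hypothetical fixed-step RK4 update for an out-of-place RHS f(u, p, t).
# It only adds two states and multiplies a state by a real scalar.
function rk4_step(f, u, p, t, dt)
    k1 = f(u, p, t)
    k2 = f(u + (dt / 2) * k1, p, t + dt / 2)
    k3 = f(u + (dt / 2) * k2, p, t + dt / 2)
    k4 = f(u + dt * k3, p, t + dt)
    return u + (dt / 6) * (((k1 + 2 * k2) + 2 * k3) + k4)
end

# du/dt = u, so one step from u = 1 should land close to exp(0.1).
u1 = rk4_step((u, p, t) -> u, MyNum(1.0), nothing, 0.0, 0.1)
println(u1.v)
```

Adaptive methods additionally compute an error norm, which is where the `sqrt` comes in; this fixed-step sketch leaves that out.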

However, your main problem is your array definition. Look at:

```julia
struct my_type{A,N} <: AbstractArray{A,N}
    data :: A
    name :: String
end
```

What you’re saying here is that the element type of your array is A, which in your example means eltype(foo) == Vector{Float64}. So when the integrator does its setup, it thinks the “number type” you’re using is Vector{Float64}, and it errors. What you meant to say is that your element type is Float64, as in:

```julia
struct my_type{A,N} <: AbstractArray{A,N}
    data :: Vector{A}
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)
```
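A quick way to see the difference is to ask each definition for its element type. In this sketch the original definition is renamed to a hypothetical `OldType` so that both can live in one file; no solver is involved:

```julia
# The original definition: A is the type of the whole data field.
struct OldType{A,N} <: AbstractArray{A,N}
    data::A
    name::String
end

# The corrected definition: A is the element type of the data vector.
struct my_type{A,N} <: AbstractArray{A,N}
    data::Vector{A}
    name::String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

f0 = [1.0, 2.0, 3.0]
bad = OldType{typeof(f0),1}(f0, "foo")
good = my_type(f0, "foo")

println(eltype(bad) == Vector{Float64})  # true: the "number type" is a whole vector
println(eltype(good) == Float64)         # true
```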

Now your array type makes sense. However, it’s not closed under arithmetic: Base’s generic AbstractArray arithmetic allocates a plain Array for the result. For example:

```julia
typeof(2foo) # Vector{Float64}
typeof(foo + foo) # Vector{Float64}
```
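Here is a self-contained version of that check. It defines only `size` and `getindex` (enough for Base's generic arithmetic to run) and adds a hypothetical one-line `euler_step`, a simplified stand-in for an integrator update rather than anything from DifferentialEquations.jl, to show that the wrapper and its name are lost after the first update:

```julia
struct my_type{A,N} <: AbstractArray{A,N}
    data::Vector{A}
    name::String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

# Just enough of the AbstractArray interface for generic arithmetic to run.
Base.size(var::my_type) = size(var.data)
Base.getindex(var::my_type, i::Int) = var.data[i]

foo = my_type([1.0, 2.0, 3.0], "foo")

# Base's generic AbstractArray arithmetic allocates a plain Array.
println(typeof(2foo) == Vector{Float64})       # true
println(typeof(foo + foo) == Vector{Float64})  # true

# Hypothetical explicit Euler step u + dt*f(u): the result is a plain
# Vector, so the my_type wrapper (and its name) is gone after one step.
euler_step(f, u, p, t, dt) = u + dt * f(u, p, t)
u1 = euler_step((u, p, t) -> u, foo, nothing, 0.0, 0.1)
println(u1 isa my_type)                        # false
```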

How is the integrator supposed to know how to propagate your my_type? We need to make a choice of definition here. For fun, I’m going to say that we keep only the name string of the left operand. So let’s define:

```julia
Base.:+(x::my_type,y::my_type) = my_type(x.data+y.data,x.name)
Base.:*(x::Number,y::my_type) = my_type(x*y.data,y.name)
Base.:/(x::my_type,y::Number) = my_type(x.data/y,x.name)
Base.similar(foo::my_type) = my_type(similar(foo.data),foo.name)
```
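With those methods in place, arithmetic on `my_type` stays inside `my_type`. The sketch below checks that and then runs a hypothetical hand-written `rk4_step` (again, not OrdinaryDiffEq's code) that only calls the two-argument `+` and scalar `*`, so the name survives a full step:

```julia
struct my_type{A,N} <: AbstractArray{A,N}
    data::Vector{A}
    name::String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

Base.size(var::my_type) = size(var.data)
Base.getindex(var::my_type, i::Int) = var.data[i]

# The out-of-place arithmetic from above: keep the left operand's name.
Base.:+(x::my_type, y::my_type) = my_type(x.data + y.data, x.name)
Base.:*(x::Number, y::my_type) = my_type(x * y.data, y.name)
Base.:/(x::my_type, y::Number) = my_type(x.data / y, x.name)
Base.similar(foo::my_type) = my_type(similar(foo.data), foo.name)

foo = my_type([1.0, 2.0, 3.0], "foo")
bar = my_type([1.0, 1.0, 1.0], "bar")

s = foo + bar  # a my_type named "foo" (the left operand wins)
h = 2foo / 4   # still a my_type named "foo"

# Hypothetical RK4 step built only from binary + and scalar *.
# The parentheses keep every + a two-argument call.
function rk4_step(f, u, p, t, dt)
    k1 = f(u, p, t)
    k2 = f(u + (dt / 2) * k1, p, t + dt / 2)
    k3 = f(u + (dt / 2) * k2, p, t + dt / 2)
    k4 = f(u + dt * k3, p, t + dt)
    return u + (dt / 6) * (((k1 + 2 * k2) + 2 * k3) + k4)
end

u1 = rk4_step((u, p, t) -> u, foo, nothing, 0.0, 0.1)
println(typeof(u1))
println(u1.name)
```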

Technically, according to the AbstractArray interface, you should add a few more similar methods, but this works. Putting it all together, the code is:

```julia
using DifferentialEquations

# The element type is A; the data is stored as a plain Vector{A}.
struct my_type{A,N} <: AbstractArray{A,N}
    data :: Vector{A}
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

# AbstractArray interface: size, indexing and indexed assignment.
Base.size(var::my_type) = size(var.data)

Base.getindex(var::my_type, i::Int) = var.data[i]
Base.getindex(var::my_type, I::Vararg{Int,N}) where {N} = var.data[I...]
Base.getindex(var::my_type, ::Colon) = var.data[:]
Base.getindex(var::my_type, kr::AbstractRange) = var.data[kr]

Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)
Base.setindex!(var::my_type, v, I::Vararg{Int,N}) where {N} = (var.data[I...] = v)
Base.setindex!(var::my_type, v, ::Colon) = (var.data[:] .= v)
Base.setindex!(var::my_type, v, kr::AbstractRange) = (var.data[kr] .= v)

# Out-of-place right-hand side: du/dt = u.
function rhs_test(f, p, t)
    f
end

# Grid and Gaussian initial condition.
xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  600
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=xnodes)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = my_type(f0, "foo")

tspan = (0.0, 1.0)

# Arithmetic closed over my_type (the left operand's name is kept).
Base.:+(x::my_type,y::my_type) = my_type(x.data+y.data,x.name)
Base.:*(x::Number,y::my_type) = my_type(x*y.data,y.name)
Base.:/(x::my_type,y::Number) = my_type(x.data/y,x.name)  # only needed on some versions, see the note below
Base.similar(foo::my_type) = my_type(similar(foo.data),foo.name)

prob  = ODEProblem(rhs_test, foo, tspan)
sol   = solve(prob, RK4())
```

(Note that to work on master you also need to define /, but I just put in a quick PR to remove that requirement.) Now, if you want your type to work with in-place functions, you don’t need the arithmetic at all, since that’s handled by broadcast. Here you just need one of the other similar methods you were missing:

```julia
Base.similar(foo::my_type,::Type{T}) where T = my_type(similar(foo.data,T),foo.name)
```
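As a rough check of why this method matters: Base's generic `zero(x)` allocates through `similar`, and integrator caches are presumably built in a similar way, so without a matching `similar` method such allocations fall back to plain `Array`s. This sketch defines only indexing plus the two `similar` methods and shows that `zero`, a change of element type, and an in-place broadcast all keep the wrapper:

```julia
struct my_type{A,N} <: AbstractArray{A,N}
    data::Vector{A}
    name::String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

Base.size(var::my_type) = size(var.data)
Base.getindex(var::my_type, i::Int) = var.data[i]
Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)

Base.similar(foo::my_type) = my_type(similar(foo.data), foo.name)
Base.similar(foo::my_type, ::Type{T}) where {T} = my_type(similar(foo.data, T), foo.name)

foo = my_type([1.0, 2.0, 3.0], "foo")

# Generic Base functions that allocate through similar now return my_type.
z = zero(foo)              # zeros, still named "foo"
w = similar(foo, Float32)  # new element type, wrapper kept

# In-place broadcast writes through setindex!, so no + or * methods are needed.
du = similar(foo)
du .= 2 .* foo .+ 1
```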

So the full working code for that is:

```julia
using DifferentialEquations

# The element type is A; the data is stored as a plain Vector{A}.
struct my_type{A,N} <: AbstractArray{A,N}
    data :: Vector{A}
    name :: String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

# AbstractArray interface: size, indexing and indexed assignment.
Base.size(var::my_type) = size(var.data)

Base.getindex(var::my_type, i::Int) = var.data[i]
Base.getindex(var::my_type, I::Vararg{Int,N}) where {N} = var.data[I...]
Base.getindex(var::my_type, ::Colon) = var.data[:]
Base.getindex(var::my_type, kr::AbstractRange) = var.data[kr]

Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)
Base.setindex!(var::my_type, v, I::Vararg{Int,N}) where {N} = (var.data[I...] = v)
Base.setindex!(var::my_type, v, ::Colon) = (var.data[:] .= v)
Base.setindex!(var::my_type, v, kr::AbstractRange) = (var.data[kr] .= v)

# Grid and Gaussian initial condition.
xmin   = -2.0*pi
xmax   =  2.0*pi
xnodes =  600
hx     = (xmax - xmin) / xnodes

xx = range(xmin, stop=xmax, length=xnodes)

x0   = 0
w    = 0.4
A    = 1

f0    = A * exp.( -((xx .- x0) ./ w).^2 )
foo   = my_type(f0, "foo")

tspan = (0.0, 1.0)

# Besides indexing, in-place only needs the similar methods (no arithmetic).
Base.similar(foo::my_type) = my_type(similar(foo.data),foo.name)
Base.similar(foo::my_type,::Type{T}) where T = my_type(similar(foo.data,T),foo.name)

# In-place right-hand side: writes du/dt = u into df.
function rhs_test2(df, f, p, t)
    df.data .= f.data
    return nothing
end

prob  = ODEProblem(rhs_test2, foo, tspan)
sol   = solve(prob, RK4())
```

So even for arrays it’s not bad: out-of-place only needed two methods outside the AbstractArray interface to define the arithmetic (+ and *), and in-place only needed the AbstractArray interface (but you do need to make sure you define the similar methods!).
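To illustrate the in-place half of that claim, here is a hypothetical hand-written in-place `rk4_step!` (a simplified stand-in, not OrdinaryDiffEq's implementation). The type defines no `+`, `*` or `/` at all: every cache comes from `similar(u)`, and every update is a fused broadcast that writes through `setindex!`:

```julia
struct my_type{A,N} <: AbstractArray{A,N}
    data::Vector{A}
    name::String
end
my_type(data::AbstractArray{T,N}, name::String) where {T,N} =
    my_type{eltype(data),N}(data, name)

# Only AbstractArray interface methods: no custom +, * or / anywhere.
Base.size(var::my_type) = size(var.data)
Base.getindex(var::my_type, i::Int) = var.data[i]
Base.setindex!(var::my_type, v, i::Int) = (var.data[i] = v)
Base.similar(foo::my_type) = my_type(similar(foo.data), foo.name)
Base.similar(foo::my_type, ::Type{T}) where {T} = my_type(similar(foo.data, T), foo.name)

# Hypothetical in-place RK4 step: caches come from similar(u) and every
# update is a fused in-place broadcast.
function rk4_step!(f!, u, p, t, dt)
    k1, k2, k3, k4, tmp = similar(u), similar(u), similar(u), similar(u), similar(u)
    f!(k1, u, p, t)
    tmp .= u .+ (dt / 2) .* k1
    f!(k2, tmp, p, t + dt / 2)
    tmp .= u .+ (dt / 2) .* k2
    f!(k3, tmp, p, t + dt / 2)
    tmp .= u .+ dt .* k3
    f!(k4, tmp, p, t + dt)
    u .= u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    return u
end

# Equivalent to rhs_test2 from the in-place example: du/dt = u.
rhs!(df, f, p, t) = (df .= f; nothing)

foo = my_type([1.0, 2.0, 3.0], "foo")
rk4_step!(rhs!, foo, nothing, 0.0, 0.1)
println(foo.name, " ", foo.data)
```

Because `u` is mutated in place, it keeps both its type and its name after the step.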

Note that for these “extra” features we still have a bit more work to do to get back to where we were in Julia v0.6. Since the definitions of AbstractArray and Broadcast have changed three times over the last two years, we haven’t been able to sit down and properly define our interface here. However, with the stability of v1.0, I hope we can finally do so soon 🙂
