Replacing static arrays for use with Enzyme.jl

I’m implementing AD in a code that uses types built on SLArray and SArray, and I’m trying to understand the best way to convert things so that the code plays well with Enzyme.

From my understanding, to work with Enzyme, (1) all arrays and custom data types must be mutable, and (2) functions must mutate arrays in place rather than return them. Is this correct?

If so, my naive strategy would be to replace SArray with MArray and SLArray with LArray. Does this seem like a good route?

If this thinking is correct, then I have some general questions. Is there an unavoidable performance trade-off in using MArray instead of SArray in order to use Enzyme? Also, I noticed in LabelledArrays.jl that SLArray is based on SArray, but LArray is not based on MArray as I might expect. Is this intentional, or has no one implemented it yet? Maybe I’m misunderstanding this.

To possibly answer that last question: I guess LArray is flexible enough to wrap any <: AbstractArray type, so something like MArray would still work.

Enzyme generally works well with static arrays. For example:

julia> using Enzyme, StaticArrays

julia> x = rand(SVector{3, Float64}, 10)
10-element Vector{SVector{3, Float64}}:
 [0.703458204835627, 0.7121617637366355, 0.7313927507314387]
 [0.4621987150741128, 0.2154237571455806, 0.10735184165752176]
 [0.545521087048848, 0.2192157907272828, 0.7063366942886221]
 [0.29922059211540786, 0.00022592689698108792, 0.4566363080401129]
 [0.0804002475278649, 0.7924864781611491, 0.7757979667815256]
 [0.4175042850253522, 0.6422189150032824, 0.3510186112988213]
 [0.7120487718195188, 0.09890390308324437, 0.7345662304826344]
 [0.08079153987374355, 0.4615163823192098, 0.23119102793791468]
 [0.6539763154678109, 0.6466826908349046, 0.8395969123227822]
 [0.3203653937684048, 0.6881807184735692, 0.12676012258252278]

julia> dx = zero(x)
10-element Vector{SVector{3, Float64}}:
 [0.0, 0.0, 0.0]
 [0.0, 0.0, 0.0]
 [0.0, 0.0, 0.0]
 [0.0, 0.0, 0.0]
 [0.0, 0.0, 0.0]
 [0.0, 0.0, 0.0]
 [0.0, 0.0, 0.0]
 [0.0, 0.0, 0.0]
 [0.0, 0.0, 0.0]
 [0.0, 0.0, 0.0]

julia> function f(x, y)
           sum(sum(y .* x))
       end
f (generic function with 1 method)

julia> f(x, 2.0)
27.626299890124855

julia> autodiff(Reverse, f, Active, Duplicated(x, dx), Active(2.0))
((nothing, 13.813149945062424),)

julia> dx
10-element Vector{SVector{3, Float64}}:
 [2.0, 2.0, 2.0]
 [2.0, 2.0, 2.0]
 [2.0, 2.0, 2.0]
 [2.0, 2.0, 2.0]
 [2.0, 2.0, 2.0]
 [2.0, 2.0, 2.0]
 [2.0, 2.0, 2.0]
 [2.0, 2.0, 2.0]
 [2.0, 2.0, 2.0]
 [2.0, 2.0, 2.0]

julia> g(x) = 2 * sum(x)
g (generic function with 1 method)

julia> autodiff(Reverse, g, Active, Active(SVector(1.0, 2.0, 3.0)))
(([2.0, 2.0, 2.0],),)

Maybe give it a go as-is and see if you get any errors.
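On the second point of the question (whether functions must mutate arrays in place): the `f` and `g` examples above return values computed from immutable inputs, so the in-place style is an option rather than a requirement. For comparison, here is a minimal sketch of the in-place style with ordinary `Vector`s. The function `square!` and the variable names are hypothetical, not from the thread; the output and input are both passed as `Duplicated`, and the `nothing` return is marked `Const`.

```julia
using Enzyme

# Hypothetical in-place function: writes x.^2 into y and returns nothing.
function square!(y, x)
    y .= x .^ 2
    return nothing
end

x  = [1.0, 2.0, 3.0]
dx = zeros(3)   # shadow of x: receives the accumulated gradient
y  = zeros(3)
dy = ones(3)    # seed: adjoint of each output element is 1

# The function returns `nothing`, so its return activity is Const.
autodiff(Reverse, square!, Const, Duplicated(y, dy), Duplicated(x, dx))

dx  # holds 2 .* x after the call
```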

Huh. I did try it as-is before, but got this error, which led me to believe that SArray was the issue (partial stack trace):

  Got exception outside of a @test
  setfield!: immutable struct of type SArray cannot be changed
  Stacktrace:
    [1] rt_jl_getfield_rev(::SMatrix{3, 3, Float64, 9}, ::Base.RefValue{NTuple{9, Float64}}, ::Type{Val{:data}}, ::Val{false})
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/rules/typeunstablerules.jl:257
    [2] getindex
      @ ~/.julia/packages/StaticArrays/eGKzB/src/SArray.jl:62 [inlined]
    [3] view
      @ ~/.julia/packages/StaticArrays/eGKzB/src/abstractarray.jl:291 [inlined]
    [4] getindex

This error persists even after I make the custom type a mutable struct rather than a struct.

That seems to be a bug in one of the custom rules, so if you have an MWE that triggers it, please open an issue.

Here’s an MWE. I haven’t opened an issue yet because I’m not sure whether the problem is on my end. It seems to be in the getindex method.

using Enzyme
using Random 
using StaticArrays

abstract type AbstractBasisType end

struct Contravariant <: AbstractBasisType end

struct CurvilinearBasisVectors{N, T, C, B, V <: AbstractBasisType} <: StaticMatrix{N, N, T}
    __x::Union{SMatrix{N, N, T}}
    function CurvilinearBasisVectors{N, T, C, B, V}(b::AbstractMatrix) where {N, T, C, B, V}
        return new{N, T, C, B, V}(SMatrix{N, N, T}(b))
    end
end

Base.@propagate_inbounds function Base.getindex(v::CurvilinearBasisVectors{N, T, C, B, V}, i::Int) where {N, T, C, B, V}
    return view(getfield(v, :__x), i)[]
end

basis_labels = (:∇x, :∇y, :∇z);
coord_labels = (:x, :y, :z);

a = CurvilinearBasisVectors{3,Float64,basis_labels,coord_labels,Contravariant}(
    rand(3,3),
);
da = CurvilinearBasisVectors{3,Float64,basis_labels,coord_labels,Contravariant}(
    zeros(3,3),
);

@show a,da

function f(x, y)
    sum(sum(y .* x))
end

@show f(a, 2.0)

@show autodiff(Reverse, f, Active, Duplicated(a, da), Active(2.0))
@show da

For Duplicated arguments in reverse mode, the arguments need to be mutable. You defined CurvilinearBasisVectors as a struct; if you use a mutable struct, it should work.
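To illustrate the mutable/immutable distinction, a minimal sketch with a hypothetical function `h` of the same shape as `g` earlier in the thread: a mutable `Vector` input is passed as `Duplicated` and its gradient is accumulated into the shadow, while an immutable `SVector` input is passed as `Active` and its gradient is returned. The `[1][1]` indexing assumes the return layout shown in the `g` example above (a tuple of per-argument derivatives inside a tuple); other Enzyme versions may lay this out differently.

```julia
using Enzyme, StaticArrays

# Hypothetical function, same shape as `g` above.
h(x) = 2 * sum(x)

# Mutable input (Vector): pass a zeroed shadow with Duplicated; Enzyme
# accumulates the gradient into the shadow in place.
x  = [1.0, 2.0, 3.0]
dx = zeros(3)
autodiff(Reverse, h, Active, Duplicated(x, dx))

# Immutable input (SVector): there is no memory to write a gradient into,
# so mark it Active; the gradient is returned instead.
sx  = SVector(1.0, 2.0, 3.0)
dsx = autodiff(Reverse, h, Active, Active(sx))[1][1]
```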

That was the first thing I tried (but forgot to mention it in the OP). Trying it again with this MWE, I get the same error:

ERROR: LoadError: setfield!: immutable struct of type SArray cannot be changed
Stacktrace:
  [1] rt_jl_getfield_rev(::SMatrix{3, 3, Float64, 9}, ::Base.RefValue{NTuple{9, Float64}}, ::Type{Val{:data}}, ::Val{false})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/rules/typeunstablerules.jl:257
  [2] getindex
    @ ~/.julia/packages/StaticArrays/eGKzB/src/SArray.jl:62 [inlined]
  [3] view
    @ ~/.julia/packages/StaticArrays/eGKzB/src/abstractarray.jl:291 [inlined]
  [4] getindex
    @ ~/software/julia/enzyme/test-basis.jl:17 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:135 [inlined]
  [6] __broadcast
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:123 [inlined]
  [7] _broadcast
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:119 [inlined]
  [8] copy
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:60 [inlined]
  [9] materialize
    @ ./broadcast.jl:903 [inlined]
 [10] f
    @ ~/software/julia/enzyme/test-basis.jl:33 [inlined]
 [11] diffejulia_f_2626wrap
    @ ~/software/julia/enzyme/test-basis.jl:0
 [12] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:5306 [inlined]
 [13] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Active{…}, ::Float64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4984
 [14] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4926
 [15] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:215
 [16] autodiff(::ReverseMode{false, FFIABI}, ::typeof(f), ::Type, ::Duplicated{CurvilinearBasisVectors{…}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:224
 [17] macro expansion
    @ show.jl:1181 [inlined]
 [18] top-level scope
    @ ~/software/julia/enzyme/test-basis.jl:38
 [19] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [20] top-level scope
    @ REPL[2]:1

If I also change the SMatrix in the example above to MMatrix, then I get a different error:

ERROR: LoadError: Enzyme Mutability Error: Cannot add one in place to immutable value (Base.RefValue{Float64}(0.0), Base.RefValue{Float64}(0.0), Base.RefValue{Float64}(0.0), Base.RefValue{Float64}(0.0), Base.RefValue{Float64}(0.0), Base.RefValue{Float64}(0.0), Base.RefValue{Float64}(0.0), Base.RefValue{Float64}(0.0), Base.RefValue{Float64}(0.0))
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] runtime_generic_rev(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, tape::Enzyme.Compiler.Tape{…}, f::Type{…}, df::Type{…}, primal_1::NTuple{…}, shadow_1_1::NTuple{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/rules/jitrules.jl:242
  [3] _broadcast
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:120 [inlined]
  [4] copy
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:60 [inlined]
  [5] materialize
    @ ./broadcast.jl:903 [inlined]
  [6] f
    @ ~/software/julia/enzyme/test-basis.jl:33 [inlined]
  [7] diffejulia_f_2725wrap
    @ ~/software/julia/enzyme/test-basis.jl:0
  [8] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:5306 [inlined]
  [9] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Active{…}, ::@NamedTuple{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4984
 [10] (::Enzyme.Compiler.AdjointThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4932
 [11] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:203
 [12] autodiff(::ReverseMode{false, FFIABI}, ::typeof(f), ::Type, ::Duplicated{CurvilinearBasisVectors{…}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:224
 [13] macro expansion
    @ show.jl:1181 [inlined]
 [14] top-level scope
    @ ~/software/julia/enzyme/test-basis.jl:38
 [15] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [16] top-level scope
    @ REPL[2]:1

So this specific error is a “not yet implemented” case: a type-unstable update of an immutable variable.

We should add support for this, but in the interim, can you type-stabilize your code?

Regardless, please open an issue.

Issue opened. https://github.com/EnzymeAD/Enzyme.jl/issues/1263

And I’ll try to type-stabilize the code. Thanks for the suggestion.
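One likely source of the instability in the MWE is the field type `Union{SMatrix{N, N, T}}`: `SMatrix{N, N, T}` leaves the length parameter `L` free, so it is not a concrete type and `getfield(v, :__x)` cannot be inferred (consistent with the `typeunstablerules.jl` frame in the stack trace). A minimal sketch of a type-stabilized variant, assuming an extra type parameter `L = N*N` is acceptable in the real code. The name `StableBasisVectors` is hypothetical, and this sketch has not been checked against the Enzyme issue above.

```julia
using StaticArrays

abstract type AbstractBasisType end
struct Contravariant <: AbstractBasisType end

# The MWE's field type leaves the length parameter unspecified,
# so it is a UnionAll rather than a concrete type.
@show isconcretetype(SMatrix{3, 3, Float64})   # false

# Hypothetical type-stable variant: the extra parameter L (= N*N) makes the
# field type concrete, so getfield(v, :__x) can be inferred.
struct StableBasisVectors{N, T, C, B, V <: AbstractBasisType, L} <: StaticMatrix{N, N, T}
    __x::SMatrix{N, N, T, L}
    function StableBasisVectors{N, T, C, B, V}(b::AbstractMatrix) where {N, T, C, B, V <: AbstractBasisType}
        len = N * N  # N is a static parameter, so this is known at compile time
        return new{N, T, C, B, V, len}(SMatrix{N, N, T, len}(b))
    end
end

# Linear indexing forwards directly to the wrapped SMatrix.
Base.@propagate_inbounds Base.getindex(v::StableBasisVectors, i::Int) = getfield(v, :__x)[i]

basis_labels = (:∇x, :∇y, :∇z)
coord_labels = (:x, :y, :z)
a = StableBasisVectors{3, Float64, basis_labels, coord_labels, Contravariant}(rand(3, 3))

@show isconcretetype(fieldtype(typeof(a), :__x))  # true
```

Besides the concrete field type, this sketch replaces the MWE's `view(getfield(v, :__x), i)[]` with plain indexing into the wrapped `SMatrix`, which avoids the `view`/`getindex` path that appears in the stack trace.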