We started a huge new project for which we aim to use Optimization.jl with Enzyme.jl as the AD backend.
We came up with some early prototypes and plugged them into Optimization.jl, unfortunately not very successfully (it seems the closures we need to provide to Optimization.jl introduce type instabilities that Enzyme does not like, but that is a separate issue).
Now I have started playing around a bit with Enzyme itself to understand the Optimization.jl wrappers better, and I just hit an issue with vectors as differentiable arguments that I don't understand (according to the docs this should work). I reduced a lot of code down to the most minimal bit that still reproduces the error. Consider the following module and variables:
module Nodes

abstract type Node end

struct Source{T<:Real} <: Node
    output::Vector{T}
end

struct Multiplicator{T<:Real} <: Node
    factors::Vector{T}
    output::Vector{T}
    upstream_node::Node
end

struct Sink{T<:Real} <: Node
    target_value::Vector{T}
    target_fraction::Vector{T}
    upstream_node::Node
end

function calculate(node::Multiplicator)
    # Scale the upstream output elementwise by this node's factors.
    @. node.output = node.factors * node.upstream_node.output
end

function calculate(node::Sink)
    # Fraction of the target value delivered by the upstream node.
    @. node.target_fraction = node.upstream_node.output / node.target_value
end

end
begin
    using Enzyme
    using .Nodes
    using Statistics

    Enzyme.API.runtimeActivity!(true)

    const system_size = 1000
    node_1 = Nodes.Source(fill(10.0, system_size))
    node_2 = Nodes.Multiplicator(randn(system_size) .+ 1, zeros(system_size), node_1)
    node_3 = Nodes.Sink(fill(11.0, system_size), zeros(system_size), node_2)
    factor_matrix = [fill(1.2, system_size) fill(0.9, system_size) fill(1.0, system_size) fill(1.1, system_size)]
    factor_matrix .+= randn(system_size, 4) .* 0.1
end
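Just to confirm the setup itself is sound, the primal calculation runs without problems:
# Quick sanity check of the primal (no Enzyme involved): with factors drawn
# around 1, output ≈ 10 and target_value = 11, so the mean fraction ≈ 10/11.
Nodes.calculate(node_2)
Nodes.calculate(node_3)
mean(node_3.target_fraction) # ≈ 0.909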
If we now define
function target(a, b, c, d)
    weights = [a, b, c, d]
    factors = factor_matrix * weights
    node_2.factors .= factors
    Nodes.calculate(node_2)
    Nodes.calculate(node_3)
    y = abs(1 - mean(node_3.target_fraction))
    return y
end
x = [0.2, 0.2, 0.2, 0.4]
We can autodiff that via Enzyme without a problem:
autodiff(Reverse, target, Active, Active(x[1]), Active(x[2]), Active(x[3]), Active(x[4])) # gives ((1274.8375256354548, 958.0874431949107, 1062.6665657969513, 1172.9304957963361),)
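(In case it is useful later, the returned ((da, db, dc, dd),) tuple unpacks into a plain gradient vector:)
# Unpack the 1-tuple returned by autodiff into a Vector{Float64}.
grad = collect(first(autodiff(Reverse, target, Active, Active(x[1]), Active(x[2]), Active(x[3]), Active(x[4]))))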
If I instead use a vector as the input
function vec_input_target(weights)
    factors = factor_matrix * weights
    node_2.factors .= factors
    Nodes.calculate(node_2)
    Nodes.calculate(node_3)
    y = abs(1 - mean(node_3.target_fraction))
    return y
end
autodiff(Reverse, vec_input_target, Active, Active(x))
This throws an assertion error:
ERROR: AssertionError: !is_split
Stacktrace:
[1] (::Enzyme.Compiler.var"#397#401"{LLVM.Function, DataType, UnionAll, Enzyme.API.CDerivativeMode, Int64, Bool, Bool, UInt64, Enzyme.Compiler.Interpreter.EnzymeInterpreter, Vector{LLVM.Argument}, LLVM.DataLayout, LLVM.Function, LLVM.StructType, LLVM.StructType, Vector{UInt8}, Vector{Int64}, Vector{Type}, LLVM.PointerType, LLVM.VoidType, LLVM.Context, LLVM.Module, Bool, Bool})(builder::LLVM.IRBuilder)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RiUxJ/src/compiler.jl:7734
[2] LLVM.IRBuilder(f::Enzyme.Compiler.var"#397#401"{LLVM.Function, DataType, UnionAll, Enzyme.API.CDerivativeMode, Int64, Bool, Bool, UInt64, Enzyme.Compiler.Interpreter.EnzymeInterpreter, Vector{LLVM.Argument}, LLVM.DataLayout, LLVM.Function, LLVM.StructType, LLVM.StructType, Vector{UInt8}, Vector{Int64}, Vector{Type}, LLVM.PointerType, LLVM.VoidType, LLVM.Context, LLVM.Module, Bool, Bool}, args::LLVM.Context; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ LLVM ~/.julia/packages/LLVM/5aiiG/src/irbuilder.jl:23
[3] LLVM.IRBuilder(f::Function, args::LLVM.Context)
@ LLVM ~/.julia/packages/LLVM/5aiiG/src/irbuilder.jl:20
[4] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{Nothing}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RiUxJ/src/compiler.jl:7692
[5] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RiUxJ/src/compiler.jl:7433
[6] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, ctx::LLVM.ThreadSafeContext, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RiUxJ/src/compiler.jl:8984
[7] codegen
@ ~/.julia/packages/Enzyme/RiUxJ/src/compiler.jl:8592 [inlined]
[8] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, ctx::Nothing, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RiUxJ/src/compiler.jl:9518
[9] _thunk
@ ~/.julia/packages/Enzyme/RiUxJ/src/compiler.jl:9515 [inlined]
[10] cached_compilation
@ ~/.julia/packages/Enzyme/RiUxJ/src/compiler.jl:9553 [inlined]
[11] #s291#456
@ ~/.julia/packages/Enzyme/RiUxJ/src/compiler.jl:9615 [inlined]
[12] var"#s291#456"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
@ Enzyme.Compiler ./none:0
[13] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[14] autodiff(#unused#::ReverseMode{false, FFIABI}, f::Const{typeof(vec_input_target)}, #unused#::Type{Active}, args::Active{Vector{Float64}})
@ Enzyme ~/.julia/packages/Enzyme/RiUxJ/src/Enzyme.jl:195
[15] autodiff(::ReverseMode{false, FFIABI}, ::typeof(vec_input_target), ::Type, ::Active{Vector{Float64}})
@ Enzyme ~/.julia/packages/Enzyme/RiUxJ/src/Enzyme.jl:222
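In case it matters: my reading of the docs is that mutable arguments are usually passed as Duplicated with a pre-allocated shadow, so the sketch below is what I would fall back to (untested against this exact failure), but I would still like to understand why Active fails here.
# Hedged sketch, continuing the session above: the Duplicated pattern for
# mutable arguments; the gradient is accumulated into the shadow dx.
dx = zero(x)
autodiff(Reverse, vec_input_target, Active, Duplicated(x, dx))
dx # would then hold ∂y/∂weights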
This is on Enzyme 0.11.4 and Julia 1.9. Any help is greatly appreciated!
PS: As an additional question, does anybody have experience with constrained nonlinear optimization via Optimization.jl with Enzyme.jl as the AD backend? Is it mature enough? We would really love to make it fly (and avoid Zygote.jl due to its issues with mutation, ReverseDiff.jl due to tape recompilation, and ForwardDiff.jl due to its inflexible type layout). The sketch below shows roughly what we have in mind.
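This is only a sketch under assumptions: AutoEnzyme is the ADTypes backend selector that Optimization.jl accepts, the objective and constraint bodies are placeholders for our Nodes system, and Ipopt via OptimizationMOI is just one possible solver choice.
using Optimization, OptimizationMOI, Ipopt

# Placeholder objective and equality constraint sum(u) == 1; the real ones
# would wrap the Nodes system above.
obj(u, p) = sum(abs2, u)
cons(res, u, p) = (res .= [sum(u) - 1.0])

optf = OptimizationFunction(obj, Optimization.AutoEnzyme(); cons = cons)
prob = OptimizationProblem(optf, x, nothing; lcons = [0.0], ucons = [0.0])
sol = solve(prob, Ipopt.Optimizer())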