Second order gradient with Lux, Zygote, CUDA, Enzyme

Dear All,

I had a small pet project over Christmas. For a very specific application (steganography), I want to train a neural network while regularizing the gradient with respect to the input, which means that during training I need to compute second-order gradients.
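In formulas (our notation, not from the post: $f_\theta$ is the network, $\ell$ the batch-mean cross-entropy on logits that celoss computes, and $D$ the total number of elements of $x$, so the second term is what regloss computes), the objective and its parameter gradient are:

$$
\mathcal{L}(\theta) = \ell\big(f_\theta(x), y\big) + \frac{1}{D}\,\big\lVert \nabla_x \ell\big(f_\theta(x), y\big) - \delta \big\rVert_2^2
$$

$$
\nabla_\theta \mathcal{L} = \nabla_\theta \ell + \frac{2}{D}\left(\frac{\partial\, \nabla_x \ell}{\partial \theta}\right)^{\!\top}\big(\nabla_x \ell - \delta\big)
$$

The second term needs the mixed second derivatives $\partial^2 \ell / \partial\theta\,\partial x$, which is why the outer gradient has to differentiate through the inner one.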

I chose Lux.jl instead of my usual go-to, Flux.jl. The reasons were (i) to try it and (ii) I believe it is more “functional”, in the sense that the framework has less implicit state, which might be nicer for AD.

The start on CPU was flawless. The following code produced the second-order gradient without a problem:

using Lux, CUDA, Random, OneHotArrays, Zygote
using Functors, Optimisers, Printf
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)

ps, st = Lux.setup(Random.default_rng(), model);

x = randn(Float32, 28,28,1,32)
δ = randn(Float32, 28,28,1,32)
y =  onehotbatch(rand((1,2),32), 1:2)

const celoss = CrossEntropyLoss(;logits=true)
const regloss = MSELoss()

function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return celoss(pred, y)
end

function ∂xloss_function(model, ps, st, x, δ, y)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ∂x = only(Zygote.gradient(Base.Fix2(celoss, y) ∘ smodel, x))
    regloss(∂x, δ) + loss_function(model, ps, st, x, y)
end

function ∂∂xloss_function(model, ps, st, x, δ, y)
    only(Zygote.gradient(ps -> ∂xloss_function(model, ps, st, x, δ, y), ps))
end

∂∂xloss_function(model, ps, st, x, δ, y)
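To sanity-check what ∂∂xloss_function should return (for example when comparing GPU results against the CPU ones), here is a small Base-Julia sketch on a hypothetical toy model: a single linear layer with logits W * x on one sample, not the CNN above. It derives the double-backprop gradient by hand and compares it with central finite differences; all names are ours.

```julia
using LinearAlgebra

# Hypothetical toy model: logits z = W * x for one sample x of length d, 2 classes.
softmax_vec(z) = (w = exp.(z .- maximum(z)); w ./ sum(w))
ce(z, y) = -sum(y .* log.(softmax_vec(z)))             # cross-entropy on logits

# Input gradient of the cross-entropy: ∇ₓ CE = Wᵀ (p - y), with p = softmax(W x)
grad_x(W, x, y) = W' * (softmax_vec(W * x) .- y)

# Regularized objective mirroring ∂xloss_function: CE + mean((∇ₓ CE - δ)²)
reg_loss(W, x, δ, y) = ce(W * x, y) + sum(abs2, grad_x(W, x, y) .- δ) / length(x)

# Hand-derived ∂/∂W of reg_loss (what the outer gradient call has to produce)
function reg_loss_grad(W, x, δ, y)
    p = softmax_vec(W * x)
    r = p .- y                    # ∂CE/∂z
    e = W' * r .- δ               # residual of the input gradient
    S = Diagonal(p) - p * p'      # Hessian of CE with respect to the logits z
    d = length(x)
    return r * x' .+ (2 / d) .* (r * e' .+ (S * (W * e)) * x')
end

# Central finite differences over every entry of W
function fd_grad(f, W; h = 1e-6)
    G = similar(W)
    for i in eachindex(W)
        Wp = copy(W); Wp[i] += h
        Wm = copy(W); Wm[i] -= h
        G[i] = (f(Wp) - f(Wm)) / (2h)
    end
    return G
end
```

The S term is where second-order information enters: dropping it gives a gradient that ignores how the input gradient itself depends on W.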

But then I needed to move the computation to CUDA, and that is where the problems started. I invoked the CUDA version by moving the arrays to the GPU and calling the same function, i.e.

x_cu = x |> cu
δ_cu = δ |> cu
y_cu = y |> cu
ps_cu = ps |> cu
st_cu = st |> cu

∂∂xloss_function(model, ps_cu, st_cu, x_cu, δ_cu, y_cu)

This breaks because maxpool_direct!, which implements max-pooling in NNlib, is not twice differentiable on the GPU because of scalar indexing. If I change the model to remove the MaxPool layers and rely only on convolutions, it crashes in the logsoftmax used in the cross-entropy, which is not twice differentiable on CUDA either: the implementation takes a different path than the CPU version and calls functions from cuDNN, which are not twice differentiable.
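One workaround sometimes used for the MaxPool problem (an assumption on our side, not verified here with nested Zygote on CuArrays) is to express non-overlapping 2×2 max-pooling with reshape and maximum(...; dims), so that AD goes through a generic reduction instead of an index-based kernel such as maxpool_direct!. A plain-Julia sketch, with a naive loop version for comparison; both function names are hypothetical:

```julia
# 2×2 max-pooling with stride 2 and no padding (the MaxPool((2, 2)) defaults),
# for an input of size (H, W, C, N) with even H and W.
function maxpool2_reshape(x::AbstractArray{T,4}) where {T}
    H, W, C, N = size(x)
    @assert iseven(H) && iseven(W) "H and W must be even"
    # Column-major layout: xr[a, i, b, j, c, n] == x[a + 2(i-1), b + 2(j-1), c, n]
    xr = reshape(x, 2, H ÷ 2, 2, W ÷ 2, C, N)
    return dropdims(maximum(xr; dims = (1, 3)); dims = (1, 3))
end

# Naive reference with explicit loops (fine on CPU, but scalar indexing on GPU)
function maxpool2_naive(x::AbstractArray{T,4}) where {T}
    H, W, C, N = size(x)
    y = similar(x, H ÷ 2, W ÷ 2, C, N)
    for n in 1:N, c in 1:C, j in 1:(W ÷ 2), i in 1:(H ÷ 2)
        y[i, j, c, n] = maximum(@view x[2i-1:2i, 2j-1:2j, c, n])
    end
    return y
end
```

Whether the pullback of maximum with dims is itself twice differentiable on the GPU depends on the rule Zygote picks for it, so this would still need testing.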

I have also tried my luck with Enzyme.jl, with an eye to eventually using Reactant.jl, though I confess I am not that proficient with it. The code modified to use Enzyme looked like this:

using Lux, Reactant, Enzyme, Random, OneHotArrays
using Functors, Optimisers, Printf

model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)

ps, st = Lux.setup(Random.default_rng(), model);

x = randn(Float32, 28,28,1,32)
δ = randn(Float32, 28,28,1,32)
y =  onehotbatch(rand((1,2),32), 1:2)

const celoss = CrossEntropyLoss(;logits=true)
const regloss = MSELoss()

function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return celoss(pred, y)
end

function ∂xloss_function(model, ps, st, x, δ, y)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ∂x = Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(celoss, y) ∘ smodel), x)[1]
    regloss(∂x, δ) + loss_function(model, ps, st, x, y)
end

function ∂∂xloss_function(model, ps, st, x, δ, y)
    Enzyme.gradient(Enzyme.Reverse, Const(∂xloss_function), Const(model),
        ps, Const(st), Const(x), Const(δ), Const(y))[2]
end

∂∂xloss_function(model, ps, st, x, δ, y)

which crashes with the following error

Error message
ERROR: MethodError: no method matching get_base_and_offset(::Nothing; offsetAllowed::Bool, inttoptr::Bool)
The function `get_base_and_offset` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  get_base_and_offset(::LLVM.Value; offsetAllowed, inttoptr)
   @ Enzyme ~/.julia/packages/Enzyme/DiEvV/src/absint.jl:235

Stacktrace:
  [1] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{…}, imported::Set{…}, f::LLVM.Function, deletedfns::Vector{…}, mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:271
  [2] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{Tuple{String, Vector{Base.StackTraces.StackFrame}, Any}}, mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:221
  [3] check_ir
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:179 [inlined]
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3413
  [5] codegen
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3338 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387
  [7] _thunk
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5439 [inlined]
  [9] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5550
 [10] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5735
 [11] autodiff
    @ ~/.julia/packages/Enzyme/DiEvV/src/Enzyme.jl:485 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/Enzyme/DiEvV/src/sugar.jl:275 [inlined]
 [13] gradient
    @ ~/.julia/packages/Enzyme/DiEvV/src/sugar.jl:263 [inlined]
 [14] ∂∂xloss_function(model::Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, x::Array{…}, δ::Array{…}, y::OneHotMatrix{…})
    @ Main ./REPL[18]:2
 [15] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.

I naively tried Reactant.jl as follows, and it crashed:

Reactant with error message
xdev = reactant_device()  # assumed definition of xdev (Lux's Reactant device)
x_ra = x |> xdev
δ_ra = δ |> xdev
y_ra = y |> xdev
ps_ra = ps |> xdev
st_ra = st |> xdev

julia> ∂∂xloss_function_compiled = @compile ∂∂xloss_function(model, ps_ra, st_ra, x_ra, δ_ra, y_ra)
error: could not compute the adjoint for this operation "enzyme.push"(%56, %125) : (!enzyme.Cache<tensor<32xf32>>, tensor<32xf32>) -> ()
loc("subtract"("/Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Ops.jl":193:0)): error: could not compute the adjoint for this operation "enzyme.push"(%51, %123) : (!enzyme.Cache<tensor<2x32xf32>>, tensor<2x32xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%42, %arg9) : (!enzyme.Cache<tensor<84x2xf32>>, tensor<84x2xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%34, %arg7) : (!enzyme.Cache<tensor<128x84xf32>>, tensor<128x84xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%26, %arg5) : (!enzyme.Cache<tensor<256x128xf32>>, tensor<256x128xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%19, %104) : (!enzyme.Cache<tensor<8x8x16x32xf32>>, tensor<8x8x16x32xf32>) -> ()
loc("reverse"("/Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Ops.jl":1038:0)): error: could not compute the adjoint for this operation "enzyme.push"(%11, %99) : (!enzyme.Cache<tensor<5x5x6x16xf32>>, tensor<5x5x6x16xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%8, %97) : (!enzyme.Cache<tensor<24x24x6x32xf32>>, tensor<24x24x6x32xf32>) -> ()
loc("reverse"("/Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Ops.jl":1038:0)): error: could not compute the adjoint for this operation "enzyme.push"(%0, %92) : (!enzyme.Cache<tensor<5x5x1x6xf32>>, tensor<5x5x1x6xf32>) -> ()

[85175] signal 11 (2): Segmentation fault: 11
in expression starting at REPL[21]:1
_ZN4mlir9Operation7getAttrEN4llvm9StringRefE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir13SymbolRefAttr3getEPNS_9OperationE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir4func6CallOp5buildERNS_9OpBuilderERNS_14OperationStateENS0_6FuncOpENS_10ValueRangeE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir9OpBuilder6createINS_4func6CallOpEJRNS2_6FuncOpERN4llvm11SmallVectorINS_5ValueELj6EEEEEET_NS_8LocationEDpOT0_ at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZNK15AutoDiffCallRev24createReverseModeAdjointEPN4mlir9OperationERNS0_9OpBuilderEPNS0_6enzyme21MGradientUtilsReverseEN4llvm11SmallVectorINS0_5ValueELj6EEE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme6detail41ReverseAutoDiffOpInterfaceInterfaceTraits13FallbackModelI15AutoDiffCallRevE24createReverseModeAdjointEPKNS2_7ConceptEPNS_9OperationERNS_9OpBuilderEPNS0_21MGradientUtilsReverseEN4llvm11SmallVectorINS_5ValueELj6EEE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme26ReverseAutoDiffOpInterface24createReverseModeAdjointERNS_9OpBuilderEPNS0_21MGradientUtilsReverseEN4llvm11SmallVectorINS_5ValueELj6EEE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic10visitChildEPNS_9OperationERNS_9OpBuilderEPNS0_21MGradientUtilsReverseE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic13differentiateEPNS0_21MGradientUtilsReverseERNS_6RegionES5_N4llvm12function_refIFvRNS_9OpBuilderEPNS_5BlockEEEENSt3__18functionIFNSE_4pairINS_5ValueESH_EENS_4TypeEEEE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic17CreateReverseDiffENS_19FunctionOpInterfaceENSt3__16vectorI10DIFFE_TYPENS3_9allocatorIS5_EEEES8_RNS0_13MTypeAnalysisENS4_IbNS6_IbEEEESC_14DerivativeModebmNS_4TypeENS0_11MFnTypeInfoESC_Pv at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN12_GLOBAL__N_117DifferentiatePass16lowerEnzymeCallsERN4mlir21SymbolTableCollectionENS1_19FunctionOpInterfaceE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6detail4walkINS_15ForwardIteratorEEEvPNS_9OperationEN4llvm12function_refIFvS4_EEENS_9WalkOrderE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN12_GLOBAL__N_117DifferentiatePass14runOnOperationEv at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor3runEPNS_4PassEPNS_9OperationENS_15AnalysisManagerEbj at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir11PassManager9runPassesEPNS_9OperationENS_15AnalysisManagerE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir11PassManager3runEPNS_9OperationE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
mlirPassManagerRunOnOp at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
mlirPassManagerRunOnOp at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/mlir/libMLIR_h.jl:5867 [inlined]
run! at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/mlir/IR/Pass.jl:74 [inlined]
#run_pass_pipeline!#1 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:284
run_pass_pipeline! at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:279 [inlined]
#compile_mlir!#8 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:338
compile_mlir! at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:314 [inlined]
#32 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:799
context! at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/mlir/IR/Context.jl:76
unknown function (ip: 0x33a6e413b)
#compile_xla#31 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:796
compile_xla at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:791 [inlined]
#compile#36 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:823
compile at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:822
unknown function (ip: 0x33a580067)
jl_apply at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/./julia.h:2157 [inlined]
do_call at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_stmt_value at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:174
eval_body at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:663
jl_interpret_toplevel_thunk at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:952 [inlined]
ijl_toplevel_eval_in at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:245
repl_backend_loop at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:342
#start_repl_backend#59 at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:327
start_repl_backend at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:324
#run_repl#72 at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:483
run_repl at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:469
jfptr_run_repl_10091.1 at /Users/tomas.pevny/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/share/julia/compiled/v1.11/REPL/u0gqU_3gH4d.dylib (unknown line)
#1150 at ./client.jl:446
jfptr_YY.1150_14648.1 at /Users/tomas.pevny/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/share/julia/compiled/v1.11/REPL/u0gqU_3gH4d.dylib (unknown line)
jl_apply at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/./julia.h:2157 [inlined]
jl_f__call_latest at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_73877.1 at /Users/tomas.pevny/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/lib/julia/sys.dylib (unknown line)
jl_apply at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/./julia.h:2157 [inlined]
true_main at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/jlapi.c:1059
Allocations: 128913871 (Pool: 128908824; Big: 5047); GC: 78
zsh: segmentation fault  julia --project=.

So, a few questions:

  1. Is there any hope of making the above run on the GPU for a person knowledgeable about AD and relatively proficient in writing AD rules? While the common lore is that Zygote.jl does not do second-order gradients, I have a feeling it is more a problem of the libraries than of Zygote.jl itself. It seems NNlib was not really designed with higher-order gradients in mind.
    I can probably write a custom logsoftmax that works on CUDA, though my first attempt ran into the problem described here: LoadError: `llvmcall` must be compiled to be called when calling Zygote.Jacobian, very likely due to doing AD over broadcasting.
  2. Am I doing something wrong with Enzyme.jl, or is the problem somewhere deeper?
  3. Is there any chance to make this work with Reactant.jl? Do I invoke it incorrectly?
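Regarding the custom logsoftmax in question 1: one common pattern for second-order AD with a hand-written rule is to build the pullback itself from simple, differentiable array operations, so that the outer AD pass can differentiate it. Below is a minimal Base-Julia sketch for a column-wise logsoftmax and a logit cross-entropy on top of it; the names are hypothetical, and whether nested Zygote handles the broadcasts on CuArrays (given the `llvmcall` error mentioned in question 1) is an open assumption.

```julia
# Column-wise log-softmax (classes along dim 1, batch along dim 2).
function logsoftmax_cols(x::AbstractMatrix)
    m = maximum(x; dims = 1)                 # shift for numerical stability
    z = x .- m
    return z .- log.(sum(exp.(z); dims = 1))
end

# Pullback built only from broadcasting and reductions, so it is itself
# differentiable: xbar = Δ - softmax(x) .* sum(Δ; dims = 1), with softmax = exp(y).
logsoftmax_cols_pullback(y::AbstractMatrix, Δ::AbstractMatrix) =
    Δ .- exp.(y) .* sum(Δ; dims = 1)

# Batch-mean logit cross-entropy with one-hot targets y (same shape as ŷ).
logitce_cols(ŷ::AbstractMatrix, y::AbstractMatrix) =
    -sum(y .* logsoftmax_cols(ŷ)) / size(ŷ, 2)

# Finite-difference check of the pullback against the forward function.
function check_pullback(x, Δ; h = 1e-6)
    xbar = logsoftmax_cols_pullback(logsoftmax_cols(x), Δ)
    fd = similar(x)
    for k in eachindex(x)
        xp = copy(x); xp[k] += h
        xm = copy(x); xm[k] -= h
        fd[k] = sum(Δ .* (logsoftmax_cols(xp) .- logsoftmax_cols(xm))) / (2h)
    end
    return isapprox(xbar, fd; rtol = 1e-5)
end
```

To use it with Zygote, the pullback would be registered through ChainRulesCore.rrule, returning y together with a closure Δ -> (NoTangent(), logsoftmax_cols_pullback(y, unthunk(Δ))); we have not tested that on the GPU.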

Many thanks in advance for any answers, suggestions, and help.

Information about my environment is below

version info
(tmp) pkg> st
Status `/private/tmp/Project.toml`
  [d360d2e6] ChainRulesCore v1.25.0
  [7da242da] Enzyme v0.13.26
  [26cc04aa] FiniteDifferences v0.12.32
  [d9f16b24] Functors v0.5.2
  [b2108857] Lux v1.4.3
  [872c559c] NNlib v0.9.26
  [0b1bfda6] OneHotArrays v0.2.6
  [3bd65402] Optimisers v0.4.2
  [3c362404] Reactant v0.2.12
  [e88e6eb3] Zygote v0.6.74

julia> versioninfo()
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 11 × Apple M3 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m2)
Threads: 1 default, 0 interactive, 1 GC (on 5 virtual cores)

Perhaps you have already seen it, but could this page help?


Yes, I have seen it. It was the source of my inspiration for how to handle nested AD with Lux.

The Reactant issue is that we need to tell EnzymeMLIR to run optimizations between the nested AD levels, and we haven’t done that yet (though it’s on the list of things to do).

Cc @avikpal and @Pangoraw

Can you share a stacktrace for this? I had a look through the logsoftmax implementation, and it should be twice-differentiable by Zygote even on GPU. There are a few leftover log softmax gradient functions like NNlib.jl/ext/NNlibCUDACUDNNExt/softmax.jl at be9c1c8b932110c39955b4dd41785df8ce2cd5db · FluxML/NNlib.jl · GitHub which use cuDNN, but IIRC they should never be called in ordinary use.

Lux swaps in ForwardDiff for nested AD with Zygote (Nested AD with Lux etc.), and that will fail here without a ForwardDiff.Dual overload.
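As far as we understand the Lux nested-AD documentation, this swap relies on the symmetry of mixed partial derivatives: the θ-gradient of ⟨∇ₓ g(x, θ), v⟩ equals the derivative of ∇θ g(x + εv, θ) with respect to ε at ε = 0, which forward mode (dual numbers) can compute. That is also why Dual numbers end up flowing through every layer's code, NNlib kernels included. A small Base-Julia check of the identity on a hypothetical function g of our choosing, using finite differences only:

```julia
# Hypothetical smooth function g(x, θ) = log(sum(exp.(θ .* x))) with analytic gradients.
g(x, θ) = log(sum(exp.(θ .* x)))
softmax_θx(x, θ) = (w = exp.(θ .* x); w ./ sum(w))   # softmax of θ .* x
grad_x_g(x, θ) = θ .* softmax_θx(x, θ)               # ∂g/∂x
grad_θ_g(x, θ) = x .* softmax_θx(x, θ)               # ∂g/∂θ

# Reverse-over-reverse view: ∇θ ⟨∇x g(x, θ), v⟩, by central differences in θ
function lhs(x, θ, v; h = 1e-6)
    map(eachindex(θ)) do k
        θp = copy(θ); θp[k] += h
        θm = copy(θ); θm[k] -= h
        (sum(grad_x_g(x, θp) .* v) - sum(grad_x_g(x, θm) .* v)) / (2h)
    end
end

# Forward-over-reverse view: d/dε ∇θ g(x + εv, θ) at ε = 0, by central differences in ε
rhs(x, θ, v; h = 1e-6) = (grad_θ_g(x .+ h .* v, θ) .- grad_θ_g(x .- h .* v, θ)) ./ (2h)
```

In a real forward-over-reverse implementation the right-hand side is computed with Dual numbers instead of finite differences, which is exactly where a kernel without a Dual method breaks.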

This was my suspicion. Therefore I am also trying Flux.jl, with which I am more familiar. The Flux code looks like this:

using CUDA, Flux, Random, OneHotArrays, Zygote
using Functors, Optimisers, Printf

model = Chain(
    Conv((5, 5), 1 => 6, relu),
    Conv((5, 5), 6 => 16, relu),
    Flux.flatten,
    Chain(
        Dense(6400 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)


x = randn(Float32, 28,28,1,32)
δ = randn(Float32, 28,28,1,32)
y =  onehotbatch(rand((1,2),32), 1:2)

model_cu = model |> gpu
x_cu = x |> gpu
δ_cu = δ |> gpu
y_cu = y |> gpu

const celoss = Flux.Losses.logitcrossentropy
const regloss = Flux.Losses.mse

function loss_function(model, x, y)
    return celoss(model(x), y)
end

function ∂xloss_function(model, x, δ, y)
    ∂x = only(Zygote.gradient(Base.Fix2(celoss, y) ∘ model, x))
    regloss(∂x, δ) + loss_function(model, x, y)
end

function ∂∂xloss_function(model, x, δ, y)
    only(Zygote.gradient(model -> ∂xloss_function(model, x, δ, y), model))
end

∂xloss_function(model, x, δ, y)

∂∂xloss_function(model, x, δ, y)

∂xloss_function(model_cu, x_cu, δ_cu, y_cu)

∂∂xloss_function(model_cu, x_cu, δ_cu, y_cu)

With Flux.jl, without a handwritten logitcrossentropy, even the CPU version fails, i.e. ∂∂xloss_function(model, x, δ, y) errors with:

Stacktrace of failure on CPU
ERROR: Mutating arrays is not supported -- called copyto!(Matrix{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/lib/array.jl:70
  [3] (::Zygote.var"#547#548"{Matrix{Float32}})(::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/lib/array.jl:85
  [4] (::Zygote.var"#2633#back#549"{Zygote.var"#547#548"{Matrix{Float32}}})(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] materialize!
    @ ./broadcast.jl:878 [inlined]
  [6] materialize!
    @ ./broadcast.jl:875 [inlined]
  [7] materialize!
    @ ./broadcast.jl:871 [inlined]
  [8] #783
    @ ~/.julia/packages/Zygote/zdPiG/src/lib/array.jl:344 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Matrix{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [10] #3028#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:81 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/Flux/1wZQP/src/losses/functions.jl:272 [inlined]
 [13] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Matrix{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/Flux/1wZQP/src/losses/functions.jl:270 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Matrix{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/Zygote/zdPiG/src/lib/base.jl:243 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Matrix{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [18] #2430#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Matrix{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [20] Pullback
    @ ./operators.jl:1053 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Tuple{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./operators.jl:1050 [inlined]
 [23] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Array{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [24] #294
    @ ~/.julia/packages/Zygote/zdPiG/src/lib/lib.jl:206 [inlined]
 [25] (::Zygote.Pullback{Tuple{Zygote.var"#294#295"{…}, Float32}, Any})(Δ::Tuple{Nothing, Nothing, Nothing, Tuple{Array{…}}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [26] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [27] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [28] Pullback
    @ ./operators.jl:1050 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Array{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [30] #78
    @ ~/.julia/packages/Zygote/zdPiG/src/compiler/interface.jl:91 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [32] gradient
    @ ~/.julia/packages/Zygote/zdPiG/src/compiler/interface.jl:148 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [34] ∂xloss_function
    @ ./REPL[16]:2 [inlined]
 [35] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [36] #1
    @ ./REPL[17]:2 [inlined]
 [37] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [38] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface.jl:91
 [39] gradient(f::Function, args::Chain{Tuple{Conv{…}, Conv{…}, typeof(Flux.flatten), Chain{…}}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface.jl:148
 [40] ∂∂xloss_function(model::Chain{Tuple{…}}, x::Array{Float32, 4}, δ::Array{Float32, 4}, y::OneHotMatrix{UInt32, Vector{…}})
    @ Main ./REPL[17]:2
 [41] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.

Here is the stacktrace of the Flux.jl version for the GPU call ∂∂xloss_function(model_cu, x_cu, δ_cu, y_cu):

Stacktrace of failure on GPU
julia> ∂∂xloss_function(model_cu, x_cu, δ_cu, y_cu)
ERROR: `llvmcall` requires the compiler
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{false}, ::Core.IntrinsicFunction, ::Tuple{String, String}, ::Type{Nothing}, ::Type{Tuple{Bool}}, ::Bool)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:91
  [3] assume
    @ ~/.julia/packages/LLVM/wMjUU/src/interop/intrinsics.jl:16 [inlined]
  [4] driver_version
    @ ~/.julia/packages/CUDA/2kjXI/lib/cudadrv/version.jl:20 [inlined]
  [5] isvalid
    @ ~/.julia/packages/CUDA/2kjXI/lib/cudadrv/context.jl:71 [inlined]
  [6] _pullback(ctx::Zygote.Context{false}, f::typeof(CUDA.isvalid), args::CuContext)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
  [7] validate_task_local_state
    @ ~/.julia/packages/CUDA/2kjXI/lib/cudadrv/state.jl:61 [inlined]
  [8] _pullback(ctx::Zygote.Context{false}, f::typeof(CUDA.validate_task_local_state), args::CUDA.TaskLocalState)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
  [9] task_local_state!
    @ ~/.julia/packages/CUDA/2kjXI/lib/cudadrv/state.jl:72 [inlined]
 [10] _pullback(::Zygote.Context{false}, ::typeof(CUDA.task_local_state!))
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [11] active_state
    @ ~/.julia/packages/CUDA/2kjXI/lib/cudadrv/state.jl:110 [inlined]
 [12] #cufunction#1169
    @ ~/.julia/packages/CUDA/2kjXI/src/compiler/execution.jl:373 [inlined]
 [13] _pullback(::Zygote.Context{…}, ::CUDA.var"##cufunction#1169", ::@Kwargs{}, ::typeof(cufunction), ::GPUArrays.var"#35#37", ::Type{…})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [14] cufunction
    @ ~/.julia/packages/CUDA/2kjXI/src/compiler/execution.jl:372 [inlined]
 [15] _pullback(::Zygote.Context{false}, ::typeof(cufunction), ::GPUArrays.var"#35#37", ::Type{Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [16] #launch_heuristic#1200
    @ ~/.julia/packages/CUDA/2kjXI/src/compiler/execution.jl:112 [inlined]
 [17] _pullback(::Zygote.Context{…}, ::CUDA.var"##launch_heuristic#1200", ::Int64, ::Int64, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#35#37", ::CuArray{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [18] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:946
 [19] adjoint
    @ ~/.julia/packages/Zygote/zdPiG/src/lib/lib.jl:203 [inlined]
 [20] adjoint(::Zygote.Context{…}, ::typeof(Core._apply_iterate), ::typeof(iterate), ::Function, ::Tuple{…}, ::Tuple{…})
    @ Zygote ./none:0
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [22] launch_heuristic
    @ ~/.julia/packages/CUDA/2kjXI/src/gpuarrays.jl:15 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#35#37", ::CuArray{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [24] _copyto!
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:78 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::typeof(GPUArrays._copyto!), ::CuArray{…}, ::Base.Broadcast.Broadcasted{…})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [26] materialize!
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:38 [inlined]
 [27] materialize!
    @ ./broadcast.jl:875 [inlined]
 [28] _pullback(::Zygote.Context{…}, ::typeof(Base.Broadcast.materialize!), ::CuArray{…}, ::Base.Broadcast.Broadcasted{…})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [29] materialize!
    @ ./broadcast.jl:871 [inlined]
 [30] #1459
    @ ~/.julia/packages/Zygote/zdPiG/src/lib/broadcast.jl:369 [inlined]
 [31] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#1459#1462"{CuArray{…}}, args::CuArray{Float32, 2, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [32] #4238#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:81 [inlined]
 [33] _pullback(ctx::Zygote.Context{…}, f::Zygote.var"#4238#back#1465"{…}, args::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [34] Pullback
    @ ~/.julia/packages/Flux/1wZQP/src/losses/functions.jl:272 [inlined]
 [35] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [36] Pullback
    @ ~/.julia/packages/Flux/1wZQP/src/losses/functions.jl:270 [inlined]
 [37] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [38] Pullback
    @ ~/.julia/packages/Zygote/zdPiG/src/lib/base.jl:243 [inlined]
 [39] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [40] #2430#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [41] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2430#back#456"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [42] Pullback
    @ ./operators.jl:1053 [inlined]
 [43] Pullback
    @ ./operators.jl:1050 [inlined]
 [44] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [45] #294
    @ ~/.julia/packages/Zygote/zdPiG/src/lib/lib.jl:206 [inlined]
 [46] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#294#295"{Tuple{…}, Zygote.Pullback{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [47] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [48] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#296"{Zygote.var"#294#295"{Tuple{…}, Zygote.Pullback{…}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [49] Pullback
    @ ./operators.jl:1050 [inlined]
 [50] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [51] #78
    @ ~/.julia/packages/Zygote/zdPiG/src/compiler/interface.jl:91 [inlined]
 [52] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [53] gradient
    @ ~/.julia/packages/Zygote/zdPiG/src/compiler/interface.jl:148 [inlined]
 [54] _pullback(::Zygote.Context{…}, ::typeof(Zygote.gradient), ::ComposedFunction{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [55] ∂xloss_function
    @ ./REPL[16]:2 [inlined]
 [56] _pullback(::Zygote.Context{…}, ::typeof(∂xloss_function), ::Chain{…}, ::CuArray{…}, ::CuArray{…}, ::OneHotMatrix{…})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [57] #1
    @ ./REPL[17]:2 [inlined]
 [58] _pullback(ctx::Zygote.Context{false}, f::var"#1#2"{CuArray{…}, CuArray{…}, OneHotMatrix{…}}, args::Chain{Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface2.jl:0
 [59] pullback(f::Function, cx::Zygote.Context{false}, args::Chain{Tuple{Conv{…}, Conv{…}, typeof(Flux.flatten), Chain{…}}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface.jl:90
 [60] pullback
    @ ~/.julia/packages/Zygote/zdPiG/src/compiler/interface.jl:88 [inlined]
 [61] gradient(f::Function, args::Chain{Tuple{Conv{…}, Conv{…}, typeof(Flux.flatten), Chain{…}}})
    @ Zygote ~/.julia/packages/Zygote/zdPiG/src/compiler/interface.jl:147
 [62] ∂∂xloss_function(model::Chain{…}, x::CuArray{…}, δ::CuArray{…}, y::OneHotMatrix{…})
    @ Main ./REPL[17]:2
Some type information was truncated. Use `show(err)` to see complete types.

The "ERROR: llvmcall requires the compiler" problem is one I hit when I defined a custom rrule for logsoftmax and ∇logsoftmax.

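# cuDNN (or LuxCUDA) must be loaded so the GPU convolution/pooling calls use the cuDNN implementations
using Lux, CUDA, cuDNN, Random, OneHotArrays, Zygote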
using Functors, Optimisers, Printf

model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MeanPool((2, 2)),  # MeanPool instead of MaxPool: MaxPool has no nested-AD (JVP) rule yet
    Conv((5, 5), 6 => 16, relu),
    MeanPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)

dev = gpu_device(; force=true)  # error out instead of silently falling back to the CPU

ps, st = Lux.setup(Random.default_rng(), model) |> dev;

x = randn(Float32, 28, 28, 1, 32) |> dev;
δ = randn(Float32, 28, 28, 1, 32) |> dev;
y = onehotbatch(rand((1, 2), 32), 1:2) |> dev;

const celoss = CrossEntropyLoss(; logits=true)
const regloss = MSELoss()

function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return celoss(pred, y)
end

function ∂xloss_function(model, ps, st, x, δ, y)
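    # Wrapping the model in StatefulLuxLayer lets Lux recognize the inner Zygote.gradient
    # call and apply its own nested-AD rules instead of plain Zygote-over-Zygote
    smodel = StatefulLuxLayer{true}(model, ps, st)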
    ∂x = only(Zygote.gradient(Base.Fix2(celoss, y) ∘ smodel, x))
    regloss(∂x, δ) + loss_function(model, ps, st, x, y)
end

function ∂∂xloss_function(model, ps, st, x, δ, y)
    only(Zygote.gradient(ps -> ∂xloss_function(model, ps, st, x, δ, y), ps))
end

∂∂xloss_function(model, ps, st, x, δ, y)

I have patched in support for (log)softmax and MeanPool in feat: more nested AD rules by avik-pal · Pull Request #1151 · LuxDL/Lux.jl · GitHub (MaxPool is a bit finicky to write the JVP for, so I will try to do that later). I will merge and tag it later tonight once the tests pass.

Also note that in the original example cuDNN (or LuxCUDA) wasn't loaded, so it couldn't use the correct, cuDNN-backed versions of the algorithms.


Hi Avik,

Thank you very much. Your solution indeed worked after I updated the packages. For completeness, here is the state of the libraries after the update:

Package status
  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.5.2
  [d360d2e6] ChainRulesCore v1.25.0
  [587475ba] Flux v0.16.0
  [d9f16b24] Functors v0.5.2
  [b2108857] Lux v1.4.3
  [d0bbae9a] LuxCUDA v0.3.3
  [872c559c] NNlib v0.9.26
  [0b1bfda6] OneHotArrays v0.2.6
  [3bd65402] Optimisers v0.4.2
  [e88e6eb3] Zygote v0.6.75
  [02a925ec] cuDNN v1.4.0
⌅ [4ee394cb] CUDA_Driver_jll v0.10.4+0
  [9a3f8284] Random v1.11.0

I have a few more questions. In this case, is it better to use reverse-over-forward, which I reckon the current solution does, or reverse-over-reverse? I am asking because the second-order gradient differentiates a many-to-one function.

Thanks a lot. Now I need to try a bigger network, ideally EfficientNet.


With reverse-over-forward we still need only a single JVP (followed by an outer VJP); see the linked nested AD discussion for the explanation. So in this case the internal transform that Lux does is going to be efficient.
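One way to see why a single JVP suffices (an editorial sketch, not taken from the Lux sources): write ℓ(θ, x) for the cross-entropy part of the loss, g(θ) = ∇ₓℓ(θ, x) for the input gradient computed inside ∂xloss_function, and N for the number of entries of x, and assume MSELoss uses its default mean aggregation. The parameter gradient of the regularizer then only needs the mixed second derivative applied to one vector u:

```latex
\begin{aligned}
R(\theta) &= \frac{1}{N}\,\bigl\lVert g(\theta) - \delta \bigr\rVert_2^2,
  \qquad g(\theta) = \nabla_x \ell(\theta, x) \\
\nabla_\theta R(\theta) &= \frac{2}{N}\left(\frac{\partial g}{\partial \theta}\right)^{\!\top}\bigl(g(\theta) - \delta\bigr)
  = \frac{\partial^2 \ell}{\partial \theta\,\partial x}\,u,
  \qquad u = \frac{2}{N}\bigl(g(\theta) - \delta\bigr) \\
\frac{\partial^2 \ell}{\partial \theta\,\partial x}\,u
  &= \left.\frac{d}{d\varepsilon}\,\nabla_\theta \ell(\theta,\, x + \varepsilon u)\right|_{\varepsilon = 0}
\end{aligned}
```

The last line is one forward-mode directional derivative (a JVP) of the parameter gradient ∇_θ ℓ, with u evaluated at the current θ and held fixed; the plain cross-entropy term in ∂xloss_function adds one ordinary reverse pass on top. The identity can be checked numerically on a toy problem. In the sketch below, the loss L and all numbers are hypothetical stand-ins for the network plus cross-entropy, and ForwardDiff (not used in the thread) handles both levels only to keep the check small and free of GPU/Lux dependencies, so it does not reproduce Lux's actual nested-AD machinery.

```julia
using ForwardDiff

# Hypothetical toy loss standing in for "network + cross-entropy" (not the Lux model).
L(θ, x) = sum(tanh.(θ .* x))

θ = [0.3, -1.2, 0.7]   # toy parameters
x = [1.0, 2.0, -0.5]   # toy input
δ = [0.1, 0.0, -0.2]   # toy target for the input gradient
N = length(x)

# g(θ) = ∇ₓ L(θ, x): the input gradient that the regularizer penalizes
g(θ) = ForwardDiff.gradient(x′ -> L(θ, x′), x)

# R(θ) = mean((g(θ) - δ).^2), i.e. a mean-aggregated squared error on g
R(θ) = sum(abs2, g(θ) .- δ) / N

# Reference: differentiate R directly (nested forward mode)
ref = ForwardDiff.gradient(R, θ)

# Single JVP: u is computed once at the current θ and then held fixed,
# and ∇_θ L is differentiated along x in the direction u
u = 2 .* (g(θ) .- δ) ./ N
jvp = ForwardDiff.derivative(ε -> ForwardDiff.gradient(θ′ -> L(θ′, x .+ ε .* u), θ), 0.0)

@assert isapprox(ref, jvp; rtol = 1e-8)
```

If the assertion holds, the gradient of the input-gradient penalty has been obtained from one JVP of ∇_θ L rather than by differentiating through the whole inner backward pass.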

That said, the most efficient solution in the long run will be to make Reactant work in this case. I opened [2nd Order AD] Regularization term in loss function · Issue #449 · EnzymeAD/Reactant.jl · GitHub to track this.

Thank you very much for your help. I knew the community would not let me down.

I have gotten the second-order gradient working with EfficientNet. In fact, I just had to adapt EfficientNet from this repository (GitHub - pxl-th/EfficientNet.jl: EfficientNet implementation in Julia) to be compatible with Lux.jl, which was a good opportunity for me to get my hands dirty with the framework.

Getting this to work with Reactant would of course be absolutely awesome, but this is enough for me to verify whether my idea for better steganalysis works.

If you have a working EfficientNet in Lux, it would be great to have it upstreamed to Boltz.jl/src/vision at main · LuxDL/Boltz.jl · GitHub!

I have thought about it. I will try to prepare a pull request.
