Dear All,
Over Christmas I had a small pet project. For a very specific application (steganography), I want to train a neural network while regularizing the gradient with respect to the input, which means that during training I need to compute second-order gradients.
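Written out in notation introduced here (not taken from any library), with f_θ the network, ℓ the batch-mean cross-entropy, and n the number of input elements (matching the mean aggregation of MSELoss), the objective and its parameter gradient are sketched below. The term J_gᵀ r is what forces second-order AD:

```latex
\begin{aligned}
g(\theta) &= \nabla_x\, \ell\bigl(f_\theta(x), y\bigr), \\
\mathcal{L}(\theta) &= \ell\bigl(f_\theta(x), y\bigr) + \tfrac{1}{n}\,\lVert g(\theta) - \delta \rVert_2^2, \\
\nabla_\theta \mathcal{L}(\theta) &= \nabla_\theta\, \ell\bigl(f_\theta(x), y\bigr) + \tfrac{2}{n}\, J_g(\theta)^\top r,
\qquad r = g(\theta) - \delta,\quad J_g(\theta) = \frac{\partial g}{\partial \theta}.
\end{aligned}
```

Since J_g(θ)ᵀ r = ∇_θ⟨g(θ), r⟩ with r held constant, the extra term is a mixed second derivative of ℓ (with respect to x and θ) contracted with the residual r.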
I chose Lux.jl instead of my usual go-to, Flux.jl. The reasons were (i) to try it, and (ii) I believe it is more “functional”, in the sense that the framework has less implicit state, which might be nicer for AD.
The start on the CPU was flawless. The following code produced the second-order gradient without a problem:
using Lux, CUDA, Random, OneHotArrays, Zygote
using Functors, Optimisers, Printf

# LeNet-style classifier for 28×28×1 inputs with 2 output classes
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)
ps, st = Lux.setup(Random.default_rng(), model);

x = randn(Float32, 28, 28, 1, 32)        # input batch (WHCN)
δ = randn(Float32, 28, 28, 1, 32)        # target for the input gradient
y = onehotbatch(rand((1, 2), 32), 1:2)   # random labels

const celoss = CrossEntropyLoss(; logits=true)
const regloss = MSELoss()

# plain classification loss
function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return celoss(pred, y)
end

# classification loss + penalty on the gradient of the loss w.r.t. the input x
function ∂xloss_function(model, ps, st, x, δ, y)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ∂x = only(Zygote.gradient(Base.Fix2(celoss, y) ∘ smodel, x))
    regloss(∂x, δ) + loss_function(model, ps, st, x, y)
end

# gradient of the regularized loss w.r.t. the parameters (second-order AD)
function ∂∂xloss_function(model, ps, st, x, δ, y)
    only(Zygote.gradient(ps -> ∂xloss_function(model, ps, st, x, δ, y), ps))
end

∂∂xloss_function(model, ps, st, x, δ, y)
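A toy version of the same objective can serve as a reference when comparing CPU and GPU results. The sketch below is hypothetical (all names are mine, not from Lux or Zygote): it swaps the CNN for a logistic model, where ∇ₓℓ has a closed form, derives the mixed second derivative by hand, and checks it against central finite differences using only the standard library.

```julia
using LinearAlgebra  # dot

# Hypothetical toy model: ℓ(w; x, y) = log(1 + exp(-y * w⋅x)) with y ∈ {-1, +1}
sigmoid(z) = 1 / (1 + exp(-z))

# Closed-form input gradient: ∇ₓℓ = -y * σ(-y w⋅x) * w
function input_grad(w, x, y)
    s = sigmoid(-y * dot(w, x))
    return -y * s .* w
end

# Regularizer: mean squared error between ∇ₓℓ and a target δ (mean aggregation)
penalty(w, x, y, δ) = sum(abs2, input_grad(w, x, y) .- δ) / length(x)

# Hand-derived ∇_w penalty = (2/n) Jᵀ r, where J = ∂(∇ₓℓ)/∂w is a mixed
# second derivative: J = σ(1-σ) w xᵀ - y σ I (valid for y = ±1)
function penalty_grad(w, x, y, δ)
    s = sigmoid(-y * dot(w, x))
    r = input_grad(w, x, y) .- δ
    return (2 / length(x)) .* (s * (1 - s) * dot(w, r) .* x .- y * s .* r)
end

# Central finite differences, one coordinate at a time
function fd_gradient(f, w; h = 1e-6)
    g = similar(w)
    for i in eachindex(w)
        e = zero(w)
        e[i] = h
        g[i] = (f(w .+ e) - f(w .- e)) / (2h)
    end
    return g
end

w0 = [0.3, -1.2, 0.8]
x0 = [1.0, 0.5, -2.0]
y0 = 1.0
δ0 = [0.1, 0.0, -0.2]

analytic = penalty_grad(w0, x0, y0, δ0)
numeric = fd_gradient(v -> penalty(v, x0, y0, δ0), w0)
println(maximum(abs, analytic .- numeric))  # should be tiny
```

The same finite-difference idea (or FiniteDifferences.jl, which is already in the environment) could be applied to the flattened Lux parameters on a small batch, though that is not shown here.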
Then I needed to move the computation to CUDA, and that is where the problems started. I invoked the CUDA version by moving the arrays to the GPU and calling the same function, i.e.
x_cu = x |> cu
δ_cu = δ |> cu
y_cu = y |> cu
ps_cu = ps |> cu
st_cu = st |> cu
∂∂xloss_function(model, ps_cu, st_cu, x_cu, δ_cu, y_cu)
which breaks because maxpool_direct!, which implements max-pooling in NNlib, is not twice differentiable on the GPU because of scalar indexing. If I change the model to remove the MaxPool layers and rely only on Conv, it crashes in the logsoftmax used in the cross-entropy, which is not twice differentiable on CUDA either: its implementation takes a different path than the CPU version and calls cuDNN functions, which are not twice differentiable.
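One possible workaround for the pooling part, sketched here as an assumption rather than NNlib's method, is to write 2×2 max-pooling (stride 2, no padding, as in MaxPool((2, 2))) purely as a reshape followed by a maximum reduction, so no scalar indexing is involved. Whether Zygote can differentiate this twice on CuArrays is not verified here; the function name is hypothetical.

```julia
# 2×2 max-pooling with stride 2 on a WHCN array, using only reshape + reduction.
# Assumes W and H are even (true for the 24×24 and 8×8 feature maps in the model above).
function maxpool2x2(x::AbstractArray{T,4}) where {T}
    W, H, C, N = size(x)
    @assert iseven(W) && iseven(H) "spatial dimensions must be even"
    # Split each spatial dim into (offset within window, window index):
    # xr[a, i, b, j, c, n] == x[a + 2(i-1), b + 2(j-1), c, n]
    xr = reshape(x, 2, W ÷ 2, 2, H ÷ 2, C, N)
    # Reduce over the two in-window offsets, then drop the singleton dims.
    return dropdims(maximum(xr; dims = (1, 3)); dims = (1, 3))
end

x_demo = reshape(Float32.(1:16), 4, 4, 1, 1)
println(dropdims(maxpool2x2(x_demo); dims = (3, 4)))
```

In a Lux model it could presumably replace MaxPool((2, 2)) through WrappedFunction(maxpool2x2); that substitution is untested.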
I have also tried my luck with Enzyme.jl, with an eye to eventually using Reactant.jl, though I confess I am not very proficient with it. The code modified to use Enzyme looked like this:
using Lux, Reactant, Enzyme, Random, OneHotArrays
using Functors, Optimisers, Printf

model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)
ps, st = Lux.setup(Random.default_rng(), model);

x = randn(Float32, 28, 28, 1, 32)
δ = randn(Float32, 28, 28, 1, 32)
y = onehotbatch(rand((1, 2), 32), 1:2)

const celoss = CrossEntropyLoss(; logits=true)
const regloss = MSELoss()

function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return celoss(pred, y)
end

function ∂xloss_function(model, ps, st, x, δ, y)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ∂x = Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(celoss, y) ∘ smodel), x)[1]
    regloss(∂x, δ) + loss_function(model, ps, st, x, y)
end

function ∂∂xloss_function(model, ps, st, x, δ, y)
    Enzyme.gradient(Enzyme.Reverse, Const(∂xloss_function), Const(model),
                    ps, Const(st), Const(x), Const(δ), Const(y))[2]
end

∂∂xloss_function(model, ps, st, x, δ, y)
which crashes with the following error:
ERROR: MethodError: no method matching get_base_and_offset(::Nothing; offsetAllowed::Bool, inttoptr::Bool)
The function `get_base_and_offset` exists, but no method is defined for this combination of argument types.
Closest candidates are:
get_base_and_offset(::LLVM.Value; offsetAllowed, inttoptr)
@ Enzyme ~/.julia/packages/Enzyme/DiEvV/src/absint.jl:235
Stacktrace:
[1] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{…}, imported::Set{…}, f::LLVM.Function, deletedfns::Vector{…}, mod::LLVM.Module)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:271
[2] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{Tuple{String, Vector{Base.StackTraces.StackFrame}, Any}}, mod::LLVM.Module)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:221
[3] check_ir
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:179 [inlined]
[4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3413
[5] codegen
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3338 [inlined]
[6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387
[7] _thunk
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387 [inlined]
[8] cached_compilation
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5439 [inlined]
[9] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5550
[10] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5735
[11] autodiff
@ ~/.julia/packages/Enzyme/DiEvV/src/Enzyme.jl:485 [inlined]
[12] macro expansion
@ ~/.julia/packages/Enzyme/DiEvV/src/sugar.jl:275 [inlined]
[13] gradient
@ ~/.julia/packages/Enzyme/DiEvV/src/sugar.jl:263 [inlined]
[14] ∂∂xloss_function(model::Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, x::Array{…}, δ::Array{…}, y::OneHotMatrix{…})
@ Main ./REPL[18]:2
[15] top-level scope
@ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.
I have also naively tried Reactant.jl as follows, which crashed with the error below:
xdev = reactant_device()   # assumed: Lux's Reactant device (definition not shown in the original post)
x_ra = x |> xdev
δ_ra = δ |> xdev
y_ra = y |> xdev
ps_ra = ps |> xdev
st_ra = st |> xdev
∂∂xloss_function_compiled = @compile ∂∂xloss_function(model, ps_ra, st_ra, x_ra, δ_ra, y_ra)
error: could not compute the adjoint for this operation "enzyme.push"(%56, %125) : (!enzyme.Cache<tensor<32xf32>>, tensor<32xf32>) -> ()
loc("subtract"("/Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Ops.jl":193:0)): error: could not compute the adjoint for this operation "enzyme.push"(%51, %123) : (!enzyme.Cache<tensor<2x32xf32>>, tensor<2x32xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%42, %arg9) : (!enzyme.Cache<tensor<84x2xf32>>, tensor<84x2xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%34, %arg7) : (!enzyme.Cache<tensor<128x84xf32>>, tensor<128x84xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%26, %arg5) : (!enzyme.Cache<tensor<256x128xf32>>, tensor<256x128xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%19, %104) : (!enzyme.Cache<tensor<8x8x16x32xf32>>, tensor<8x8x16x32xf32>) -> ()
loc("reverse"("/Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Ops.jl":1038:0)): error: could not compute the adjoint for this operation "enzyme.push"(%11, %99) : (!enzyme.Cache<tensor<5x5x6x16xf32>>, tensor<5x5x6x16xf32>) -> ()
error: could not compute the adjoint for this operation "enzyme.push"(%8, %97) : (!enzyme.Cache<tensor<24x24x6x32xf32>>, tensor<24x24x6x32xf32>) -> ()
loc("reverse"("/Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Ops.jl":1038:0)): error: could not compute the adjoint for this operation "enzyme.push"(%0, %92) : (!enzyme.Cache<tensor<5x5x1x6xf32>>, tensor<5x5x1x6xf32>) -> ()
[85175] signal 11 (2): Segmentation fault: 11
in expression starting at REPL[21]:1
_ZN4mlir9Operation7getAttrEN4llvm9StringRefE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir13SymbolRefAttr3getEPNS_9OperationE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir4func6CallOp5buildERNS_9OpBuilderERNS_14OperationStateENS0_6FuncOpENS_10ValueRangeE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir9OpBuilder6createINS_4func6CallOpEJRNS2_6FuncOpERN4llvm11SmallVectorINS_5ValueELj6EEEEEET_NS_8LocationEDpOT0_ at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZNK15AutoDiffCallRev24createReverseModeAdjointEPN4mlir9OperationERNS0_9OpBuilderEPNS0_6enzyme21MGradientUtilsReverseEN4llvm11SmallVectorINS0_5ValueELj6EEE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme6detail41ReverseAutoDiffOpInterfaceInterfaceTraits13FallbackModelI15AutoDiffCallRevE24createReverseModeAdjointEPKNS2_7ConceptEPNS_9OperationERNS_9OpBuilderEPNS0_21MGradientUtilsReverseEN4llvm11SmallVectorINS_5ValueELj6EEE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme26ReverseAutoDiffOpInterface24createReverseModeAdjointERNS_9OpBuilderEPNS0_21MGradientUtilsReverseEN4llvm11SmallVectorINS_5ValueELj6EEE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic10visitChildEPNS_9OperationERNS_9OpBuilderEPNS0_21MGradientUtilsReverseE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic13differentiateEPNS0_21MGradientUtilsReverseERNS_6RegionES5_N4llvm12function_refIFvRNS_9OpBuilderEPNS_5BlockEEEENSt3__18functionIFNSE_4pairINS_5ValueESH_EENS_4TypeEEEE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic17CreateReverseDiffENS_19FunctionOpInterfaceENSt3__16vectorI10DIFFE_TYPENS3_9allocatorIS5_EEEES8_RNS0_13MTypeAnalysisENS4_IbNS6_IbEEEESC_14DerivativeModebmNS_4TypeENS0_11MFnTypeInfoESC_Pv at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN12_GLOBAL__N_117DifferentiatePass16lowerEnzymeCallsERN4mlir21SymbolTableCollectionENS1_19FunctionOpInterfaceE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6detail4walkINS_15ForwardIteratorEEEvPNS_9OperationEN4llvm12function_refIFvS4_EEENS_9WalkOrderE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN12_GLOBAL__N_117DifferentiatePass14runOnOperationEv at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor3runEPNS_4PassEPNS_9OperationENS_15AnalysisManagerEbj at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir11PassManager9runPassesEPNS_9OperationENS_15AnalysisManagerE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
_ZN4mlir11PassManager3runEPNS_9OperationE at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
mlirPassManagerRunOnOp at /Users/tomas.pevny/.julia/artifacts/3c1a3adbd5fea9dc6cef5ff6f459ce623bb25db8/lib/libReactantExtra.dylib (unknown line)
mlirPassManagerRunOnOp at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/mlir/libMLIR_h.jl:5867 [inlined]
run! at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/mlir/IR/Pass.jl:74 [inlined]
#run_pass_pipeline!#1 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:284
run_pass_pipeline! at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:279 [inlined]
#compile_mlir!#8 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:338
compile_mlir! at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:314 [inlined]
#32 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:799
context! at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/mlir/IR/Context.jl:76
unknown function (ip: 0x33a6e413b)
#compile_xla#31 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:796
compile_xla at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:791 [inlined]
#compile#36 at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:823
compile at /Users/tomas.pevny/.julia/packages/Reactant/WudhJ/src/Compiler.jl:822
unknown function (ip: 0x33a580067)
jl_apply at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/./julia.h:2157 [inlined]
do_call at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_stmt_value at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:174
eval_body at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:663
jl_interpret_toplevel_thunk at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:952 [inlined]
ijl_toplevel_eval_in at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:245
repl_backend_loop at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:342
#start_repl_backend#59 at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:327
start_repl_backend at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:324
#run_repl#72 at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:483
run_repl at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:469
jfptr_run_repl_10091.1 at /Users/tomas.pevny/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/share/julia/compiled/v1.11/REPL/u0gqU_3gH4d.dylib (unknown line)
#1150 at ./client.jl:446
jfptr_YY.1150_14648.1 at /Users/tomas.pevny/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/share/julia/compiled/v1.11/REPL/u0gqU_3gH4d.dylib (unknown line)
jl_apply at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/./julia.h:2157 [inlined]
jl_f__call_latest at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_73877.1 at /Users/tomas.pevny/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/lib/julia/sys.dylib (unknown line)
jl_apply at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/./julia.h:2157 [inlined]
true_main at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/jlapi.c:1059
Allocations: 128913871 (Pool: 128908824; Big: 5047); GC: 78
zsh: segmentation fault julia --project=.
So, a few questions:
- Is there any hope of making the above run on the GPU for a person knowledgeable about AD and relatively proficient in writing AD rules? While the common lore is that Zygote.jl does not do second-order gradients, I have a feeling it is more a problem of the libraries than of Zygote.jl itself. It seems that NNlib was not really designed with higher-order gradients in mind.
  I can probably write a custom logsoftmax that works on CUDA, though my first attempt runs into the problem described in the thread "LoadError: `llvmcall` must be compiled to be called when calling Zygote.Jacobian", very likely due to doing AD over broadcasting.
- Am I doing something wrong with Enzyme.jl, or is the problem somewhere deeper?
- Is there any chance of making this work with Reactant.jl? Am I invoking it incorrectly?
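For the custom logsoftmax mentioned in the first question, one possible formulation (a sketch, not NNlib's implementation; all names hypothetical) uses only broadcasts and reductions, and its hand-written pullback is built from the same kind of operations, so an outer AD pass could in principle differentiate through it. Registering it via a ChainRulesCore rrule and whether that survives nested Zygote on CUDA is not verified here.

```julia
# Numerically stable log-softmax along `dims`, using only broadcasts and reductions.
function my_logsoftmax(x::AbstractArray; dims = 1)
    m = maximum(x; dims = dims)   # stability shift; cancels mathematically
    return x .- (m .+ log.(sum(exp.(x .- m); dims = dims)))
end

# Hand-written pullback: for y = logsoftmax(x) and cotangent Δ,
# x̄ = Δ - softmax(x) .* sum(Δ; dims), where softmax(x) = exp.(y).
my_logsoftmax_pullback(y, Δ; dims = 1) = Δ .- exp.(y) .* sum(Δ; dims = dims)

# Logit cross-entropy averaged over the batch (columns); intended to mirror
# CrossEntropyLoss(; logits = true) with its default mean aggregation (assumption).
my_logitce(ŷ, y) = -sum(y .* my_logsoftmax(ŷ; dims = 1)) / size(ŷ, 2)

x = randn(5, 3)
y = my_logsoftmax(x)
println(sum(exp.(y); dims = 1))   # each column should sum to ≈ 1
```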
Many thanks in advance for any answers, suggestions, and help.
Information about my environment is below:
(tmp) pkg> st
Status `/private/tmp/Project.toml`
[d360d2e6] ChainRulesCore v1.25.0
[7da242da] Enzyme v0.13.26
[26cc04aa] FiniteDifferences v0.12.32
[d9f16b24] Functors v0.5.2
[b2108857] Lux v1.4.3
[872c559c] NNlib v0.9.26
[0b1bfda6] OneHotArrays v0.2.6
[3bd65402] Optimisers v0.4.2
[3c362404] Reactant v0.2.12
[e88e6eb3] Zygote v0.6.74
julia> versioninfo()
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (arm64-apple-darwin24.0.0)
CPU: 11 × Apple M3 Pro
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, apple-m2)
Threads: 1 default, 0 interactive, 1 GC (on 5 virtual cores)