Which direction: DifferentiationInterface, Enzyme, Zygote with CUDA and FFTs?

Hi!

For the third time, I am trying to implement some wave-optics algorithms in Julia, but each time I start, I struggle with the current state of automatic differentiation.

In principle, my functions look like:

# NumberOrArray stands for Union{Number, AbstractArray}
function propagate(field::AbstractArray{T, 6}, distance::NumberOrArray,
                   wavelengths::NumberOrArray, pixel_size)
    H = # some kernel calculated from distance and wavelengths

    field_p = # some processing on the field with wavelengths and distance

    field_prop = ifft(fft(field_p) .* H) # roughly
    return field_prop
end

The adjoint is a bit nasty to write manually since distance and wavelengths can each be either a scalar or a vector. The output is six-dimensional (x, y, z, wavelength, polarization, batch), but some dimensions are singletons depending on the types of distance and wavelengths.

Anyway, it seems like Enzyme.jl does not support FFTs at the moment (the latest is this issue, or this older one). So in my current code I use a mix of Zygote and custom-written ChainRules.jl rules.
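To give an idea of those rules, here is roughly what they look like, as a simplified sketch for a fixed-kernel Fourier filter (fourier_filter is an illustrative stand-in, not my actual propagate):

using FFTW, ChainRulesCore

# Illustrative stand-in: multiply by a transfer function H in Fourier space.
fourier_filter(x, H) = ifft(fft(x) .* H)

function ChainRulesCore.rrule(::typeof(fourier_filter), x, H)
    X = fft(x)
    y = ifft(X .* H)
    function fourier_filter_pullback(ȳ)
        ȳ = unthunk(ȳ)
        Ȳ = fft(ȳ)
        # Adjoint of the linear map x ↦ ifft(fft(x) .* H); the 1/n factors of the two ifft's cancel.
        x̄ = ifft(Ȳ .* conj.(H))
        # Cotangent w.r.t. the kernel: conj(X) .* (F⁻¹)ᴴȳ = conj(X) .* fft(ȳ) ./ n
        H̄ = @thunk(conj.(X) .* Ȳ ./ length(ȳ))
        return NoTangent(), x̄, H̄
    end
    return y, fourier_filter_pullback
end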

I’d love to use DifferentiationInterface.jl but if FFTs do not work, that’s gonna be a hard time.

Does anyone have an idea what I should do? I feel like using Zygote.jl is not quite future-proof, as everyone else seems to be switching to Enzyme.jl.

Honestly, I am seriously considering implementing this in another language, because this pain has been ongoing for years (I started hitting these general issues in 2020). I know there is progress on this front and things have improved, but they haven’t reached my feature requirements yet (CUDA + FFT + complex numbers). And implementing an Enzyme rule seems very hard to me. I would love to help out a bit, but right now I want to focus on my real research problems.

I see similar packages implemented in PyTorch and JAX and it works well for them, so I am really a bit hopeless and cannot recommend Julia to them right now.

Best,

Felix

2 Likes

Hi Felix!

I just want to point out that the question of using DifferentiationInterface is pretty much orthogonal to the choice of backend. As long as you fit inside the restrictions of DI (a single active argument), you can choose Enzyme, Zygote or anything else you fancy. Whether FFTs work or not depends on the existence of rules for a given backend, not on the DI infrastructure; see this docs page for details.

In fact, using DI can actually make your code more resilient and allow easy switches between autodiff systems down the road. One of the main motivations for developing DI was to encourage this kind of seamless transition, e.g. from Zygote to Enzyme. Of course, the caveat is that DI may introduce some overhead and/or bugs that can be avoided with a backend’s native API. If you run into something like that, please open an issue so we can work to fix it!
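To give a sense of what that switch looks like, here is a minimal sketch with a toy loss (not your propagate function):

using DifferentiationInterface
import Enzyme, Zygote  # load whichever backend packages you want to try

# Toy scalar loss with a single active array argument (DI's model).
loss(x) = sum(abs2, x)

x = randn(32)

# The backend is just an argument; switching AD systems is a one-word change.
grad_zygote = gradient(loss, AutoZygote(), x)
grad_enzyme = gradient(loss, AutoEnzyme(), x)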

Have you tried implementing the Enzyme FFT rule yourself? Maybe it’s not out of reach?

Alternatively, have you tried Mooncake.jl? Even if it doesn’t have FFT rules either, its rule system is less complex than Enzyme’s, so it may be an easier starting point.
EDIT: I don’t think Mooncake fully supports CUDA at the moment, so it doesn’t answer your full question. Sorry.

Honestly, that’s fair. If it ain’t broke, don’t fix it. It would be great to get this in Julia too, but if you need it for your research and Python has it, no one will blame you ^^

6 Likes

Reactant.jl supports FFTs (and their differentiation); they are used quite extensively in NeuralOperators.jl. It also runs on CUDA, TPU, CPU and whatever other accelerator you might want.

Is your field_p a 6-dimensional tensor? (We added up to 3-dim FFTs in Reactant, though it is easy to expand if that is what is needed.)

2 Likes

But it’s not supported by Enzyme? What would I need to do to use it?

Yes, it’s a 6-dim array. FFTs are only done along the 1st and 2nd dimensions, though.

Reactant.jl uses Enzyme.jl for autodiff support. So you just need to @compile Enzyme.gradient(....).

Then it should just work
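Something along these lines, as a rough sketch with a toy FFT loss (array types and sizes are placeholders, not your real propagate):

using Reactant, Enzyme, FFTW

# Toy loss that goes through an FFT, just to show the pattern.
loss(field) = sum(abs2, ifft(fft(field)))

field = Reactant.to_rarray(randn(ComplexF32, 128, 128))

# @jit traces, compiles and runs; Enzyme(MLIR) differentiates the traced program.
grads = @jit Enzyme.gradient(Reverse, loss, field)  # tuple holding the gradient w.r.t. field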

1 Like

Is there any further link? Didn’t see anything specific in the docs.

So Reactant introduces some FFT rules, but they are not included in Enzyme itself?

Compiling Lux Models using Reactant.jl | Lux.jl Docs has an example of compiling Enzyme.gradient. I just realized we don’t show Enzyme examples in the Reactant docs (I will add something today).

So Reactant uses EnzymeMLIR, in contrast to EnzymeLLVM; both are interfaced via Enzyme.jl (so no change is required on the user’s side, except moving data via Reactant.to_rarray and doing a @compile). We have the FFT rules defined for EnzymeMLIR.

2 Likes

Not quite correct, but an analogy is that Reactant is similar to jax.jit, and Enzyme is similar to jax.grad. The difference here is that if you Reactant.@compile a function, it converts your Julia code into a nice tensor program that can be automatically optimized/fused/etc. and executed on whatever backend you want, including CPU, GPU, TPU, and distributed versions thereof. Essentially all programs in this tensor IR are differentiable by Enzyme(MLIR), so if it Reactant.@compiles, you’re good!

As for Enzyme.jl without Reactant compilation, there’s no technical blocker; I think someone just needs to write the rule and/or use Enzyme.import_rrule to import the ChainRules rule, as described in the issue you linked above. Maybe you, @ptiede, or someone else is interested in getting that over the line?
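For anyone who wants to try the import route, it would look roughly like this (an untested sketch; the argument types below are a guess and have to match the rrule signature in AbstractFFTs):

using Enzyme, FFTW
import AbstractFFTs

# Untested sketch: import the existing ChainRules rrule for fft into Enzyme.
# The argument types must match how the rrule is defined in AbstractFFTs.
Enzyme.@import_rrule(typeof(AbstractFFTs.fft), AbstractArray, Any)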

That said, I’d recommend just using Reactant.@compile with Enzyme on the inside. It’ll give you all the things you say are missing, and much more :slight_smile:

3 Likes

Wait so when you Reactant.@compile with Enzyme.gradient on the inside, Enzyme nonetheless differentiates the Reactant-generated code?

yes

julia> using Reactant, Enzyme

julia> function foo(x)
           return sum(x)
       end

julia> x = Reactant.to_rarray(ones(10));

julia> @code_hlo optimize=false Enzyme.gradient(Reverse, foo, x)
module @reactant_gradient attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func private @identity_broadcast_scalar(%arg0: tensor<f64>) -> tensor<f64> {
    return %arg0 : tensor<f64>
  }
  func.func private @"Const{typeof(foo)}(Main.foo)_autodiff"(%arg0: tensor<10xf64>) -> (tensor<f64>, tensor<10xf64>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %1 = stablehlo.convert %cst : tensor<f64>
    %2 = enzyme.batch @identity_broadcast_scalar(%0) {batch_shape = array<i64: 10>} : (tensor<10xf64>) -> tensor<10xf64>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %3 = stablehlo.convert %cst_0 : tensor<f64>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %4 = stablehlo.convert %cst_1 : tensor<f64>
    %5 = stablehlo.reduce(%2 init: %1) applies stablehlo.add across dimensions = [0] : (tensor<10xf64>, tensor<f64>) -> tensor<f64>
    %6 = stablehlo.transpose %2, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    return %5, %6 : tensor<f64>, tensor<10xf64>
  }
  func.func @main(%arg0: tensor<10xf64> {tf.aliasing_output = 1 : i32}) -> (tensor<10xf64>, tensor<10xf64>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %1 = stablehlo.convert %cst : tensor<f64>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %2 = stablehlo.convert %cst_1 : tensor<f64>
    %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f64>) -> tensor<10xf64>
    %cst_2 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
    %4 = stablehlo.transpose %0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %5 = stablehlo.transpose %3, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %6:2 = enzyme.autodiff @"Const{typeof(foo)}(Main.foo)_autodiff"(%4, %cst_2, %5) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>]} : (tensor<10xf64>, tensor<f64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
    %7 = stablehlo.transpose %6#0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %8 = stablehlo.transpose %6#1, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %9 = stablehlo.transpose %8, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %10 = stablehlo.transpose %7, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    return %9, %10 : tensor<10xf64>, tensor<10xf64>
  }
}

julia> @code_hlo Enzyme.gradient(Reverse, foo, x)
module @reactant_gradient attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<10xf64>) -> tensor<10xf64> {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<10xf64>
    return %cst : tensor<10xf64>
  }
}
1 Like

Generally, using Reactant compilation makes things more likely to be efficiently differentiated by Enzyme. It removes type instabilities, preserves structure in the IR, and does lots of other good things that make differentiation (and also just the performance of the original code) better.

1 Like

The code snippets pulled together in this Discourse thread work fine on Julia 1.10. So far I cannot get them to work on 1.11 or 1.12. I also haven’t tested on GPU.

A more detailed autodiff tutorial is in the Reactant docs: Automatic Differentiation | Reactant.jl

4 Likes

I have some private code where I added Enzyme support for FFTs (non-Reactant version). I didn’t consider it production-ready, but I can share it, clean it up, and we could make a PR.

4 Likes

Does it only work with fft(x) but not with FFT plans p = plan_fft(x); p * x?

using Enzyme, Reactant, ImageShow, ImageIO, FFTW

arr = rand(Float32, (100, 100));

p = plan_fft(arr)

loss_function(x) = sum(abs2.((p * x)))

x = Reactant.to_rarray(arr);

# Compute gradient using reverse mode
function f(x)
    return @jit Enzyme.gradient(Reverse, loss_function, x)
end

f_compiled = @compile f(x)

Error:

Scalar indexing is disallowed.
Invocation of getindex(::TracedRArray, ::Union{Int, TracedRNumber{Int}}) resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations do not execute on the GPU, but very slowly on the CPU, and therefore should be avoided.
If you want to allow scalar iteration, use allowscalar or @allowscalar to enable scalar iteration globally or for the operations in question.

Stack trace (condensed, most recent call first):

error                                          error.jl:35
errorscalar / _assertscalar / assertscalar     GPUArraysCore.jl:150, 123, 97
getindex(::TracedRArray, ::Int64)              TracedRArray.jl:120
copyto_unaliased! / copyto!                    abstractarray.jl:1081, 1055
circcopy!                                      multidimensional.jl:1303
copy1 / *(::FFTW.cFFTWPlan, ::TracedRArray)    definitions.jl:53, 224
loss_function(x) = sum(abs2.((p * x)))         notebook cell, line 1
make_mlir_fn                                   Reactant TracedUtils.jl:330
overload_autodiff / autodiff                   ReactantEnzyme.jl:303, ReactantOverlay.jl:21
Enzyme.gradient                                Enzyme sugar.jl:262
f                                              notebook cell
compile_mlir! / compile_xla / compile          Reactant Compiler.jl:1544, 3420, 3492
@compile f(x)                                  Compiler.jl:2573

I think an Enzyme rule that works with FFT plans is:

using Enzyme, FFTW
using Enzyme: EnzymeRules  # module providing the custom-rule interface

grad_safe_fft!(arr, fft_plan) = fft_plan * arr

function EnzymeRules.augmented_primal(config, ::Const{typeof(grad_safe_fft!)}, t,
                                      arr, fft_plan)
    # Run the primal FFT; no primal value, shadow, or tape is passed back here.
    fft_plan.val * arr.val
    return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end

function EnzymeRules.reverse(config, ::Const{typeof(grad_safe_fft!)}, dret, tape,
                             arr, fft_plan)
    # Pull the cotangent back through the FFT: the adjoint of fft is bfft.
    arr.dval .= bfft(arr.dval)
    return (nothing, nothing)
end

This is based on the FFT rrule in AbstractFFTs.jl; the other variants are available there.

2 Likes

Is the bfft correct here? Why don’t you use inv(fft_plan)?

bfft matches the rrule and passes tests for my use case. See some discussion on the original PR. inv(fft_plan) may also work.

2 Likes

bfft is the correct adjoint operator (it is literally the adjoint, the conjugate-transpose, of fft). If you used inv(fft_plan) or ifft then it would introduce an additional 1/n scale factor that you don’t want.
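You can also check the adjoint identity ⟨fft(x), y⟩ = ⟨x, fftᴴ(y)⟩ numerically with a quick sketch:

using FFTW, LinearAlgebra

x = randn(ComplexF64, 64)
y = randn(ComplexF64, 64)

dot(fft(x), y) ≈ dot(x, bfft(y))               # true: bfft is the adjoint of fft
dot(fft(x), y) ≈ dot(x, ifft(y))               # false: ifft = bfft / n, so it's off by n
dot(fft(x), y) ≈ dot(x, length(y) .* ifft(y))  # true again once the factor of n is restored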

4 Likes