Which direction: DifferentiationInterface, Enzyme, Zygote with CUDA and FFTs?

Hi!

For the third time, I am trying to implement some wave optics algorithms in Julia, but each time I start, I suffer from the current state of automatic differentiation.

In principle, my functions look like:

function propagate(field::AbstractArray{T, 6}, distance::NumberOrArray, 
                   wavelengths::NumberOrArray, pixel_size)
    H = # some kernel calculated with distances and wavelengths

    field_p = # some processing on the field with wavelengths and distance

    field_prop = ifft(fft(field_p) .* H) # roughly
    return field_prop
end

The adjoint is a bit nasty to write manually, since distance and wavelengths can each be either a scalar or a vector. The output is 6-dimensional (x, y, z, wavelength, polarization, batch), but some dimensions are singletons depending on the types of distance and wavelengths.

Anyway, it seems that Enzyme.jl does not support FFTs at the moment (the latest is this issue, or, older, this one). So in my current code I use a mix of Zygote and custom-written ChainRules.jl rules.
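
For reference, my custom rule looks roughly like this. It is a simplified sketch: only field is treated as active, distance/wavelengths are marked non-differentiable, the field preprocessing is omitted, and make_kernel is just a placeholder name for the kernel computation above.

using ChainRulesCore, FFTW

function ChainRulesCore.rrule(::typeof(propagate), field, distance, wavelengths, pixel_size)
    H = make_kernel(distance, wavelengths, pixel_size)  # placeholder for the kernel computation above
    field_prop = ifft(fft(field) .* H)                  # field preprocessing omitted for brevity
    function propagate_pullback(ȳ)
        # adjoint of the C-linear map x -> ifft(fft(x) .* H) is y -> ifft(fft(y) .* conj(H))
        dfield = ifft(fft(unthunk(ȳ)) .* conj(H))
        return NoTangent(), dfield, NoTangent(), NoTangent(), NoTangent()
    end
    return field_prop, propagate_pullback
end

The real version has to handle all the scalar/vector combinations of distance and wavelengths, which is where it gets nasty.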

I’d love to use DifferentiationInterface.jl but if FFTs do not work, that’s gonna be a hard time.

Does anyone have an idea what I should do? I feel like using Zygote.jl is not quite future-proof, as everyone else seems to be switching to Enzyme.jl.

Honestly, I am seriously considering implementing it in another language, because this pain has been ongoing for years (I started hitting these general issues in 2020). I know there has been progress on this front and things have improved, but they haven't reached my feature requirements yet (CUDA + FFT + complex numbers). And implementing an Enzyme rule seems very hard for me. I would love to help out a bit, but right now I want to focus on my actual research problems.

I see similar packages implemented in PyTorch and JAX, and it works well for them, so I am really a bit hopeless and cannot recommend Julia to them right now.

Best,

Felix

Hi Felix!

I just want to point out that the question of using DifferentiationInterface is pretty much orthogonal to the choice of backend. As long as you fit within the restrictions of DI (a single active argument), you can choose between Enzyme, Zygote, or anything else you fancy. Whether FFTs work or not depends on the existence of rules for a given backend, not on the DI infrastructure; see this docs page for details.

In fact, using DI can actually make your code more resilient and allow easy switches between autodiff systems down the road. One of the main motivations for developing DI was to encourage this kind of seamless transition, e.g. from Zygote to Enzyme. Of course, there is a caveat: DI can introduce some overhead and/or bugs that would be avoided with a backend's native API. If you run into something like that, please open an issue so we can work to fix it!
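
To make the switch concrete, here is a minimal sketch. It assumes a scalar loss wrapped around your propagate call, with distance, wavelengths and pixel_size captured as constants; the AutoZygote/AutoEnzyme backend objects come from ADTypes.jl.

using DifferentiationInterface
using ADTypes: AutoZygote, AutoEnzyme
import Zygote, Enzyme

# hypothetical scalar loss wrapped around the propagation step
loss(field) = sum(abs2, propagate(field, distance, wavelengths, pixel_size))

grad = gradient(loss, AutoZygote(), field)    # Zygote backend
# grad = gradient(loss, AutoEnzyme(), field)  # same call, just a different backend object

Whether each of those two lines actually runs still depends on the backend having FFT rules, but none of your calling code has to change.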

Have you tried implementing the Enzyme FFT rule yourself? Maybe it’s not out of reach?

Alternatively, have you tried Mooncake.jl? Even if it doesn't have FFT rules either, its rule system is less complex than Enzyme's, so it may be an easier starting point.
EDIT: I don’t think Mooncake fully supports CUDA at the moment, so it doesn’t answer your full question. Sorry.

Honestly, that’s fair. If it ain’t broke, don’t fix it. It would be great to get this in Julia too, but if you need it for your research and Python has it, no one will blame you ^^

2 Likes

Reactant.jl supports FFT (and its differentiation); it is used quite extensively in NeuralOperators.jl (it also runs on CUDA, TPU, CPU, and whatever other accelerator you might want).

Is your field_p a 6-dimensional tensor? (We added FFTs of up to 3 dims in Reactant, though it is easy to expand if that is what is needed.)

1 Like

But it's not supported by Enzyme? What would I have to do to use it?

Yes, it's a 6-dim array. FFTs are only done along the 1st and 2nd dimensions, though.
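
In code, that part is roughly just the following sketch (dims passed as a tuple):

using FFTW
spectrum = fft(field, (1, 2))  # transform only along dims 1 and 2; the remaining dims are batched over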

Reactant.jl uses Enzyme.jl for autodiff support. So you just need to @compile Enzyme.gradient(....).

Then it should just work
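
Something along these lines. This is an untested sketch following the pattern in the Lux docs; the loss function is hypothetical, and I believe the compiled thunk is then called with the same arguments that were passed to @compile.

using Reactant, Enzyme

# hypothetical scalar loss wrapped around the propagation step
loss(field) = sum(abs2, propagate(field, distance, wavelengths, pixel_size))

field_ra = Reactant.to_rarray(field)                          # move data into Reactant arrays
grad_fn  = @compile Enzyme.gradient(Reverse, loss, field_ra)  # trace and compile the whole gradient computation
grads    = grad_fn(Reverse, loss, field_ra)                   # call the compiled thunk with the same arguments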

1 Like

Is there any further link? Didn’t see anything specific in the docs.

So Reactant introduces some FFT rules, but they are not included in Enzyme itself?

Compiling Lux Models using Reactant.jl | Lux.jl Docs has an example of compiling Enzyme.gradient. I just realized we don't show Enzyme examples in the Reactant docs (I will add something today).

So Reactant uses EnzymeMLIR, in contrast to EnzymeLLVM; both are interfaced via Enzyme.jl (so no change is required on the user side except moving data via Reactant.to_rarray and doing a @compile). We have the FFT rules defined for EnzymeMLIR.

1 Like

Not quite correct, but an analogy is that Reactant is similar to jax.jit, and Enzyme is similar to jax.grad. The difference here is that if you Reactant.@compile a function, it converts your Julia code into a nice tensor program that can be automatically optimized/fused/etc. and executed on whatever backend you want, including CPU, GPU, TPU, and distributed versions thereof. Essentially all programs in this tensor IR are differentiable by Enzyme(MLIR), so if it Reactant.@compiles, you're good!

As for Enzyme.jl without Reactant compilation, there's no technical reason it couldn't work – I think someone just needs to write the rule and/or use Enzyme.import_rrule to import the ChainRules rule as described in the issue you linked above. Maybe you, @ptiede, or someone else is interested in getting that over the line?
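
Completely untested, but from the Enzyme docs the import would look something like the sketch below; the argument types here are a guess and would have to match the rrule signature that AbstractFFTs actually defines.

using Enzyme, AbstractFFTs, FFTW

# guess: pull the existing ChainRules rrule for fft into Enzyme's rule system
Enzyme.@import_rrule(typeof(AbstractFFTs.fft), AbstractArray, Any)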

That said, I'd recommend just using Reactant.@compile, with Enzyme on the inside. It'll give you all the things you say are missing, and much more :slight_smile:

1 Like

Wait, so when you Reactant.@compile with Enzyme.gradient on the inside, Enzyme nonetheless differentiates the Reactant-generated code?

yes

julia> using Reactant, Enzyme

julia> function foo(x)
           return sum(x)
       end

julia> x = Reactant.to_rarray(ones(10));

julia> @code_hlo optimize=false Enzyme.gradient(Reverse, foo, x)
module @reactant_gradient attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func private @identity_broadcast_scalar(%arg0: tensor<f64>) -> tensor<f64> {
    return %arg0 : tensor<f64>
  }
  func.func private @"Const{typeof(foo)}(Main.foo)_autodiff"(%arg0: tensor<10xf64>) -> (tensor<f64>, tensor<10xf64>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %1 = stablehlo.convert %cst : tensor<f64>
    %2 = enzyme.batch @identity_broadcast_scalar(%0) {batch_shape = array<i64: 10>} : (tensor<10xf64>) -> tensor<10xf64>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %3 = stablehlo.convert %cst_0 : tensor<f64>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %4 = stablehlo.convert %cst_1 : tensor<f64>
    %5 = stablehlo.reduce(%2 init: %1) applies stablehlo.add across dimensions = [0] : (tensor<10xf64>, tensor<f64>) -> tensor<f64>
    %6 = stablehlo.transpose %2, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    return %5, %6 : tensor<f64>, tensor<10xf64>
  }
  func.func @main(%arg0: tensor<10xf64> {tf.aliasing_output = 1 : i32}) -> (tensor<10xf64>, tensor<10xf64>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %1 = stablehlo.convert %cst : tensor<f64>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %2 = stablehlo.convert %cst_1 : tensor<f64>
    %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f64>) -> tensor<10xf64>
    %cst_2 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
    %4 = stablehlo.transpose %0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %5 = stablehlo.transpose %3, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %6:2 = enzyme.autodiff @"Const{typeof(foo)}(Main.foo)_autodiff"(%4, %cst_2, %5) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>]} : (tensor<10xf64>, tensor<f64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
    %7 = stablehlo.transpose %6#0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %8 = stablehlo.transpose %6#1, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %9 = stablehlo.transpose %8, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %10 = stablehlo.transpose %7, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    return %9, %10 : tensor<10xf64>, tensor<10xf64>
  }
}

julia> @code_hlo Enzyme.gradient(Reverse, foo, x)
module @reactant_gradient attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<10xf64>) -> tensor<10xf64> {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<10xf64>
    return %cst : tensor<10xf64>
  }
}
1 Like

Generally, using Reactant compilation makes things more likely to be efficiently differentiated by Enzyme. It removes type instabilities, preserves structure in the IR, and does lots of other good things that make differentiation (and also the performance of the original code) better.

The code snippets pulled together in this Discourse thread work fine on Julia 1.10. So far I cannot get them to work on 1.11 or 1.12. I also haven't tested on GPU.

A more detailed autodiff tutorial is in the Reactant docs: Automatic Differentiation | Reactant.jl

2 Likes