Reactant: how to use it, limitations and opportunities?

After reading a few posts (and a suggestion from @wmshen0323 to try Enzyme), Reactant got my attention. My understanding is that, if I manage to compile my code with Reactant, I get to run it on GPUs for free and I can differentiate it with Enzyme.

I did a few tests on my (very old) laptop and I was quite impressed.

Small, successful example
using Reactant, Tullio, LoopVectorization, FFTW, BenchmarkTools, LinearAlgebra
import AbstractFFTs

# --- 1. THE ORIGINAL PIPELINE (Pre-planned) ---
function full_pipeline_original(vals, T, plan)
    # Apply pre-computed plan
    coefs = plan * vals
    
    s = size(coefs)
    coefs ./= 2*(s[1]-1)
    
    # Interior doubling (optimized with broadcasting for comparison)
    weights = ones(Float64, s[1])
    weights[2:end-1] .= 2.0
    coefs .*= reshape(weights, (s[1], 1, 1))
    
    # Contract
    @tullio w[i,j,k] := coefs[j,k,l] * T[i,j,k,l]
    return w
end

# --- 2. THE REACTANT PIPELINE ---
function reactant_dct1_equivalent(x)
    n = size(x, 1)
    # Traceable mirroring logic
    mirrored = vcat(x, reverse(x[2:end-1, :, :], dims=1))
    transformed = real.(AbstractFFTs.fft(mirrored, 1))
    return transformed[1:n, :, :]
end

function fast_chebcoefs_reactant(vals)
    coefs = reactant_dct1_equivalent(vals)
    s1 = size(coefs, 1)
    weights = ones(Float64, s1)
    weights[2:end-1] .= 2.0
    scale = 1.0 / (2 * (s1 - 1))
    return (coefs .* scale) .* reshape(weights, (s1, 1, 1))
end

function full_pipeline_reactant(vals, T)
    c = fast_chebcoefs_reactant(vals)
    c_reshaped = reshape(c, 1, size(c)...)
    return dropdims(sum(c_reshaped .* T; dims=4); dims=4)
end

# --- 3. DATA SETUP ---
I, J, K, L = 20, 96, 48, 128
vals_cpu = rand(Float64, J, K, L)
T_cpu = rand(Float64, I, J, K, L)

# Pre-compute the FFTW plan
# Use PATIENT to give FFTW the best possible chance
println("Planning FFTW...")
plan = FFTW.plan_r2r(copy(vals_cpu), FFTW.REDFT00, [1]; flags=FFTW.PATIENT)

# Setup Reactant data and Compile
vals_ra = Reactant.Array(vals_cpu)
T_ra = Reactant.Array(T_cpu)

println("Compiling Reactant...")
compiled_pipe = @compile full_pipeline_reactant(vals_ra, T_ra)

# --- 4. VERIFICATION ---
res_orig = full_pipeline_original(vals_cpu, T_cpu, plan)
res_ra = compiled_pipe(vals_ra, T_ra)
diff = norm(Array(res_ra) - res_orig) / norm(res_orig)
println("\nVerification Relative Difference: $diff")

# --- 5. THE COMPARISON ---

println("\n--- BENCHMARK: ORIGINAL (Pre-planned FFTW + Tullio) ---")
display(@benchmark full_pipeline_original($vals_cpu, $T_cpu, $plan))

println("\n--- BENCHMARK: REACTANT (Compiled XLA Pipeline) ---")
display(@benchmark $compiled_pipe($vals_ra, $T_ra))

with results

--- BENCHMARK: ORIGINAL (Pre-planned FFTW + Tullio) ---

julia> display(@benchmark full_pipeline_original($vals_cpu, $T_cpu, $plan))
BenchmarkTools.Trial: 316 samples with 1 evaluation per sample.
 Range (min … max):  14.283 ms … 21.669 ms  ┊ GC (min … max): 0.00% … 21.24%
 Time  (median):     15.725 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   15.810 ms ±  1.110 ms  ┊ GC (mean ± σ):  1.24% ±  2.98%

    ▄ ▁▆▅           ▇█▃ ▁▂                                     
  ▂▂█▄███▃▃▂▂▃▁▃▄▃▂▃███▆███▃▃▂▁▄▂▁▃▁▃▅▅▂▂▃▃▅▂▂▃▁▂▂▄▃▂▄▂▂▂▂▂▃▂ ▃
  14.3 ms         Histogram: frequency by time        18.6 ms <

 Memory estimate: 5.21 MiB, allocs estimate: 56.

julia> println("\n--- BENCHMARK: REACTANT (Compiled XLA Pipeline) ---")

--- BENCHMARK: REACTANT (Compiled XLA Pipeline) ---

julia> display(@benchmark $compiled_pipe($vals_ra, $T_ra))
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  51.366 μs …  1.502 ms  ┊ GC (min … max):  0.00% … 90.80%
 Time  (median):     60.232 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   72.141 μs ± 93.269 μs  ┊ GC (mean ± σ):  13.85% ±  9.98%

  █▃▂▂                                                        ▁
  ████▇▇▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▆ █
  51.4 μs      Histogram: log(frequency) by time       842 μs <

 Memory estimate: 720.09 KiB, allocs estimate: 3.

I was quite impressed!

This improves on our previous results by a factor of 10, basically for free! Now I have a few questions.

  1. Autodiff. Is it true that, if I can compile my program with Reactant, I will be able to differentiate it with Enzyme? (Of course there might be edge cases; let’s say for a standard scenario.) Can I/should I write some rrules, or can I assume the performance is good in most cases? If not, can I just write rrules using Enzyme the usual way?

  2. SciML & ODE. An important part of my workflow is solving ODEs with the SciML stack. Can I compile that part as well, or will it stay outside of the compilation?

  3. Turing. After producing all of this high-performance code, I usually put it in a pipeline that leverages Turing to either run chains or compute best fits. I usually don’t do weird things on the Turing side (Uniform and Normal priors; at most I use the submodels syntax). Can I expect these pieces of Turing to work OK with Enzyme?

  4. Compilation. One last point: compilation times. While I am more than happy with the final performance of the code, I wonder whether there is anything I can do to reduce compilation times. For some of the tests I am doing (rewriting some legacy codes), I see compilation times of order 20 minutes. These are more than compensated by the final performance, but is there any best practice or approach to reduce them?

Thanks in advance!


Too good to be true: the code I initially put in the example above was not working. There was some constant folding, so the benchmark was not measuring the actual computation.

Here is the corrected version of that code.
Now I am sad.
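The dynamic check in the corrected code below relies on the pipeline being linear in vals: doubling the input must double the output, whereas a constant-folded function returns the same baked-in result whatever it is fed. Here is a minimal sketch of that idea in plain Julia, with no Reactant involved. The names dynamic_pipe, make_folded and looks_dynamic are hypothetical stand-ins I made up for illustration, not Reactant API.

```julia
# Stand-in for a correctly traced function: linear in `vals` (illustrative only)
dynamic_pipe(vals, T) =
    dropdims(sum(reshape(vals, 1, size(vals)...) .* T; dims=4); dims=4)

# Stand-in for a constant-folded function: the result is computed once and
# returned regardless of the arguments passed later (illustrative only)
function make_folded(vals, T)
    baked = dynamic_pipe(vals, T)
    return (_vals, _T) -> baked
end

# Hypothetical helper: for a pipeline that is linear in `vals`,
# doubling the input must double the output.
function looks_dynamic(f, vals, T; atol=1e-10)
    out1 = f(vals, T)
    out2 = f(2 .* vals, T)
    diff = maximum(abs.(out2 .- 2 .* out1))
    return diff < atol && maximum(abs.(out1)) > 0
end

vals = rand(6, 4, 8)
T = rand(3, 6, 4, 8)

println(looks_dynamic(dynamic_pipe, vals, T))          # true
println(looks_dynamic(make_folded(vals, T), vals, T))  # false
```

This only works because the pipeline is linear in its first argument; for a nonlinear function one would instead compare against a reference computed outside the compiled code.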

Corrected code
using Reactant
using Test
using BenchmarkTools
using LinearAlgebra
import AbstractFFTs
using FFTW

# --- TRACEABLE CHEBYSHEV PIPELINE ---

function reactant_dct1_equivalent(x)
    n = size(x, 1)
    mid = x[2:n-1, :, :]
    mirrored = cat(x, reverse(mid, dims=1), dims=1)
    
    # Use 0.0im instead of 0im to create ComplexF64
    mirrored_c = mirrored .+ 0.0im
    
    transformed_c = AbstractFFTs.fft(mirrored_c, 1)
    return real.(transformed_c[1:n, :, :])
end

function fast_chebcoefs_reactant(vals)
    coefs = reactant_dct1_equivalent(vals)
    s1 = size(coefs, 1)
    
    # Apply scale to everything
    scale = 1.0 / (2.0 * (s1 - 1))
    coefs_scaled = coefs .* scale
    
    # Double everything first
    result = coefs_scaled .* 2.0
    
    # Set first and last rows back to their original scaled values
    result[1:1, :, :] .= coefs_scaled[1:1, :, :]
    result[s1:s1, :, :] .= coefs_scaled[s1:s1, :, :]
    
    return result
end

function full_pipeline_reactant(vals, T)
    c = fast_chebcoefs_reactant(vals)
    c_reshaped = reshape(c, 1, size(c)...)
    return dropdims(sum(c_reshaped .* T; dims=2); dims=2)
end

# --- VERIFICATION & BENCHMARK ---

function verify_and_benchmark()
    println("=== Verifying Reactant Dynamic Behavior & Performance ===")
    
    vals = rand(Float64,  96, 48, 128)
    T = rand(Float64, 20, 96, 48, 128)

    println("\n1. Compilation with ConcreteRArray...")
    # Convert to ConcreteRArray to mark as dynamic inputs
    vals_concrete = Reactant.ConcreteRArray(vals)
    T_concrete = Reactant.ConcreteRArray(T)
    
    # Compile with ConcreteRArray inputs
    compiled = Reactant.@compile full_pipeline_reactant(vals_concrete, T_concrete)

    println("2. Running dynamic check...")
    
    # Create new ConcreteRArray for each call
    vals1_c = Reactant.ConcreteRArray(vals)
    T1_c = Reactant.ConcreteRArray(T)
    out1 = compiled(vals1_c, T1_c)
    
    vals2_c = Reactant.ConcreteRArray(2.0 .* vals)
    T2_c = Reactant.ConcreteRArray(T)
    out2 = compiled(vals2_c, T2_c)
    
    diff = maximum(abs.(out2 .- 2.0 .* out1))
    println("   Max Diff (out2 - 2*out1): ", diff)
    println("   out1 magnitude: ", maximum(abs.(out1)))
    println("   out2 magnitude: ", maximum(abs.(out2)))
    println("   2*out1 magnitude: ", maximum(abs.(2.0 .* out1)))
    
    if diff < 1e-10 && maximum(abs.(out1)) > 0
        println("   ✓ SUCCESS: Pipeline is dynamic (No constant folding)!")
    else
        println("   ✗ FAILURE: Constant folding detected or numerical issue.")
    end

    println("\n3. Benchmarking...")
    println("   --- Pure Julia ---")
    @btime full_pipeline_reactant($vals, $T)
    
    println("   --- Reactant (compilation excluded) ---")
    vals_bench = Reactant.ConcreteRArray(vals)
    T_bench = Reactant.ConcreteRArray(T)
    @btime $compiled($vals_bench, $T_bench)
    
    println("   --- Reactant (with setup, proper benchmark) ---")
    p = @benchmark f($vals_concrete, $T_concrete) setup=(
        vals_concrete = Reactant.ConcreteRArray($vals);
        T_concrete = Reactant.ConcreteRArray($T);
        f = Reactant.@compile sync=true full_pipeline_reactant($vals_concrete, $T_concrete)
    )
    display(p)
end

if abspath(PROGRAM_FILE) == @__FILE__
    try
        verify_and_benchmark()
    catch e
        println("\nERROR: Compilation or Execution failed.")
        println("Error message: ", e)
        rethrow(e)
    end
end

with results

3. Benchmarking...
   --- Pure Julia ---
  85.813 ms (63 allocations: 166.88 MiB)
   --- Reactant (compilation excluded) ---
  23.924 μs (14 allocations: 432 bytes)
   --- Reactant (with setup, proper benchmark) ---
BenchmarkTools.Trial: 9 samples with 1 evaluation per sample.
 Range (min … max):  52.223 ms … 71.184 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     62.318 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   61.355 ms ±  5.858 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▁   ▁              ▁        ▁  ▁    █   ▁                 ▁  
  █▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁█▁▁█▁▁▁▁█▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  52.2 ms         Histogram: frequency by time        71.2 ms <

 Memory estimate: 480 bytes, allocs estimate: 15.
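
For anyone wondering why the mirroring in reactant_dct1_equivalent can replace FFTW.REDFT00: the FFT of the even extension of x (length 2(n-1)) is real, and its first n entries equal the unnormalized DCT-I. Below is a minimal 1D sketch in plain Julia that checks this numerically. The helper names (dct1_direct, naive_dft, dct1_via_mirror) are mine and purely illustrative, and the naive O(N^2) DFT only stands in for fft so that the snippet needs no packages.

```julia
# Direct DCT-I in the unnormalized REDFT00 convention (0-based indices in the math):
#   Y[k] = x[0] + (-1)^k * x[n-1] + 2 * sum_{j=1}^{n-2} x[j] * cos(pi*j*k/(n-1))
function dct1_direct(x::AbstractVector{<:Real})
    n = length(x)
    y = zeros(n)
    for k in 0:n-1
        s = x[1] + (-1)^k * x[n]
        for j in 1:n-2
            s += 2 * x[j+1] * cospi(j * k / (n - 1))
        end
        y[k+1] = s
    end
    return y
end

# Naive O(N^2) DFT, used here only as a package-free stand-in for fft
function naive_dft(v::AbstractVector)
    N = length(v)
    return [sum(v[j+1] * cispi(-2 * j * k / N) for j in 0:N-1) for k in 0:N-1]
end

# Same mirroring logic as reactant_dct1_equivalent, reduced to 1D
function dct1_via_mirror(x::AbstractVector{<:Real})
    n = length(x)
    mirrored = vcat(x, reverse(x[2:end-1]))  # even extension, length 2(n-1)
    return real.(naive_dft(mirrored))[1:n]
end

x = rand(9)
maxerr = maximum(abs.(dct1_via_mirror(x) .- dct1_direct(x)))
println("max abs difference: ", maxerr)
```

The difference should be at the level of floating-point round-off. The price of the trick is an FFT of roughly twice the length, which is presumably fine when the goal is to stay within operations that Reactant can trace.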