Which direction: DifferentiationInterface, Enzyme, Zygote with CUDA and FFTs?

Sorry, to be precise: I meant it’s better to reuse the existing FFT plan (through its adjoint `P'`) instead of calling `bfft`. This matters especially with CUDA, where `bfft` usually allocates a lot of memory.
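
To illustrate the difference on the CPU, here is a small sketch with FFTW.jl. It assumes FFTW.jl and AbstractFFTs.jl versions recent enough to support adjoint plans (`P'`); the variable names are only for illustration. The adjoint of an unnormalized forward FFT plan computes the same unnormalized backward transform as `bfft`, but it reuses the plan you already have. A generic `bfft(y)` call, by contrast, builds a plan internally each time.

```julia
# Sketch: reuse an FFT plan's adjoint instead of calling bfft.
# Assumes recent FFTW.jl / AbstractFFTs.jl versions where plans support `P'`.
using FFTW, LinearAlgebra

n = 64
x = randn(ComplexF64, n)

P = plan_fft(x)   # build the plan once (plan creation is the expensive part)
y = P * x         # forward, unnormalized DFT

# bfft builds a plan internally on every call
x_bfft = bfft(y)

# the adjoint of the existing plan gives the same unnormalized backward DFT
x_adj = P' * y

# bfft is unnormalized, so bfft(fft(x)) == n * x
@assert x_adj ≈ x_bfft
@assert x_adj ≈ n .* x

# the adjoint identity the pullback relies on: <P*x, v> == <x, P'*v>
v = randn(ComplexF64, n)
@assert dot(P * x, v) ≈ dot(x, P' * v)
```

The same pattern should carry over to CUFFT plans on `CuArray`s, which is where avoiding repeated plan creation saves the most memory.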

Something along these lines:

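    ....
    Pt = P'  # adjoint of the existing plan, used instead of bfft
    scale = P.scale
    # projectors map the cotangents back onto the types of x and scale
    project_x = ChainRulesCore.ProjectTo(x)
    project_scale = ChainRulesCore.ProjectTo(scale)
    function mul_scaledplan_pullback(ȳ)
        # cotangent of the input (computed lazily): apply the adjoint plan to ȳ
        x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ))
        # cotangent of the scale factor: dot(P.p * x, ȳ), rewritten via y = scale * (P.p * x)
        scale_tangent = ChainRulesCore.@thunk(project_scale(AbstractFFTs.dot(y, ȳ) / conj(scale)))
        # the inner plan `p` gets no tangent; only `scale` is differentiated
        plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
        ...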

Source
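
For reference, here is my reading of the math behind the two thunks, using ChainRules' convention for complex cotangents. Write the scaled plan as $y = s F x$, where $F$ is the inner, unscaled plan `P.p`, $s$ is `P.scale`, and $s^{*}$ denotes the complex conjugate. Then:

$$
\bar{x} = (sF)^{H}\,\bar{y} = s^{*} F^{H} \bar{y} = P^{H} \bar{y},
\qquad
\bar{s} = (F x)^{H} \bar{y} = \frac{y^{H} \bar{y}}{s^{*}}
$$

For an FFT, $F^{H}\bar{y}$ is the unnormalized backward transform, which is what `bfft(ȳ)` would compute. Applying it through `Pt = P'` avoids creating a new plan inside every pullback call.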