Sorry, to be precise: I meant that it’s better to use the FFT plan instead of bfft. Especially with CUDA, bfft usually allocates a lot of memory.
Something along these lines:
....
Pt = P' # using the plan instead of bfft
scale = P.scale
project_x = ChainRulesCore.ProjectTo(x)
project_scale = ChainRulesCore.ProjectTo(scale)
function mul_scaledplan_pullback(ȳ)
x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ))
scale_tangent = ChainRulesCore.@thunk(project_scale(AbstractFFTs.dot(y, ȳ) / conj(scale)))
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
...
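
For reference, here is a rough sketch of why the two tangents above take this form. It assumes the complex-to-complex case, that the scaled plan acts as $y = s\,F x$ with $F$ the unscaled linear transform and $s$ the scale, and that the dot product conjugates its first argument (as `LinearAlgebra.dot` does); $F$ and $s$ are just notation for this sketch, not names from the code.

$$
\begin{aligned}
y &= s\,F x, \\
\bar{x} &= \left(\frac{\partial y}{\partial x}\right)^{H} \bar{y} = \overline{s}\,F^{H}\,\bar{y} = P'\,\bar{y}, \\
\bar{s} &= \langle F x,\ \bar{y} \rangle = \left\langle \frac{y}{s},\ \bar{y} \right\rangle = \frac{\langle y,\ \bar{y} \rangle}{\overline{s}}.
\end{aligned}
$$

So the input cotangent only needs the adjoint of the existing plan, and the scale cotangent reuses the already computed output `y`; neither step requires building a separate bfft.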