ignore_derivatives for an entire module with ChainRules.jl

I’m trying to get reverse-mode differentiation working on some simulations. Central to the code is a module that does a lot of array manipulation and bit shifting, which Zygote.jl complains about. None of it matters for the gradients I actually want, so I’d like to wrap the entire module in ignore_derivatives. Is there a concise way to do this without adding it inside every function? I’ve tried a few approaches, but I can’t get the syntax right.


The hacky, “let’s not think about this” way is to do some code generation.
You can use something like

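for name in names(MyModule; all=true)
    Base.isidentifier(name) || continue     # skip operators and generated names like #f#1
    name in (:eval, :include) && continue   # helpers that every module defines
    isdefined(MyModule, name) || continue
    func = getfield(MyModule, name)
    func isa Function || continue           # not actually a function
    # keyword arguments must not appear in the @non_differentiable signature
    println("@non_differentiable MyModule.$name(::Any...)")
end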

to print a declaration for every function; save the output to a file (with using ChainRulesCore at the top) and include it.
You probably only need it for some of the functions, though.
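A rough sketch of the “write it to a file” step, restricted to the functions you actually care about. write_non_differentiable is a hypothetical helper (plain Base, no packages needed to generate the file): it collects the declarations, optionally only for the names listed in subset, and writes them to path behind a using ChainRulesCore line. It assumes the generated file is included somewhere the module’s name is visible, and it skips operators, types and constants.

```julia
# Hypothetical helper, not part of ChainRulesCore: collect `@non_differentiable`
# declarations for the functions defined in `mod` and write them to `path`.
# `subset`, if given, restricts the output to those function names.
function write_non_differentiable(path::AbstractString, mod::Module; subset = nothing)
    modname = nameof(mod)
    decls = String[]
    for name in names(mod; all = true)
        subset === nothing || name in subset || continue  # keep only requested names
        Base.isidentifier(name) || continue     # skip operators and generated names like #f#1
        name in (:eval, :include) && continue   # helpers every module defines
        isdefined(mod, name) || continue
        getfield(mod, name) isa Function || continue  # skip types, constants, submodules
        # Keyword arguments are deliberately left out of the signature.
        push!(decls, "@non_differentiable $modname.$name(::Any...)")
    end
    write(path, "using ChainRulesCore\n\n" * join(decls, "\n") * "\n")
    return decls
end
```

For example, write_non_differentiable("nondiff_rules.jl", MyModule; subset = [:f, :g]) and then include("nondiff_rules.jl") in the code that computes gradients, where f and g stand in for your own function names.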


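@ignore_derivatives wraps an expression, typically a function call, so that AD skips it; it cannot be applied to a module or to a function definition.
You could use MacroTools.jl or MLStyle.jl to walk your code’s AST recursively, find the calls, and wrap them.
But it is definitely not trivial.
It could be a good excuse to use the two-argument form of include, though: include(mapexpr, path) applies mapexpr to each expression parsed from the file before evaluating it.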
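A sketch of the include route, assuming ChainRulesCore is loaded in the module doing the including. Instead of matching individual calls, this swaps in a simpler variant: it wraps each function body in ChainRulesCore.ignore_derivatives(() -> body), which sidesteps having to tell signatures apart from ordinary calls. wrap_bodies, is_signature and ignored are hypothetical names. Only plain definitions at the top level (or directly inside a module block) are rewritten; anything wrapped in a macro, including docstrings and @inline, is left alone.

```julia
# Hypothetical mapexpr for `include`: rewrite function definitions so their
# bodies run inside ChainRulesCore.ignore_derivatives. Assumes ChainRulesCore
# is in scope wherever the transformed code is evaluated.
const IGNORE = :(ChainRulesCore.ignore_derivatives)

# True for the left-hand side of a short-form definition such as
# `f(x) = ...`, `f(x)::Int = ...` or `f(x::T) where {T} = ...`.
is_signature(ex) = ex isa Expr && (ex.head === :call ||
    (ex.head in (:where, :(::)) && is_signature(ex.args[1])))

# Build `ChainRulesCore.ignore_derivatives(() -> body)`.
ignored(body) = Expr(:call, IGNORE, Expr(:->, Expr(:tuple), body))

function wrap_bodies(ex)
    ex isa Expr || return ex                  # symbols, literals, LineNumberNodes
    if ex.head === :function && length(ex.args) == 2
        # `function f(...) ... end`: keep the signature, wrap the body
        return Expr(:function, ex.args[1], Expr(:block, ignored(ex.args[2])))
    elseif ex.head === :(=) && is_signature(ex.args[1])
        # short-form `f(...) = body`
        return Expr(:(=), ex.args[1], ignored(ex.args[2]))
    elseif ex.head in (:module, :toplevel, :block)
        # descend into containers that can hold definitions
        return Expr(ex.head, map(wrap_bodies, ex.args)...)
    end
    return ex                                 # everything else is left untouched
end
```

Usage would then be something like include(wrap_bodies, "MyModule.jl") from the parent module. It is coarser than wrapping individual calls: every output of a rewritten function is treated as a constant by AD, which only makes sense because none of this module’s results need gradients.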
