I’m trying to get reverse differentiation up and running on some simulations. Central to the code is a module that does a lot of array manipulation and bit shifting, which are things that Zygote.jl will complain about. None of this is relevant for actually computing the gradients that I want, so I’m trying to figure out how to just wrap the entire module with an `ignore_derivatives` command. Is there a concise way of doing this without having to add it inside every function? I’ve tried a few different ways but I can’t seem to get the syntax right.
The hacky “let's not think about this” way is to do some code generation.
You can use something like
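for name in names(MyModule; all=true)
    s = string(name)
    (startswith(s, '#') || name in (:eval, :include)) && continue # compiler-generated or implicit
    func = getfield(MyModule, name)
    func isa Function || continue # not actually a function
    println("@non_differentiable MyModule.$name(args...)")
end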
to generate a file with all the declarations filled in, ready to `include`.
Though you probably only need to declare it for some of the functions.
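If you want the file written directly, here is a minimal sketch of the same loop wrapped in a function. `write_nondiff_rules`, its `only` keyword, and the toy contents of `MyModule` are all hypothetical; the generated file assumes ChainRulesCore is available wherever it gets `include`d.

```julia
# Hypothetical stand-in for the simulation module from the question.
module MyModule
shiftmix(x::UInt64) = xor(x, x >> 7)   # bit twiddling
reindex(v::Vector{Int}) = sortperm(v)  # integer array bookkeeping
const NBITS = 64                       # not a function, so it gets skipped
end

"""
    write_nondiff_rules(mod, path; only = nothing)

Write one `@non_differentiable` declaration per function defined in `mod` to
`path`. If `only` is a collection of names, only those functions are written.
"""
function write_nondiff_rules(mod::Module, path::AbstractString; only = nothing)
    modname = nameof(mod)
    open(path, "w") do io
        println(io, "using ChainRulesCore")
        for name in names(mod; all = true)
            s = string(name)
            # skip compiler-generated names and the implicit eval/include
            (startswith(s, '#') || name in (:eval, :include)) && continue
            # honour the optional whitelist
            only === nothing || name in only || continue
            # skip the module itself, types and constants
            getfield(mod, name) isa Function || continue
            println(io, "@non_differentiable $modname.$name(args...)")
        end
    end
    return path
end
```

Something like `write_nondiff_rules(MyModule, "nondiff_rules.jl"; only = [:shiftmix])` followed by `include("nondiff_rules.jl")` would then declare only the functions you name, which fits the case where just a few of them are actually called on the differentiated path.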
`@ignore_derivatives` can only be applied to function calls, not to modules or function definitions.
You could use MacroTools.jl or MLStyle.jl to walk your code's AST recursively, match the calls, and wrap each one.
But it is definitely not trivial.
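For a sense of what that AST walk involves, here is a hand-rolled sketch that uses plain `Expr` recursion from Base in place of MacroTools.jl or MLStyle.jl. `wrap_calls` and `_is_signature` are hypothetical helpers, and the default wrapper assumes ChainRulesCore's `ignore_derivatives(f)`, which calls `f()` while hiding it from AD. It leaves method signatures alone but does not match broadcast calls like `f.(x)` or qualified calls like `M.f(x)`, which is part of why this gets fiddly.

```julia
# Hypothetical helper: rewrite `ex` so that every call to a function named in
# `targets` becomes `wrapper(() -> call)`. The default wrapper is
# ChainRulesCore.ignore_derivatives, which runs the closure but hides it from AD.
function wrap_calls(ex, targets::Set{Symbol};
                    wrapper = :(ChainRulesCore.ignore_derivatives))
    ex isa Expr || return ex  # symbols, literals and line numbers pass through
    if ex.head === :function && length(ex.args) == 2
        # `function f(x) ... end`: keep the signature, rewrite only the body
        return Expr(:function, ex.args[1], wrap_calls(ex.args[2], targets; wrapper))
    elseif ex.head === :(=) && _is_signature(ex.args[1])
        # short form `f(x) = ...`: same treatment
        return Expr(:(=), ex.args[1], wrap_calls(ex.args[2], targets; wrapper))
    elseif ex.head === :call && ex.args[1] in targets
        # a call AD should skip: wrap it whole, no need to look inside
        return :($wrapper(() -> $ex))
    else
        # anything else: rebuild the node with rewritten children
        return Expr(ex.head, (wrap_calls(a, targets; wrapper) for a in ex.args)...)
    end
end

# True for method signatures such as `f(x)`, `f(x::T) where T` or `f(x)::Int`.
function _is_signature(lhs)
    lhs isa Expr || return false
    lhs.head === :call && return true
    lhs.head in (:where, :(::)) && return _is_signature(lhs.args[1])
    return false
end
```

You would apply it to definitions before they are evaluated, for example from inside a macro or an `include` mapexpr.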
It could be a good excuse to use the two-argument form of `include`, though.
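A rough sketch of the two-argument `include` idea: `include(mapexpr, path)` parses the file and passes each top-level expression through `mapexpr` before evaluating it, so the rewriting happens once at load time. `nondiff_mapexpr` and `defined_name` are hypothetical names, the mapexpr assumes the including module already has `using ChainRulesCore`, and docstring-wrapped definitions are not handled.

```julia
# Hypothetical mapexpr for `include(mapexpr, path)`: after the first top-level
# method definition of each function, append a ChainRulesCore
# `@non_differentiable` declaration for it.
function nondiff_mapexpr()
    seen = Set{Symbol}()  # emit one declaration per function, not per method
    return function (ex)
        name = defined_name(ex)
        (name === nothing || name in seen) && return ex
        push!(seen, name)
        decl = :(ChainRulesCore.@non_differentiable $name(args...))
        return Expr(:toplevel, ex, decl)  # evaluate the definition, then the rule
    end
end

# Name of the function defined by a top-level `function f(...) ... end` or
# `f(...) = ...`; `nothing` for anything else (constants, structs, `using`,
# docstring-wrapped definitions, ...).
function defined_name(ex)
    ex isa Expr && ex.head in (:function, :(=)) || return nothing
    sig = ex.args[1]
    while sig isa Expr && sig.head in (:where, :(::))  # peel `where` and return types
        sig = sig.args[1]
    end
    if sig isa Expr && sig.head === :call && sig.args[1] isa Symbol
        return sig.args[1]
    end
    return nothing
end
```

Inside the simulation module that would look like `include(nondiff_mapexpr(), "bitops.jl")`, where `bitops.jl` is a hypothetical file holding the array and bit-shifting helpers. Emitting one declaration per function rather than per method avoids defining the same rule twice, which Julia warns about and, during precompilation in recent versions, rejects.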