Let’s say I have some nested struct (in my case the nesting is arbitrarily complicated)
c = Bar(Foo([2.0]), 1.0)
g = Zygote.gradient(c) do x
g will be a tuple of the form
My problem is that I have no idea how to apply this gradient to my object automatically.
I know there is Flux.params for computing the gradients implicitly, but it has big disadvantages, such as saving unwanted gradients, and it also just fails in a lot of cases for me.
What should I do?
I guess you are referring to the implicit-gradients mechanism, which is meant for cases where you know in advance exactly which parts of the nested struct you want gradients for. In other words, it should not have the problem of getting unwanted gradients.
Flux.params is just Flux's way of conveniently returning all AbstractArrays found in the nested struct, but afaik it is not tied to that mechanism.
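To make the idea of "returning all AbstractArrays found in the nested struct" concrete, here is a rough sketch of such a collector. It is not how Flux.params is actually implemented; `Foo`, `Bar`, their field names, and `collect_arrays`/`collect_arrays!` are hypothetical, chosen to match the `Bar(Foo([2.0]), 1.0)` example from the question, and the sketch only handles plain structs, numeric arrays and scalars.

```julia
# Hypothetical example types; the real definitions from the question are not shown.
struct Foo
    w::Vector{Float64}
end

struct Bar
    foo::Foo
    b::Float64
end

# Collect every numeric array reachable from `x` by recursing over struct fields.
# Rough illustration only, NOT the actual implementation of Flux.params.
collect_arrays(x) = collect_arrays!(AbstractArray[], x)

# Numeric arrays are leaves: store the array object itself (not a copy).
collect_arrays!(acc, a::AbstractArray{<:Number}) = push!(acc, a)

# Scalars are skipped.
collect_arrays!(acc, ::Number) = acc

# Any other value is treated as a struct: visit each field in turn.
function collect_arrays!(acc, x)
    for name in fieldnames(typeof(x))
        collect_arrays!(acc, getfield(x, name))
    end
    return acc
end

c = Bar(Foo([2.0]), 1.0)
ps = collect_arrays(c)
println(ps)
```

The collected arrays are the same objects stored inside `c`, not copies, so updating them in place updates `c` itself.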
Anyway, the docs (in the same section I linked) recommend not using that approach, so it might be better to work with the output you already have.
I don't know what the best way is, but I think you should be able to use getfield to traverse the struct. I have found that Julia's multiple dispatch makes it relatively painless to recurse into nested structs. Here is an untested skeleton implementation:
apply_gradient(g::NamedTuple, s) = foreach(pairs(g)) do (fieldname, subgradient)
    apply_gradient(subgradient, getfield(s, fieldname))
end
apply_gradient(::Nothing, x) = nothing # No gradient -> do nothing
apply_gradient(g::AbstractArray, p::AbstractArray) = (p .-= g; nothing) # update p in place; might want to propagate some policy (e.g. a learning rate) as a third argument
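Here is a runnable version of the same skeleton, with the learning-rate idea from the last comment added as a third argument `η`. The `Foo`/`Bar` definitions and field names are hypothetical, and the gradient `g` is written by hand, assuming Zygote's usual layout of one NamedTuple per struct. With a real `Zygote.gradient` call you would pass the first element of the returned tuple instead.

```julia
# Hypothetical example types matching `Bar(Foo([2.0]), 1.0)` from the question.
struct Foo
    w::Vector{Float64}
end

struct Bar
    foo::Foo
    b::Float64
end

# A struct's gradient is a NamedTuple keyed by field name: recurse into each field.
apply_gradient(g::NamedTuple, s, η) = foreach(pairs(g)) do (fieldname, subgradient)
    apply_gradient(subgradient, getfield(s, fieldname), η)
end

# No gradient for this field -> do nothing.
apply_gradient(::Nothing, x, η) = nothing

# Arrays are mutable, so take a gradient-descent step in place: p ← p - η * g.
apply_gradient(g::AbstractArray, p::AbstractArray, η) = (p .-= η .* g; nothing)

c = Bar(Foo([2.0]), 1.0)

# Hand-written gradient mimicking Zygote's nesting for `c`. The scalar field `b`
# gets `nothing` here because this skeleton has no method for scalars.
g = (foo = (w = [1.0],), b = nothing)

apply_gradient(g, c, 0.5)
println(c.foo.w)  # [1.5]
```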
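You might need to add a few more methods depending on what can appear in your structure (e.g. arrays or tuples of structs, or scalar fields like the 1.0 in Bar, which cannot be updated in place the way arrays can). Also note that Zygote.gradient returns a tuple with one gradient per argument, so you would call apply_gradient(g[1], c).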
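A few of those extra methods might look like the following sketch. It assumes Zygote returns an array of per-element gradients for an array of structs, a tuple for a tuple of structs, and a plain number for a scalar field; the types `Layer` and `Model` and the learning rate `η` are hypothetical. Scalars cannot be updated in place, so this version writes them back with `setfield!`, which only works when the enclosing struct is `mutable`.

```julia
# Hypothetical types: an immutable leaf struct and a mutable container.
struct Layer
    w::Vector{Float64}
end

mutable struct Model
    layers::Vector{Layer}     # array of structs
    pair::Tuple{Layer,Layer}  # tuple of structs
    scale::Float64            # scalar field
end

# No gradient -> do nothing.
apply_gradient(::Nothing, x, η) = nothing

# Numeric arrays: gradient-descent step in place.
apply_gradient(g::AbstractArray{<:Number}, p::AbstractArray{<:Number}, η) = (p .-= η .* g; nothing)

# Arrays or tuples of structs (assumed gradient: a matching array/tuple of
# per-element gradients): recurse element by element.
function apply_gradient(g::Union{AbstractArray,Tuple}, p::Union{AbstractArray,Tuple}, η)
    foreach((gi, xi) -> apply_gradient(gi, xi, η), g, p)
end

# Structs: recurse field by field. Scalar fields are replaced via setfield!,
# which throws an error if `s` is an immutable struct.
function apply_gradient(g::NamedTuple, s, η)
    for (name, sub) in pairs(g)
        field = getfield(s, name)
        if field isa Number && sub isa Number
            setfield!(s, name, field - η * sub)
        else
            apply_gradient(sub, field, η)
        end
    end
end

m = Model([Layer([1.0]), Layer([2.0])], (Layer([3.0]), Layer([4.0])), 1.0)

# Hand-written gradient in the assumed Zygote layout; `nothing` marks no gradient.
g = (layers = [(w = [1.0],), nothing],
     pair = ((w = [2.0],), (w = [2.0],)),
     scale = 1.0)

apply_gradient(g, m, 0.5)
```

For an immutable struct you would instead have to rebuild it with the updated scalar values.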
Thanks, this is exactly what I was looking for.