Using gradients from struct with Zygote


Let’s say I have a nested struct (in my case, the nesting is arbitrarily complicated):

struct Foo
    a::Vector{Float64}
end

struct Bar
    b::Foo
    c::Float64
end

c = Bar(Foo([2.0]), 1.0)

When using

using Zygote

g = Zygote.gradient(c) do x
    # loss computed from x
end

g[1] will be a named tuple of the form (b = (a = [...],), c = nothing).

My problem is that I have no idea how to apply this gradient to my object automatically.

I know there is Flux.params for computing gradients implicitly, but it has big disadvantages, such as saving unwanted gradients, and it simply fails in many cases for me.

What should I do?

I guess you are referring to implicit gradients, a mechanism for when you know in advance exactly which parts of the nested struct you want gradients for. In other words, it should not have the problem of producing unwanted gradients. Flux.params is just Flux’s way of conveniently collecting all AbstractArrays found in the nested struct, but AFAIK it is not tied to that mechanism.
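For comparison, the implicit style roughly looks like this: you hand Zygote a Params collection containing exactly the arrays you care about, and only those get gradients. This is a minimal sketch; the loss sum(abs2, c.b.a) is made up for illustration, and the field names a, b, c are inferred from the gradient shape in the question.

```julia
using Zygote

struct Foo
    a::Vector{Float64}
end

struct Bar
    b::Foo
    c::Float64
end

c = Bar(Foo([2.0]), 1.0)

# Implicit style: list exactly the arrays you want gradients for.
ps = Zygote.Params([c.b.a])

# The closure takes no arguments; Zygote records gradients for the
# tracked arrays it reaches (here via c.b.a). The loss is hypothetical.
gs = Zygote.gradient(() -> sum(abs2, c.b.a), ps)

gs[c.b.a]  # gradients are looked up by the array itself: 2 .* [2.0]
```

Note that the result is keyed by array identity, so you still have to know which arrays you put into ps; nothing here walks the struct for you.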

Anyway, the docs (in the same section I linked) recommend not using that approach, so it might be better to work with the output you already have.
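That explicit output can be inspected directly: it is a NamedTuple that mirrors the struct’s fields, with nothing for fields that do not affect the loss. A small sketch, again assuming the field names a, b, c and a made-up loss:

```julia
using Zygote

struct Foo
    a::Vector{Float64}
end

struct Bar
    b::Foo
    c::Float64
end

c = Bar(Foo([2.0]), 1.0)

# Explicit style: the struct itself is the argument being differentiated.
# sum(abs2, x.b.a) is a hypothetical loss that never reads x.c.
g = Zygote.gradient(x -> sum(abs2, x.b.a), c)

grad = g[1]  # NamedTuple with one entry per field of Bar
grad.b.a     # gradient w.r.t. c.b.a, i.e. 2 .* [2.0]
grad.c       # nothing, because the loss does not depend on x.c
```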

I don’t know what the best way is, but I think you should be able to use getfield to traverse the struct. I have found that Julia’s multiple dispatch makes it relatively painless to recurse into nested structs. Here is an untested skeleton implementation:

apply_gradient(g::NamedTuple, s) = foreach(pairs(g)) do (fieldname, subgradient)
    apply_gradient(subgradient, getfield(s, fieldname))
end

function apply_gradient(::Nothing, x) end # No gradient -> do nothing

apply_gradient(g::AbstractArray, p::AbstractArray) = (p .-= g) # gradient-descent step, updating p in place; might want to propagate some policy (e.g. a learning rate) as a third argument
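A self-contained variant of the skeleton above, passing the learning rate η as the third argument (as the comment suggests) and doing a plain gradient-descent step in place. The ! suffix just follows the Julia convention for mutating functions; the loss, η = 0.1 and the three-step loop are only for illustration.

```julia
using Zygote

struct Foo
    a::Vector{Float64}
end

struct Bar
    b::Foo
    c::Float64
end

# Struct gradient (a NamedTuple): recurse into each field by name.
apply_gradient!(g::NamedTuple, s, η) = foreach(pairs(g)) do (fieldname, subgradient)
    apply_gradient!(subgradient, getfield(s, fieldname), η)
end

# No gradient for this field: leave it untouched.
apply_gradient!(::Nothing, x, η) = nothing

# Leaf: p ← p - η * g, mutating the array in place.
apply_gradient!(g::AbstractArray, p::AbstractArray, η) = (p .-= η .* g; nothing)

c = Bar(Foo([2.0]), 1.0)
loss(x) = sum(abs2, x.b.a)  # hypothetical loss that ignores x.c

for _ in 1:3
    g = Zygote.gradient(loss, c)
    apply_gradient!(g[1], c, 0.1)
end

c.b.a  # each step scales a by (1 - 2 * 0.1), so roughly 2.0 * 0.8^3
```

Since Foo and Bar are immutable, only the arrays inside them can be updated like this; if the loss did depend on a plain Float64 field such as c, that gradient would need separate handling (e.g. rebuilding the struct or making it mutable).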

You might need to add a few more methods depending on what your structure contains (e.g. arrays or tuples of structs).
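The extra methods for tuples and arrays of structs could look roughly like this. As far as I know, Zygote returns a Tuple of per-element gradients for a tuple field and an array of per-element gradients (NamedTuples or nothing) for an array of structs, so both methods just pair up elements and recurse. Model and its fields are hypothetical, and the leaf method is restricted to numeric arrays so it does not clash with the array-of-structs method.

```julia
using Zygote

struct Foo
    a::Vector{Float64}
end

# Hypothetical container with a tuple of structs and a vector of structs.
struct Model
    pair::Tuple{Foo, Foo}
    layers::Vector{Foo}
end

apply_gradient!(g::NamedTuple, s, η) = foreach(pairs(g)) do (fieldname, subgradient)
    apply_gradient!(subgradient, getfield(s, fieldname), η)
end
apply_gradient!(::Nothing, x, η) = nothing

# Leaf: numeric arrays get an in-place gradient-descent step.
apply_gradient!(g::AbstractArray{<:Number}, p::AbstractArray{<:Number}, η) = (p .-= η .* g; nothing)

# Tuple of structs: one gradient entry per tuple element.
apply_gradient!(g::Tuple, s::Tuple, η) = foreach((gi, si) -> apply_gradient!(gi, si, η), g, s)

# Array of structs: one gradient entry (NamedTuple or nothing) per element.
apply_gradient!(g::AbstractArray, s::AbstractArray, η) = foreach((gi, si) -> apply_gradient!(gi, si, η), g, s)

m = Model((Foo([1.0]), Foo([5.0])), [Foo([3.0]), Foo([2.0])])
loss(m) = sum(abs2, m.pair[1].a) + sum(abs2, m.layers[2].a)  # hypothetical loss

g = Zygote.gradient(loss, m)
apply_gradient!(g[1], m, 0.25)
```

Only m.pair[1].a and m.layers[2].a should change here; the elements the loss never touches receive nothing and are skipped.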


Thanks, this is exactly what I was looking for.
