Question about writing a custom training function using Flux.jl

I think the problem is in your logging statement: gs[1] asks for the gradient with respect to the integer 1. You need to request the gradient with respect to a parameter instead, for example gs[ps[1]].
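To make that concrete, here is a minimal sketch of the lookup. It assumes a Flux version that still supports the implicit-parameter style (Flux.params); the toy model, the data, and the mse loss are placeholders I made up, not your actual code.

```julia
using Flux

# Hypothetical toy model and data, for illustration only.
model = Dense(2 => 1)
x = rand(Float32, 2, 4)
y = rand(Float32, 1, 4)

# Collect the trainable arrays (implicit-parameter style).
ps = Flux.params(model)

# Gradients of the loss with respect to every array in ps.
gs = gradient(() -> Flux.mse(model(x), y), ps)

# gs is keyed by the parameter arrays themselves, not by position,
# so index it with a parameter such as ps[1] (here, the weight matrix).
for p in ps
    println(size(p), " => ", size(gs[p]))
end
```

Looping over ps like this is usually the easiest way to log every gradient, since each p is itself a valid key into gs.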
