How can I obtain the pullback of a Flux model not only w.r.t. its params but also w.r.t. its inputs?
outputs, back = Zygote.pullback( () -> model(inputs), Flux.params(model) )
returns the adjoint (i.e., back) for the model params only…
Maybe that’s the sort of thing that would be easier with Lux.jl?
One way is to avoid Flux.params in favour of explicit gradients, and write
grad_in, grad_model = gradient(|>, inputs, model)
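A minimal sketch of the explicit style, assuming a recent Flux and a hypothetical toy model with the same layer sizes as the Lux example later in this thread. Zygote's gradient needs a scalar result, so the 1×10 model output is reduced with sum here; the one-liner gradient(|>, inputs, model) applies as written only when model(inputs) already returns a number.

```julia
using Flux, Zygote

# Hypothetical toy model and data (same layer sizes as the Lux example in this thread)
model = Flux.Chain(Flux.Dense(1 => 4), Flux.Dense(4 => 1))
inputs = rand(Float32, 1, 10)

# Explicit mode: Zygote returns one gradient per positional argument.
# sum(...) turns the 1×10 output into the scalar that gradient requires.
grad_in, grad_model = Zygote.gradient((x, m) -> sum(m(x)), inputs, model)

# grad_in has the same shape as inputs;
# grad_model is a NamedTuple mirroring the Chain, e.g. grad_model.layers[1].weight
```

Zygote.gradient is qualified because newer Flux versions define their own gradient, and loading both packages unqualified can make the name ambiguous.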
If you must use implicit gradients, then I believe you could write something like
grad = gradient(() -> model(inputs), Flux.params((model, inputs)))
To be more precise you may need to tell us what inputs are.
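A sketch of the implicit version, assuming a Flux release that still provides Flux.params (implicit mode is deprecated in newer Flux). It works here because inputs is a plain numeric array, which Params can track; the gradient is then looked up by the array object itself. As above, sum is only a stand-in scalar loss and the model is a hypothetical toy.

```julia
using Flux, Zygote

model = Flux.Chain(Flux.Dense(1 => 4), Flux.Dense(4 => 1))
inputs = rand(Float32, 1, 10)

# Implicit mode: track the model's trainable arrays *and* the input array.
ps = Flux.params((model, inputs))

# The closure takes no arguments; Zygote records gradients for everything in ps.
gs = Zygote.gradient(() -> sum(model(inputs)), ps)

# Grads are indexed by the tracked array objects themselves
grad_in = gs[inputs]
grad_W1 = gs[model.layers[1].weight]
```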
(These could both use withgradient or pullback instead.)
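The explicit computation with pullback and withgradient, sketched under the same toy-model assumptions. pullback accepts a non-scalar output, so no sum is needed there; the cotangent passed to back must match the shape of outputs (all ones below, which reproduces the gradient of sum).

```julia
using Flux, Zygote

model = Flux.Chain(Flux.Dense(1 => 4), Flux.Dense(4 => 1))
inputs = rand(Float32, 1, 10)

# pullback does not need a scalar output: outputs is the 1×10 model output
outputs, back = Zygote.pullback(|>, inputs, model)

# The cotangent must match outputs; all ones reproduces the gradient of sum(outputs)
grad_in, grad_model = back(ones(Float32, size(outputs)))

# withgradient returns the (scalar) value and the gradients in one call
res = Zygote.withgradient((x, m) -> sum(m(x)), inputs, model)
total = res.val
grad_in2, grad_model2 = res.grad
```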
Thanks @gdalle! It does indeed do the job, but I might need Flux layers that are not yet mirrored in Lux.
I’ll leave a REPL snippet below for any curious reader of this topic:
julia> using Lux, Random, Zygote; rng = Random.default_rng(); Random.seed!(rng, 0); model = Chain(Dense(1, 4), Dense(4, 1)); ps, st = Lux.setup(rng, model); inputs = rand(rng, 1, 10);
julia> outputs, back = Zygote.pullback( (inputs_, ps_) -> model(inputs_, ps_, st), inputs, ps );
julia> grads = back( deepcopy(outputs) );
julia> grads[1] # for the model's inputs
1×10 Matrix{Float64}:
0.0459606 1.15793 1.15771 1.51406 0.65143 1.1169 0.677322 0.826841 1.39484 0.655619
julia> grads[2] # for the model's parameters
(layer_1 = (weight = Float32[-3.4613643; 4.178537; -2.6454673; -4.777247;;], bias = Float32[-4.8556366; 5.8616934; -3.7110882; -6.70157;;]), layer_2 = (weight = Float32[-0.16420026 -1.0880805 -5.5679674 -4.978304], bias = Float32[7.3879676;;]))
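A note on the seed used above: back(deepcopy(outputs)) uses the outputs themselves as the cotangent, and since a Lux model returns the tuple (y, st), that seed includes the state as well. If the goal is the gradient of a scalar such as sum(y), one option, sketched here under the same setup, is to drop the state with first and seed with ones. The Lux.-qualified layer names are only there to avoid clashes with Flux when both packages are loaded.

```julia
using Lux, Random, Zygote

rng = Random.default_rng()
Random.seed!(rng, 0)
# Lux.-qualified names avoid clashes if Flux is loaded in the same session
model = Lux.Chain(Lux.Dense(1, 4), Lux.Dense(4, 1))
ps, st = Lux.setup(rng, model)
inputs = rand(rng, Float32, 1, 10)

# Keep only the prediction y; first(...) drops the state returned by the model
y, back = Zygote.pullback((x, p) -> first(model(x, p, st)), inputs, ps)

# A ones cotangent gives d(sum(y))/d(inputs) and d(sum(y))/d(ps)
grad_in, grad_ps = back(ones(Float32, size(y)))
```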
Thank you @mcabbott! Could you explain further how |> is doing it under the hood? I do see that explicit gradients come out as the result, rather than the Grads(...) object:
julia> back( deepcopy(outputs) )
([0.0023901141181289407 0.0035153696753484334 … 0.0029914787527228937 0.0003643633434114374], (layers = ((weight = Float32[-0.0073440215; -0.10114246; 0.0638036; 0.012706678;;], bias = Float32[-0.014901155, -0.20521991, 0.12945868, 0.025782084], σ = nothing), (weight = Float32[0.045871202 0.03815957 0.07034941 0.062251702], bias = Float32[0.2531139], σ = nothing)),))
Are there limitations on obtaining explicit/implicit gradients for some types of inputs?
Could you explain further how |> is doing it under the hood?
|> is just a shorthand for (inputs_, model_) -> model_(inputs_).
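To see that |> adds nothing special, here is a sketch with a hypothetical one-parameter callable struct standing in for a Flux layer: differentiating |> and the equivalent anonymous function gives identical results, with the "model" gradient returned as a NamedTuple mirroring the struct's fields.

```julia
using Zygote

# Hypothetical one-parameter "model": a callable struct, like a Flux layer
struct Scale
    w::Float64
end
(s::Scale)(x) = s.w * x

m = Scale(3.0)
x = 2.0

# |>(x, m) is just m(x), so these two calls differentiate the same function
g1 = gradient(|>, x, m)
g2 = gradient((x_, m_) -> m_(x_), x, m)
# Both give (3.0, (w = 2.0,)): d/dx = w and d/dw = x
```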
What you write in the Lux example is exactly the same Zygote explicit mode as this. In both cases grads[2] is a nested structure like (layers = ((weight = [..., and grads[1] == grad_in matches the structure of inputs (which here is just an array).
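To illustrate that grads[1] follows the structure of inputs, here is a sketch with a hypothetical NamedTuple input (two arrays a and b, concatenated before entering the model). Nested containers such as tuples and NamedTuples of arrays get gradients with matching structure; the model and loss_fn are illustrative assumptions.

```julia
using Flux, Zygote

model = Flux.Chain(Flux.Dense(2 => 3), Flux.Dense(3 => 1))
# Hypothetical structured input: two 1×5 arrays in a NamedTuple
inputs = (a = rand(Float32, 1, 5), b = rand(Float32, 1, 5))

# Stack the two parts into a 2×5 matrix before the model sees them
loss_fn(inp, m) = sum(m(vcat(inp.a, inp.b)))

grad_in, grad_model = Zygote.gradient(loss_fn, inputs, model)
# grad_in is a NamedTuple with fields a and b, each shaped like the matching input
```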
The biggest difference with Lux is another meaning of “explicit”: the model state st is also passed around explicitly, and is one of the things returned by apply. For a model this simple the state is trivial, st == (layer_1 = NamedTuple(), layer_2 = NamedTuple()), so you can ignore it.
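A minimal sketch of that state handling, assuming the current Lux API: Lux.setup returns ps and st separately, and calling the model (or, equivalently, Lux.apply) returns the output together with a possibly updated state. For two Dense layers the state stays empty.

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)
model = Lux.Chain(Lux.Dense(1, 4), Lux.Dense(4, 1))

# Parameters and state are created, stored and passed separately
ps, st = Lux.setup(rng, model)
# For two Dense layers the state is empty:
# st == (layer_1 = NamedTuple(), layer_2 = NamedTuple())

x = rand(rng, Float32, 1, 10)
# Calling the model (Lux.apply(model, x, ps, st) is the function form)
# returns the output together with the updated state
y, st_new = model(x, ps, st)
```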
Could you mention which Flux layers you need that aren’t mirrored? (It will help me prioritize stuff in Flux Feature Parity · Issue #13 · avik-pal/Lux.jl · GitHub)
Also, lux.csail.mit.edu/stable/manual/migrate_from_flux/#can-we-still-use-flux-layers might interest you.
Also, you shouldn’t have to deepcopy; if any input is mutated, consider that a bug and open an issue.
Hi @avikpal! I’d say Maxout might eventually be needed in my use cases. Thanks for the link on using Flux layers inside Lux!
No inputs or outputs are mutated in the case above.
I could have written back(outputs) rather than back( deepcopy(outputs) ), indeed!