Pullback for both a Flux model's inputs and parameters

How can I obtain a pullback for a Flux model not only w.r.t. its params but also w.r.t. its inputs?

outputs, back = Zygote.pullback( () -> model(inputs), Flux.params(model) )

returns the adjoint (i.e., back) for the model's parameter values only…
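For reference, a minimal runnable sketch of this (the toy Chain and sizes below are made up for illustration):

using Flux, Zygote

model = Chain(Dense(3 => 4, tanh), Dense(4 => 1))
inputs = rand(Float32, 3, 5)

# implicit mode: only the arrays collected by Flux.params(model) are tracked
outputs, back = Zygote.pullback(() -> model(inputs), Flux.params(model))
grads = back(ones(Float32, size(outputs)))   # a Zygote.Grads

grads[model[1].weight]   # defined: the weight matrix is in Flux.params(model)
# grads has no entry for inputs, because inputs is not in Flux.params(model)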

Maybe that’s the sort of thing that would be easier with Lux.jl?


One way is to avoid Flux.params in favour of explicit gradients, and write

grad_in, grad_model = gradient(|>, inputs, model)
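For instance, a runnable sketch of this explicit route (the toy Chain and sizes below are made up; pullback is used here because the model output is not a scalar):

using Flux, Zygote

model = Chain(Dense(3 => 4, tanh), Dense(4 => 1))
inputs = rand(Float32, 3, 5)

# explicit mode: |>(inputs, model) is model(inputs), differentiated w.r.t. both arguments
outputs, back = Zygote.pullback(|>, inputs, model)
grad_in, grad_model = back(ones(Float32, size(outputs)))

size(grad_in) == size(inputs)   # gradient w.r.t. the inputs
grad_model.layers[1].weight     # gradient w.r.t. the first layer's weights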

If you must use implicit gradients, then I believe you could write something like

grad = gradient(() -> model(inputs), Flux.params((model, inputs)))

To be more precise, you may need to tell us what inputs are.

(These could both use withgradient or pullback instead.)
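For example, a sketch of the implicit route (same kind of toy Chain; the loss is reduced with sum so that gradient applies to a scalar, and Flux.params is given both the model and the inputs so that both get tracked):

using Flux, Zygote

model = Chain(Dense(3 => 4, tanh), Dense(4 => 1))
inputs = rand(Float32, 3, 5)

ps = Flux.params(model, inputs)                  # track the inputs alongside the weights
grads = gradient(() -> sum(model(inputs)), ps)   # a Zygote.Grads

grads[inputs]            # gradient w.r.t. the inputs array
grads[model[1].weight]   # gradient w.r.t. one parameter array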


Thanks @gdalle! It does the job indeed, but I might need Flux layers that are not yet mirrored in Lux.
I'll leave a REPL code snippet below for curious readers of this particular topic:


julia> using Lux, Random, Zygote

julia> rng = Random.default_rng(); Random.seed!(rng, 0);

julia> model = Chain(Dense(1, 4), Dense(4, 1));

julia> ps, st = Lux.setup(rng, model);

julia> inputs = rand(rng, 1, 10);


julia> outputs, back = Zygote.pullback( (inputs_, ps_) -> model(inputs_, ps_, st), inputs, ps );


julia> grads = back( deepcopy(outputs) );


julia> grads[1] # for the model's inputs

1×10 Matrix{Float64}:
 0.0459606  1.15793  1.15771  1.51406  0.65143  1.1169  0.677322  0.826841  1.39484  0.655619


julia> grads[2] # for the model's parameters

(layer_1 = (weight = Float32[-3.4613643; 4.178537; -2.6454673; -4.777247;;], bias = Float32[-4.8556366; 5.8616934; -3.7110882; -6.70157;;]), layer_2 = (weight = Float32[-0.16420026 -1.0880805 -5.5679674 -4.978304], bias = Float32[7.3879676;;]))

Thank you @mcabbott! Could you explain further how |> does it under the hood? I see that the explicit gradients indeed come out as the result, rather than the Grads(...) object:


julia> back( deepcopy(outputs) )
([0.0023901141181289407 0.0035153696753484334 … 0.0029914787527228937 0.0003643633434114374], (layers = ((weight = Float32[-0.0073440215; -0.10114246; 0.0638036; 0.012706678;;], bias = Float32[-0.014901155, -0.20521991, 0.12945868, 0.025782084], σ = nothing), (weight = Float32[0.045871202 0.03815957 0.07034941 0.062251702], bias = Float32[0.2531139], σ = nothing)),))

Are there limitations on obtaining explicit/implicit gradients for some types of inputs?

Could you explain further how |> does it under the hood?

|> is just a shorthand for (inputs_, model_) -> model_(inputs_).
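In other words (a tiny check; sum and the small vector below are just stand-ins for any callable and any input):

m = sum                  # any callable works here
x = [1.0, 2.0, 3.0]
(x |> m) == m(x)         # true: |>(x, f) is defined as f(x)
# so gradient(|>, inputs, model) is the same as gradient((x_, m_) -> m_(x_), inputs, model)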

What you write in the Lux example is exactly the same Zygote explicit mode as this. In both cases grads[2] is some nested thing like (layers = ((weight = [..., and grads[1] == grad_in is something matching the structure of inputs (which here is just an array).

The biggest difference with Lux is another meaning of “explicit”, about the handling of the model state st, which is one of the things returned by apply. For a model this simple the state is trivial, st == (layer_1 = NamedTuple(), layer_2 = NamedTuple()), so you can ignore it.
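For completeness, a sketch reusing model, ps, st and inputs from the REPL snippet above: since model(x_, ps_, st) returns (y, st), the pullback's primal is a tuple and its seed mirrors that tuple, so the trivial state part can be seeded with nothing (which Zygote treats as a zero cotangent):

(y, st_), back = Zygote.pullback((x_, ps_) -> model(x_, ps_, st), inputs, ps)
grad_in, grad_ps = back((ones(size(y)), nothing))   # grad_in matches inputs, grad_ps matches ps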


Could you mention which Flux layers you need that aren't mirrored yet? (This will help me prioritize stuff in Flux Feature Parity · Issue #13 · avik-pal/Lux.jl · GitHub.)

Also, lux.csail.mit.edu/stable/manual/migrate_from_flux/#can-we-still-use-flux-layers might interest you.


Also, you shouldn't have to deepcopy; if any input is mutated, consider that a bug and open an issue.


Hi @avikpal! I'd say Maxout may eventually be used in my use cases. Thanks for the link on using Flux layers inside Lux!

No inputs or outputs are mutated in the case above.
Indeed, I could have written back(outputs) rather than back( deepcopy(outputs) )!
