Untracking a Flux Model

I’m using Flux to train a model and it’s working wonderfully. After the model has been trained, its parameters are still tracked, which means the output of calling the model is tracked too. Is there a way to “untrack” each of the model’s parameters, so that I don’t need to append .data each time I want to use the model? I thought this would be trivial, but I can’t find a way. The workaround I have in mind is re-creating the model from the untracked versions of the trained parameters, but I’m hoping there’s an easier way.


I wrote my own method to do that here: https://github.com/rdeits/FluxExtensions.jl/blob/7d6d67eb6d4da66f0fddc24bf8ec2d04c60db254/src/FluxExtensions.jl#L16-L18

This was some slightly hacky research code, but it seems to work and even has tests.


Here’s one easy way to do it:

model = Flux.mapchildren(Flux.data, model)

Edit: As pointed out below, this should be mapleaves:

model = Flux.mapleaves(Flux.data, model)

What’s the difference between mapchildren and mapleaves as used here? See the definition of mapleaves here

Oh, good catch! Actually, that line was exactly what I had in mind 🙂 mapleaves is the way to do it. There’s a note in the docs here:

Flux provides mapleaves, which allows you to alter all parameters of a model at once.
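For readers on newer Flux versions: mapleaves and Flux.data belong to the older Tracker-based API. In more recent releases, a comparable way to alter every array in a model at once is `fmap` (from Functors.jl, which Flux uses internally). The sketch below assumes Flux 0.13 or later (for the `Dense(in => out, σ)` syntax); the model shape and the Float64 conversion are made-up illustrations, not part of the original answer.

```julia
using Flux

# Hypothetical two-layer model, just for illustration.
model = Chain(Dense(2 => 3, relu), Dense(3 => 1))

# fmap walks the model structure and applies the function to every leaf.
# Here every Float32 array (weights and biases) is converted to Float64;
# other leaves, such as the activation function, are returned unchanged.
model64 = Flux.fmap(x -> x isa AbstractArray{Float32} ? Float64.(x) : x, model)

y = model64(rand(2, 5))  # Float64 input now matches the Float64 parameters
```

Flux also ships convenience wrappers built on the same idea (for example `f64` for exactly this conversion), so a hand-written `fmap` is mainly useful for custom transformations.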

Awesome, thank you!


After we do this, how do we turn tracking back on?

I don’t have a Julia installation in front of me right now, but I think this should work for re-tracking a model:

model = Flux.mapleaves(Flux.param, model)

Since Flux switched to Zygote, Flux.data has been deprecated and the proposed untracking solution no longer works.

Is there a way to untrack a model with Flux+Zygote?

There should be no need to do so, as Zygote works on normal arrays.

To prevent Zygote from taking gradients of a function, you can use @nograd, but I guess that is not what the original question was about.
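To illustrate the point that Zygote works on plain arrays, here is a minimal sketch assuming Flux 0.13 or later with Zygote as the gradient backend; the model and data are hypothetical. Calling the model returns an ordinary array, so there is nothing to untrack, and gradients are only computed when you explicitly ask for them.

```julia
using Flux

# Hypothetical model and input batch (5 samples with 2 features each).
model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
x = rand(Float32, 2, 5)

# Inference: the output is a plain Matrix{Float32}, with no tracked wrapper to strip.
y = model(x)

# Training: Zygote differentiates the closure only when gradient is called.
# The result is a tuple with one entry per argument (here, just the model).
grads = Flux.gradient(m -> sum(m(x)), model)
```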


Thanks for pointing me in the right direction.