How to capture model output along with the loss in Flux.withgradient

Here is the sample code for my training loop:

using Flux, ProgressMeter  # @showprogress comes from ProgressMeter; fmap is re-exported by Flux

function train()
    @showprogress for i in 1:10
        # withgradient returns the loss value together with the gradients
        l, grads = Flux.withgradient(m -> lossfn(m(X), Y), model)
        # walk the model and its gradient together, updating each parameter in place
        fmap(model, grads[1]) do p, g
            p .= p .- η .* g
        end
        println(l)
    end
end

I am able to capture the loss while computing the gradients. However, I also want to log, say, the predictions generated in the m(X) call, for example by writing them to a file. What is the idiomatic way of doing this?

I could modify lossfn to do the logging inside it, but that doesn't seem clean. Is there a way to return multiple values from Flux.withgradient?

Thanks!

You can use ignore_derivatives from ChainRulesCore, which is also accessible through Zygote:

using Flux, Zygote, ProgressMeter

function train()
    preds = []  # predictions collected across iterations
    @showprogress for i in 1:10
        l, grads = Flux.withgradient(model) do m
            Ŷ = m(X)
            # everything inside this block is invisible to the AD pass,
            # so logging here does not affect the gradient
            Zygote.ignore_derivatives() do
                push!(preds, Ŷ)
            end
            lossfn(Ŷ, Y)
        end
        fmap(model, grads[1]) do p, g
            p .= p .- η .* g
        end
        println(l)
    end
end
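
On your question about returning multiple values: if I remember correctly, recent versions of Zygote (and hence Flux.withgradient) also let the closure return a Tuple or NamedTuple whose first element is the loss; only that element is differentiated, and the whole tuple comes back as the value. A minimal sketch of how that could look, assuming your versions support it (X, Y, lossfn as in your code):

# Sketch: return (loss, extras...) from the closure. Only the first
# element is differentiated; the rest ride along as plain values.
val, grads = Flux.withgradient(model) do m
    Ŷ = m(X)
    (lossfn(Ŷ, Y), Ŷ)  # loss first, predictions as an auxiliary output
end
l, Ŷ = val             # unpack the loss and the captured predictions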

Thanks @CarloLucibello, I was indeed looking for something like PyTorch's detach :smiley:
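
For anyone else who lands here looking for the detach analogy: as far as I can tell, ignore_derivatives can also be applied directly to a value, not just to a do-block. A tiny sketch, where y stands for a hypothetical intermediate computed inside a differentiated closure:

using ChainRulesCore  # ignore_derivatives is defined here

# returns y unchanged, but tells the AD system to treat it as a
# constant, so no gradient flows through it (detach-like behavior)
y_detached = ignore_derivatives(y)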

ValueHistories.jl worked very nicely for me. I highly recommend taking a look at the ecosystem page in the Flux docs if you're creating custom training pipelines and models. I didn't, and realized months later that a bunch of what I had written already existed in some package :smiling_face_with_tear:
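
In case it helps, this is roughly the pattern I mean with ValueHistories.jl; the key names are just illustrative, and i, l, Ŷ are the loop variables from the training code above:

using ValueHistories

hist = MVHistory()                # one container for several tracked series
push!(hist, :loss, i, l)          # record the loss at iteration i
push!(hist, :preds, i, copy(Ŷ))   # and the predictions alongside it

iters, losses = get(hist, :loss)  # retrieve a series later for plotting or saving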
