How to capture the model output alongside the loss in `Flux.withgradient`

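You can use `ignore_derivatives` from ChainRulesCore, which is also accessible through Zygote as `Zygote.ignore_derivatives`. Code inside its `do` block is hidden from automatic differentiation. You can therefore store the forward-pass output (e.g. `push!` it into a vector) without affecting the gradient: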

using Flux, Zygote

"""
    train(model, X, Y, lossfn, η; epochs=10)
    train()

Run `epochs` steps of plain gradient descent with step size `η` on `model`.
Record the model output `Ŷ = model(X)` from every forward pass.

Each output is pushed into `preds` inside `Zygote.ignore_derivatives`. This
side effect is therefore invisible to automatic differentiation and does not
change the gradient of `lossfn(Ŷ, Y)`. The loss is printed after each step.

The zero-argument method uses the global `model`, `X`, `Y`, `lossfn` and `η`,
as the original snippet did.

Returns the vector of recorded outputs, one per step. Each is taken before that
step's parameter update. `model` is updated in place.
"""
function train(model, X, Y, lossfn, η; epochs::Integer = 10)
    # `Descent(η)` performs the same `p .-= η .* g` step as a manual fmap over
    # the parameters. Unlike a raw `fmap`, `Flux.update!` skips leaves whose
    # gradient is `nothing` (activation functions, `bias=false`, ...).
    opt_state = Flux.setup(Flux.Descent(η), model)
    preds = Any[]  # eltype of Ŷ is not known until the first forward pass
    # For a progress bar: `using ProgressMeter` and prefix this loop with `@showprogress`.
    for _ in 1:epochs
        l, grads = Flux.withgradient(model) do m
            Ŷ = m(X)
            # Record the output without making the push! part of the AD trace.
            Zygote.ignore_derivatives() do
                push!(preds, Ŷ)
            end
            lossfn(Ŷ, Y)
        end
        Flux.update!(opt_state, model, grads[1])
        println(l)
    end
    return preds
end

train() = train(model, X, Y, lossfn, η)
1 Like