Thanks @de-souza for this reference. Based on this work, I propose the following (not really optimized) solution that seems to work well and fast. However, when we compare it with the fully-Zygote solution, there is a non-negligible difference in the gradients. It is difficult to say which one is correct, since we do not have a reference solution.
using Flux, Zygote, ForwardDiff, TaylorDiff, SliceMap
# Ten random 2-D samples: first coordinate scaled to [0, 2), second to [0, 1).
x_train = Float32[2 0; 0 1] * rand(Float32, 2, 10)
# One random scalar target per sample.
target = rand(Float32, 1, 10)
# Small MLP surrogate: 2 inputs -> two tanh hidden layers of width 15 -> 1 output.
model = Chain(Dense(2 => 15, tanh), Dense(15 => 15, tanh), Dense(15 => 1))
"""
    get_residual(f::Chain)

Return a function `u -> r(u)` computing the PDE residual
`∂₁f + (1 - 2f) ∂₂f` column-wise, where the partial derivatives are the
rows of the Zygote gradient of `u -> sum(f(u))`.

Note: Zygote cannot differentiate this a second time, so no higher-order
terms are possible with this backend.
"""
function get_residual(f::Chain)
    df(u) = Zygote.gradient(u -> sum(f(u)), u)[1]
    return function (u)
        # Evaluate the (expensive) gradient once instead of twice as before.
        g = df(u)
        return g[[1], :] .+ (1.0f0 .- 2.0f0 .* f(u)) .* g[[2], :]
    end
end
"""
    get_residual_forward(f::Chain)

Same residual as `get_residual` (`∂₁f + (1 - 2f) ∂₂f`), but using
ForwardDiff for the gradient. (A second-order `ddf` term was sketched here
previously but left disabled.)
"""
function get_residual_forward(f::Chain)
    df(u) = ForwardDiff.gradient(u -> sum(f(u)), u)
    return function (u)
        # Single gradient evaluation instead of two per call.
        g = df(u)
        return g[[1], :] .+ (1.0f0 .- 2.0f0 .* f(u)) .* g[[2], :]
    end
end
"""
    get_residual_Taylor(f::Chain, x)

Residual `∂₁f + (1 - 2f) ∂₂f` at a single sample `x` (a 2-vector),
using TaylorDiff first-order directional derivatives along the two axes.
"""
function get_residual_Taylor(f::Chain, x)
    xv = convert(Vector{Float32}, x)
    scalar(u) = sum(f(u))
    # First-order derivatives in the "t" (first) and "x" (second) directions.
    ∂t = TaylorDiff.derivative(scalar, xv, [1.0f0, 0.0f0], 1)
    ∂x = TaylorDiff.derivative(scalar, xv, [0.0f0, 1.0f0], 1)
    return ∂t .+ (1.0f0 .- 2.0f0 .* f(xv)) .* ∂x
end
"""
    get_residual_fd(f::Chain)

Return `x -> r(x)` approximating the residual `∂₁f + (1 - 2f) ∂₂f`
with one-sided (forward) finite differences, column-wise over `x`.
"""
function get_residual_fd(f::Chain)
    # For forward differences the total error is O(ε) truncation + O(eps/ε)
    # roundoff, so the optimal step is ε ≈ sqrt(eps); cbrt(eps) (used before)
    # is the optimal step for *central* differences.
    ε = sqrt(eps(Float32))
    ε₁ = [ε; 0]
    ε₂ = [0; ε]
    return function (x)
        # Evaluate the network at the base point once (it was evaluated
        # three times per call before).
        fx = f(x)
        V = 1.0f0 .- 2.0f0 .* fx
        return (f(x .+ ε₁) - fx) / ε .+ V .* (f(x .+ ε₂) - fx) / ε
    end
end
# Time one residual evaluation per backend. NOTE(review): these are
# first-call timings, so each `@timed` includes JIT compilation of that
# backend's code path, not just the evaluation itself.
r = get_residual(model)
@info "Getting residuals (Zygote)"
timed = @timed r(x_train)
@show timed
r_forward = get_residual_forward(model)
@info "Getting residuals (ForwardDiff)"
timed = @timed r_forward(x_train)
@show timed
# TaylorDiff works on one sample (column) at a time, hence `mapcols`.
r_taylor(x) = get_residual_Taylor(model, x)
@info "Getting residuals (TaylorDiff)"
timed = @timed mapcols(x -> get_residual_Taylor(model, x), x_train)
@show timed
r_fd = get_residual_fd(model)
@info "Getting residuals (FiniteDifference)"
timed = @timed r_fd(x_train)
@show timed
@info "Getting gradients (Zygote)"
"""
    withgradient_Zygote(model, x_train)

Loss = data MSE against the global `target` + MSE penalty on the PDE
residual, differentiated wrt `model`'s parameters entirely with Zygote.
Returns `(loss, grads)`.
"""
function withgradient_Zygote(model, x_train)
    loss, grads = Zygote.withgradient(model) do m
        y_hat = m(x_train)
        # `target` is captured from global scope — matches x_train's 10 columns.
        loss = Flux.Losses.mse(y_hat, target)
        r = get_residual(m)(x_train)
        # Size the zero target from the residual itself instead of the
        # hard-coded (1, 10), so any batch size works.
        penalty = Flux.Losses.mse(r, zeros(Float32, size(r)))
        loss + penalty
    end
    (loss, grads)
end
timed_Zygote = @timed withgradient_Zygote(model, x_train)
@show timed_Zygote
@info "Getting gradients (TaylorDiff)"
"""
    withgradient_taylor(model, x_train)

Same loss as `withgradient_Zygote` (data MSE + residual penalty), but the
residual is computed with TaylorDiff inside the Zygote closure.
Returns `(loss, grads)`.
"""
function withgradient_taylor(model, x_train)
    loss, grads = Zygote.withgradient(model) do m
        y_hat = m(x_train)
        loss = Flux.Losses.mse(y_hat, target)
        # BUG FIX: build the residual from `m` — the variable Zygote
        # differentiates — not the outer `model`. Using `model` here makes
        # the penalty term a constant wrt the differentiated parameters, so
        # its gradient is silently dropped; this is the likely cause of the
        # reported mismatch with the pure-Zygote gradients.
        r = mapcols(x -> get_residual_Taylor(m, x), x_train)
        # Size the zero target from the residual instead of hard-coding (1, 10).
        penalty = Flux.Losses.mse(r, zeros(Float32, size(r)))
        loss + penalty
    end
    (loss, grads)
end
timed_taylor = @timed withgradient_taylor(model, x_train)
@show timed_taylor
The Zygote solution takes around 175 s to compile and run (first call), while the TaylorDiff solution is roughly 10 times faster (14 s).
One option would be to compare both on a simple example with an analytical solution, but I am not sure I have the energy to do that right now.
I hope this helps you too, @de-souza — tell me what you think!