Speed up compilation with Zygote and PINN

Thanks @de-souza for this reference. Based on this work, I propose the following (not really optimized) solution that seems to work well and fast. However, when we compare it with the fully Zygote solution, there is a non-negligible difference in the gradients. But it is difficult to say which one is best, since we do not have any reference.

using Flux, Zygote, ForwardDiff, TaylorDiff, SliceMap

# Training inputs: 2×10 batch. The `;;` literal builds the 2×2 matrix [2 0; 0 1],
# so row 1 (t) is uniform in [0, 2) and row 2 (x) is uniform in [0, 1).
x_train = [2; 0;; 0; 1] * rand(Float32, 2, 10)
# Random regression targets, one scalar per sample (no analytical reference here).
target = rand(Float32, 1, 10)

# Small MLP surrogate: 2 inputs (t, x) -> 1 output, two tanh hidden layers.
model = Chain(
    Dense(2 => 15, tanh),
    Dense(15 => 15, tanh),
    Dense(15 => 1)
)

"""
    get_residual(f::Chain)

Return a function `u -> residual` for the PDE residual ∂f/∂t + (1 - 2f)·∂f/∂x,
where the rows of `u` are (t, x) and derivatives are taken with reverse-mode AD
(Zygote). Works column-wise on a 2×N batch.
"""
function get_residual(f::Chain)
    df(u) = Zygote.gradient(v -> sum(f(v)), u)[1]
    return function (u)
        # Evaluate the gradient once and reuse it; the original called df(u)
        # twice, doubling the cost of every residual evaluation.
        g = df(u)
        return g[[1], :] .+ (1.0f0 .- 2.0f0 .* f(u)) .* g[[2], :] # Cannot do higher order (Zygote over Zygote)
    end
end

"""
    get_residual_forward(f::Chain)

Same residual as `get_residual`, but the input gradient is computed with
forward-mode AD (`ForwardDiff.gradient`).
"""
function get_residual_forward(f::Chain)
    df(u) = ForwardDiff.gradient(v -> sum(f(v)), u)
    #ddf(u) = ForwardDiff.gradient(u -> sum(df(u)), u)
    return function (u)
        # Single gradient evaluation shared by both residual terms (the
        # original evaluated df(u) twice).
        g = df(u)
        return g[[1], :] .+ (1.0f0 .- 2.0f0 .* f(u)) .* g[[2], :] #.- 0.001f0 .* ddf(u)[2]
    end
end

"""
    get_residual_Taylor(f::Chain, x)

Residual ∂f/∂t + (1 - 2f)·∂f/∂x at a single sample `x = (t, x)`, using
first-order directional derivatives from TaylorDiff along the unit directions.
"""
function get_residual_Taylor(f::Chain, x)
    x = convert(Vector{Float32}, x)
    # One scalar wrapper shared by both directional derivatives.
    scalar(u) = sum(f(u))
    ∂t = TaylorDiff.derivative(scalar, x, [1.0f0, 0.0f0], 1)
    ∂x = TaylorDiff.derivative(scalar, x, [0.0f0, 1.0f0], 1)
    #∂xx = TaylorDiff.derivative(scalar, x, [0.0f0, 1.0f0], 2)
    return ∂t .+ (1.0f0 .- 2.0f0 .* f(x)) .* ∂x #.- 0.001f0 .* ∂xx
end

"""
    get_residual_fd(f::Chain)

Residual ∂f/∂t + (1 - 2f)·∂f/∂x approximated with first-order forward
finite differences.

NOTE(review): `cbrt(eps)` is the classical optimum for *central* differences;
for one-sided differences `sqrt(eps(Float32))` is the usual step — kept as-is
to preserve the numbers reported below.
"""
function get_residual_fd(f::Chain)
    ε = cbrt(eps(Float32))
    ε₁ = [ε; 0]
    ε₂ = [0; ε]
    return function (x)
        # Evaluate the network at the base point once and reuse it; the
        # original recomputed f(x) inside each difference and inside V(x),
        # for five forward passes per call instead of three.
        fx = f(x)
        return (f(x .+ ε₁) - fx) / ε .+ (1.0f0 .- 2.0f0 .* fx) .* (f(x .+ ε₂) - fx) / ε
    end
end

# --- Timings: one residual evaluation per backend on the full batch.
# First calls include compilation time (@timed, not @btime, on purpose here).

r = get_residual(model)
@info "Getting residuals (Zygote)"
timed = @timed r(x_train)
@show timed

r_forward = get_residual_forward(model)
@info "Getting residuals (ForwardDiff)"
timed = @timed r_forward(x_train)
@show timed

# Use the named wrapper in the timed call; the original defined r_taylor and
# then timed a fresh anonymous closure instead, leaving r_taylor unused.
r_taylor(x) = get_residual_Taylor(model, x)
@info "Getting residuals (TaylorDiff)"
timed = @timed mapcols(r_taylor, x_train)
@show timed

r_fd = get_residual_fd(model)
@info "Getting residuals (FiniteDifference)"
timed = @timed r_fd(x_train)
@show timed

@info "Getting gradients (Zygote)"
"""
    withgradient_Zygote(model, x_train)

Return `(loss, grads)` where the loss is the data MSE against `target` plus a
PINN penalty driving the Zygote-based residual to zero, and `grads` are the
gradients w.r.t. `model`.

NOTE(review): `target` is captured from global scope — confirm this is intended.
"""
function withgradient_Zygote(model, x_train)
    loss, grads = Zygote.withgradient(model) do m
        y_hat = m(x_train)
        loss = Flux.Losses.mse(y_hat, target)
        # Zero-residual target sized from the batch, rather than the
        # hard-coded 10 of the original, so any batch size works.
        penalty = Flux.Losses.mse(get_residual(m)(x_train),
                                  zeros(Float32, 1, size(x_train, 2)))
        loss + penalty
    end
    return (loss, grads)
end

# Time the full loss+gradient with the Zygote-based residual penalty
# (first call includes compilation).
timed_Zygote = @timed withgradient_Zygote(model, x_train)
@show timed_Zygote

@info "Getting gradients (TaylorDiff)"
"""
    withgradient_taylor(model, x_train)

Return `(loss, grads)`: data MSE against the global `target` plus a PINN
penalty built from the TaylorDiff residual, differentiated w.r.t. `model`.
"""
function withgradient_taylor(model, x_train)
    loss, grads = Zygote.withgradient(model) do m
        y_hat = m(x_train)
        loss = Flux.Losses.mse(y_hat, target)
        # BUG FIX: the residual must be built from `m` — the argument Zygote
        # differentiates — not the captured global `model`; with `model` the
        # penalty term contributed no gradient at all.
        r = mapcols(x -> get_residual_Taylor(m, x), x_train)
        # Zero-residual target sized from the batch (was hard-coded to 10).
        penalty = Flux.Losses.mse(r, zeros(Float32, 1, size(x_train, 2)))
        loss + penalty
    end
    return (loss, grads)
end

# Time the full loss+gradient with the TaylorDiff-based residual penalty.
timed_taylor = @timed withgradient_taylor(model, x_train)
@show timed_taylor

The solution with Zygote takes around 175 sec to compile and run (first time), while the solution with TaylorDiff is about 10 times faster (14 sec).

One option would be to compare them on a simple example with an analytical solution, but I am not sure I have the will to do that right now.

I hope that can help you too @de-souza , tell me what you think!

2 Likes