Hi!
I am relatively new to Julia and I am trying to move all my research projects from Python to this new ecosystem.
However, I run into trouble when a project involves physics-informed neural networks (PINNs). Since my work is about improving PINNs, I need to keep control of what is happening, so I cannot rely on a ready-made package. I am trying to write everything myself, but I am running into excessively long compilation times. Here is a very simple example:
using Flux, Zygote
x_train = [2; 0;; 0; 1] * rand(Float32, 2, 10)
target = rand(Float32, 1, 10)
model = Chain(
    Dense(2 => 15, tanh),
    Dense(15 => 15, tanh),
    Dense(15 => 1)
)
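function get_residual(f::Chain)
    # Gradient of the summed output w.r.t. the inputs: a 2×N matrix, one row per input coordinate
    df(u) = Zygote.gradient(u -> sum(f(u)), u)[1]
    #ddf(u) = Zygote.gradient(u -> sum(df(u)), u)[1]
    # Slice whole rows (1×N) rather than single elements, so the residual is evaluated per sample
    return u -> df(u)[1:1, :] .+ (1.0f0 .- 2.0f0 .* f(u)) .* df(u)[2:2, :] #.- 0.001f0 .* ddf(u)[2:2, :]
end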
r = get_residual(model)
@info "Getting residuals"
timed = @timed r(x_train)
@show timed.time
function withgradient(model, x_train)
    loss, grads = Flux.withgradient(model) do m
        y_hat = m(x_train)
        loss = Flux.Losses.mse(y_hat, target)
        penalty = Flux.Losses.mse(get_residual(m)(x_train), zeros(Float32, 1, 10))
        loss + penalty
    end
    (loss, grads)
end
@info "Getting gradients"
timed = @timed withgradient(model, x_train)
@show timed.time
The first call takes a very long time to compile, even though the model is very simple: 25 s for the residuals and 264 s for the gradients. Moreover, if I uncomment the commented-out parts of get_residual (the second-derivative term), compilation almost never finishes. The equivalent code ran in a couple of seconds in TensorFlow.
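To check that these numbers really are compilation rather than run time, the same call can be timed twice. This is only a minimal sketch: `first_vs_second` and `toy` are hypothetical names I made up for illustration, not part of Flux or Zygote, and `toy` merely stands in for the residual function `r`.

```julia
# Hypothetical helper (not part of Flux or Zygote): time the same call twice.
# The first call includes JIT compilation; the second is mostly pure run time.
function first_vs_second(f, args...)
    t_first = @elapsed f(args...)
    t_second = @elapsed f(args...)
    return (first = t_first, second = t_second)
end

# Toy stand-in for the residual function; swap in `r` and `x_train` from the example.
toy(x) = sum(abs2, tanh.(x))
times = first_vs_second(toy, rand(Float32, 2, 10))
@show times.first times.second
```

With the real code this would be `first_vs_second(r, x_train)`. If the second timing is small, the cost is in compiling the nested derivatives and not in evaluating them.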
Do you have any suggestions for improving this code? Is there anything I am doing wrong? I tried Yota, which is faster in my case, but in the end it cannot compute the gradients…
Best,
Matthieu