Hi,
I am trying to implement a deep-learning-based differential equation solver in Julia. I have implemented the model using custom layers and @functor. My loss function, however, uses Flux.pullback to compute first and second derivatives of the model output with respect to its inputs, which are sampled from a uniform distribution over the domain of the problem. When I run train_model() below, I get the “Can’t differentiate foreigncall expression” error, which I am unable to make sense of. Here is my code, starting from the loss function definition (assume all constants are defined properly):
function loss_function(N)  # N is unused; the batch is sampled inside
    t_internal, S_internal, t_terminal, S_terminal = sampling_function(nSim_interior, nSim_terminal)

    # differential operator loss
    model_output_internal = model(t_internal, S_internal)
    out1, back_S = Flux.pullback(S -> model(t_internal, S), S_internal)
    ∂g∂x = back_S(ones(size(out1)))[1]
    println("Computed ∂g∂x")

    out2, back_t = Flux.pullback(t -> model(t, S_internal), t_internal)
    ∂g∂t = back_t(ones(size(out2)))[1]
    println("Computed ∂g∂t")

    # finite difference of first derivatives to approximate the second derivative
    ϵ = 0.01
    S_internal_shift = S_internal .+ ϵ
    out3, back_S_shift = Flux.pullback(l -> model(t_internal, l), S_internal_shift)
    ∂g∂x_shift = back_S_shift(ones(size(out3)))[1]
    ∂g∂xx = (∂g∂x_shift .- ∂g∂x) ./ ϵ
    println("Computed ∂g∂xx")

    # Black-Scholes operator residual: ∂g/∂t + r·S·∂g/∂x + ½σ²·S²·∂²g/∂x² − r·g
    operator_loss_vec = ∂g∂t .+ r .* S_internal .* ∂g∂x .+ (0.5 * sigma^2) .* (S_internal .^ 2) .* ∂g∂xx .- r .* model_output_internal
    operator_loss = sum(abs2, operator_loss_vec)
    println("Computed operator loss")

    # terminal condition loss: European call payoff max(S − K, 0)
    target_output_terminal = relu.(S_terminal .- K)
    model_output_terminal = model(t_terminal, S_terminal)
    terminal_loss = sum(abs2, model_output_terminal .- target_output_terminal)
    println("Computed terminal loss")

    return operator_loss + terminal_loss
end
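For context, I believe the structure above amounts to second-order differentiation: Flux.train! asks Zygote to differentiate loss_function with respect to the model parameters, while loss_function itself calls Flux.pullback, so Zygote has to differentiate through its own backward pass. I understand nested Zygote differentiation can be fragile, so this nesting may be what triggers the error, though I can't tell from the trace. Here is a stripped-down sketch of the same nesting (a toy Chain standing in for my custom layers; all names here are just for illustration):

using Flux

toy_model = Chain(Dense(1, 8, tanh), Dense(8, 1))
x = rand(Float32, 1, 16)

function toy_loss(x)
    # first derivative of the network output w.r.t. its input
    y, back = Flux.pullback(z -> toy_model(z), x)
    ∂y∂x = back(ones(Float32, size(y)))[1]
    return sum(abs2, ∂y∂x)
end

# differentiating toy_loss w.r.t. the parameters nests Zygote inside Zygote
gs = Flux.gradient(() -> toy_loss(x), Flux.params(toy_model))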
function train_model()
    # set the optimizer to ADAM
    opt = ADAM(learning_rate)
    # dummy dataset: (1) is just the integer 1, so this is a Vector{Int};
    # zip wraps each element in a 1-tuple for train!
    dataset = [(1) for i in 1:10]
    Flux.train!(loss_function, params(model), zip(dataset), opt)
end
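If it helps, my understanding is that the train! call above is roughly equivalent to the explicit loop below, so loss_function is invoked as loss_function(1) ten times (the argument is ignored, since sampling happens inside the loss):

# roughly equivalent to the Flux.train! call above
for d in zip(dataset)
    gs = Flux.gradient(() -> loss_function(d...), params(model))
    Flux.Optimise.update!(opt, params(model), gs)
end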
train_model()
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] get at ./iddict.jl:87 [inlined]
[3] (::typeof(∂(get)))(::Nothing) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
[4] accum_global at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/lib/lib.jl:56 [inlined]
[5] (::typeof(∂(accum_global)))(::Nothing) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
[6] #89 at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/lib/lib.jl:67 [inlined]
[7] (::typeof(∂(λ)))(::Nothing) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
[8] #1550#back at /Users/amritjalan/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[9] (::typeof(∂(λ)))(::Nothing) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
[10] getindex at ./tuple.jl:24 [inlined]
[11] gradindex at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/reverse.jl:12 [inlined]
[12] #39 at /Users/amritjalan/OneDrive - Massachusetts Institute of Technology/Coursework/Fall 2020/18.337 - Parallel Computing/FinalProject/Deep Galerkin Method/european_call_option.jl:191 [inlined]
[13] (::typeof(∂(λ)))(::Tuple{Nothing,Array{Float64,2}}) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
[14] #41 at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface.jl:40 [inlined]
[15] (::typeof(∂(λ)))(::Tuple{Array{Float64,2}}) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
[16] loss_function at /Users/amritjalan/OneDrive - Massachusetts Institute of Technology/Coursework/Fall 2020/18.337 - Parallel Computing/FinalProject/Deep Galerkin Method/european_call_option.jl:192 [inlined]
[17] (::typeof(∂(loss_function)))(::Float64) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
[18] #150 at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/lib/lib.jl:191 [inlined]
[19] #1693#back at /Users/amritjalan/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[20] #14 at /Users/amritjalan/.julia/packages/Flux/q3zeA/src/optimise/train.jl:103 [inlined]
[21] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface.jl:172
[22] gradient(::Function, ::Params) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface.jl:49
[23] macro expansion at /Users/amritjalan/.julia/packages/Flux/q3zeA/src/optimise/train.jl:102 [inlined]
[24] macro expansion at /Users/amritjalan/.julia/packages/Juno/n6wyj/src/progress.jl:119 [inlined]
[25] train!(::Function, ::Params, ::Base.Iterators.Zip{Tuple{Array{Int64,1}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /Users/amritjalan/.julia/packages/Flux/q3zeA/src/optimise/train.jl:100
[26] train! at /Users/amritjalan/.julia/packages/Flux/q3zeA/src/optimise/train.jl:98 [inlined]
[27] train_model() at /Users/amritjalan/OneDrive - Massachusetts Institute of Technology/Coursework/Fall 2020/18.337 - Parallel Computing/FinalProject/Deep Galerkin Method/european_call_option.jl:246
[28] top-level scope at none:1
Any guidance would be much appreciated!