"Can't differentiate foreigncall expression" error while using Flux.train!

Hi,

I am trying to implement a deep-learning-based differential equation solver in Julia (the Deep Galerkin Method, applied here to the Black–Scholes equation for a European call option). I have implemented the model using custom layers and @functor. My loss function, however, uses Flux.pullback to compute first and second derivatives of the model output after sampling from a uniform distribution over the problem domain. When I run train_model() below, I get the “Can’t differentiate foreigncall expression” error, which I am unable to make sense of. Here is my code, starting from the loss function definition (assume all constants are defined properly):
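
For context, here is a simplified, self-contained stand-in for the model and the sampling function. My actual model uses custom layers, but it has the same call signature; all names, layer sizes, and constants below are placeholders rather than my real definitions:

using Flux
using Flux: @functor

# placeholder constants (my real values differ)
const nSim_interior = 100
const nSim_terminal = 100
const r, sigma, K = 0.05, 0.25, 50.0
const T_maturity, S_max = 1.0, 100.0
const learning_rate = 1e-3

# minimal custom layer marked with @functor, callable as model(t, S)
struct DGMSolver{N}
    net::N
end
@functor DGMSolver
(m::DGMSolver)(t, S) = m.net(vcat(t, S))    # t and S are 1×n row matrices

model = DGMSolver(Chain(Dense(2, 32, tanh), Dense(32, 32, tanh), Dense(32, 1)))

# uniform samples: interior points in [0, T]×[0, S_max], terminal points at t = T
function sampling_function(n_interior, n_terminal)
    t_internal = T_maturity .* rand(1, n_interior)
    S_internal = S_max .* rand(1, n_interior)
    t_terminal = fill(T_maturity, 1, n_terminal)
    S_terminal = S_max .* rand(1, n_terminal)
    return t_internal, S_internal, t_terminal, S_terminal
end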

function loss_function(N)

    t_internal, S_internal, t_terminal, S_terminal = sampling_function(nSim_interior, nSim_terminal)

    # differential operator loss
    model_output_internal = model(t_internal, S_internal)

    # ∂g/∂S: vector–Jacobian product seeded with ones (one scalar output per sample)
    out1, back_S = Flux.pullback(S -> model(t_internal, S), S_internal)
    ∂g∂x = back_S(ones(size(out1)))[1]
    println("Computed ∂g∂x")

    # ∂g/∂t, computed the same way
    out2, back_t = Flux.pullback(t -> model(t, S_internal), t_internal)
    ∂g∂t = back_t(ones(size(out2)))[1]
    println("Computed ∂g∂t")

    # second derivative ∂²g/∂S²: finite difference of the first derivative
    ϵ = 0.01
    S_internal_shift = S_internal .+ ϵ
    out3, back_S_shift = Flux.pullback(l -> model(t_internal, l), S_internal_shift)
    ∂g∂x_shift = back_S_shift(ones(size(out3)))[1]
    ∂g∂xx = (∂g∂x_shift .- ∂g∂x) ./ ϵ
    println("Computed ∂g∂xx")

    # Black–Scholes PDE residual: ∂g/∂t + r·S·∂g/∂S + ½σ²·S²·∂²g/∂S² − r·g
    operator_loss_vec = ∂g∂t + r.*S_internal.*∂g∂x + (0.5*(sigma^2)).*(S_internal.^2).*∂g∂xx - r.*model_output_internal
    operator_loss = sum(abs2, operator_loss_vec)
    println("Computed operator loss")

    # terminal condition loss: European call payoff max(S - K, 0) at maturity
    target_output_terminal = relu.(S_terminal .- K)
    model_output_terminal = model(t_terminal, S_terminal)
    terminal_loss = sum(abs2, model_output_terminal .- target_output_terminal)
    println("Computed terminal loss")

    return operator_loss + terminal_loss
end
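
For clarity, this is the derivative pattern the loss relies on, isolated on a toy function. Note that calling loss_function directly works fine (all the println statements fire); the error only appears when Flux.train! tries to differentiate through these pullback calls:

using Flux

f(x) = x.^2
x = [1.0 2.0 3.0]                       # 1×3 row matrix, shaped like S_internal

# first derivative as a vector–Jacobian product seeded with ones
y, back = Flux.pullback(f, x)
df = back(ones(size(y)))[1]             # == 2 .* x

# second derivative via a finite difference of the first derivative
ϵ = 0.01
y2, back2 = Flux.pullback(f, x .+ ϵ)
df_shift = back2(ones(size(y2)))[1]
d2f = (df_shift .- df) ./ ϵ             # ≈ fill(2.0, size(x))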

function train_model()
    # use the ADAM optimiser
    opt = ADAM(learning_rate)
    # dummy dataset: loss_function ignores its argument, so any iterable works
    dataset = [1 for i in 1:10]
    Flux.train!(loss_function, params(model), zip(dataset), opt)
end
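
For reference, my understanding is that Flux.train! expands to roughly the loop below, which is why I wrap the dummy dataset in zip: each element d becomes a 1-tuple and is splatted into the loss. Please correct me if this mental model is wrong:

# simplified sketch of what Flux.train!(loss, ps, data, opt) does per element
ps = params(model)
for d in zip(dataset)                    # d == (1,)
    gs = Flux.gradient(ps) do
        loss_function(d...)              # loss_function(1); the argument is unused
    end
    Flux.Optimise.update!(opt, ps, gs)
end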

train_model()
ERROR: Can't differentiate foreigncall expression
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] get at ./iddict.jl:87 [inlined]
 [3] (::typeof(∂(get)))(::Nothing) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
 [4] accum_global at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/lib/lib.jl:56 [inlined]
 [5] (::typeof(∂(accum_global)))(::Nothing) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
 [6] #89 at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/lib/lib.jl:67 [inlined]
 [7] (::typeof(∂(λ)))(::Nothing) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
 [8] #1550#back at /Users/amritjalan/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [9] (::typeof(∂(λ)))(::Nothing) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
 [10] getindex at ./tuple.jl:24 [inlined]
 [11] gradindex at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/reverse.jl:12 [inlined]
 [12] #39 at /Users/amritjalan/OneDrive - Massachusetts Institute of Technology/Coursework/Fall 2020/18.337 - Parallel Computing/FinalProject/Deep Galerkin Method/european_call_option.jl:191 [inlined]
 [13] (::typeof(∂(λ)))(::Tuple{Nothing,Array{Float64,2}}) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
 [14] #41 at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface.jl:40 [inlined]
 [15] (::typeof(∂(λ)))(::Tuple{Array{Float64,2}}) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
 [16] loss_function at /Users/amritjalan/OneDrive - Massachusetts Institute of Technology/Coursework/Fall 2020/18.337 - Parallel Computing/FinalProject/Deep Galerkin Method/european_call_option.jl:192 [inlined]
 [17] (::typeof(∂(loss_function)))(::Float64) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface2.jl:0
 [18] #150 at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/lib/lib.jl:191 [inlined]
 [19] #1693#back at /Users/amritjalan/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [20] #14 at /Users/amritjalan/.julia/packages/Flux/q3zeA/src/optimise/train.jl:103 [inlined]
 [21] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface.jl:172
 [22] gradient(::Function, ::Params) at /Users/amritjalan/.julia/packages/Zygote/bRa8J/src/compiler/interface.jl:49
 [23] macro expansion at /Users/amritjalan/.julia/packages/Flux/q3zeA/src/optimise/train.jl:102 [inlined]
 [24] macro expansion at /Users/amritjalan/.julia/packages/Juno/n6wyj/src/progress.jl:119 [inlined]
 [25] train!(::Function, ::Params, ::Base.Iterators.Zip{Tuple{Array{Int64,1}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /Users/amritjalan/.julia/packages/Flux/q3zeA/src/optimise/train.jl:100
 [26] train! at /Users/amritjalan/.julia/packages/Flux/q3zeA/src/optimise/train.jl:98 [inlined]
 [27] train_model() at /Users/amritjalan/OneDrive - Massachusetts Institute of Technology/Coursework/Fall 2020/18.337 - Parallel Computing/FinalProject/Deep Galerkin Method/european_call_option.jl:246
 [28] top-level scope at none:1

Any guidance would be much appreciated!

Have you found a solution?