Hello everyone. I am stuck on the following problem, for which I have created an MWE below. It concerns automatic differentiation using DifferentiationInterface.jl together with Mooncake.jl. From the error message I get, I am guessing that I am doing something wrong with the signatures of my functions, but I can't find the problem. This is not a problem with either DifferentiationInterface or Mooncake.
Basically, in the MWE below, I have an objective function that depends on the model passed to it. The model implements a simple calculation. In my actual application, I will have multiple models that I want to pass to the objective function. Here I show two contrived models, Model1 and Model2.
- Method `outerfunction` represents a flexible implementation, in the sense that the calculation of the objective function depends on the model passed to it. Unfortunately, this doesn't work as I expected, and this is my problem.
- Method `outerfunction_cos` represents an implementation where the model has been "hard coded" into it (in this case `Model2`), and it works as expected.
```julia
using DifferentiationInterface
import Mooncake

abstract type AbstractModel end

# contrived model definitions
struct Model1 <: AbstractModel end
calculate(::Model1, x) = sin(x)

struct Model2 <: AbstractModel end
calculate(::Model2, x) = cos(x)

# This function represents the intended functionality.
# It should be flexible with respect to the chosen model.
function outerfunction(y, model::AbstractModel)
    function mock_objective(x)
        sum(abs2.(y .- calculate.(model, x)))
    end
    backend = AutoMooncake(; config=nothing)
    x = randn(size(y))
    prep = prepare_gradient(mock_objective, backend, x)
    gradient(mock_objective, prep, backend, x)
end

# This function hard codes the model
function outerfunction_cos(y)
    function mock_objective(x)
        sum(abs2.(y .- cos.(x)))
    end
    backend = AutoMooncake(; config=nothing)
    x = randn(size(y))
    prep = prepare_gradient(mock_objective, backend, x)
    gradient(mock_objective, prep, backend, x)
end
```
This is how I call the above code:
```julia
# call above
y = randn(10)
outerfunction_cos(y)        # works as expected
outerfunction(y, Model1())  # fails
outerfunction(y, Model2())  # fails
```
I do get an error message in my terminal, but I can't properly interpret it.
In brief, it's a `MooncakeRuleCompilationError`.
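To check whether the failure comes from the objective itself rather than from Mooncake, the objectives can be evaluated directly, with no AD involved. The sketch below repeats only `Model1` and `calculate` from the MWE; the helper `try_primal` and the names `hardcoded` and `flexible` are hypothetical and not part of the original code. It compares the hard-coded objective, the model-dependent objective, and a scalar (non-broadcast) call to `calculate`.

```julia
abstract type AbstractModel end
struct Model1 <: AbstractModel end
calculate(::Model1, x) = sin(x)

# Hypothetical helper: evaluate f(x) without any AD and return :ok,
# or the type of the exception that was thrown.
function try_primal(f, x)
    try
        f(x)
        return :ok
    catch err
        return typeof(err)
    end
end

y = randn(10)
x = randn(10)
model = Model1()

# Same objectives as in the MWE, but with the AD calls removed
hardcoded(x) = sum(abs2.(y .- cos.(x)))
flexible(x) = sum(abs2.(y .- calculate.(model, x)))

println(try_primal(hardcoded, x))
println(try_primal(flexible, x))
# Scalar call: `model` is passed directly, not broadcast over
println(try_primal(xi -> calculate(model, xi[1]), x))
```

If the plain evaluation already throws, the Mooncake error is likely a downstream symptom of the objective rather than of the AD setup.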
Could somebody please help with this? Many thanks in advance.