Hello everyone. I am stuck on the following problem, for which I have created an MWE below. It concerns automatic differentiation using DifferentiationInterface.jl together with Mooncake.jl. Judging from the error message, I suspect I am doing something wrong with the signatures of my functions, but I can't find the problem. In other words, I believe the issue lies in my own code rather than in DifferentiationInterface or Mooncake.
Basically, in the MWE below, I have an objective function that depends on the model passed to it. Each model implements a simple calculation. In my actual application, I will have multiple models that I want to pass to the objective function. Here I show two contrived models, `Model1` and `Model2`.
- Method `outerfunction` represents a flexible implementation, in the sense that the calculation of the objective function depends on the model passed to it. Unfortunately, this does not work as I expected, and this is my problem.
- Method `outerfunction_cos` represents an implementation where the model has been hard-coded into it, in this case `Model2`, and it indeed works as expected.
```julia
using DifferentiationInterface
import Mooncake

abstract type AbstractModel end

# contrived model definitions
struct Model1<:AbstractModel end
calculate(::Model1, x) = sin(x)

struct Model2<:AbstractModel end
calculate(::Model2, x) = cos(x)

# This function represents the intended functionality.
# It should work with whichever model is passed to it.
function outerfunction(y, model::AbstractModel)
    function mock_objective(x)
        sum(abs2.(y .- calculate.(model, x)))
    end
    backend = AutoMooncake(; config=nothing)
    x = randn(size(y))
    prep = prepare_gradient(mock_objective, backend, x)
    gradient(mock_objective, prep, backend, x)
end

# This function hard-codes the model (here: Model2, i.e. cos)
function outerfunction_cos(y)
    function mock_objective(x)
        sum(abs2.(y .- cos.(x)))
    end
    backend = AutoMooncake(; config=nothing)
    x = randn(size(y))
    prep = prepare_gradient(mock_objective, backend, x)
    gradient(mock_objective, prep, backend, x)
end
```
This is how I call the above code:
```julia
# call above
y = randn(10)
outerfunction_cos(y)        # works as expected
outerfunction(y, Model1())  # fails
outerfunction(y, Model2())  # fails
```
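Before digging into the Mooncake error, one way to separate the AD machinery from the plain Julia code is to evaluate the body of `mock_objective` with no backend at all. The sketch below does that for `Model1`; `plain_objective` and `try_plain` are hypothetical helper names, not part of the MWE. If the plain evaluation already throws, the error raised during rule compilation is probably Mooncake surfacing a failure that exists without AD too; if it succeeds, the problem is more likely specific to differentiation.

```julia
abstract type AbstractModel end
struct Model1<:AbstractModel end
calculate(::Model1, x) = sin(x)

# Hypothetical helper: the body of `mock_objective`, evaluated without any AD backend.
plain_objective(y, model, x) = sum(abs2.(y .- calculate.(model, x)))

# Returns `nothing` if the plain evaluation succeeds, otherwise the caught exception.
function try_plain(y, model, x)
    try
        plain_objective(y, model, x)
        return nothing
    catch e
        return e
    end
end

y = randn(10)
x = randn(10)

err = try_plain(y, Model1(), x)
if err === nothing
    println("plain evaluation works; the problem is specific to AD")
else
    println("plain evaluation already fails: ", typeof(err))
end
```

I would expect this to report a `MethodError`, which would point at the `calculate.(model, x)` broadcast rather than at the function signatures.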
I get an error message in my terminal, but I can't properly interpret it. In brief, it is a `MooncakeRuleCompilationError`.
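If the broadcast over `model` turns out to be the culprit, a plausible explanation is that Julia's broadcasting does not treat an arbitrary struct instance as a scalar by default (it tries to collect it like a container). A common way to declare a type scalar-like is to define `Base.broadcastable` for it, so the body of `outerfunction` could stay unchanged. The sketch below checks only plain evaluation, without DifferentiationInterface or Mooncake; `objective` is a hypothetical stand-in for `mock_objective`, and it assumes that no model is ever meant to be a container.

```julia
abstract type AbstractModel end
struct Model1<:AbstractModel end
struct Model2<:AbstractModel end
calculate(::Model1, x) = sin(x)
calculate(::Model2, x) = cos(x)

# Assumption: a model is never a container, so every model should behave like a
# scalar in broadcast expressions. Wrapping it in a Ref achieves exactly that.
Base.broadcastable(m::AbstractModel) = Ref(m)

# Same expression as in `mock_objective`, with the model passed as an argument.
objective(y, model, x) = sum(abs2.(y .- calculate.(model, x)))

y = randn(10)
x = randn(10)
v1 = objective(y, Model1(), x)  # should match sum(abs2.(y .- sin.(x)))
v2 = objective(y, Model2(), x)  # should match sum(abs2.(y .- cos.(x)))
```

Writing `calculate.(Ref(model), x)` directly inside `mock_objective` would be a more local alternative with the same effect on broadcasting.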
Could somebody please help with this? Many thanks in advance.