Hello there,
I’m trying to write efficient code for custom automatic differentiation in Julia.
I noticed that even in a simple case, Zygote returns type-unstable gradients. Can you please help me understand why?
Here’s the example: I define a model called GaussianModel which takes a single parameter as input. This model can also be evaluated to return the value of its parameter. This is inspired by implicit gradients in Zygote’s documentation.
using LinearAlgebra, Random, Zygote

abstract type Model end

mutable struct GaussianModel{T<:AbstractFloat} <: Model
    σ::T
end

(model::GaussianModel)() = model.σ

function logprob(x::T, y::T, model::GaussianModel{T}) where {T<:AbstractFloat}
    σ = model()
    return -(y - x)^2 / (2σ^2) - log(2 * π * σ^2) / 2
end
Now with this code
x = 0.2
y = 0.1
mymodel = GaussianModel(0.5)
logprob(x, y, mymodel)
gradient(model -> logprob(x, y, model), mymodel)
I get the correct result, but this
@code_warntype gradient(model -> logprob(x, y, model), mymodel)
gives me
MethodInstance for Zygote.gradient(::var"#129#130", ::GaussianModel{Float64})
from gradient(f, args...) in Zygote at [...]
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#129#130"())
  args::Tuple{GaussianModel{Float64}}
Locals
  @_4::Int64
  grad::Union{Nothing, Tuple}
  back::Zygote.var"#75#76"
  y::Any
Body::Union{Nothing, Tuple{Any}}
1 ─ %1  = Core.tuple(f)::Core.Const((var"#129#130"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Any, Zygote.var"#75#76"}
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#75#76", Int64}, Any[Zygote.var"#75#76", Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Any
│         (grad = (back)(%8))
│   %10 = Zygote.isnothing(grad)::Bool
└──       goto #3 if not %10
2 ─       return Zygote.nothing
3 ─ %13 = Zygote.map(Zygote._project, args, grad::Tuple)::Tuple{Any}
└──       return %13
which looks like a type instability: y is inferred as Any, and the return type of the body is Union{Nothing, Tuple{Any}}. I have attached a screenshot for clarity.
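As a sanity check, here is a minimal sketch that tests whether logprob on its own infers a concrete return type for Float64 inputs. It assumes only the Test standard library and does not involve Zygote, so it says nothing about the pullback itself. The names val and rt are just illustrative.

```julia
using Test  # standard library; provides @inferred

abstract type Model end

mutable struct GaussianModel{T<:AbstractFloat} <: Model
    σ::T
end

(model::GaussianModel)() = model.σ

function logprob(x::T, y::T, model::GaussianModel{T}) where {T<:AbstractFloat}
    σ = model()
    return -(y - x)^2 / (2σ^2) - log(2 * π * σ^2) / 2
end

mymodel = GaussianModel(0.5)

# @inferred throws an error if the inferred return type
# does not match the type of the value actually returned.
val = @inferred logprob(0.2, 0.1, mymodel)

# Inferred return type of logprob for concrete Float64 arguments.
rt = only(Base.return_types(logprob, (Float64, Float64, GaussianModel{Float64})))
```

If this check passes, the Any in the output above would have to come from the gradient call (the closure or Zygote's pullback) rather than from logprob itself.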
Thank you in advance!