Zygote differentiation issues

I am trying to replicate an example I saw in the talk "The impact of differentiable programming: how ∂P is enabling new science in Julia":

function g(x)
    if x < 0
        print("Enter function name: ")
        getfield(Base, Symbol(readline()))(x)
    else
        2*x^3 + 4*x^2 +5*x
    end
end
julia> g'(4)
133
julia> g'(-pi/6) 
Enter function name: sin 
ERROR: Can't differentiate foreigncall expression 
Stacktrace: 
[1] error(::String) at ./error.jl:33 
[2] Symbol at ./boot.jl:438 [inlined] 
[3] (::typeof(∂(Symbol)))(::Nothing) at /home/user/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0 
[4] g at /home/user/julia_control/cm_control.jl:17 [inlined] 
[5] (::typeof(∂(g)))(::Float64) at /home/user/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0 
[6] (::Zygote.var"#41#42"{typeof(∂(g))})(::Float64) at /home/user/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:40 
[7] gradient(::Function, ::Float64) at /home/user/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:49
[8] (::Zygote.var"#43#44"{typeof(g)})(::Float64) at /home/user/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:52
[9] top-level scope at none:1

Why am I getting this error, and how can I resolve it?

using Flux
using Zygote
using Trebuchet

function shoot(wind, angle, weight)
  Trebuchet.shoot((wind, Trebuchet.deg2rad(angle), weight))[2]
end


julia> shoot'(5,50,220)
ERROR: MethodError: no method matching (::Zygote.var"#43#44"{typeof(shoot)})(::Int64, ::Int64, ::Int64)

It seems Zygote doesn’t accept a call with (::Int64, ::Int64, ::Int64). How can I fix it?

I also tried this:

function shoot(pars)
    Trebuchet.shoot((pars[1], Trebuchet.deg2rad(pars[2]), pars[3]))[2]
end

and got this error:

julia> shoot'([5,55,200])
ERROR: Compiling Tuple{typeof(Trebuchet.shoot),Tuple{Int64,Float64,Int64}}: try/catch is not supported.

I don’t think Zygote (or any tool that works only at compile time) can differentiate a function that is determined only at run time.

In the talk "The impact of differentiable programming: how ∂P is enabling new science in Julia" (YouTube), it seems to work for the speaker.

Ah, yes, I cheated a smidge there. I elided a Zygote.@nograd Symbol. That should probably be upstreamed — and I imagine that’s what I was thinking when I chose to not include it in the talk, but promptly forgot.


Thanks for replying.

I am new to this. Could you please explain what Zygote.@nograd Symbol does? Is there a good resource where I can read about these things (other than fluxml.ai’s blog)?
Also, how did you make shoot'() work?

Zygote.@nograd Symbol is basically telling Zygote that it shouldn’t bother trying to differentiate the construction of the symbol :sin itself because it’s non-differentiable. It fails without this because it’s part of the small bit of Julia that’s not implemented in Julia — we end up calling into Julia’s C core (which is what’s marked as the “foreigncall”).
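
To make that concrete, here is a minimal sketch, assuming a Zygote version that provides Zygote.@nograd. To keep it non-interactive, it uses a hypothetical variant of g (called g_named here, not from the talk) that takes the function name as an argument instead of calling readline():

```julia
using Zygote

# Mark the Symbol constructor as non-differentiable, so Zygote does not
# try to trace into the foreigncall that builds the symbol.
Zygote.@nograd Symbol

# Hypothetical variant of g: `name` replaces the interactive readline().
function g_named(x, name::String)
    if x < 0
        # Look the function up in Base at run time, then call it.
        getfield(Base, Symbol(name))(x)
    else
        2*x^3 + 4*x^2 + 5*x
    end
end

# Derivative on the polynomial branch: 6x^2 + 8x + 5 at x = 4.
d_poly = gradient(x -> g_named(x, "sin"), 4.0)[1]

# Derivative on the run-time branch: should match cos(x) when name == "sin".
d_sin = gradient(x -> g_named(x, "sin"), -pi/6)[1]
```

Only the symbol construction is excluded from differentiation; the function that the symbol resolves to is still differentiated as usual.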

Looks like I never published my scratch work that I used to generate that talk. I should do that:

https://github.com/mbauman/dTrebuchet

The key to differentiating through shoot is to opt in to forward-mode AD and to pass a single array argument when differentiating with respect to that argument.

https://github.com/mbauman/dTrebuchet/blob/ae5e215747b6da2ce59197ff18cdde16807eb980/reinforcement.jl#L9-L10
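
As a rough sketch of that pattern (not the code from the repository above), Zygote.forwarddiff runs a function with forward-mode AD inside an otherwise reverse-mode gradient call. The simulate function below is a made-up stand-in for Trebuchet.shoot, used only so the example is self-contained; its try/catch is there to mimic code that the reverse-mode compiler complained about:

```julia
using Zygote

# Hypothetical stand-in for Trebuchet.shoot: takes ONE array of parameters.
function simulate(ps)
    wind, angle_deg, weight = ps
    try
        return weight * sin(angle_deg * pi / 180) - wind^2
    catch
        return zero(eltype(ps))
    end
end

# Opt in to forward mode: Zygote differentiates `simulate` with dual
# numbers instead of transforming its source code.
distance(ps) = Zygote.forwarddiff(simulate, ps)

# Gradient with respect to the single array argument (floats, not Ints).
grad = gradient(distance, [5.0, 50.0, 220.0])[1]
```

Here grad has one entry per parameter, in the order wind, angle, weight.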


That’s actually something Zygote can very much do. It’s even on the Readme.

julia> using Zygote

julia> fs = Dict("sin" => sin, "cos" => cos, "tan" => tan);

julia> gradient(x -> fs[readline()](x), 1)
sin
(0.5403023058681398,)

I typed that sin into the REPL myself.


Thanks a lot