Haven't tried it, but Yota should be general enough to support any kind of differentiation that boils down to primitives and the chain rule. Do you have an example of what you need at hand?
Not really a mathematical text, but I once explained a reverse-mode autodiff example here.
Awesome. But for more complicated functions, say I have something where I'd like to cache some large calculation from which both forward & backward are then fast. Can I safely just place this in mybigfunc(x) and call that twice (with identical arguments)?
julia> rx = Ref(0.5);
julia> using BenchmarkTools
julia> foo(x) = exp(x)
foo (generic function with 1 method)
julia> ∇foo(x) = (exp(x), exp(x))
∇foo (generic function with 1 method)
julia> @benchmark foo($(rx)[])
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 6.350 ns (0.00% GC)
median time: 6.450 ns (0.00% GC)
mean time: 6.554 ns (0.00% GC)
maximum time: 21.960 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
julia> @benchmark ∇foo($(rx)[])
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 13.317 ns (0.00% GC)
median time: 13.778 ns (0.00% GC)
mean time: 14.066 ns (0.00% GC)
maximum time: 28.647 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 998
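As the benchmark shows, ∇foo costs roughly twice foo because exp(x) is evaluated twice. A sketch of the caching idea from the question above, computing the shared value once and returning it for both the primal and the derivative (foo_fused is a made-up name, not part of Yota's API):

```julia
# Hypothetical fused rule: evaluate exp(x) a single time and reuse the
# result for both the value and the derivative, instead of the two
# separate exp calls in ∇foo above.
foo_fused(x) = (y = exp(x); (y, y))
```

Calling foo_fused(0.5) returns the value and the derivative as one tuple while paying for exp only once.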
Which means, unless I'm mistaken, logistic(x) in this example is going to get computed three times:
logistic(x) = 1 / (1 + exp(-x))
# for an expression like `logistic(x)` where x is a Number
# gradient w.r.t. x
# is `(logistic(x) * (1 - logistic(x)) * ds)` where "ds" stands for derivative "dL/dy"
@diffrule logistic(x::Number) x (logistic(x) * (1 - logistic(x)) * ds)
Once on the forward pass, and twice on the backwards pass.
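For contrast, here is what a hand-applied common-subexpression elimination of that rule looks like (a sketch only, not Yota's actual machinery; the name logistic_with_pullback is made up):

```julia
logistic(x) = 1 / (1 + exp(-x))

# Hand-applied CSE: evaluate logistic(x) once on the forward pass and
# close over the cached value in the pullback, instead of the three
# calls implied by the textual rule above.
function logistic_with_pullback(x)
    y = logistic(x)                  # single forward evaluation
    pullback(ds) = y * (1 - y) * ds  # reuses the cached y twice
    return y, pullback
end
```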
OK, so the CSE in question here is just Julia's? I had the idea that the macro @diffrule was itself looking for such repeats… but perhaps that is the planned tape feature mentioned?
Maybe I can answer this by benchmarking gradients of logistic(x) when I have a minute.
Maybe Yota does. I haven't tested that or dived into the source code.
The hard part is that you need some form of purity (weaker than Julia's @pure) to know that functions like logistic will return the same value each time they're called.
Given that Yota emphasizes performance and already requires static graphs, I could certainly see the case for some side-effect-free assumptions to allow aggressive CSE.
And this is exactly how it works. To be precise, here's the line that performs CSE. You can check the result like this (though it's not part of the external API):
Except for the prologue and epilogue (which are used for buffer pre-allocation and are the most important optimization for large-scale deep learning models), the only code left is:
%2 = (exp)(%1)
%3 = 1.0f0
%5 = (*)(%2, %3)
We could further optimize it to just exp(%1), but usually it doesn't make much difference and is presumably eliminated by the Julia compiler anyway.
Supporting in-place operations in autograd is a hard matter, and we discourage their use in most cases. Autograd's aggressive buffer freeing and reuse makes it very efficient and there are very few occasions when in-place operations actually lower memory usage by any significant amount. Unless you're operating under heavy memory pressure, you might never need to use them.
Haven't tried it, but Yota should be general enough to support any kind of differentiation that boils down to primitives and the chain rule. Do you have an example of what you need at hand?
I need derivatives of multivariate polynomials with real or complex coefficients and maybe some basic analytic functions (exp, log, sin, cos). So far I have not found an AD package which doesn't introduce significant overhead.
I wonder if it would not be possible to share differential rules across libraries. There is GitHub - JuliaDiff/DiffRules.jl: A simple shared suite of common derivative definitions, and I wonder if Yota could not use those. Since I work mostly with static graphs, I would be curious to try Yota. On the other hand, I use a few custom gradients, which I have currently made available for Flux.
Sure, that is possible, but there are a lot of applications where you already have your polynomials in factored form, and the expanded symbolic form would be quite wasteful.
For example, take a linear form l(x) = x[1] + 3x[2] + 2x[3] + 5x[4] and the polynomial g(x) = l(x)^4 - l(x)^2. You really want to avoid computing with the symbolic form of this.
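To illustrate why the factored form is attractive, here is the gradient of that example computed by the chain rule on the factored form, never expanding the degree-4 polynomial (the names l, g, and grad_g are just for this sketch):

```julia
# l(x) = x[1] + 3x[2] + 2x[3] + 5x[4], g(x) = l(x)^4 - l(x)^2
coeffs = [1.0, 3.0, 2.0, 5.0]    # coefficients of the linear form
l(x) = sum(coeffs .* x)
g(x) = l(x)^4 - l(x)^2

# Chain rule on the factored form: ∇g(x) = (4l(x)^3 - 2l(x)) .* coeffs,
# since ∇l(x) = coeffs. One scalar evaluation of l, no symbolic expansion.
grad_g(x) = (t = l(x); (4t^3 - 2t) .* coeffs)
```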
Usually yes, but as with any early-stage software, it's worth checking in practice. Right now you can use the code from this reply to inspect the generated code; I also created an issue to make it more transparent.
One possible issue is that you use lots of x[i] expressions which themselves have derivatives like:
dx = zero(x)
dx[i] .= dy
So in your example you will get 20 intermediate arrays of the same size as x, and most libraries - including Yota in its current state - won't optimize them out. Let me see what we can do in this case (although I can't promise great results since Yota is mostly optimized for quite large arrays).
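One way such libraries could avoid the intermediate arrays is to accumulate all the index gradients into a single shared buffer instead of materializing a fresh zero(x)-sized array per x[i] access. A hypothetical sketch (accumulate_index_grads! is not a Yota function):

```julia
# Accumulate the gradient contributions of several x[i] accesses into
# one buffer dx, instead of allocating zero(x) once per access.
function accumulate_index_grads!(dx, idxs, dys)
    for (i, dy) in zip(idxs, dys)
        dx[i] += dy    # single buffer, updated in place
    end
    return dx
end
```

For instance, for y = 2x[1] + 3x[2] the seeds are dys = [2.0, 3.0] at idxs = [1, 2], and a single dx buffer collects both.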
Also, thanks for this example - it helped uncover a bug in the derivatives of multi-argument *, /, etc. I've created an issue for this.
The problem with this is that it is at odds with recommendations for performance in general Julia code, and often does make a significant difference. (I understand it's hard to support mutation; I'm just saying that this is something that would be very useful to have for general-purpose AD.)
I agree that using pre-allocated buffers does make a lot of difference, but I would call the technique an inconvenient workaround instead of a recommendation. It opens a whole can of worms, including element type calculations for containers and buffer ownership, which are all manageable but inconvenient.
Is broadcasting supported? I had a go with the following function, but got an error: Failed to find a derivative for %44 = broadcast(%43, %38, %-1)::Array{Float64,2}:
using LinearAlgebra  # for diag

function pd(x)
    m = x' * x
    v = diag(m)
    v .+ v' .- 2 .* m
end
Yota.grad(x -> sum(pd(x)), rand(2,3))
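For context, pd(x) computes the matrix of pairwise squared Euclidean distances between the columns of x, via the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2a'b. A quick self-contained check of that identity (restating pd so the snippet runs on its own; nothing here is part of Yota's API):

```julia
using LinearAlgebra

# pd(x)[i, j] is the squared distance between columns i and j of x:
# m = x'x holds the inner products, v = diag(m) the squared norms.
function pd(x)
    m = x' * x
    v = diag(m)
    v .+ v' .- 2 .* m
end
```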
Thanks for this example! This is exactly the feedback I wanted from the announcement - a set of workflows which differ from mine and which I had thus missed.
In this particular case the function diag wasn't recorded as a primitive, but instead traced through, resulting in multiple getindex() calls (and some other unexpected things that I still need to investigate). I have created an issue for this. Also, pardon the slow reaction - the start of the week turned out to be pretty busy, but I hope to fix all mentioned issues by next Monday.
As for broadcasting itself, it should work fine as long as the function being broadcast is a primitive with a defined @diffrule. Possibly we will remove this restriction in the future, but right now it's unclear how much of a constraint it is in real life (e.g. in your case all the broadcasting happens over primitives).
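To make the "broadcast over a primitive" case concrete: if a primitive f has a known scalar derivative df (the @diffrule), the pullback of sum(f.(x)) with seed 1 is simply df.(x), the elementwise chain rule. A sketch with illustrative names (f, df, grad_sum_f are not Yota functions):

```julia
# A primitive and its @diffrule-style scalar derivative.
f(x)  = sin(x)
df(x) = cos(x)

# Gradient of x -> sum(f.(x)): broadcast the scalar derivative.
grad_sum_f(x) = df.(x)
```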
Letās break it into 2 parts:
Using autodiff to differentiate existing functions with mutable operations.
Getting efficient code for forward and backward passes.
Somewhere deep in my backlog I have a ticket to rewrite the most popular mutating operations (e.g. mul!) back into non-mutating ones as part of preprocessing, so it should address (1) at least for some portion of functions.
(2) is already partially addressed: all matrix multiplications are replaced with mul! in generated code, and all array variables are modified in place using .= instead of =. In my previous packages I also had a macro to define rules for rewriting non-mutating operations into mutating ones; perhaps I'll bring it to Yota too.
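The mul! rewrite mentioned above is the standard pre-allocation pattern from LinearAlgebra: `A * B` allocates a fresh result on every call, while `mul!(C, A, B)` writes into a buffer allocated once up front.

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]
C = similar(A)     # pre-allocated buffer, reused across calls
mul!(C, A, B)      # same values as A * B, but no new allocation
```

In a generated backward pass, C would live across gradient calls, so the allocation cost is paid only once.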
The set of rules in Yota dates back to ReverseDiffSource.jl and has never been seriously revisited, so yes, it's possible I will migrate to them at some point.
The other problem I ran into was, I think, due to calling vec(m)… which @less tells me is reshape(a, length(a)), and if I'm reading it right, it looks like you only have diffrules for reshape(a, n, m).
Congrats on the announcement @dfdx! It's great to have something to show off the power of Cassette here, and a bonus to have something robust and maintained to benchmark against.
I'm hoping that we can end up doing many Yota-style optimisations on the "static subgraphs" of Julia's IR, but there's some way to go in terms of compiler support before that's possible. With some cleverness in handling things like memory management it should be possible to avoid trading off performance and flexibility entirely.
I'm curious if you've tried this on something like DiffEq? For me this is the ultimate challenge case for Zygote and I'd be very curious how Yota fares.
Thanks Mike! I've actually spent quite a lot of time inspecting Zygote's code to learn more about Julia IR in the early stages of Yota.
With some cleverness in handling things like memory management it should be possible to avoid trading off performance and flexibility entirely.
In my tests memory pre-allocation is by far the most important performance optimization. The tricky part is that you need to hold pre-allocated buffers somewhere instead of creating them on each function call. One way to do it that might be suitable for Zygote is to pass a memory manager (in the simplest case, a Dict with buffers for each tensor variable) as an optional parameter to a gradient function. I used this approach in XGrad and it worked pretty well in practice, giving the user the ability (but not the obligation) to speed up calculations with just a bit more typing. I believe if you implement just this (or any other clever memory pre-allocation strategy), it will bring you to almost optimal performance without losing any of the current flexibility.
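A minimal sketch of that "memory manager as an optional argument" idea (the name get_buffer! and the key are made up for this example): a Dict holds one pre-allocated buffer per tensor variable, created lazily on the first call with get! and reused on every later call.

```julia
# Return the buffer registered under `key`, allocating it with zeros(dims)
# only the first time this key is seen.
function get_buffer!(mem::Dict, key, dims)
    get!(mem, key) do
        zeros(dims)    # runs only on the first request for this key
    end
end

mem = Dict{Symbol,Array{Float64}}()
b1 = get_buffer!(mem, :dy, (2, 2))
b2 = get_buffer!(mem, :dy, (2, 2))   # same buffer, no new allocation
```

A gradient function taking `mem` as an optional argument could then write all its intermediate results into these buffers instead of allocating fresh arrays per call.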
I'm curious if you've tried this on something like DiffEq? For me this is the ultimate challenge case for Zygote and I'd be very curious how Yota fares.
I've never considered DiffEq, but it sounds like an interesting challenge!
For an adaptive integrator, I am not sure it will ever work, since the computation is very value-dependent and the AD would need to know how to backprop through the control flow. But one thing we could do is define an adjoint primitive via DiffEqSensitivity.jl's adjoint for Yota.jl, so that users can make use of DiffEq + Yota, though it would have limitations of course.
Another approach would be to just add support for dynamic graphs in Yota - technically it's not hard, but it invalidates most of the current optimizations. I've created an issue to track ideas.