[ANN] Yota.jl - yet another reverse-mode autodiff package

Haven't tried it, but Yota should be general enough to support any kind of differentiation that boils down to primitives and the chain rule. Do you have an example of what you need at hand?

Not really a mathematical text, but I once explained reverse-mode autodiff with an example here.

Awesome. But for more complicated functions, say I have something where I'd like to cache some large calculation from which both the forward and backward passes are then fast. Can I safely just place this in mybigfunc(x) and call that twice (with identical arguments)?
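
To make the question concrete, here is a toy version of the pattern I mean (the names are made up):

big_calculation(x) = (sleep(0.01); x .^ 2)   # stand-in for an expensive shared step
mybigfunc(x) = sum(big_calculation(x))

# if the forward and backward rules each call mybigfunc(x) with the same x,
# is big_calculation evaluated once (via CSE) or twice?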

CSE can fail even on simple functions like exp.

julia> rx = Ref(0.5);

julia> using BenchmarkTools

julia> foo(x) = exp(x)
foo (generic function with 1 method)

julia> ∂foo(x) = (exp(x),exp(x))
∂foo (generic function with 1 method)

julia> @benchmark foo($(rx)[])
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     6.350 ns (0.00% GC)
  median time:      6.450 ns (0.00% GC)
  mean time:        6.554 ns (0.00% GC)
  maximum time:     21.960 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark ∂foo($(rx)[])
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     13.317 ns (0.00% GC)
  median time:      13.778 ns (0.00% GC)
  mean time:        14.066 ns (0.00% GC)
  maximum time:     28.647 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     998

Which means, unless I'm mistaken, logistic(x) in this example is going to get computed three times:

logistic(x) = 1 / (1 + exp(-x))
# for an expression like `logistic(x)` where x is a Number
# gradient w.r.t. x
# is `(logistic(x) * (1 - logistic(x)) * ds)` where "ds" stands for derivative "dL/dy"
@diffrule logistic(x::Number) x (logistic(x) * (1 - logistic(x)) * ds)

Once on the forward pass, and twice on the backward pass.
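
For comparison, doing the CSE by hand would look something like this, with logistic(x) evaluated once and shared between the forward value and the gradient (just a sketch, not what Yota generates):

logistic(x) = 1 / (1 + exp(-x))

# evaluate the forward value once and reuse it in the derivative
function ∂logistic(x, ds)
    y = logistic(x)
    (y, y * (1 - y) * ds)   # (forward value, dL/dx)
end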


OK, so the CSE in question here is just Julia's? I had the idea that the macro @diffrule was itself looking for such repeats… but perhaps that is the planned tape feature mentioned?

Maybe I can answer this by benchmarking gradients of logistic(x) when I have a minute.
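
Presumably something along these lines (untested sketch, benchmark interpolation and friends omitted):

using Yota, BenchmarkTools

logistic(x) = 1 / (1 + exp(-x))

# if CSE kicks in, the gradient should cost about one logistic evaluation
# plus a few cheap scalar ops, not three separate evaluations
@btime logistic(0.5)
@btime grad(logistic, 0.5)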

With libraries like CommonSubexpressions or DataFlow, it's certainly possible to do it on the Julia side!
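
For example, something like this, if I remember the CommonSubexpressions API correctly (untested):

using CommonSubexpressions

x = 0.5
# @cse hoists the repeated exp(x) into a temporary so it is evaluated only once
y, dy = @cse (exp(x), exp(x))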

Maybe Yota does. I haven't tested that / dived into the source code.
The hard part is that you need some form of purity (weaker than Julia's @pure) to know that functions like logistic are going to return the same thing each time they're called.

Given that Yota emphasizes performance, and already requires static graphs, I could certainly see the case for some side-effect-free assumptions to allow aggressive CSE.

And this is exactly how it works :slight_smile: To be precise, here's the line that performs CSE. You can check the result like this (although this is not part of the external API):

julia> using Yota

julia> foo(x) = exp(x)
foo (generic function with 1 method)

julia> _, g = grad(foo, 0.5)
(1.6487212707001282, GradResult(1))

julia> g.tape
Tape
  inp %1::Float64
  %2 = exp(%1)::Float64
  const %3 = 1.0::Float32
  %4 = exp(%1)::Float64
  %5 = *(%4, %3)::Float64

julia> Yota.generate_function_expr(g.tape)
:(function ##tape_fn#364()
      #= /home/slipslop/work/Yota/src/compile.jl:112 =#
      #= prologue:0 =#
      %1 = (inp %1::Float64).val
      %2 = (%2 = exp(%1)::Float64).val
      %3 = (const %3 = 1.0::Float32).val
      %4 = (%4 = exp(%1)::Float64).val
      %5 = (%5 = *(%4, %3)::Float64).val
      #= body:0 =#
      #= /home/slipslop/.julia/packages/Espresso/Lewh0/src/exgraph.jl:100 =#
      %2 = (exp)(%1)
      %3 = 1.0f0
      %5 = (*)(%2, %3)
      #= epilogue:0 =#
      (inp %1::Float64).val = %1
      (%2 = exp(%1)::Float64).val = %2
      (const %3 = 1.0::Float32).val = %3
      (%4 = exp(%1)::Float64).val = %4
      (%5 = *(%4, %3)::Float64).val = %5
  end)

Except for the prologue and epilogue (which are used for buffer pre-allocation and are the most important optimization for large-scale deep learning models), the only code left is:

%2 = (exp)(%1)
%3 = 1.0f0
%5 = (*)(%2, %3)

We could further optimize it to just exp(%1), but usually it doesn't make much difference and is presumably eliminated by the Julia compiler anyway.

It might look limiting to forbid mutating operations, but for comparison here's a quote from the PyTorch documentation:

Supporting in-place operations in autograd is a hard matter, and we discourage their use in most cases. Autograd's aggressive buffer freeing and reuse makes it very efficient and there are very few occasions when in-place operations actually lower memory usage by any significant amount. Unless you're operating under heavy memory pressure, you might never need to use them.


Haven't tried it, but Yota should be general enough to support any kind of differentiation that boils down to primitives and the chain rule. Do you have an example of what you need at hand?

I need derivatives of multivariate polynomials with real or complex coefficients and maybe some basic analytic functions (exp, log, sin, cos). So far I have not found an AD package which doesn't introduce significant overhead.

An example would be the derivative of

 f(x) = x[3]*x[8]^3 - 3*x[3]*x[8]*x[6]^2 - x[1]*x[6]^3 + 3*x[1]*x[6]*x[8]^2 + x[4]*x[5]^3 - 3*x[4]*x[5]*x[7]^2 - x[2]*x[7]^3 + 3*x[2]*x[7]*x[5]^2 - 1.2342523

Do you want symbolic derivatives? It sounds like you just need a polynomial manipulator.

This is probably a more general question.

I wonder if it would be possible to share differentiation rules between libraries. There is GitHub - JuliaDiff/DiffRules.jl: A simple shared suite of common derivative definitions, and I wonder if Yota could use those. Since I work mostly with static graphs, I would be curious to try Yota. On the other hand, I use a few custom gradients, which I have currently made available for Flux.
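
For reference, DiffRules stores its definitions as plain expressions, so in principle any package could consume them, e.g.:

using DiffRules

# look up the registered derivative of sin(x) with respect to x;
# this returns the expression :(cos(x))
DiffRules.diffrule(:Base, :sin, :x)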

But I am already impressed by Yota.


Sure that is possible, but there are a lot of applications where you have your polynomials already in a factored form and the expanded symbolic form would be quite wasteful.

For example, take the linear form l(x) = x[1] + 3x[2] + 2x[3] + 5x[4] and the polynomial g(x) = l(x)^4 - l(x)^2. You really want to avoid computing with the expanded symbolic form of this.
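
To make the contrast concrete (the monomial counts below are just a quick count on my side):

# factored form: a handful of additions, multiplications and powers
l(x) = x[1] + 3x[2] + 2x[3] + 5x[4]
g(x) = l(x)^4 - l(x)^2

# expanding g symbolically in four variables produces 35 monomials of
# degree 4 plus 10 of degree 2, so the factored form is far cheaper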


Usually yes, but as with any early-stage software, it's worth checking it in practice :slight_smile: Right now you can use the code from this reply to inspect what gets generated; I have also created an issue to make it more transparent.

One possible issue is that you use lots of x[i] expressions which themselves have derivatives like:

dx = zero(x)
dx[i] .= dy

So in your example you will get 20 intermediate arrays of the same size as x, and most libraries - including Yota in its current state - won't optimize them out. Let me see what we can do in this case (although I can't promise great results, since Yota is mostly optimized for quite large arrays).

Also, thanks for this example - it helped to uncover a bug in the derivatives of multi-argument *, /, etc. I've created an issue for this.


The problem with this is that it is at odds with the recommendations for performance in general Julia code, and often does make a significant difference. (I understand it's hard to support mutation; I'm just saying that this is something that would be very useful to have for general-purpose AD.)


I agree that using pre-allocated buffers does make a lot of difference, but I would call the technique an inconvenient workaround instead of a recommendation. It opens a whole can of worms, including element type calculations for containers and buffer ownership, which are all manageable but inconvenient.

Is broadcasting supported? I had a go with the following function, but got the error Failed to find a derivative for %44 = broadcast(%43, %38, %-1)::Array{Float64,2}:

using LinearAlgebra   # diag comes from LinearAlgebra

function pd(x)
    m = x' * x
    v = diag(m)
    v .+ v' .- 2 .* m
end

Yota.grad(x -> sum(pd(x)), rand(2,3)) 

Thanks for this example! This is exactly the feedback I wanted from the announcement - a set of workflows which are different from mine and which I had thus missed.

In this particular case the function diag wasn't recorded as a primitive, but was instead traced through, resulting in multiple getindex() calls (and some other unexpected things that I still need to investigate). I have created an issue for this. Also, pardon the slow reaction - the start of the week turned out to be pretty busy, but I hope to fix all the mentioned issues by next Monday.

As for broadcasting itself, it should work fine as long as the function being broadcast is a primitive with a defined @diffrule. Possibly we will remove this restriction in the future, but right now it's unclear how much of a constraint it is in real life (e.g. in your case all the broadcasting happens over primitives).
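
For example, mirroring the logistic rule quoted earlier in the thread, you can register your own scalar primitive and then broadcast it (rough sketch):

softplus(x::Number) = log(1 + exp(x))

# register softplus as a primitive together with its scalar derivative;
# softplus.(xs) can then be differentiated through broadcasting
@diffrule softplus(x::Number) x ((1 / (1 + exp(-x))) * ds)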

Let's break it into two parts:

  1. Using autodiff to differentiate existing functions that contain mutating operations.
  2. Getting efficient code for the forward and backward passes.

Somewhere deep in my backlog I have a ticket to rewrite the most popular mutating operations (e.g. mul!) back into non-mutating ones as part of preprocessing, so it should address (1) at least for some portion of functions.

(2) is already partially addressed - all matrix multiplications are replaced with mul! in the generated code, and all array variables are modified in-place using .= instead of =. In my previous packages I also had a macro to define rules for rewriting non-mutating operations into mutating ones; perhaps I'll bring it to Yota too.
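
As a plain-Julia illustration of that rewrite (not Yota's actual generated code):

using LinearAlgebra

W, x = rand(100, 100), rand(100)
y_buf = zeros(100)         # pre-allocated once, e.g. in the prologue

y = W * x                  # non-mutating: allocates a fresh vector on every call
mul!(y_buf, W, x)          # mutating: writes the product into the existing buffer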

The set of rules in Yota dates back to ReverseDiffSource.jl and has never been seriously revisited, so yes, it's possible I will migrate to them at some point.


OK, thanks for taking a look!

The other problem I ran into was, I think, due to calling vec(m)… which @less tells me is reshape(a, length(a)), and if I'm reading it right it looks like you only have diffrules for reshape(a, n, m).


Congrats on the announcement @dfdx! It's great to have something to show off the power of Cassette here, and a bonus to have something robust and maintained to benchmark against.

I'm hoping that we can end up doing many Yota-style optimisations on the "static subgraphs" of Julia's IR, but there's some way to go in terms of compiler support before that's possible. With some cleverness in handling things like memory management it should be possible to avoid trading off performance and flexibility entirely.

I'm curious if you've tried this on something like DiffEq? For me this is the ultimate challenge case for Zygote and I'd be very curious how Yota fares.


Thanks Mike! I've actually spent quite a lot of time inspecting Zygote's code to learn more about Julia IR in the early stages of Yota.

With some cleverness in handling things like memory management it should be possible to avoid trading off performance and flexibility entirely.

In my tests, memory pre-allocation is by far the most important performance optimization. The tricky part is that you need to hold the pre-allocated buffers somewhere instead of creating them on each function call. One way to do it that might be suitable for Zygote is to pass a memory manager (in the simplest case, a Dict with buffers for each tensor variable) as an optional parameter to the gradient function. I used this approach in XGrad and it worked pretty well in practice, giving the user the ability (but not the obligation) to speed up calculations with just a bit more typing. I believe if you implement just this (or any other clever memory pre-allocation strategy), it will bring you to almost optimal performance without losing any of the current flexibility.
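
A rough sketch of what I mean (simplified compared to what XGrad actually does):

# map from variable name to a reusable buffer
const MEM = Dict{Symbol,Any}()

# return the buffer for `name`, allocating it only on the first call
getbuffer!(mem, name, dims) = get!(() -> zeros(dims), mem, name)

dW = getbuffer!(MEM, :dW, (100, 100))   # allocates the array
dW = getbuffer!(MEM, :dW, (100, 100))   # subsequent calls reuse it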

I'm curious if you've tried this on something like DiffEq? For me this is the ultimate challenge case for Zygote and I'd be very curious how Yota fares.

I've never considered DiffEq, but it sounds like an interesting challenge!


It might just be a different problem.

For an adaptive integrator, I am not sure if it will ever work, since it needs to know how to backprop through the control flow and the computation is very value-dependent. But one thing we could do is define an adjoint primitive for Yota.jl via DiffEqSensitivity.jl's adjoint, so that users can make use of DiffEq+Yota, though it would have limitations of course.

Another approach would be to just add support for dynamic graphs in Yota - technically it's not hard, but it invalidates most of the current optimizations. I've created an issue to track ideas.
