[ANN] Yota.jl - yet another reverse-mode autodiff package

#1

Yota.jl is a package for reverse-mode automatic differentiation designed specifically for machine learning applications.


Usage

using Yota

mutable struct Linear{T}
    W::AbstractArray{T,2}
    b::AbstractArray{T}
end

forward(m::Linear, X) = m.W * X

loss(m::Linear, X) = sum(forward(m, X))

m = Linear(rand(3,4), rand(3))
X = rand(4,5)

val, g = grad(loss, m, X)

where g is an object of type GradientResult holding gradients w.r.t. the input variables. Indexing it by argument position gives the gradient for that argument: for scalars and tensors, the gradient value itself; for structs, a dictionary of (field path → gradient) pairs:

julia> g[1]
Dict{Tuple{Symbol},Array{Float64,2}} with 1 entry:
  (:W,) => [3.38128 2.97142 2.39706 1.55525; 3.38128 2.97142 2.39706 1.55525; 3.38128 2.97142 2.39706 1.55525]   # gradient w.r.t. m.W

julia> g[2]  # gradient w.r.t. X
4×5 Array{Float64,2}:
 0.910691  0.910691  0.910691  0.910691  0.910691
 1.64994   1.64994   1.64994   1.64994   1.64994
 1.81215   1.81215   1.81215   1.81215   1.81215
 2.31594   2.31594   2.31594   2.31594   2.31594
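
Because loss is linear in both m.W and X, its gradients can also be derived by hand, which gives a quick sanity check on the output above. With Y = W * X and L = sum(Y), dL/dY is a matrix of ones, so dL/dW = ones * X' (identical rows, as in g[1]) and dL/dX = W' * ones (each row constant, as in g[2]). Note that m.b never enters forward, which is presumably why g[1] holds only a (:W,) entry. This is a plain-Julia sketch with no Yota involved; manual_grad is a hypothetical helper introduced only for illustration.

```julia
# Hand-derived gradients for loss(m, X) = sum(m.W * X); m.b is unused by forward.
mutable struct Linear{T}
    W::AbstractArray{T,2}
    b::AbstractArray{T}
end

forward(m::Linear, X) = m.W * X
loss(m::Linear, X) = sum(forward(m, X))

# Hypothetical helper: returns the loss value and the gradients w.r.t. m.W and X.
function manual_grad(m::Linear, X)
    dY = ones(eltype(X), size(m.W, 1), size(X, 2))  # dL/dY: all ones for a sum
    dW = dY * X'                                     # every row = row sums of X
    dX = m.W' * dY                                   # every column = column sums of W
    return loss(m, X), dW, dX
end

m = Linear(rand(3, 4), rand(3))
X = rand(4, 5)
val, dW, dX = manual_grad(m, X)
```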

GradientResult can be used together with the update!() function to modify tensors and fields of (mutable) structs in place; see the README for details.
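
The README documents the real update!() API. Purely as an illustration of what applying a (field path → gradient) dictionary to a mutable struct amounts to, here is a hypothetical plain-Julia stand-in (sgd_update! is not part of Yota): follow each field path with getfield and subtract the scaled gradient from that array in place.

```julia
# Hypothetical stand-in for an update step; NOT Yota's update!().
# grads maps field paths such as (:W,) to gradient arrays, mirroring the
# Dict{Tuple{Symbol},...} shown for g[1] above.
mutable struct Linear{T}
    W::AbstractArray{T,2}
    b::AbstractArray{T}
end

function sgd_update!(m, grads::AbstractDict, lr::Real)
    for (path, g) in grads
        target = m
        for name in path              # follow the field path, e.g. (:W,) -> m.W
            target = getfield(target, name)
        end
        target .-= lr .* g            # in-place gradient-descent step on that array
    end
    return m
end

m = Linear(ones(3, 4), zeros(3))
grads = Dict((:W,) => fill(2.0, 3, 4))
sgd_update!(m, grads, 0.1)            # every entry of m.W becomes 1.0 - 0.1 * 2.0
```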


Features

  • differentiation over scalars, tensors and structs, designed to support a PyTorch-like API
  • easy to add custom derivatives
  • Cassette-based tracer that avoids structural type constraints
  • experimental GPU support via CuArrays
  • performance-first implementation

Note that the tracer is fully customizable and can be used independently of automatic differentiation. Again, see the README for an example.

Performance

Comparing autodiff implementations turns out to be unexpectedly hard because of different sets of supported features, different sets of primitives, etc. (e.g. see an attempt to compare it with Zygote.jl). However, Yota uses a number of proven optimizations from my previous autodiff packages, so the differentiation pass should generally take no more than 2-3x the time of a call to the original function. In other words, if you see a significant difference between automatic and hand-written differentiation, consider it a bug.

Comparison to other autodiff packages

Unlike Zygote, Yota emphasizes performance over flexibility. While Zygote aims to support the full dynamism of the Julia language, Yota restricts the user to a static graph consisting of analytical functions commonly used in ML. This lets us apply a number of optimizations, including memory buffer pre-allocation, common subexpression elimination, etc.

Unlike Capstan, Yota is implemented :slight_smile: That said, I’ll be very curious to test Capstan once it ships.

Also, Yota doesn’t use tracked arrays or function overloading like AutoGrad, ReverseDiff or the current Flux tracker, and thus doesn’t run into multiple-dispatch ambiguity issues. The downside is that dynamic computational graphs are harder to implement and are currently not supported in Yota. For the curious, an older version of the package, implemented similarly to the libraries above, is stored in YotaTracked.

Feedback and bug reports are welcome!


#2

Thanks this looks interesting.

I see that you define the derivative w.r.t. each argument separately, which I presume means that each is only evaluated if required:

@diffrule ^(x::Real, y::Real)    x     y * x ^ (y-1) * ds
@diffrule ^(x::Real, y::Real)    y     log(x) * x ^ y * ds

Is there any way to share information, e.g. here to re-use x ^ y from the forward pass? Or perhaps this falls under common subexpression elimination & happens automatically?


#3

Looks very interesting! Does Yota.jl currently support nested and higher-order differentiation?


#4

Very cool! Does Yota support complex differentiation?


#5

A related question: is there a mathematical text that explains what automatic differentiation is?


#6

Exactly, CSE should take care of it. Right now it’s done during function compilation and is somewhat hard to observe; in the future I’m planning to make it part of the tape transformation, so that one can see the exact instructions to be executed.

Also note that if one of the parameters is constant (e.g. in the expression x ^ 2.0), the derivative w.r.t. that parameter isn’t recorded at all.
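
To illustrate what CSE should achieve for the two ^ rules, here is a hand-written plain-Julia pullback for z = x ^ y (pow_with_pullback and pow_const_with_pullback are hypothetical names; this is not how @diffrule is implemented). x ^ y is computed once on the forward pass and reused in the derivative w.r.t. y, and when the exponent is a constant only the derivative w.r.t. x is produced.

```julia
# z = x ^ y with both partials:
#   dz/dx = y * x ^ (y - 1) * ds
#   dz/dy = log(x) * x ^ y * ds   (x ^ y is reused from the forward pass)
function pow_with_pullback(x::Real, y::Real)
    z = x ^ y                       # forward value, computed once
    function pullback(ds)
        dx = y * x ^ (y - 1) * ds
        dy = log(x) * z * ds        # reuses z instead of recomputing x ^ y
        return dx, dy
    end
    return z, pullback
end

# Constant exponent (e.g. x ^ 2.0): only the derivative w.r.t. x is needed.
function pow_const_with_pullback(x::Real, c::Real)
    z = x ^ c
    pullback(ds) = c * x ^ (c - 1) * ds
    return z, pullback
end

z, pb = pow_with_pullback(2.0, 3.0)
dx, dy = pb(1.0)
```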


#7

Not yet, but it would be an interesting feature to add! However, if you are thinking about an interface like:

f(x) = ...
g(x) = grad(f, x)[2] * ...

it most likely won’t work, at least not with the current implementation of grad: grad currently works as a caching layer over the underlying _grad, which itself returns a compiled tape with pre-allocated buffers. So it doesn’t take one ordinary function and return another ordinary function. But with an appropriate API, I believe we can compute higher-order derivatives at the tape level and still have all the optimizations applied to the final result.


#8

There are quite a few good texts; I think

@book{griewank2008evaluating,
  title = {Evaluating derivatives: principles and techniques of algorithmic differentiation},
  author = {Griewank, Andreas and Walther, Andrea},
  volume = {105},
  year = {2008},
  publisher = {SIAM},
}

is a good introduction.


#9

Haven’t tried it, but Yota should be general enough to support any kind of differentiation that boils down to primitives and the chain rule. Do you have an example of what you need at hand?

Not really a mathematical text, but I once explained reverse-mode autodiff with an example here.
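
As a companion to those pointers, here is a deliberately tiny reverse-mode sketch in plain Julia. It is a toy under stated assumptions, not Yota’s design: it records a tape through operator overloading on a hypothetical Var type (closer to the tracked approach of AutoGrad or ReverseDiff than to Yota’s Cassette-based tracer), supports only +, * and sin on scalars, and then runs a single backward sweep that applies the chain rule.

```julia
# Each tape node stores its value and, for each input, (input index, local partial).
struct Node
    value::Float64
    parents::Vector{Tuple{Int,Float64}}
end

# A traced scalar: a reference to a position on the shared tape.
struct Var
    tape::Vector{Node}
    idx::Int
end

value(v::Var) = v.tape[v.idx].value

function push_node!(tape, val, parents)
    push!(tape, Node(val, parents))
    return Var(tape, length(tape))
end

# Forward operations record their local derivatives on the tape.
Base.:+(a::Var, b::Var) = push_node!(a.tape, value(a) + value(b), [(a.idx, 1.0), (b.idx, 1.0)])
Base.:*(a::Var, b::Var) = push_node!(a.tape, value(a) * value(b), [(a.idx, value(b)), (b.idx, value(a))])
Base.sin(a::Var) = push_node!(a.tape, sin(value(a)), [(a.idx, cos(value(a)))])

# Reverse sweep: seed d(out)/d(out) = 1, then walk the tape backwards and push
# each node's adjoint to its parents (nodes are recorded in topological order).
function reverse_gradient(f, xs::Vector{Float64})
    tape = Node[]
    inputs = [push_node!(tape, x, Tuple{Int,Float64}[]) for x in xs]
    out = f(inputs...)
    adj = zeros(length(tape))
    adj[out.idx] = 1.0
    for i in length(tape):-1:1
        for (p, d) in tape[i].parents
            adj[p] += d * adj[i]
        end
    end
    return value(out), [adj[v.idx] for v in inputs]
end

# d/dx = y + cos(x), d/dy = x
val, g = reverse_gradient((x, y) -> x * y + sin(x), [1.0, 2.0])
```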


#10

Awesome. But for more complicated functions: say I have something where I’d like to cache some large calculation from which both the forward and backward passes are then fast. Can I safely just place this in mybigfunc(x) and call it twice (with identical arguments)?


#11

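CSE can fail even on simple functions like exp (here the Julia compiler does not merge the two exp calls in ∂foo, so it takes about twice as long as foo):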

julia> rx = Ref(0.5);

julia> using BenchmarkTools

julia> foo(x) = exp(x)
foo (generic function with 1 method)

julia> ∂foo(x) = (exp(x),exp(x))
∂foo (generic function with 1 method)

julia> @benchmark foo($(rx)[])
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     6.350 ns (0.00% GC)
  median time:      6.450 ns (0.00% GC)
  mean time:        6.554 ns (0.00% GC)
  maximum time:     21.960 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark ∂foo($(rx)[])
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     13.317 ns (0.00% GC)
  median time:      13.778 ns (0.00% GC)
  mean time:        14.066 ns (0.00% GC)
  maximum time:     28.647 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     998

This means that, unless I’m mistaken, logistic(x) in this example is going to be computed three times:

logistic(x) = 1 / (1 + exp(-x))
# for an expression like `logistic(x)` where x is a Number
# gradient w.r.t. x
# is `(logistic(x) * (1 - logistic(x)) * ds)` where "ds" stands for derivative "dL/dy"
@diffrule logistic(x::Number) x (logistic(x) * (1 - logistic(x)) * ds)

Once on the forward pass and twice on the backward pass.
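
For comparison, a closure-based (pullback) design, which is a different technique from @diffrule plus CSE, avoids the repeated calls by keeping the forward result and reusing it in the derivative. A minimal plain-Julia sketch, with logistic_with_pullback as a hypothetical name; the same pattern is one way to cache an expensive intermediate that both passes need.

```julia
logistic(x) = 1 / (1 + exp(-x))

# Compute logistic(x) once; the pullback closes over y, so the derivative
# y * (1 - y) * ds needs no further exp evaluations.
function logistic_with_pullback(x::Number)
    y = logistic(x)                   # the only evaluation of exp
    pullback(ds) = y * (1 - y) * ds   # reuses y instead of recomputing logistic(x)
    return y, pullback
end

y, pb = logistic_with_pullback(0.0)
pb(1.0)   # 0.25, since logistic(0) = 0.5
```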


#12

OK, so the CSE in question here is just Julia’s? I had the idea that the macro @diffrule was itself looking for such repeats… but perhaps that is the planned tape feature you mentioned?

Maybe I can answer this by benchmarking gradients of logistic(x) when I have a minute.


#13

With libraries like CommonSubexpressions or DataFlow, it’s certainly possible to do it on the Julia side!

Maybe Yota does. I haven’t tested that or dived into the source code.
The hard part is that you need some form of purity (weaker than Julia’s @pure) to know that functions like logistic are going to return the same result each time they’re called.

Given that Yota emphasizes performance and already requires static graphs, I could certainly see the case for some side-effect-free assumptions to allow aggressive CSE.


#14

And this is exactly how it works :slight_smile: To be precise, here’s the line that performs CSE. You can check the result like this (this is not part of the external API, though):

julia> using Yota

julia> foo(x) = exp(x)
foo (generic function with 1 method)

julia> _, g = grad(foo, 0.5)
(1.6487212707001282, GradResult(1))

julia> g.tape
Tape
  inp %1::Float64
  %2 = exp(%1)::Float64
  const %3 = 1.0::Float32
  %4 = exp(%1)::Float64
  %5 = *(%4, %3)::Float64

julia> Yota.generate_function_expr(g.tape)
:(function ##tape_fn#364()
      #= /home/slipslop/work/Yota/src/compile.jl:112 =#
      #= prologue:0 =#
      %1 = (inp %1::Float64).val
      %2 = (%2 = exp(%1)::Float64).val
      %3 = (const %3 = 1.0::Float32).val
      %4 = (%4 = exp(%1)::Float64).val
      %5 = (%5 = *(%4, %3)::Float64).val
      #= body:0 =#
      #= /home/slipslop/.julia/packages/Espresso/Lewh0/src/exgraph.jl:100 =#
      %2 = (exp)(%1)
      %3 = 1.0f0
      %5 = (*)(%2, %3)
      #= epilogue:0 =#
      (inp %1::Float64).val = %1
      (%2 = exp(%1)::Float64).val = %2
      (const %3 = 1.0::Float32).val = %3
      (%4 = exp(%1)::Float64).val = %4
      (%5 = *(%4, %3)::Float64).val = %5
  end)

Except for the prologue and epilogue (which are used for buffer pre-allocation and are the most important optimization for large-scale deep learning models), the only code left is:

%2 = (exp)(%1)
%3 = 1.0f0
%5 = (*)(%2, %3)

We could further optimize it to just exp(%1), but usually it doesn’t make much difference, and the multiplication by 1.0f0 is presumably eliminated by the Julia compiler anyway.
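
The idea behind that CSE step can be sketched on a toy tape. The representation below is hypothetical and is not Yota’s Tape type: each instruction is a function symbol plus the ids of its arguments, and inputs and constants are nullary instructions with distinct names. Instructions that apply the same function to the same (already deduplicated) arguments are merged, which is only sound if the functions are pure. Running it on a mirror of the tape above folds %4 into %2.

```julia
# Toy instruction: (function symbol, argument ids). Ids are 1-based tape positions.
const Instr = Tuple{Symbol,Vector{Int}}

# Returns the deduplicated tape and a map from old ids to new ids.
function cse(tape::Vector{Instr})
    seen = Dict{Instr,Int}()    # canonical instruction => id in the new tape
    remap = Dict{Int,Int}()     # old id => new id
    out = Instr[]
    for (i, (f, args)) in enumerate(tape)
        key = (f, Int[get(remap, a, a) for a in args])  # rewrite args first
        if haskey(seen, key)
            remap[i] = seen[key]        # same pure call seen before: reuse it
        else
            push!(out, key)
            seen[key] = length(out)
            remap[i] = length(out)
        end
    end
    return out, remap
end

# Mirrors: inp %1, %2 = exp(%1), const %3, %4 = exp(%1), %5 = *(%4, %3)
tape = Instr[(:x, Int[]), (:exp, [1]), (:one, Int[]), (:exp, [1]), (:*, [4, 3])]
out, remap = cse(tape)
# out: (:x, Int[]), (:exp, [1]), (:one, Int[]), (:*, [2, 3])
```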

It might seem limiting to forbid mutating operations, but for comparison, here’s a quote from the PyTorch documentation:

Supporting in-place operations in autograd is a hard matter, and we discourage their use in most cases. Autograd’s aggressive buffer freeing and reuse makes it very efficient and there are very few occasions when in-place operations actually lower memory usage by any significant amount. Unless you’re operating under heavy memory pressure, you might never need to use them.


#15

Haven’t tried it, but Yota should be general enough to support any kind of differentiation that boils down to primitives and the chain rule. Do you have an example of what you need at hand?

I need derivatives of multivariate polynomials with real or complex coefficients and maybe some basic analytic functions (exp, log, sin, cos). So far I haven’t found an AD package that doesn’t introduce significant overhead.

An example would be the derivative of

 f(x) = x[3]*x[8]^3 - 3*x[3]*x[8]*x[6]^2 - x[1]*x[6]^3 + 3*x[1]*x[6]*x[8]^2 + x[4]*x[5]^3 - 3*x[4]*x[5]*x[7]^2 - x[2]*x[7]^3 + 3*x[2]*x[7]*x[5]^2 - 1.2342523

#16

Do you want symbolic derivatives? It sounds like you just need a polynomial manipulator.


#17

This is probably a more general question.

I wonder whether it would be possible to share differentiation rules between libraries. There is https://github.com/JuliaDiff/DiffRules.jl, and I wonder if Yota could use those. Since I work mostly with static graphs, I would be curious to try Yota. On the other hand, I use a few custom gradients, which I currently have available for Flux.

But I am already impressed by Yota.


#18

Sure, that is possible, but there are many applications where your polynomials are already in a factored form and the expanded symbolic form would be quite wasteful.

For example, take the linear form l(x) = x[1] + 3x[2] + 2x[3] + 5x[4] and the polynomial g(x) = l(x)^4 - l(x)^2. You really want to avoid computing with the symbolic form of this.
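
As a worked illustration of why the factored form matters, the chain rule gives ∇g(x) = (4 l(x)^3 - 2 l(x)) ∇l(x) with ∇l(x) = (1, 3, 2, 5), so the gradient needs a single evaluation of l and no expansion at all. A plain-Julia sketch (coeffs and grad_g are names introduced here for illustration):

```julia
const coeffs = [1.0, 3.0, 2.0, 5.0]   # coefficients of the linear form l

l(x) = sum(coeffs .* x)
g(x) = l(x)^4 - l(x)^2

# Chain rule on the factored form: scalar outer derivative times ∇l(x) = coeffs.
function grad_g(x)
    s = l(x)                          # evaluate the linear form once
    return (4s^3 - 2s) .* coeffs
end

grad_g([0.1, 0.2, 0.3, 0.4])
```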


#19

Usually yes, but as with any early-stage software, it’s worth checking in practice :slight_smile: Right now you can use the code from this reply to inspect the generated code; I’ve also created an issue to make this more transparent.

One possible issue is that you use lots of x[i] expressions, each of which has a derivative like:

dx = zero(x)   # a full-size array for every x[i] expression
dx[i] = dy     # the upstream gradient goes into position i

So in your example you will get 20 intermediate arrays of the same size as x, and most libraries, including Yota in its current state, won’t optimize them out. Let me see what we can do in this case (although I can’t promise great results, since Yota is mostly optimized for quite large arrays).
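
A plain-Julia sketch of the difference (both function names are hypothetical, not Yota internals): the naive pullback of x[i] allocates a full-size array per expression, while accumulating every contribution into a single gradient buffer allocates only once. The example uses x[1]*x[2] + x[2]*x[3], whose four index reads receive the upstream gradients x[2], x[1], x[3] and x[2] respectively.

```julia
# Naive pullback for y = x[i]: a fresh full-size array per index expression.
function getindex_pullback_naive(x::AbstractVector, i::Int, dy)
    dx = zero(x)
    dx[i] = dy
    return dx
end

# Alternative: add each contribution into one shared gradient buffer.
function accumulate_getindex!(dx::AbstractVector, i::Int, dy)
    dx[i] += dy
    return dx
end

x = [1.0, 2.0, 3.0]
# (index, upstream gradient) for the four x[i] reads in x[1]*x[2] + x[2]*x[3]
contribs = [(1, x[2]), (2, x[1]), (2, x[3]), (3, x[2])]

naive = sum(getindex_pullback_naive(x, i, dy) for (i, dy) in contribs)  # 4 temporaries
dx = zero(x)                                                             # 1 buffer
for (i, dy) in contribs
    accumulate_getindex!(dx, i, dy)
end
```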

Also, thanks for this example: it helped uncover a bug in the derivatives of multi-argument *, /, etc. I’ve created an issue for it.


#20

The problem is that this is at odds with the performance recommendations for general Julia code, where in-place operations often do make a significant difference. (I understand it’s hard to support mutation; I’m just saying it would be very useful to have for general-purpose AD.)
