[ANN] Yota v0.5 - now with ChainRules support

Yota.jl, a reverse-mode autodiff package, just got its biggest update yet with the release of v0.5.

ChainRules support

In an effort towards a unified autodiff ecosystem, Yota has moved from custom derivative rules to the de facto standard package - ChainRules.jl. Thanks to its numerous contributors, ChainRules and related packages provide a rich set of rules and make it easier to switch between AD implementations.
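For instance, a rule already shipped with ChainRules can be queried directly - a primal value plus a pullback - which is the kind of primitive a reverse-mode backend consumes (a rough sketch, not Yota's internal code):

using ChainRules   # provides rrules for many Base functions

y, pullback = rrule(sin, 1.0)   # primal value and pullback closure
_, x̄ = pullback(1.0)            # x̄ == cos(1.0)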

New tracer

The code tracer has been significantly reworked and moved to a separate package, Ghost.jl. See the announcement for more details.
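The basic workflow looks roughly like this (a sketch based on Ghost's trace function; inc_mul is just a toy example):

import Ghost: trace

inc_mul(x, y) = (x + 1) * y            # toy function to trace
val, tape = trace(inc_mul, 2.0, 3.0)   # run the function and record a Tape of its operations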


If you migrate from previous versions of Yota, please take a look at the most significant breaking changes.

Given the amount of change in the code, Yota is expected to take some time to stabilize. As always, issues, feature requests and code contributions are welcome!

35 Likes

Hi, is there a way to get a pullback function with Yota (for vector-Jacobian products), instead of “just” a gradient?

Nope. Even though Yota uses ChainRules, it doesn’t build a stack of pullbacks; instead, it destructures them onto the tape.

One thing I’m planning to add in the near future, though, is the ability to provide a custom seed. Something like this:

foo(W, x) = W * x    # vector-valued function!
val, g = grad(foo, W, x; seed=[0, 0, 1])

In this case Yota will start the reverse pass with [0, 0, 1] instead of the usual 1. I guess this makes the seed the “vector” in the VJP.
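To spell out what the seed computes (a rough sketch in plain Julia, not Yota code): for foo(W, x) = W * x, the Jacobian with respect to x is W, so a seed of [0, 0, 1] yields the VJP seed' * J, i.e. the third row of W:

W = rand(3, 4); x = rand(4)
seed = [0.0, 0.0, 1.0]
vjp_x = W' * seed   # == W[3, :], the gradient w.r.t. x for this seed
vjp_W = seed * x'   # nonzero only in the third row (compare the REPL example at the end of the thread)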

2 Likes

Oh, yes. So is Yota OK with the last/outermost function returning a vector instead of a scalar?

How does Nabla compare nowadays, in terms of capabilities and speed, to other alternatives such as Nabla.jl, Zygote.jl, ReverseDiff.jl, …?

Right now Yota forbids non-scalar-valued functions, but this restriction is completely artificial and will be removed once the seed is introduced. In turn, the need for a seedable gradient is driven by the next big feature - differentiation through loops (which Ghost can already trace into a dedicated operation, but which Yota cannot yet handle in the reverse pass) - and the seeded gradient is thus expected to land in Yota v0.5.1 (~1-2 weeks).

Since you mention Nabla twice and don’t mention Yota at all, I’ll assume you mean to compare Yota to Nabla, Zygote and ReverseDiff.

Unlike ReverseDiff and Nabla, Yota doesn’t use type wrappers (e.g. TrackedArray or Leaf) to build an execution graph (tape). I tried this approach long ago and was disappointed with it. Not only do type wrappers limit you to functions with generic arguments (e.g. foo(::Real) is ok, but foo(::Float64) isn’t), they also make multiple dispatch and type hierarchies your enemies (should you have Diagonal{TrackedArray} or TrackedArray{Diagonal}? how about overloading * for all their combinations?). Thanks to IR-level tools, we can now trace arbitrary Julia code and build tapes without any of these restrictions.
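To illustrate the first point, here is a contrived sketch (TrackedReal is a made-up stand-in for wrapper types like TrackedArray or Leaf):

import Base: *

struct TrackedReal <: Real    # a made-up tracking wrapper
    val::Float64
end
*(a::TrackedReal, b::Number) = TrackedReal(a.val * b)   # every operator has to be overloaded by hand

foo_generic(x::Real) = x * 2      # fine: TrackedReal <: Real, so the wrapper passes through dispatch
foo_concrete(x::Float64) = x * 2  # MethodError when called with a TrackedReal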

Yet just like ReverseDiff and Nabla, Yota uses a tape to represent the computational graph. By contrast, Zygote uses stacks of closures to combine derivatives (see the paper). Both approaches have their advantages and disadvantages, but what I like about tapes is that they are easily hackable. In Yota, you can call gradtape(f, args...) to get the Tape, optimize it with your favorite graph algorithm, insert more operations, visualize it, export it to ONNX, etc. Zygote works entirely at the Julia IR level, which requires much more knowledge and attention to do correctly.
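In practice the workflow looks roughly like this (a sketch around the gradtape function mentioned above; loss is just a placeholder, and the tape internals are not shown):

using Yota

loss(W, x) = sum(W * x)
tape = Yota.gradtape(loss, rand(3, 4), rand(4))   # build the tape for loss and its gradient
# `tape` is a plain data structure: inspect its operations, run graph optimizations,
# insert extra operations, or export it, as described above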

On the speed ↔ flexibility scale, Zygote is definitely more flexible, trying to support as many use cases as possible, while Yota keeps things simple and fast. Avalon.jl (based on Yota) contains several benchmarks versus Flux (based on Zygote) - these are relatively old and may not reflect the current state, but they may give you an intuition about Yota’s focus.

(As a disclaimer, I don’t follow the development of the mentioned libraries and apologize for any possible misinformation).

7 Likes

Yes, I meant Yota.
Some users report problems when using Flux with Zygote, so I’m looking into other alternatives - maybe Flux with Yota, or maybe Soss.jl.

1 Like

Cool, thanks! I would love to try Yota with use cases that need vector-Jacobian products.

I cannot thank you enough for adding support for ChainRules. It is so much easier to write and debug my frules and rrules once, then experiment with various AD frameworks.

I also appreciate how Yota.grad returns both the gradient and the value. It is very rare to just need the gradient and not the value, so this is the best API design for most purposes.

9 Likes

Those we should really thank are the team behind ChainRules, who enabled the unified autodiff ecosystem in Julia - integrating it into Yota was the easy part :slight_smile:

Just to clarify: since Yota is a purely reverse-mode autodiff package, it uses only rrules, not frules. This is not set in stone, though - if you hit a problem that can only be solved with frules and it fits into Yota’s target domain (machine learning & co.), I’ll be happy to look into that possibility as well!

8 Likes

Definitely — I think this was a great initiative. But it takes two to tango, so I also feel like thanking all the AD package authors who opted in so that my life is made easier.

Yes, I know — I just like to write and test both frule and rrule at the same time; if I am implementing one then adding the other is a no-brainer.
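For what it's worth, the pattern I mean looks roughly like this (a hedged sketch; mysquare is a toy function, and ChainRulesTestUtils is one way to check both rules against finite differences):

using ChainRulesCore, ChainRulesTestUtils

mysquare(x::Real) = x^2

# forward rule: given the input tangent Δx, return the primal and its tangent
ChainRulesCore.frule((Δself, Δx), ::typeof(mysquare), x) = (x^2, 2x * Δx)

# reverse rule: return the primal and a pullback mapping ȳ to input cotangents
function ChainRulesCore.rrule(::typeof(mysquare), x)
    mysquare_pullback(ȳ) = (NoTangent(), 2x * ȳ)
    return x^2, mysquare_pullback
end

test_frule(mysquare, 1.5)   # checks the frule against finite differences
test_rrule(mysquare, 1.5)   # checks the rrule against finite differences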

So far I had no need for mixed mode in Yota, but it would be an interesting option.

Not to mention that computing the value is essentially required to get to the gradient. I agree it only makes sense to return it.

I look forward to a whole new world of bug reports against ChainRules

9 Likes

Any performance benchmarks vs Zygote and ReverseDiff on things with scalar iteration (i.e. not ML)? What’s the mutation support like?

2 Likes

Any performance benchmarks vs Zygote and ReverseDiff on things with scalar iteration (i.e. not ML)?

No benchmarks for Yota v0.5 yet. I have a preliminary plan to add regular benchmarks in CI/CD for several repositories (at least Yota and NNlib), but I want to stabilize things first.

In general, I expect Yota to be somewhat faster on high-dimensional ML problems and slower on scalar and low-dimensional problems. This is because Yota doesn’t optimize for small constant factors like type stability, which usually account for < 1% of the overall run time in ML tasks. At the same time, Yota cares about things like array preallocation, kernel fusion, etc., which you don’t need in scalar functions but which become a killer feature in high-dimensional tasks.

Also note that all explicit optimizations have been removed during the refactoring and will be slowly added back in future versions. Benchmark suggestions (in both ML and non-ML domains) are highly welcome, as they will help track the progress.

What’s the mutation support like?

No mutation support is intended. If you have a mutating operation, you can try to wrap it into an rrule() to “hide” it from the AD engine, but the engine itself will not attempt to handle it.
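Roughly like this (a hedged sketch; scaled and its in-place kernel are made-up examples, not part of Yota):

using ChainRulesCore

scale!(y, x, a) = (y .= a .* x; y)   # in-place kernel the AD engine should never see

# pure wrapper: the mutation stays inside, so the tracer only records `scaled`
function scaled(x, a)
    y = similar(x)
    scale!(y, x, a)
    return y
end

# rrule for the wrapper: the AD engine treats `scaled` as a single primitive
function ChainRulesCore.rrule(::typeof(scaled), x, a)
    y = scaled(x, a)
    scaled_pullback(ȳ) = (NoTangent(), a .* ȳ, sum(ȳ .* x))
    return y, scaled_pullback
end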

A lot has been said about mutation support in Zygote, but I should highlight a couple of important reasons for not supporting it in Yota:

  1. Mutation often means slower code, not faster. For example, filling an array element by element like x[i] = y may be fast on CPU, but on GPU it is a disaster. Yota explicitly shouts out when it sees something we don’t have a fast way to do. Sure, this restricts a number of use cases, but at the same time it helps to uncover performance bugs.
  2. Since Yota produces an easy-to-handle tape, in many cases it should be possible to optimize it in a way that replaces possibly non-optimal immutable operations with their mutable counterparts.
6 Likes

I see, basically the opposite of Enzyme. That’ll be interesting.

2 Likes

It’s now on the master branch:

julia> W = rand(3, 4)
3×4 Matrix{Float64}:
 0.905898  0.751797  0.724524  0.762578
 0.416724  0.307087  0.212776  0.694455
 0.754229  0.91923   0.205732  0.130126

julia> x = rand(4)
4-element Vector{Float64}:
 0.46976815615264256
 0.29409430273791903
 0.9282164909689541
 0.502618804377758

julia> val, g = grad(foo, W, x; seed=[0, 0, 1])
([1.702462429791826, 0.8326244095396746, 0.8810201289094359], (ZeroTangent(), [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.46976815615264256 0.29409430273791903 0.9282164909689541 0.502618804377758], [0.754228829088897, 0.9192296288745856, 0.205731786463623, 0.13012567040781997]))
4 Likes