[ANN] Yota.jl - yet another reverse-mode autodiff package

I agree that using pre-allocated buffers does make a lot of difference, but I would call the technique an inconvenient workaround rather than a recommendation. It opens a whole can of worms - element type calculations for containers, buffer ownership, and so on - all of which are manageable but tedious.
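To make that concrete, here is a minimal sketch of the kind of bookkeeping a pre-allocated buffer forces on you; scaled_sum! and the buffer handling are made up for illustration and not tied to any particular AD package:

# A hypothetical function that reuses a caller-owned buffer.
function scaled_sum!(buf, x, a)
    T = promote_type(eltype(x), typeof(a))      # manual element-type calculation
    eltype(buf) == T || error("buffer has the wrong element type - and who reallocates it?")
    buf .= a .* x
    return sum(buf)
end

buf = zeros(Float64, 3)            # who owns this buffer, and for how long?
scaled_sum!(buf, rand(3), 2.0)     # works
# scaled_sum!(buf, rand(3), 2im)   # fails: a Float64 buffer can't hold Complex results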

Is broadcasting supported? I had a go with the following function, but got an error: Failed to find a derivative for %44 = broadcast(%43, %38, %-1)::Array{Float64,2}:

using LinearAlgebra   # needed for diag

function pd(x)
    m = x'*x
    v = diag(m)
    v .+ v' .- 2 .* m
end

Yota.grad(x -> sum(pd(x)), rand(2,3)) 

Thanks for this example! This is exactly the feedback I wanted from the announcement - a set of workflows which are different from mine and which I thus missed.

In this particular case the function diag wasn’t recorded as a primitive, but was instead traced through, resulting in multiple getindex() calls (and some other unexpected things that I still need to investigate). I have created an issue for this. Also, pardon the slow reaction - the start of the week turned out to be pretty busy, but I hope to fix all the mentioned issues by next Monday.

As for broadcasting itself, it should work fine as long as the function being broadcast is a primitive with a defined @diffrule. We may remove this restriction in the future, but right now it’s unclear how much of a constraint it is in real life (e.g. in your case all the broadcasting happens over primitives).
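As a minimal sanity check of what "broadcasting over primitives" means in practice (assuming sin is covered by the built-in rules, and using the grad API shown later in this thread):

using Yota

f(x) = sum(sin.(x))    # broadcasting over a primitive with a known derivative
x = rand(3)
_, g = grad(f, x)      # if sin has a @diffrule, g[1] should be ≈ cos.(x)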

Let’s break it into 2 parts:

  1. Using autodiff to differentiate existing functions with mutable operations.
  2. Getting efficient code for forward and backward passes.

Somewhere deep in my backlog I have a ticket to rewrite the most popular mutating operations (e.g. mul!) back into non-mutating ones as part of preprocessing, which should address (1) at least for some portion of functions.

(2) is already partially addressed - all matrix multiplications are replaced with mul! in the generated code, and all array variables are modified in-place using .= instead of =. In my previous packages I also had a macro to define rules for rewriting non-mutating operations into mutating ones; perhaps I’ll bring it to Yota too.
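A toy sketch of both directions in plain Julia (made-up function names, not Yota’s actual preprocessing or generated code):

using LinearAlgebra

# (1) A user's mutating code...
f!(C, A, B) = (mul!(C, A, B); sum(C))
# ...and the non-mutating form a preprocessing step could rewrite it into:
f(A, B) = sum(A * B)

# (2) What buffer-reusing backward code for C = A * B can look like:
# given the upstream gradient dC, write the results in-place with mul!
function matmul_backward!(dA, dB, dC, A, B)
    mul!(dA, dC, transpose(B))   # dA .= dC * B'
    mul!(dB, transpose(A), dC)   # dB .= A' * dC
    return dA, dB
end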

The set of rules in Yota dates back to ReverseDiffSource.jl and has never been seriously revisited, so yes, it’s possible I will migrate to them at some point.


OK, thanks for taking a look!

The other problem I ran into was, I think, due to calling vec(m)… which @less tells me is reshape(a, length(a)), and if I’m reading right it looks like you only have diffrules for reshape(a, n, m).
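For reference, that lowering is easy to confirm, so the missing rule would be the single-dimension reshape form:

m = rand(2, 3)
vec(m) == reshape(m, length(m))   # true: vec(a) is just reshape(a, length(a))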


Congrats on the announcement @dfdx! It’s great to have something to show off the power of Cassette here, and a bonus to have something robust and maintained to benchmark against.

I’m hoping that we can end up doing many Yota-style optimisations on the “static subgraphs” of Julia’s IR, but there’s some way to go in terms of compiler support before that’s possible. With some cleverness in handling things like memory management it should be possible to avoid trading off performance and flexibility entirely.

I’m curious if you’ve tried this on something like DiffEq? For me this is the ultimate challenge case for Zygote and I’d be very curious how Yota fares.


Thanks Mike! I actually spent quite a lot of time inspecting Zygote’s code to learn more about Julia IR in the early stages of Yota.

With some cleverness in handling things like memory management it should be possible to avoid trading off performance and flexibility entirely.

In my tests, memory pre-allocation is by far the most important performance optimization. The tricky part is that you need to hold the pre-allocated buffers somewhere instead of creating them on each function call. One way to do it that might suit Zygote is to pass a memory manager (in the simplest case, a Dict with buffers for each tensor variable) as an optional parameter to the gradient function. I used this approach in XGrad and it worked pretty well in practice, giving the user the ability (but not the obligation) to speed up calculations with just a bit more typing. I believe if you implement just this (or any other clever memory pre-allocation strategy), it will bring you to almost optimal performance without losing any of the current flexibility.
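A minimal sketch of that idea for loss(W, x) = sum(W * x) - the function and buffer names are made up here and don’t correspond to XGrad’s or Zygote’s actual APIs:

using LinearAlgebra

# The caller owns a Dict of buffers and passes it in; the gradient function
# fills them in-place instead of allocating on every call.
function grad_loss!(mem::Dict{Symbol,Any}, W, x)
    y    = get!(() -> zeros(eltype(W), size(W, 1), size(x, 2)), mem, :y)
    ybar = get!(() -> ones(eltype(W), size(W, 1), size(x, 2)), mem, :ybar)
    dW   = get!(() -> similar(W), mem, :dW)
    mul!(y, W, x)        # forward pass into a reused buffer
    mul!(dW, ybar, x')   # backward pass for sum(W * x): dW = ones(size(y)) * x'
    return sum(y), dW
end

mem = Dict{Symbol,Any}()          # buffers live across calls
W, x = rand(3, 4), rand(4, 5)
val, dW = grad_loss!(mem, W, x)   # first call allocates the buffers
val, dW = grad_loss!(mem, W, x)   # subsequent calls reuse them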

I’m curious if you’ve tried this on something like DiffEq? For me this is the ultimate challenge case for Zygote and I’d be very curious how Yota fares.

I’ve never considered DiffEq, but it sounds like an interesting challenge!


It might just be a different problem.

For an adaptive integrator, I am not sure it will ever work, since it needs to know how to backprop through the control flow and the computation is very value-dependent. But one thing we could do is define an adjoint primitive via DiffEqSensitivity.jl’s adjoint for Yota.jl, so that users can make use of DiffEq+Yota - though it would have limitations, of course.

Another approach would be to just add support for dynamic graphs in Yota - technically it’s not hard, but it invalidates most of the current optimizations. I’ve created an issue to track ideas.


I’ve added a couple of experimental features that have been mentioned here or in other threads. Experimental status means that they are most likely buggy and incomplete, and may be completely removed in future versions. Yet I’m eager to hear what folks think about them.

Dynamic graphs

function iterative(x, n)
    for i=1:n
        x = 2 .* x
    end
    return sum(x)
end

x = rand(4)
_, g = grad(iterative, x, 1; dynamic=true)   # g[1] == [2.0, 2.0, 2.0, 2.0]
_, g = grad(iterative, x, 2; dynamic=true)   # g[1] == [4.0, 4.0, 4.0, 4.0]
_, g = grad(iterative, x, 3; dynamic=true)   # g[1] == [8.0, 8.0, 8.0, 8.0]

When calling grad() with the dynamic=true keyword, Yota will still trace the function (so the cost of tracing remains), but instead of backpropagating right away, it will first look up previously calculated gradients corresponding to the same tape. For a computational graph with a reasonable number of branches, most traces should be cached pretty quickly.
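Conceptually, the caching boils down to something like the sketch below; trace() and compile_grad() are hypothetical placeholders, not Yota’s actual internals:

const GRAD_CACHE = Dict{UInt,Function}()

function dynamic_grad(f, args...)
    tape = trace(f, args...)         # hypothetical: re-record ops for this call's control-flow path
    key = hash(tape)                 # same path => same tape => same cache key
    compiled = get!(() -> compile_grad(tape), GRAD_CACHE, key)   # hypothetical: compile once per distinct tape
    return compiled(args...)
end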

The performance of re-tracing a function on each call depends mostly on the performance of Cassette’s tagging mechanism, so any improvements to it should automatically be reflected in the performance of the tracer, and thus of dynamic grad() calls as a whole.

Simple grad

When you don’t want to deal with buffers or “strange” data structures like GradientResult attached to your gradient function, you can use simplegrad() to generate a “pure” gradient function that takes the same arguments as the original one and returns its value along with the derivatives of its arguments:

import Yota: simplegrad

loss(W::AbstractMatrix, b::AbstractVector, x::AbstractArray) = sum(W * x .+ b)

W, b, x = rand(128, 784), rand(128), rand(784, 100)
∇loss = simplegrad(loss, W, b, x)   # note: ∇loss is a new _function_, world age concerns apply

val, dW, db, dx = ∇loss(W, b, x)

@code_lowered ∇loss(W, b, x)
# CodeInfo(
# 1 ─       %4 = (*)(%1, %3)
# │         %5 = +
# │         %6 = (broadcast)(%5, %4, %2)
# │         %7 = (sum)(%6)
# │         %8 = 1.0
# │         %9 = (Yota.sum_grad)(%6, %8)
# │         %10 = (Yota.unbroadcast)(%4, %9)
# │         %11 = (Yota.unbroadcast)(%2, %9)
# │         %12 = (transpose)(%3)
# │         %13 = (*)(%10, %12)
# │         %14 = (transpose)(%1)
# │         %15 = (*)(%14, %10)
# │         %16 = (Core.tuple)(%7, %13, %11, %15)
# └──       return %16
# )

Note that this form doesn’t support gradients of struct fields and will return nothing for them.
