Yota.jl is a package for reverse-mode automatic differentiation designed specifically for machine learning applications.
Usage
using Yota

mutable struct Linear{T}
    W::AbstractArray{T,2}
    b::AbstractArray{T}
end

forward(m::Linear, X) = m.W * X
loss(m::Linear, X) = sum(forward(m, X))

m = Linear(rand(3,4), rand(3))
X = rand(4,5)

val, g = grad(loss, m, X)
where g is an object of type GradientResult holding gradients w.r.t. the input variables. For scalars and tensors it holds the gradient value directly; for structs it holds a dictionary of (field path → gradient) pairs:
julia> g[1]
Dict{Tuple{Symbol},Array{Float64,2}} with 1 entry:
(:W,) => [3.38128 2.97142 2.39706 1.55525; 3.38128 2.97142 2.39706 1.55525; 3.38128 2.97142 2.39706 1.55525] # gradient w.r.t. m.W
julia> g[2] # gradient w.r.t. X
4×5 Array{Float64,2}:
0.910691 0.910691 0.910691 0.910691 0.910691
1.64994 1.64994 1.64994 1.64994 1.64994
1.81215 1.81215 1.81215 1.81215 1.81215
2.31594 2.31594 2.31594 2.31594 2.31594
A GradientResult can be used in conjunction with the update!() function to modify tensors and fields of (mutable) structs; see the README for more details.
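For illustration, a simple gradient descent loop might look like the sketch below. The exact update!() signature is documented in the README; here we assume it takes the struct, the matching gradient entry, and an update rule applied to each (value, gradient) pair, with an arbitrary learning rate of 0.01:

for i in 1:100
    val, g = grad(loss, m, X)
    # assumed signature: apply the update rule to every field of m
    # using the corresponding gradient from g[1]
    update!(m, g[1], (x, gx) -> x .- 0.01 .* gx)
end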
Features
- differentiation over scalars, tensors and structs, designed to support a PyTorch-like API
- easy to add custom derivatives (see the sketch after this list)
- Cassette-based tracer which avoids structural type constraints
- experimental GPU support via CuArrays
- performance-first implementation
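As a sketch of the second point, a custom derivative can be registered with the @diffrule macro. This is only an illustration: the my_softplus function is made up for the example, and we assume the README's convention that dy denotes the gradient propagated from the output:

# hypothetical custom function and a derivative rule for it;
# dy stands for the gradient coming from the output
my_softplus(x) = log(1 + exp(x))
@diffrule my_softplus(x) x (dy * exp(x) / (1 + exp(x)))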
Note that the tracer is fully customizable and can be used independently of automatic differentiation. Again, see the README for an example.
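A minimal sketch of standalone tracing, assuming trace returns the computed value together with the recorded tape, as in the README example:

f(x) = 2x + 1

# record the operations performed by f into a tape;
# the tape can then be inspected or transformed without running AD
val, tape = Yota.trace(f, 3.0)
print(tape)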
Performance
Comparing autodiff implementations turns out to be unexpectedly hard because of different sets of supported features, different sets of primitives, etc. (e.g. see an attempt to compare it with Zygote.jl). However, Yota uses a number of proven optimizations from my previous autodiff packages, so in general a differentiation pass should cost no more than 2-3x a call to the original function. To put it differently, if you see a significant difference between automatic and manually crafted differentiation, consider it a bug.
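One way to check this claim on the usage example above (an illustrative sketch using BenchmarkTools.jl):

using BenchmarkTools

# forward pass alone vs. forward + backward pass;
# per the note above, the latter should be within ~2-3x of the former
@btime loss($m, $X)
@btime grad($loss, $m, $X)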
Comparison to other autodiff packages
Unlike Zygote, Yota emphasizes performance over flexibility. While Zygote aims to support the full dynamism of the Julia language, Yota restricts the user to a static graph consisting of analytical functions commonly used in ML. This way we can apply a number of optimizations including memory buffer pre-allocation, common subexpression elimination, etc.
Unlike Capstan, Yota is implemented and usable today, although I'll be very curious to test Capstan when it's shipped.
Also, Yota doesn't use tracked arrays or function overloading like AutoGrad, ReverseDiff or the current Flux tracker, and thus doesn't hit the ambiguity issues of multiple dispatch. The downside is that dynamic computational graphs become harder to implement and are currently not supported in Yota. For the curious, there's also an older version of the package stored in YotaTracked and implemented similarly to the mentioned libraries.
Feedback and bug reports are welcome!