Hi, let me introduce a Julia package called BangBang.jl. Quoting its README:
BangBang.jl implements functions whose name ends with !!. Those functions provide a uniform interface for mutable and immutable data structures. Furthermore, those functions implement the “widening” fallback for the case the usual mutating function does not work (e.g., push!!(Int[], 1.5) creates a new array Float64[1.5]).
See the supported functions in the documentation.
It is one of the spin-off packages from Transducers.jl v0.3 (ANN: Transducers.jl 0.3. taking "zeros" seriously, type stability improvements, fusible groupby, OnlineStats, "GPU support", and more). The motivation back then was to make it easy to write functions that work with both mutable and immutable data structures:
julia> using BangBang, StaticArrays  # StaticArrays provides the SVector used below
julia> mapappend!!(f, ys, xs) = foldl(xs; init=ys) do ys, x
           push!!(ys, f(x))
       end
mapappend!! (generic function with 1 method)
julia> mapappend!!(x -> x + 1, [], 1:3)
3-element Array{Any,1}:
2
3
4
julia> mapappend!!(x -> x + 1, (), 1:3)
(2, 3, 4)
julia> mapappend!!(x -> x + 1, SVector(0), 1:3)
4-element SArray{Tuple{4},Int64,1,4} with indices SOneTo(4):
0
2
3
4
as well as eltype-incompatible containers (aka widening):
julia> mapappend!!(x -> x / 2, [0], 1:3)
4-element Array{Float64,1}:
0.0
0.5
1.0
1.5
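One practical consequence of widening is that the result of a !! function may or may not be the object you passed in, so the return value should always be used. The following is a minimal sketch of that behavior using only push!!; the variable names are made up for illustration.

```julia
using BangBang

ys = [0]                 # Vector{Int}

# An Int fits into ys, so push!! mutates ys and returns the same object.
zs = push!!(ys, 1)
println(zs === ys)

# A Float64 does not fit into Vector{Int}: push!! returns a new, widened
# vector and leaves ys untouched.
ws = push!!(ys, 0.5)
println(ws === ys)
println(eltype(ws))

# Tuples are immutable, so a new tuple is always returned.
t = push!!((1, 2), 3)
println(t)
```

This is why mapappend!! above threads the accumulator through foldl instead of relying on in-place mutation.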
Zygote.jl support
One of the new features in BangBang.jl v0.3 is Zygote.jl support. You can now write functions that mutate data structures in normal execution but switch to immutable mode during the forward pass of the differentiation. This is, of course, far from the real mutation support Zygote may have at some point, but I still find it handy: I don’t have to rewrite my models just for Zygote and then test that the mutating and non-mutating paths agree.
Here is a toy example (the macro @! replaces mutating functions with their counterparts in BangBang):
using BangBang
using LinearAlgebra
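function rnn!!(n, J, x)
    dest = x
    y = similar(x)
    for _ in 1:n
        @! y = mul!(y, J, x)   # mul! is rewritten to mul!!
        @! y .= tanh.(y)       # broadcast assignment goes through materialize!!
        x, y = y, x            # swap the two buffers
    end
    @! dest .= x
    return dest
end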
using Zygote
d = 10
J = randn(d, d)
x0 = randn(d)
y_target = randn(d)
g, = Zygote.gradient(J -> sum((rnn!!(20, J, x0) .- y_target) .^ 2), J)
This example uses mul!! and materialize!! (for broadcasting) behind the scenes to make it work with Zygote.
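To make the rewriting concrete, here is a small sketch of the same idea with push!, assuming @! simply swaps each call to a function whose name ends with ! for its !! counterpart, as described above. The variable names are made up for illustration.

```julia
using BangBang

xs = (1, 2)              # a tuple is immutable, so Base.push! has no method for it

# Written with the macro: the call to push! is replaced by push!!.
@! xs = push!(xs, 3)

# The same step written by hand.
ys = push!!((1, 2), 3)

println(xs)
println(ys)
```

In both forms the result is rebound to a variable, which is what lets the same code work whether or not the container can be mutated.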
Just in case you are curious, the way this “Zygote support” is implemented is super trivial (just a one-liner). I suppose/hope it is as easy as that with ChainRulesCore.jl, but I haven’t looked into it yet.