The pass I wrote transforms:
# Before the pass: the distribution of μ is an anonymous model written inline.
m1 = @model x begin
    μ ~ @model z ~ Normal(0,1)
    x ~ Normal(μ,1)
end
into
# After the pass: the inline model is first bound to a name (M here), and μ is
# then drawn from that binding.
m2 = @model x begin
    M = @model z ~ Normal(0,1)
    μ ~ M
    x ~ Normal(μ,1)
end
(but with an autogenerated name instead of `M`),
so I was implicitly assuming that the latter is a more canonical form, i.e. the next step in a transformation process.
This is harder:
# Harder case: M is bound outside m3, so m3's own expression does not contain
# M's definition -- only the name M.
M = @model z ~ Normal(0,1)
m3 = @model x begin
    μ ~ M
    x ~ Normal(μ,1)
end
I really should look into Cassette.jl!
I should probably do that before, for example, trying to resurrect or reimplement Sugar.jl. Here is the example from its README:
using Sugar
# Sugar.jl README example: a function mixing an `if`/`else` expression, a `for`
# loop, `continue`, and `break`, used to show what `Sugar.sugared` recovers
# from lowered code (the recovered code is shown after the call below).
function controlflow_1(a, b)
    if a == 10
        # `if` used as an expression: x starts at 7 when b == 22, otherwise 8.
        x = if b == 22
            7
        else
            8
        end
        for i=1:100
            x += i
            x -= 77
            if i == 77
                # Last statement of the loop body, so this `continue` does not
                # change the result; it is there to exercise control flow.
                continue
            elseif i == 99
                # Leaves the loop before the i == 100 iteration runs.
                break
            end
        end
        return x
    else
        return 77
    end
end
# Ask Sugar to rebuild structured code (if/while/break/continue) from the
# `code_lowered` IR of controlflow_1 for argument types (Int, Int).
Sugar.sugared(controlflow_1, (Int, Int), code_lowered)
yields
quote
NewvarNode(:(_4))
if _2 == 10
if _3 == 22
_7 = 7
else
_7 = 8
end
_4 = _7
SSAValue(0) = (Main.colon)(1,100)
_5 = (Base.start)(SSAValue(0))
while !((Base.done)(SSAValue(0),_5))
SSAValue(1) = (Base.next)(SSAValue(0),_5)
_6 = (Core.getfield)(SSAValue(1),1)
_5 = (Core.getfield)(SSAValue(1),2)
_4 = _4 + _6
_4 = _4 - 77
if _6 == 77
continue
end
if _6 == 99
break
end
end
return _4
end
return 77
end
This approach could give us the AST of other functions.
Alternatively, because in your example `M` was defined using the `@model` macro, you could make sure this macro provides the necessary information.
I’ve been “watching” Soss on GitHub (but didn’t dive into the code base) since you first announced it here!
Soss is much further along than my efforts, which probably don’t deserve much more credit than a pipe dream at the moment. I put some code up here, but had a lot of time to rethink implementation details since the last update. I’d also use ChainRules instead of DiffRules, and probably drop LightGraphs. There are also no comments or documentation.
When I resume working on this, I’m likely to start fresh with most of it, or at least refactor it significantly.
Much of my work recently has been on building the lower level tools I plan to use (or on projects that will actually go into my dissertation).
For example, vectorization work on matrix multiplication and improved loop vectorization, such as:
julia> using StaticArrays, BenchmarkTools, Random
julia> function fill_BPP!(BPP::AbstractMatrix{T}) where T
    # Fill `BPP` in place with random benchmark inputs and return it.
    #
    # Columns written for every row:
    #   1:3   standard-normal draws (the vector x)
    #   4     zero (skipped by the process_* kernels)
    #   5:10  S11, S12, S22, S13, S23, S33 -- the upper triangle of S = A'A for a
    #         fresh 6×3 standard-normal static matrix A
    #
    # The row loop runs under `@inbounds` with columns 4:10 hard-coded, so check
    # up front that columns 1:10 exist (throws `BoundsError` otherwise); a matrix
    # with 3 to 9 columns would otherwise be written out of bounds. The check is
    # skipped when there are no rows, because then the loop touches nothing.
    rows = axes(BPP, 1)
    isempty(rows) || checkbounds(BPP, rows, 1:10)
    @views randn!(BPP[:,1:3])
    @inbounds for i ∈ rows
        S = ( @SMatrix randn(T,6,3) ) |> x -> x' * x
        BPP[i,4] = zero(T)
        BPP[i,5] = S[1,1]
        BPP[i,6] = S[1,2]
        BPP[i,7] = S[2,2]
        BPP[i,8] = S[1,3]
        BPP[i,9] = S[2,3]
        BPP[i,10] = S[3,3]
    end
    return BPP
end
fill_BPP! (generic function with 1 method)
julia> BPP = fill_BPP!(Matrix{Float32}(undef, 4096, 10));
julia> @inline function pdbacksolve(x1,x2,x3,S11,S12,S22,S13,S23,S33)
    # Return the 3-tuple U⁻¹x for x = (x1, x2, x3), where U is the upper-triangular
    # factor of the symmetric matrix S with S = U*U' (U*U', not U'*U).
    # Only the upper triangle of S is passed, in the order S11, S12, S22, S13, S23, S33.
    #
    # U is built from the last row/column upward, and only the entries of U⁻¹
    # that are needed are formed, so no 3×3 matrix is materialized. Only `inv`,
    # `sqrt`, `*`, `+` and `-` are used; the macroexpanded @vectorize loop further
    # down calls this same function with SIMD `SVec` arguments.
    # NOTE(review): S is not checked for positive definiteness; for real scalar
    # inputs a non-PD S reaches sqrt of a negative number (DomainError).
    Ui33 = inv(sqrt(S33))                          # 1/U33, with U33 = sqrt(S33)
    U13 = S13 * Ui33                               # from S13 = U13*U33
    U23 = S23 * Ui33                               # from S23 = U23*U33
    Ui22 = inv(sqrt(S22 - U23*U23))                # 1/U22, from S22 = U22² + U23²
    U12 = (S12 - U13*U23) * Ui22                   # from S12 = U12*U22 + U13*U23
    Ui33x3 = Ui33*x3
    Ui11 = inv(sqrt(S11 - U12*U12 - U13*U13))      # 1/U11, from S11 = U11² + U12² + U13²
    Ui12 = - U12 * Ui11 * Ui22                     # (U⁻¹)₁₂
    Ui13x3 = - (U13 * Ui11 + U23 * Ui12) * Ui33x3  # (U⁻¹)₁₃ * x3
    Ui23x3 = - U23 * Ui22 * Ui33x3                 # (U⁻¹)₂₃ * x3
    # Rows of the upper-triangular product U⁻¹x.
    (
        Ui11*x1 + Ui12*x2 + Ui13x3,
        Ui22*x2 + Ui23x3,
        Ui33x3
    )
end
pdbacksolve (generic function with 1 method)
julia> X = Matrix{Float32}(undef, size(BPP,1), 3);
julia> function process_big_prop_points_v1!(X::AbstractMatrix{T}, Data::AbstractMatrix{T}) where T
    # For every row i of `Data`, write pdbacksolve of that row into row i of `X`.
    # Columns 1:3 of `Data` hold the vector x and columns 5:10 hold the upper
    # triangle of S (S11, S12, S22, S13, S23, S33); column 4 is not read.
    # Row i of `X` must have exactly 3 columns (the broadcast enforces this).
    # Returns `nothing`.
    #
    # The loop runs under `@inbounds`, so every index it touches is validated
    # first (throws `BoundsError`); otherwise a `Data` with fewer than 10 columns
    # or an `X` with too few rows would be accessed out of bounds. Nothing is
    # accessed when `Data` has no rows, so the check is skipped then.
    rows = axes(Data, 1)
    if !isempty(rows)
        checkbounds(Data, rows, 1:10)
        checkbounds(X, rows, 1:3)
    end
    # `ivdep`: each iteration writes a different row of X. This assumes X and
    # Data do not alias.
    @inbounds @simd ivdep for i ∈ rows
        X[i,:] .= pdbacksolve(
            Data[i,1],Data[i,2],Data[i,3],
            Data[i,5],Data[i,6],Data[i,7],Data[i,8],Data[i,9],Data[i,10]
        )
    end
    return nothing
end
process_big_prop_points_v1! (generic function with 1 method)
julia> @benchmark process_big_prop_points_v1!($X, $BPP)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 51.442 μs (0.00% GC)
median time: 53.201 μs (0.00% GC)
mean time: 54.350 μs (0.00% GC)
maximum time: 90.791 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 1
julia> using SIMDPirates, SLEEFwrap
julia> @generated function process_big_prop_points!(X::AbstractMatrix{T}, Data::AbstractMatrix{T}) where T
    # Same computation as process_big_prop_points_v1!, with the loop rewritten by
    # `@vectorize`. The function is `@generated` so that the element type can be
    # spliced into the macro call as a literal (`$T`): a macro sees syntax, not
    # the runtime types of its arguments.
    #
    # The expansion of `@vectorize` (see the @macroexpand output further down)
    # loads and stores via `vload`/`vstore` on `vectorizable(Data)` and
    # `vectorizable(X)` and appears to do no bounds checking, so the indices the
    # loop uses are validated here first (throws `BoundsError`). Nothing is
    # accessed when `Data` has no rows, so the check is skipped then.
    quote
        if size(Data, 1) > 0
            checkbounds(Data, 1:size(Data, 1), 1:10)
            checkbounds(X, 1:size(Data, 1), 1:3)
        end
        @vectorize $T for i ∈ 1:size(Data,1)
            X[i,:] .= pdbacksolve(
                Data[i,1],Data[i,2],Data[i,3],
                Data[i,5],Data[i,6],Data[i,7],Data[i,8],Data[i,9],Data[i,10]
            )
        end
    end
end
process_big_prop_points! (generic function with 1 method)
julia> @benchmark process_big_prop_points!($X, $BPP)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 6.051 μs (0.00% GC)
median time: 6.111 μs (0.00% GC)
mean time: 6.308 μs (0.00% GC)
maximum time: 13.125 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 5
julia> (@macroexpand @vectorize Float32 for i ∈ 1:size(Data,1)
X[i,:] .= pdbacksolve(
Data[i,1],Data[i,2],Data[i,3],
Data[i,5],Data[i,6],Data[i,7],Data[i,8],Data[i,9],Data[i,10]
)
end) |> striplines
quote
##N#419 = size(Data, 1)
(Q, r) = (SLEEFwrap.VectorizationBase).size_loads(Data, 1, Val{16}())
##pData#426 = SLEEFwrap.vectorizable(Data)
##pX#420 = SLEEFwrap.vectorizable(X)
begin
for ##i#418 = 1:16:Q * 16
##iter#417 = ##i#418
begin
####pData#426_i#427 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 0 * SLEEFwrap.stride_row(Data))
####pData#426_i#428 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 1 * SLEEFwrap.stride_row(Data))
####pData#426_i#429 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 2 * SLEEFwrap.stride_row(Data))
####pData#426_i#430 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 4 * SLEEFwrap.stride_row(Data))
####pData#426_i#431 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 5 * SLEEFwrap.stride_row(Data))
####pData#426_i#432 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 6 * SLEEFwrap.stride_row(Data))
####pData#426_i#433 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 7 * SLEEFwrap.stride_row(Data))
####pData#426_i#434 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 8 * SLEEFwrap.stride_row(Data))
####pData#426_i#435 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 9 * SLEEFwrap.stride_row(Data))
begin
##numiter#424 = SLEEFwrap.num_row_strides(X)
##stride#425 = SLEEFwrap.stride_row(X)
##B#421 = SLEEFwrap.extract_data.(pdbacksolve(####pData#426_i#427, ####pData#426_i#428, ####pData#426_i#429, ####pData#426_i#430, ####pData#426_i#431, ####pData#426_i#432, ####pData#426_i#433, ####pData#426_i#434, ####pData#426_i#435))
for ##j#423 = 0:SIMDPirates.vsub(##numiter#424, 1)
SIMDPirates.vstore(getindex(##B#421, SIMDPirates.vadd(1, ##j#423)), ##pX#420, SIMDPirates.vmuladd(##stride#425, ##j#423, ##iter#417))
end
end
end
end
end
begin
if r > 0
mask = SIMDPirates.vless_or_equal(SIMDPirates.vsub((Core.VecElement{Int32}(1), Core.VecElement{Int32}(2), Core.VecElement{Int32}(3), Core.VecElement{Int32}(4), Core.VecElement{Int32}(5), Core.VecElement{Int32}(6), Core.VecElement{Int32}(7), Core.VecElement{Int32}(8), Core.VecElement{Int32}(9), Core.VecElement{Int32}(10), Core.VecElement{Int32}(11), Core.VecElement{Int32}(12), Core.VecElement{Int32}(13), Core.VecElement{Int32}(14), Core.VecElement{Int32}(15), Core.VecElement{Int32}(16)), unsafe_trunc(Int32, r)), zero(Int32))
##i#418 = (##N#419 - r) + 1
##iter#417 = ##i#418
begin
####pData#426_i#427 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 0 * SLEEFwrap.stride_row(Data), mask)
####pData#426_i#428 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 1 * SLEEFwrap.stride_row(Data), mask)
####pData#426_i#429 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 2 * SLEEFwrap.stride_row(Data), mask)
####pData#426_i#430 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 4 * SLEEFwrap.stride_row(Data), mask)
####pData#426_i#431 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 5 * SLEEFwrap.stride_row(Data), mask)
####pData#426_i#432 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 6 * SLEEFwrap.stride_row(Data), mask)
####pData#426_i#433 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 7 * SLEEFwrap.stride_row(Data), mask)
####pData#426_i#434 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 8 * SLEEFwrap.stride_row(Data), mask)
####pData#426_i#435 = SIMDPirates.vload(SVec{16,Float32}, ##pData#426, ##iter#417 + 9 * SLEEFwrap.stride_row(Data), mask)
begin
##numiter#424 = SLEEFwrap.num_row_strides(X)
##stride#425 = SLEEFwrap.stride_row(X)
##B#421 = SLEEFwrap.extract_data.(pdbacksolve(####pData#426_i#427, ####pData#426_i#428, ####pData#426_i#429, ####pData#426_i#430, ####pData#426_i#431, ####pData#426_i#432, ####pData#426_i#433, ####pData#426_i#434, ####pData#426_i#435))
for ##j#423 = 0:SIMDPirates.vsub(##numiter#424, 1)
SIMDPirates.vstore(getindex(##B#421, SIMDPirates.vadd(1, ##j#423)), ##pX#420, SIMDPirates.vmuladd(##stride#425, ##j#423, ##iter#417), mask)
end
end
end
end
end
end
My goal would be for the software to be several times faster than the likes of Stan.
I’ll have to spend some time looking at the Soss code base.