Hi,
Below, I have an extract of code that does
a) automatic differentiation (I know, ForwardDiff.jl does it better, but I want to learn) and
b) some “tensor” operation.
Running the same test on plain old Float64 is faster by a factor of 100ish, should be 10ish.
Profiling reveals two things I do not understand:
- my elementary operations on my autodiff ∂ℝ type are type-unstable. Why on earth?
- in the function/operator ∘, the call to zeros wastes time in “deprecated”. Am I not following the book?
I am running Julia 0.6.0.
Thank you in advance!
workspace()
using ForwardDiff
ι(a,d::Int) = 1:size(a,d)
ι(a::Vector) = 1:size(a,1)
ι(a::StepRangeLen) = 1:size(a,1)
using StaticArrays
ℤ = Int64
ℝ = Float64
ℝ1 = Vector{ℝ}
ι(a,d::Int) = 1:size(a,d)
ι(a::Vector) = 1:size(a,1)
ι(a::StepRangeLen) = 1:size(a,1)
struct ∂ℝ{N} <:Real
x :: ℝ
dx :: SVector{N,ℝ}
end
∂ℝ(x,dx::ℝ1) = ∂ℝ(x,SVector{size(dx,1),ℝ}(dx))
Base.convert(::Type{∂ℝ{N}},x::Union{ℤ,ℝ}) where{N} = ∂ℝ{N}(convert(ℝ,x),zeros(ℝ,N))
Base.promote_rule(::Type{∂ℝ{N}},::Type{<:Union{ℤ,ℝ}}) where{N} = ∂ℝ{N}
Base.show(io::IO,a::∂ℝ) = print(io,a.x," + ɛ⋅",a.dx,"\n")
variate(a::ℝ) = ∂ℝ(a ,[1.])
variate(a::ℝ1) = [∂ℝ(a[i],[i==j?1.:0. for j=ι(a,1)]) for i=ι(a,1)]
Base.:(+)(a::∂ℝ{N},b::∂ℝ{N}) where{N} = ∂ℝ{N}(a.x+b.x, a.dx.+b.dx)
Base.:(+)(a:: ℝ,b::∂ℝ{N}) where{N} = ∂ℝ{N}(a +b.x, b.dx)
Base.:(+)(a::∂ℝ{N},b:: ℝ) where{N} = ∂ℝ{N}(a.x+b , a.dx )
Base.:(*)(a::∂ℝ{N},b::∂ℝ{N}) where{N} = ∂ℝ{N}(a.x*b.x, a.dx.*b.x.+b.dx.*a.x)
Base.:(*)(a:: ℝ,b::∂ℝ{N}) where{N} = ∂ℝ{N}(a *b.x, b.dx.*a )
Base.:(*)(a::∂ℝ{N},b:: ℝ) where{N} = ∂ℝ{N}(a.x*b , a.dx.*b )
function ∘(a::Array{Float64,3},b::Vector{T}) where {T}
c = zeros(T,size(a,1),size(a,2))
for k=ι(a,3),j=ι(a,2),i=ι(a,1)
@inbounds c[i,j] += a[i,j,k]*b[k]
end
return c
end
function foo(a,b)
c = a∘b
for i = 1:10000
c = a∘b
end
return c
end
a = randn(3,4,5)
b = randn(5)
@time c = foo(a,b) # plain old Float64
@time ForwardDiff.jacobian(b->foo(a,b),b) # ForwardDiff - fast!
b = variate(b)
@time c = foo(a,b) # my stuff, "wasting ti'ime"
if false
Base.Profile.init(delay=0.001)
Base.Profile.clear()
Base.Profile.clear_malloc_data()
tic()
@profile c=foo(a,b)
toc()
Base.Profile.print()
using ProfileView
ProfileView.view()
end