I just scratched my head for a while because my super fast reconstruction algorithm suddenly took 100 times longer than usual after I played around with the quality function. Then I realised that the abs() function inside the function being minimised is causing the issue. The only explanation I can think of is that some dispatch goes wrong and the automatic differentiation causes type-inference problems.
Before I dive in and create an MWE: is this a known issue?
Edit: ehm, actually the following workaround is almost as slow as abs(), so I guess the < operator is the problem:
a = a < 0 ? -a : a
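A side note on the ternary workaround (not raised in the thread): it is not a drop-in replacement for abs() on floats, because `-0.0 < 0` is `false`, so negative zero is returned unchanged. A minimal check, where `ternary_abs` is a hypothetical name for the workaround:

```julia
# Hypothetical name for the workaround quoted above.
ternary_abs(a) = a < 0 ? -a : a

# For ordinary values both agree...
println(abs(-1.5), " ", ternary_abs(-1.5))   # 1.5 1.5
# ...but not for negative zero: `-0.0 < 0` is false, so no sign flip happens.
println(abs(-0.0), " ", ternary_abs(-0.0))   # 0.0 -0.0
```

This rarely matters for an objective function, but it means the two forms are not bitwise identical.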
Have you tried changing 0 to zero(T)? Maybe it's a dispatch problem.
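A sketch of the suggested generic version, using a hypothetical `generic_abs`. For `Float64` (and `Float32` or `Int`) input, comparing against the literal `0` and against `zero(T)` both infer a concrete return type, consistent with the reply below that zero(T) made no difference with regular floats:

```julia
using Test

# Hypothetical generic version of the workaround using zero(T).
generic_abs(a::T) where {T<:Real} = a < zero(T) ? -a : a

# Literal-zero version, as in the original workaround.
literal_abs(a) = a < 0 ? -a : a

for x in (-2.5, -2.5f0, -3)
    # @inferred throws if the actual return type differs from the inferred one.
    @inferred generic_abs(x)
    @inferred literal_abs(x)
    @assert generic_abs(x) == literal_abs(x) == abs(x)
end
```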
What's the speed of sqrt(a^2)?
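If you try the sqrt(a^2) variant, note that it is not numerically equivalent to abs(a) at the extremes of the `Float64` range: squaring can overflow or underflow before the square root is taken. A quick check:

```julia
# sqrt(a^2) agrees with abs(a) for moderate magnitudes...
@assert sqrt((-3.0)^2) == abs(-3.0)

# ...but a^2 overflows to Inf for very large |a|...
println(sqrt((1e200)^2))    # Inf
println(abs(1e200))         # 1.0e200

# ...and underflows to 0.0 for very small |a|.
println(sqrt((1e-200)^2))   # 0.0
println(abs(1e-200))        # 1.0e-200
```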
What autodiff package are you using? I recently had issues with ForwardDiff returning Any.
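One way to check whether a function suffers from this kind of inference problem is to ask the compiler for its inferred return type, for example with `Base.return_types` or `@inferred` from the `Test` standard library. `unstable` and `stable` below are hypothetical illustrations, not code from the thread:

```julia
using Test

# Returns a Float64 on one branch and an Int on the other: type-unstable.
unstable(x::Float64) = x > 0 ? x : 0
# Returns a Float64 on both branches: type-stable.
stable(x::Float64) = x > 0 ? x : zero(x)

# Inferred return types for a Float64 argument.
println(Base.return_types(unstable, (Float64,)))  # a Union of Float64 and Int
println(Base.return_types(stable, (Float64,)))

# @inferred throws when the actual return type differs from the inferred one.
@inferred stable(-1.0)
try
    @inferred unstable(-1.0)
catch err
    println("unstable is not inferrable: ", typeof(err))
end
```

Inside a real objective you would run the same checks on the function you register, or use `@code_warntype` to see where an abstract type first appears.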
@longemen3000 I'll try with sqrt; zero(T) has no effect, as I am using regular floats.
@sebastianpech I am using JuMP and have no idea what it uses under the hood. I guess it's ForwardDiff.jl, but I have to look it up…
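For context: JuMP's documentation states that user-defined functions registered with autodiff=true are differentiated with ForwardDiff.jl. Forward-mode AD evaluates the function on dual numbers that carry a value and a derivative together, and a comparison such as < only looks at the value part, so on its own it should be cheap. The following is a minimal hand-rolled sketch; `MyDual` and `derivative` are hypothetical and much simpler than ForwardDiff's actual `Dual` type:

```julia
# Minimal dual number: a value plus its derivative with respect to one input.
struct MyDual{T<:Real}
    val::T
    der::T
end

# Arithmetic rules needed for the examples below.
Base.:-(x::MyDual) = MyDual(-x.val, -x.der)
Base.:+(x::MyDual, y::MyDual) = MyDual(x.val + y.val, x.der + y.der)
Base.:*(x::MyDual, y::MyDual) = MyDual(x.val * y.val, x.val * y.der + x.der * y.val)

# Comparisons only look at the value; the derivative is ignored.
Base.:<(x::MyDual, y::Real) = x.val < y

# Derivative of f at x: seed the derivative part with one.
derivative(f, x::Real) = f(MyDual(float(x), one(float(x)))).der

ternary_abs(a) = a < 0 ? -a : a

println(derivative(ternary_abs, -2.0))    # -1.0
println(derivative(ternary_abs, 3.0))     # 1.0
println(derivative(a -> a * a + a, 2.0))  # 5.0
```

The branch picks which derivative rule applies, but the comparison itself is a plain Float64 comparison, which is why a large slowdown more likely points at type instability elsewhere than at < itself.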
I cannot reproduce it… This is an MNWE which should only contain the relevant parts. As far as I can tell I did everything the same way, but the following snippet takes the same time with or without the < code block.
using JuMP
using Ipopt
using BenchmarkTools
using StaticArrays
using ProgressMeter

struct Bar{T} <: FieldVector{2, T}
    s::T
    t::T
end

dosomethingwithbar(bar::Bar) = bar.s + bar.t

struct FooMinimiser <: Function
    x::Vector{Float64}
end

function (f::FooMinimiser)(a)
    bar = Bar(f.x[1], f.x[2])
    z = dosomethingwithbar(bar)
    # This is the block which in my case causes a ~200x slowdown, where both `z` and `a` are `Float64`
    # if a < z
    #     a .+ 0.01
    # end
    sum((f.x .- a).^2)
end

function loop(; n=10000)
    results = Vector{Float64}()
    @showprogress 1 for _ in 1:n
        x = rand(Float64, 1000)
        foo = FooMinimiser(x)
        # with_optimizer was removed in newer JuMP; optimizer_with_attributes replaces it
        model = Model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0))
        register(model, :foo, 1, foo; autodiff=true)
        @variable(model, 0 <= a <= 1)
        @NLobjective(model, Min, foo(a))
        optimize!(model)
        push!(results, value(a))
    end
    results
end

loop()
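As a sanity check on what the snippet should return (my own derivation, not from the thread): the objective sum((x .- a).^2) is a convex quadratic in the scalar a, with derivative -2 * sum(x .- a), so it is minimised at a = mean(x). With the bounds 0 <= a <= 1 the optimum is clamp(mean(x), 0, 1), and since rand draws from [0, 1) the mean already lies inside the interval. Each entry of results should therefore be close to mean(x) for that iteration, which can be checked without JuMP; `objective` and `dobjective` are hypothetical helper names:

```julia
using Statistics

# Objective from the snippet, written as a plain function of x and a.
objective(x, a) = sum((x .- a) .^ 2)
# Its derivative with respect to the scalar a.
dobjective(x, a) = -2 * sum(x .- a)

x = rand(Float64, 1000)
a_star = clamp(mean(x), 0.0, 1.0)

# The derivative vanishes (up to rounding) at the mean...
println(dobjective(x, a_star))
# ...and moving away from it in either direction increases the objective.
println(objective(x, a_star) < objective(x, a_star + 0.01))
println(objective(x, a_star) < objective(x, a_star - 0.01))
```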