Function that fixes an argument of another function

I need to define a function that fixes the value of another function’s argument. One way is the following:

using InvertedIndices, PreallocationTools, BenchmarkTools
function construct_fixed_f(g, n, val, cache)
    new_f = @inline (x, p) -> begin
        cache2 = get_tmp(cache, x)
        cache2[Not(n)] .= x
        cache2[n] = val
        return g(cache2, p)
    end
end
_f = (x, p) -> x[1] * x[2] * x[3] + p[1] * p[2]
_p = [2.0, 3.0]
_n = 2
_val = 2.0
_x = rand(3)
cache = DiffCache(_x)
fixed_f = construct_fixed_f(_f, _n, _val, cache) # _f([x[1], 2.0, x[3]], p)

But this seems to be very slow. Here is a comparison of evaluating the returned function vs. the original function:

test_x_fix = [2.7, 3.7]
test_x = [2.7, _val, 3.7]
@benchmark $_f($test_x, $_p)
@benchmark $fixed_f($test_x_fix, $_p)
julia> @benchmark $_f($test_x, $_p)
BenchmarkTools.Trial: 10000 samples with 1000 evaluations.
 Range (min … max):  2.300 ns … 25.700 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.400 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.445 ns ±  0.513 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▆  █   ▃                                                   ▁
  █▁▁█▁▁▁█▁▁█▁▁▁▇▁▁▁█▁▁▆▁▁▁▇▁▁▁█▁▁▇▁▁▁▇▁▁▇▁▁▁▇▁▁▁▇▁▁▆▁▁▁▆▁▁▅ █
  2.3 ns       Histogram: log(frequency) by time      3.9 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark $fixed_f($test_x_fix, $_p)
BenchmarkTools.Trial: 10000 samples with 197 evaluations.
 Range (min … max):  452.792 ns … 62.085 μs  ┊ GC (min … max): 0.00% … 98.73%
 Time  (median):     470.051 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   593.130 ns ±  1.571 μs  ┊ GC (mean ± σ):  6.93% ±  2.61%

  █▅▃▁▁  ▁▅▅▅▃▂                                                ▁
  ██████▇████████▇▇▇▆▇▆▆▆▆▇▆▆▆▅▄▅▅▅▅▃▆▅▅▅▆▅▅▅▇▆▅▆▆▅▅▅▄▅▆▄▅▅▄▆▆ █
  453 ns        Histogram: log(frequency) by time      1.38 μs <

 Memory estimate: 496 bytes, allocs estimate: 15.

Why the big difference? What’s the best way to do this?

The performance problem here is not in the function wrapper but in the correction code. Because new_f is a closure, the argument correction runs every time it is called, not once at construction time. Compare these benchmarks, which separate the call to g from the correction step:

function construct_fixed_identity(g, n, val, cache)
    return @inline (x, p) -> begin
        return g(get_tmp(cache, x), p)
    end
end

function construct_fixed_preprocessing(g, n, val, cache)
    return @inline (x, p) -> begin
        cache2 = get_tmp(cache, x)
        cache2[Not(n)] .= x
        cache2[n] = val
        return cache2
    end
end

fixed_f_identity = construct_fixed_identity(_f, _n, _val, cache)
fixed_f_preprocessing = construct_fixed_preprocessing(_f, _n, _val, cache)
julia> @benchmark $fixed_f_identity($test_x_fix, $_p)
BenchmarkTools.Trial: 10000 samples with 125 evaluations.
 Range (min … max):  751.912 ns … 30.858 μs  ┊ GC (min … max): 0.00% … 96.90%
 Time  (median):     877.280 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   928.744 ns ±  1.035 μs  ┊ GC (mean ± σ):  4.00% ±  3.49%

julia> @benchmark $fixed_f_preprocessing($test_x_fix, $_p)
BenchmarkTools.Trial: 10000 samples with 1000 evaluations.
 Range (min … max):  4.824 ns … 29.725 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.848 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.974 ns ±  0.964 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%
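The construction-time vs. call-time distinction can be made concrete with a small sketch. This is an illustration under stated assumptions, not the thread's solution: it uses a plain Vector{Float64} cache instead of a PreallocationTools DiffCache (so it is not AD-ready as written), and the name construct_fixed_f_hoisted is hypothetical. The set of "every index except n" is computed once when the closure is built, so each call only copies values.

```julia
# Hypothetical sketch: hoist the complement-index computation out of the closure.
# Assumes a plain Vector{Float64} cache (no get_tmp / DiffCache), so not AD-ready.
function construct_fixed_f_hoisted(g, n, val, cache::Vector{Float64})
    N = length(cache)
    # Construction time: all positions except n, computed once.
    free_idx = [i for i in 1:N if i != n]
    return (x, p) -> begin
        # Call time: only copying remains.
        cache[free_idx] .= x
        cache[n] = val
        return g(cache, p)
    end
end

_f = (x, p) -> x[1] * x[2] * x[3] + p[1] * p[2]
fixed = construct_fixed_f_hoisted(_f, 2, 2.0, zeros(3))
fixed([2.7, 3.7], [2.0, 3.0])  # same as _f([2.7, 2.0, 3.7], [2.0, 3.0])
```

Since free_idx does not depend on the buffer's element type, the same hoisting should also carry over to a version that obtains its buffer from get_tmp(cache, x).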

I’m not sure what your correction does, but if you provide more details, we can think about how to optimize it.

I’m not sure I 100% follow. What do you mean by correction?

Regarding what the code is trying to do: the function fixes x[n], i.e. the n-th entry of the full argument vector, using cache to hold that full vector (with cache2[n] set to val on every call). get_tmp from PreallocationTools is used so that this works with automatic differentiation. I want it to return a function rather than evaluate g directly so that it can be used in, e.g., an OptimizationFunction.
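To check any optimized version against the intended behavior, it may help to have a deliberately simple reference. The sketch below spells out the semantics described above: the returned function evaluates g on x with val spliced in at position n. construct_fixed_reference is a hypothetical name, and it builds a fresh vector on every call, so it is meant for testing rather than for use inside an optimizer.

```julia
# Hypothetical reference implementation: allocates a new vector on each call.
# fixed(x, p) == g([x[1:n-1]; val; x[n:end]], p)
construct_fixed_reference(g, n, val) = (x, p) -> g(vcat(x[1:n-1], val, x[n:end]), p)

_f = (x, p) -> x[1] * x[2] * x[3] + p[1] * p[2]
ref = construct_fixed_reference(_f, 2, 2.0)
ref([2.7, 3.7], [2.0, 3.0])  # evaluates _f([2.7, 2.0, 3.7], [2.0, 3.0])
```

Because vcat promotes element types, this version should also accept dual numbers without a cache, at the cost of one allocation per call.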

By correction I mean the part that fixes the arguments:

cache2 = get_tmp(cache, x)
cache2[Not(n)] .= x
cache2[n] = val

These are the three lines you need to optimize.
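One way to make those three lines cheap is to avoid building a complement index at all: the entries before n and after n form two contiguous blocks, so each can be copied with Base's copyto!(dest, dstart, src, sstart, count). This is a sketch, assuming a plain Vector buffer; fill_fixed! and construct_fixed_f_copy are hypothetical names, and in the thread's setting the buffer would come from get_tmp(cache, x).

```julia
# Hypothetical helper: write x into buf with val inserted at position n.
# Assumes length(buf) == length(x) + 1 and 1 <= n <= length(buf).
function fill_fixed!(buf::AbstractVector, x::AbstractVector, n::Integer, val)
    # buf[1:n-1] <- x[1:n-1] (a zero-length copy when n == 1)
    copyto!(buf, 1, x, 1, n - 1)
    buf[n] = val
    # buf[n+1:end] <- x[n:end] (a zero-length copy when n == length(buf))
    copyto!(buf, n + 1, x, n, length(x) - n + 1)
    return buf
end

# Hypothetical closure built on the helper, with a plain Vector buffer.
construct_fixed_f_copy(g, n, val, buf) = (x, p) -> g(fill_fixed!(buf, x, n, val), p)

buf = zeros(3)
fill_fixed!(buf, [2.7, 3.7], 2, 2.0)  # buf is now [2.7, 2.0, 3.7]
```

Whether this is measurably faster than a plain index loop would need to be benchmarked; both avoid the per-call index handling of Not(n).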


Thanks, you were right. I was overthinking it; the following function does the job:

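function construct_fixed_f(g, n, val, cache)
    new_f = @inline (x, p) -> begin
        # Buffer whose element type matches x (e.g. dual numbers under AD)
        cache2 = get_tmp(cache, x)
        for i in eachindex(cache2)
            if i < n
                cache2[i] = x[i]      # entries before the fixed slot keep their index
            elseif i == n
                cache2[i] = val       # the fixed slot
            else
                cache2[i] = x[i-1]    # entries after the fixed slot are shifted by one
            end
        end
        return g(cache2, p)
    end
end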
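To sanity-check the loop version without PreallocationTools installed, the same logic can be run against a plain Vector{Float64} buffer standing in for get_tmp(cache, x). That substitution is an assumption made only so the snippet runs on its own (it drops AD support), and construct_fixed_f_plain is a hypothetical name. The result is compared with _f evaluated on the full vector from the original post.

```julia
# Same loop as the solution above, with a plain Vector standing in for get_tmp(cache, x).
function construct_fixed_f_plain(g, n, val, buf::Vector{Float64})
    return (x, p) -> begin
        for i in eachindex(buf)
            if i < n
                buf[i] = x[i]        # before the fixed slot: same index
            elseif i == n
                buf[i] = val         # the fixed slot
            else
                buf[i] = x[i - 1]    # after the fixed slot: shifted by one
            end
        end
        return g(buf, p)
    end
end

_f = (x, p) -> x[1] * x[2] * x[3] + p[1] * p[2]
fixed_f = construct_fixed_f_plain(_f, 2, 2.0, zeros(3))
fixed_f([2.7, 3.7], [2.0, 3.0]) == _f([2.7, 2.0, 3.7], [2.0, 3.0])  # true
```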