Hi all!
I just stumbled upon this post because we have a similar problem: we want to fix certain parameters of a function.
The main piece of code is:
function forward_with_fixed_params(forward, params)
    forward_fixed(x) = forward((; params..., x...))
    return forward_fixed
end
Is there anything wrong with that?
julia> forward(x) = x.a+x.b+x.c
forward (generic function with 1 method)
julia> forward((;a=0, b=10, c=10))
20
julia> fwd_n = forward_with_fixed_params(forward, (;a=10, b=10, c=10))
(::var"#forward_fixed#1"{typeof(forward), NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}) (generic function with 1 method)
julia> fwd_n((;a=100))
120
julia> @code_warntype fwd_n((;a=100))
MethodInstance for (::var"#forward_fixed#1"{typeof(forward), NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}})(::NamedTuple{(:a,), Tuple{Int64}})
  from (::var"#forward_fixed#1")(x) in Main at REPL[1]:2
Arguments
  #self#::var"#forward_fixed#1"{typeof(forward), NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}
  x::NamedTuple{(:a,), Tuple{Int64}}
Body::Int64
1 ─ %1 = Core.getfield(#self#, :forward)::Core.Const(forward)
│   %2 = Base.NamedTuple()::Core.Const(NamedTuple())
│   %3 = Core.getfield(#self#, :params)::NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}
│   %4 = Base.merge(%2, %3)::NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}
│   %5 = Base.merge(%4, x)::NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}
│   %6 = (%1)(%5)::Int64
└──      return %6
julia> @time fwd_n((;a=100))
0.000004 seconds
120
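As the `@code_warntype` output shows, the splat `(; params..., x...)` lowers to two `Base.merge` calls, so keys in `x` override keys in `params`. Below is a minimal sketch of what I believe is an equivalent helper that calls `merge` directly. The name `forward_with_fixed_params_merge` is made up for illustration. One thing it makes visible: a key that exists only in `x` (for example a typo) is merged in without any error, and `forward` simply never reads it.

```julia
# Hypothetical variant of the helper above that calls merge explicitly.
function forward_with_fixed_params_merge(forward, params::NamedTuple)
    # Keys in x take precedence over keys in params; keys that appear
    # only in x are appended to the merged NamedTuple.
    return x -> forward(merge(params, x))
end

forward(x) = x.a + x.b + x.c

fwd = forward_with_fixed_params_merge(forward, (; a = 10, b = 10, c = 10))

# a is overridden by the argument; b and c come from the fixed params.
r1 = fwd((; a = 100))

# The extra key d is merged in but never read by forward, so no error is raised.
r2 = fwd((; a = 1, d = 5))
```

If silently accepting unknown keys is a concern, checking the keys of `x` against `keys(params)` inside the closure would be one way to guard against it.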