How to store _parametric_ functions with JLD2?

I have a struct (part of an ML model in the real case) where one field is a function.
Since that function is provided by the package, I have no issues saving it and loading it back.

However, I run into problems when the function is parametric.

For example:

using JLD2
scaling_par= 0.1
f(x,scaling_par=scaling_par) = x/scaling_par # a small function that can be exported and made available by my package
struct Foo
    a::Int
    b::Function
end
foo = Foo(1, x->f(x,0.01))
r = foo.b(2)
save("foo.jld2", "foo", foo)

Loading this struct back raises an error, as anonymous functions are not supported by JLD2.

I don’t really have much control over the struct Foo, but I can store whatever I want in it and use it as I prefer. What alternative approach can I use to also store and use the information about the scaling parameter?

Something along these lines may work in my situation, even if I am not super happy about using non-local variables in a function…

File definitions.jl:

module Foomodule

export scaling_par, f, Foo 
scaling_par::Float64 = 1.0
f(x) = x/scaling_par 
struct Foo
    a::Int
    b::Function
end

end

File writer.jl:

using Pkg
cd(@__DIR__)
Pkg.activate(".")
using JLD2
include("definitions.jl") # `includet` would require `using Revise` first
using .Foomodule

Foomodule.scaling_par = 0.001 
foo = Foo(1, f)
r = foo.b(2) # 2000
write("scaling_parameter.txt", string(Foomodule.scaling_par))
save("foo.jld2", "foo", foo)

File reader.jl (in a different Julia session):

using Pkg
cd(@__DIR__)
Pkg.activate(".")
using JLD2

include("definitions.jl") # `includet` would require `using Revise` first
using .Foomodule

Foomodule.scaling_par = parse(Float64,readline("scaling_parameter.txt")) 
foo2 = load("foo.jld2", "foo")
r2 = foo2.b(2) # 2000

If you need to save it, you could use a more explicit data structure instead of an anonymous function. In this case, Base provides such a type for you: you can replace x->f(x,0.01) with Base.Fix2(f, 0.01).
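A minimal sketch of what Base.Fix2 does (the names f and g are just for illustration): Fix2(f, x) is a small concrete struct that stores the wrapped function in its field f and the bound second argument in its field x, and calling it with y evaluates f(y, x). Because it subtypes Function, it still fits a field declared as b::Function. Its type is parameterized by typeof(f), so, presumably, JLD2 can rebuild it on load as long as a function with the same name f is defined in the reading session, which is what the follow-up reader script in this thread does.

```julia
# Minimal sketch: Base.Fix2 binds the second argument of a two-argument function.
f(x, scaling_par) = x / scaling_par

g = Base.Fix2(f, 0.01)   # behaves like x -> f(x, 0.01), but is a named struct

y = g(2)                 # same value as f(2, 0.01)
g.f === f                # the wrapped function is stored in field `f`
g.x                      # the bound parameter (0.01) is stored in field `x`
g isa Function           # Fix2 subtypes Function, so it fits `b::Function`
```

Because the parameter lives in a plain field, it can also be read back after loading (for example foo2.b.x) without re-deriving it from a separate text file.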


Thank you. This works and simplifies my code:

Writer:

using Pkg
cd(@__DIR__)
Pkg.activate(".")
using JLD2

scaling_par::Float64 = 1.0
f(x,scaling_par) = x/scaling_par 
struct Foo
    a::Int
    b::Function
end
foo = Foo(1, Base.Fix2(f, 0.001))
r = foo.b(2) # 2000
save("foo.jld2", "foo", foo)

Reader (after a Julia restart):

using Pkg
cd(@__DIR__)
Pkg.activate(".")
using JLD2

f(x,scaling_par) = x/scaling_par 
struct Foo
    a::Int
    b::Function
end

foo2 = load("foo.jld2", "foo")
r2 = foo2.b(2) # 2000
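Base.Fix2 only binds the second argument. If a function needs several parameters, one option in the spirit of the "more explicit data structure" suggestion is a hand-written callable struct. The type AffineScaler and its fields shift and scale are hypothetical, invented for this sketch; subtyping Function is a choice made here so that the object still fits the existing b::Function field of Foo.

```julia
# Hypothetical sketch: a named callable struct that carries its own parameters.
# Subtyping Function keeps it compatible with a field declared as `b::Function`.
struct AffineScaler <: Function
    shift::Float64
    scale::Float64
end

# Make instances callable: s(x) computes (x - s.shift) / s.scale
(s::AffineScaler)(x) = (x - s.shift) / s.scale

struct Foo
    a::Int
    b::Function
end

foo = Foo(1, AffineScaler(0.0, 0.001))
r = foo.b(2)          # same computation as 2 / 0.001
foo.b.scale           # the parameters remain inspectable after loading
```

To load such an object with JLD2, the reading session would need the same struct definition and call method, just as the reader above redefines f and Foo.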

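I think this should be parametrized over b::F if performance is important, since Function is an abstract type and a field declared as b::Function prevents the compiler from inferring what foo.b returns.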
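A sketch of that suggestion, assuming the rest of the code stays as in the thread: making the type of b a type parameter F means the field is concretely typed (here a Base.Fix2 over typeof(f) and Float64), so calls through foo.b can be inferred.

```julia
# Sketch: parametrize the struct over the type of the function field.
struct Foo{F}
    a::Int
    b::F          # concrete type per instance, instead of the abstract `Function`
end

f(x, scaling_par) = x / scaling_par

foo = Foo(1, Base.Fix2(f, 0.001))   # F is inferred from the second argument
r = foo.b(2)

fieldtype(typeof(foo), :b)          # a concrete Base.Fix2 type
```

The reader would then need this parametric definition of Foo (and f) before calling load.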


Thank you. Indeed it is in the actual implementation 🙂