Saving and loading architectures with multiple blocks

Hi!

I’m working on a project where a colleague has been trying out different architectures in Flux.jl for a given problem. The one that turned out best is giving us trouble when saved and loaded through BSON.jl’s @save and @load macros: a forward pass on the loaded model ends in a segmentation fault and core dump, although simply displaying the model by evaluating its variable in the REPL works fine. I have to find a way to make this model usable in production, and I was hoping to do it without completely refactoring it in a more “julian” way. Retraining seems inevitable, since the current BSON files are unusable, but it takes just one day on our GPUs.

I cannot post the full model for disclosure reasons, but I’ll try to elucidate. Initially we thought this might be related to not properly moving blocks from the GPU to the CPU before saving, but the problem persists even when using only the CPU. We don’t need to train anything to reproduce it, since just initializing the blocks and saving them already triggers the problem.

The training script defines the blocks as separate global variables, roughly like in this MWE:

using Flux
using Random
using BSON: @save, @load


inputpoints = 24 * 7
auxfeatures = 3 # one main feature, 3 aux features
samples = 2000
labelpoints = 24 * 2
inputs = randn(Float32, inputpoints, 1 + auxfeatures, samples)


# main feature block
mainblock = Chain(
    x->x[begin:inputpoints, 1, :],
    Dense(inputpoints, labelpoints) # test with bias
)


# aux features block
auxblock =  Chain(
    x -> x[:, (begin+1):end, :],
    Flux.flatten,
    Dense(auxfeatures * inputpoints, labelpoints)
)


# model
model = Parallel(+, mainblock, auxblock)


@save "testsavemodel.bson" model

In this same session, calling model(inputs) works.

Now let’s kill the session, start a new one, load the model and try the same call:

> 
> Process julia finished
>                _
>    _       _ _(_)_     |  Documentation: https://docs.julialang.org
>   (_)     | (_) (_)    |
>    _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
>   | | | | | | |/ _` |  |
>   | | |_| | | | (_| |  |  Version 1.9.3 (2023-08-24)
>  _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
> |__/                   |
> 
> julia> using Flux
> 
> julia> using Random
> 
> julia> using BSON: @save, @load
> 
> julia> @load "testsavemodel.bson" model
> 
> julia> inputpoints = 24 * 7
> 168
> 
> julia> auxfeatures = 3 # one main feature, 3 aux features
> 3
> 
> julia> samples = 2000
> 2000
> 
> julia> labelpoints = 24 * 2
> 48
> 
> julia> inputs = randn(Float32, inputpoints, 1 + auxfeatures, samples)
> 168×4×2000 Array{Float32, 3}:
> [:, :, 1] =
>  -0.9374        0.321614    0.200769   -1.34398
>  -0.980804      0.398244   -0.0601651  -1.32854
>  -0.335474     -1.58472     0.448912    0.222036
>  -0.779917      0.603528    0.57162    -0.388875
>  -0.406599      0.834919    0.226975   -2.17507
>   0.0446723    -0.944999    0.365482    0.0676495
>  -0.544764     -0.985228   -0.0551741  -0.368832
>   0.316947     -0.157169   -1.85628    -0.498152
>  -0.733956     -0.211746   -1.32181     2.2329
>  -0.226041     -0.696468   -0.535073   -1.92551
>  -0.680718     -1.07135     0.437113    1.34637
>  -0.687338      1.2534      0.0686588   0.369953
>  -0.827495      0.371375    0.588788    0.601197
>   0.718233     -2.11922     0.258975   -0.821273
>   0.458899     -1.42513     1.18537     1.44016
>  -0.50053       0.147473    0.716762   -1.08765
>  -0.709397     -0.734699   -0.542565    0.130984
>   0.366745      0.65585    -1.06062     1.70048
>   0.0459199    -0.224804    0.565536   -0.591385
>   ⋮                                    

(the rest of the array printing...)

> 
> julia> model
> Parallel(
>   +,
>   Chain(
>     BSON.__deserialized_types__.var"#3#4"(),
>     Dense(168 => 48),                   # 8_112 parameters
>   ),
>   Chain(
>     BSON.__deserialized_types__.var"#5#6"(),
>     Flux.flatten,
>     Dense(504 => 48),                   # 24_240 parameters
>   ),
> )                   # Total: 4 arrays, 32_352 parameters, 126.688 KiB.

We get the segmentation fault when we try to do a forward pass:

julia> model(inputs)

[14132] signal (11.1): Segmentation fault
in expression starting at REPL[11]:1
jl_binding_boundp at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/module.c:740
abstract_eval_globalref at ./compiler/abstractinterpretation.jl:2439 [inlined]
abstract_eval_special_value at ./compiler/abstractinterpretation.jl:2135
abstract_eval_value at ./compiler/abstractinterpretation.jl:2145 [inlined]
collect_argtypes at ./compiler/abstractinterpretation.jl:2154
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2176
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2396
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2682
typeinf_local at ./compiler/abstractinterpretation.jl:2867
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2955
_typeinf at ./compiler/typeinfer.jl:246
typeinf at ./compiler/typeinfer.jl:216
typeinf_edge at ./compiler/typeinfer.jl:932
abstract_call_method at ./compiler/abstractinterpretation.jl:611
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:152
abstract_call_known at ./compiler/abstractinterpretation.jl:1949
jfptr_abstract_call_known_16818.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
tojlinvoke21153.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
j_abstract_call_known_16304.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
abstract_call at ./compiler/abstractinterpretation.jl:2020
abstract_call at ./compiler/abstractinterpretation.jl:1999
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2183
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2396
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2658
typeinf_local at ./compiler/abstractinterpretation.jl:2867
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2955
_typeinf at ./compiler/typeinfer.jl:246
typeinf at ./compiler/typeinfer.jl:216
typeinf_edge at ./compiler/typeinfer.jl:932
abstract_call_method at ./compiler/abstractinterpretation.jl:611
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:152
abstract_call_known at ./compiler/abstractinterpretation.jl:1949
jfptr_abstract_call_known_16818.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
tojlinvoke21153.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
j_abstract_call_known_16304.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
abstract_call at ./compiler/abstractinterpretation.jl:2020
abstract_call at ./compiler/abstractinterpretation.jl:1999
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2183
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2396
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2682
typeinf_local at ./compiler/abstractinterpretation.jl:2867
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2955
_typeinf at ./compiler/typeinfer.jl:246
typeinf at ./compiler/typeinfer.jl:216
typeinf_edge at ./compiler/typeinfer.jl:932
abstract_call_method at ./compiler/abstractinterpretation.jl:611
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:152
abstract_call at ./compiler/abstractinterpretation.jl:2017
abstract_call at ./compiler/abstractinterpretation.jl:1999
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2183
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2396
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2682
typeinf_local at ./compiler/abstractinterpretation.jl:2867
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2955
_typeinf at ./compiler/typeinfer.jl:246
typeinf at ./compiler/typeinfer.jl:216
typeinf_edge at ./compiler/typeinfer.jl:932
abstract_call_method at ./compiler/abstractinterpretation.jl:611
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:152
abstract_call at ./compiler/abstractinterpretation.jl:2017
abstract_call at ./compiler/abstractinterpretation.jl:1999
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2183
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2396
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2682
typeinf_local at ./compiler/abstractinterpretation.jl:2867
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2955
_typeinf at ./compiler/typeinfer.jl:246
typeinf at ./compiler/typeinfer.jl:216
typeinf_edge at ./compiler/typeinfer.jl:932
abstract_call_method at ./compiler/abstractinterpretation.jl:611
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:152
abstract_call_known at ./compiler/abstractinterpretation.jl:1949
jfptr_abstract_call_known_16818.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
tojlinvoke21153.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
j_abstract_call_known_16304.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
abstract_call at ./compiler/abstractinterpretation.jl:2020
abstract_call at ./compiler/abstractinterpretation.jl:1999
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2183
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2396
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2682
typeinf_local at ./compiler/abstractinterpretation.jl:2867
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2955
_typeinf at ./compiler/typeinfer.jl:246
typeinf at ./compiler/typeinfer.jl:216
typeinf_ext at ./compiler/typeinfer.jl:1057
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1090
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1086
jfptr_typeinf_ext_toplevel_20405.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2758 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2940
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1880 [inlined]
jl_type_infer at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:320
jl_generate_fptr_impl at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/jitlayers.cpp:444
jl_compile_method_internal at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2348 [inlined]
jl_compile_method_internal at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2237
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2750 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2940
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1880 [inlined]
do_call at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:126
eval_value at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:226
eval_stmt_value at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:177 [inlined]
eval_body at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:624
jl_interpret_toplevel_thunk at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:762
jl_toplevel_eval_flex at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:912
jl_toplevel_eval_flex at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:856
jl_toplevel_eval_flex at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:856
jl_toplevel_eval_flex at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:856
ijl_toplevel_eval_in at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:971
eval at ./boot.jl:370 [inlined]
eval_user_input at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:153
repl_backend_loop at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:249
#start_repl_backend#46 at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:234
start_repl_backend at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:231
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2758 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2940
#run_repl#59 at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:379
run_repl at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:365
jfptr_run_repl_60908.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2758 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2940
#1017 at ./client.jl:421
jfptr_YY.1017_36106.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2758 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2940
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1880 [inlined]
jl_f__call_latest at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/builtins.c:774
#invokelatest#2 at ./essentials.jl:819 [inlined]
invokelatest at ./essentials.jl:816 [inlined]
run_main_repl at ./client.jl:405
exec_options at ./client.jl:322
_start at ./client.jl:522
jfptr__start_40034.clone_1 at /home/tomas/packages/julias/julia-1.9.3/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2758 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2940
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1880 [inlined]
true_main at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/jlapi.c:573
jl_repl_entrypoint at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/jlapi.c:717
main at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/cli/loader_exe.c:59
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 12462189 (Pool: 12456394; Big: 5795); GC: 17

Process julia segmentation fault (core dumped)

I suspect this is related to the blocks being declared as separate global variables in the script. Any thoughts?

No suggestions? I’m going to try saving it with JLD2.jl instead of BSON.jl.

I also opened an issue in the repo: “Segmentation fault when doing a forward pass with a model saved with BSON” (Issue #2339, FluxML/Flux.jl on GitHub).

In case anyone drops by here in the future, this has to do with the anonymous functions. We can follow the documentation and use JLD2.jl instead of BSON.jl. Keep in mind that we have to redefine the architecture when loading, since only the state of the model is saved and loaded. I’m going to paste the solution scripts here (the same ones as in the issue).

Saving:

### SAVING TEST


using Flux
using Random
using JLD2

inputpoints = 24 * 7
auxfeatures = 3 # one main feature, 3 aux features
samples = 2000
labelpoints = 24 * 2
inputs = randn(Float32, inputpoints, 1 + auxfeatures, samples)

# main feature block
mainblock = Chain(
    x->x[begin:inputpoints, 1, :],
    Dense(inputpoints, labelpoints)
)

# aux features block
auxblock =  Chain(
    x -> x[:, (begin+1):end, :],
    Flux.flatten,
    Dense(auxfeatures * inputpoints, labelpoints)
)


struct TestModel
    architecture
end

Flux.@functor TestModel
TestModel() = TestModel(Parallel(+, mainblock, auxblock))
model = TestModel()
model_state = Flux.state(model)

jldsave("testsavemodel_1.jld2"; model_state)
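
To see why this route sidesteps the anonymous functions, it can help to look at what Flux.state returns. The snippet below is only a sketch on a small stand-in model (the names `toy` and `toy_state` are made up for illustration and are not part of the project), and it assumes a Flux version that provides Flux.state. The state should be a nested structure of tuples, named tuples and arrays that mirrors the model, so only plain data ends up in the file.

```julia
using Flux

# Hypothetical stand-in for one block: an anonymous slicing function plus a Dense layer.
toy = Chain(
    x -> x[:, 1, :],
    Dense(4 => 2),
)

toy_state = Flux.state(toy)

# The state mirrors the model: a `layers` tuple with one entry per layer.
dense_state = toy_state.layers[2]
println(typeof(dense_state.weight), " ", size(dense_state.weight))
println(typeof(dense_state.bias), " ", size(dense_state.bias))

# Entry for the anonymous function: it carries no arrays to save.
println(toy_state.layers[1])
```

Because the functions themselves are not part of the saved state, the loading side has to rebuild them, which is why the architecture is defined again in the loading script.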

New session, loading:

### LOADING TEST


using Flux
using Random
using JLD2


inputpoints = 24 * 7
auxfeatures = 3 # one main feature, 3 aux features
samples = 2000
labelpoints = 24 * 2
inputs = randn(Float32, inputpoints, 1 + auxfeatures, samples)

# main feature block
mainblock = Chain(
    x->x[begin:inputpoints, 1, :],
    Dense(inputpoints, labelpoints)
)

# aux features block
auxblock =  Chain(
    x -> x[:, (begin+1):end, :],
    Flux.flatten,
    Dense(auxfeatures * inputpoints, labelpoints)
)


struct TestModel
    architecture
end

Flux.@functor TestModel
TestModel() = TestModel(Parallel(+, mainblock, auxblock))
model = TestModel()
model_state = JLD2.load("testsavemodel_1.jld2", "model_state");
Flux.loadmodel!(model, model_state)

A forward pass is now successful:

julia> test_fwdpass = model.architecture(inputs)
48×2000 Matrix{Float32}:
  4.44022      0.802483  -0.785774    1.168      …  -2.20038    -2.76761    -2.75335
 -1.81597      3.47688    0.0940425  -0.837038       0.155797    2.5794     -0.0938851
  2.21066     -0.4048     0.903007    0.167684       1.19697    -1.00276    -1.79072
  0.929377     1.1883    -1.82898     1.01884       -0.725962    1.04085     2.17898
  0.00539947   1.49683    1.25519     1.50682        2.72747     0.716122    2.52785
  1.03204      3.22989   -1.66981    -0.999194   …  -0.215202   -1.27665     0.376921
  2.51096      2.41828    0.436551    0.517585       1.67277     0.609859   -1.54591
 -1.34194     -0.228893   1.87149    -0.986849      -0.191224    0.687425    2.22133
  1.9604       0.951124   1.43568    -0.238653      -1.622       4.54916    -3.99599
 -0.993577    -2.96885   -1.70936    -0.713654       1.94885    -1.54148     0.403749
  0.18666      0.834455   2.35449     1.00192    …  -0.136148    0.861816   -1.7685
 -2.36995      1.94883    1.31425    -1.37012        1.78269    -1.19305     0.525236
 -0.556477     0.447952  -0.959529    0.850635      -1.19533    -0.692481   -1.17249
  2.14281      0.17941   -0.65601    -3.38384       -0.336295    0.250721   -0.866344
  2.52481      3.07921   -0.58382    -0.656336      -0.994389   -0.602142    0.530116
  1.2512       0.877351  -0.74357    -0.797333   …   3.61359    -1.4924     -2.77331
  0.0869287   -0.671315  -0.128169    1.9544        -0.242938    0.586071   -0.168547
  1.09363      0.708124  -1.0453      2.32946        5.08991    -3.25003     0.0925286
 -0.548058    -0.681359   0.0118403  -3.75676       -1.88147     0.104736    0.480259
  ⋮                                              ⋱                          
 -0.416707     3.67179   -2.48939     1.52213       -0.776104   -0.346431    1.32079
  0.655315    -0.415754  -1.45568     0.0851394  …   4.02886     2.77373     2.17698
 -0.317264    -0.439673  -0.530158   -0.837444       0.284554   -1.00613    -0.366141
 -0.296634    -1.96891   -2.48071     2.27509       -0.6101     -0.508833   -1.74481
  1.96883      1.32886   -0.969475   -1.23352       -3.45104     2.03444     1.31539
  3.81404      1.32852    2.34517    -2.12479        1.67277     0.0501646   1.32144
 -0.0490075   -0.218952   2.18        3.05685    …  -0.44117    -2.41891    -1.35152
  1.33143     -0.689682  -1.03449    -0.0169412     -0.773172   -2.20266    -2.73936
 -1.35926     -0.917676   4.6618     -1.13945       -3.41797     0.761221    0.333108
 -0.225759     0.278201   1.78722    -0.131045      -2.63882     0.433773   -2.62248
  4.81922     -0.870089  -4.80774    -1.5178         0.123205   -2.02181     1.56211
  0.127518     0.723261   3.8712     -0.400356   …   0.197132   -3.68057     2.66511
 -1.14512     -0.829157   0.0856611   0.0258443     -0.740243   -1.0791      0.617436
  1.65157     -2.93585    0.989425    0.754669       0.606092   -1.09547    -1.23846
 -0.490223    -0.190012   2.91653     1.45833       -0.137385   -2.23218    -1.20121
 -1.37658     -4.13181    1.79136    -3.11379       -1.21975     0.521379    1.01322
 -0.952514     1.02663   -0.793957   -1.69722    …  -0.0394366  -4.34157     4.12784
 -0.515579    -1.04139   -2.13667     1.92703       -0.915622    2.50567    -3.46607
  3.25076     -2.62687    0.576621    1.19447        4.88387     0.0299822  -0.749113
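
Since the saving and loading scripts have to repeat the same architecture, one way to keep them in sync is to build the model in a single function that both scripts include. The following is only a sketch under my own assumptions: `build_model`, `take_main`, `take_aux` and the temporary file name are hypothetical, and the named slicing functions are my substitution for the anonymous ones, not something the issue prescribes. It also makes TestModel callable, so the forward pass can be written as `model(inputs)` instead of `model.architecture(inputs)`.

```julia
using Flux
using JLD2

const inputpoints = 24 * 7
const auxfeatures = 3
const labelpoints = 24 * 2

# Named replacements for the anonymous slicing functions (hypothetical names).
take_main(x) = x[begin:inputpoints, 1, :]
take_aux(x) = x[:, (begin+1):end, :]

struct TestModel
    architecture
end
Flux.@functor TestModel

# Make the wrapper callable so the forward pass is simply model(x).
(m::TestModel)(x) = m.architecture(x)

# Single place where the architecture is defined; both scripts would call this.
function build_model()
    mainblock = Chain(take_main, Dense(inputpoints, labelpoints))
    auxblock = Chain(take_aux, Flux.flatten, Dense(auxfeatures * inputpoints, labelpoints))
    return TestModel(Parallel(+, mainblock, auxblock))
end

# Round trip: save the state of one model, then load it into a freshly built one.
inputs = randn(Float32, inputpoints, 1 + auxfeatures, 16)
trained = build_model()
path = joinpath(mktempdir(), "model_state.jld2")
jldsave(path; model_state = Flux.state(trained))

restored = build_model()
Flux.loadmodel!(restored, JLD2.load(path, "model_state"))

# Compare the two models on the same inputs.
println(restored(inputs) ≈ trained(inputs))
```

In a real setup the definitions down to `build_model` would live in one file that the training and inference scripts both `include`, and the round trip at the bottom would be split between the two sessions.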