Ideas for Saving Intermediate Results

Hello,

I’d appreciate any thoughts or feedback on how I could implement the following.

I have a long-running simulation (several days) that I run on my university’s HPC. Until now, I have been saving the results at the end of each run to .jld2 files using JLD2.jl.

It works, but I am also still developing and troubleshooting certain aspects of the simulation code. Sometimes my changes cause the run to time out on the cluster or crash at some intermediate stage. It would be nice if I could save intermediate snapshots of the results, say every 1000 iterations out of 10,000 total. Then I could inspect afterwards what the calculations were doing, and I could use that data to restart the run if needed.

In the current implementation, my run script looks essentially like this (I am using DrWatson):

using DrWatson  # provides safesave

function main()

  params = parse_commandline_args()
  input_data = prepare_input_data(params)
  result = solve(input_data, params)
  safesave("path_to_my_results", result)  # only reached once the whole run has finished

end

main()

What I would like to do is somehow open a stream to a file, pass that stream to solve, and then write to the file through that stream inside solve. If solve crashes or times out, the stream would clean itself up and close the file.

Maybe this is a stupid or impossible idea; I have no idea. I’m hoping others might have some insight. If possible, I’d like to avoid adding DrWatson as a dependency of my package code (which is where solve is defined).
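For reference, the cleanup half of this idea is already built into Base Julia: the do-block form of `open` closes the file when the block returns or throws an error. A minimal sketch, assuming a hypothetical `solve_streaming!` whose update step is only a placeholder, using just the Serialization standard library so the package code needs no DrWatson dependency:

```julia
using Serialization

# Hypothetical solver: appends an (iteration, value) snapshot to `io`
# every `every` iterations. `x += 1.0` stands in for the real work.
function solve_streaming!(io::IO, n_iter::Int; every::Int = 1000)
    x = 0.0
    for k in 1:n_iter
        x += 1.0
        if k % every == 0
            serialize(io, (k, x))  # append one snapshot to the stream
            flush(io)              # hand buffered bytes to the OS right away
        end
    end
    return x
end

path = tempname()
# open(f, path, "w") closes the file when f returns *or* throws
open(path, "w") do io
    solve_streaming!(io, 10_000)
end

# Read the snapshots back one by one until the end of the file
snapshots = Tuple{Int,Float64}[]
open(path, "r") do io
    while !eof(io)
        push!(snapshots, deserialize(io))
    end
end
```

One caveat: a hard kill (for example SIGKILL when the wall-time limit is reached) stops the process without running any Julia cleanup code. The `flush` after each snapshot is what keeps the already written data in that case.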


The idea is definitely not stupid, and it is a common one: checkpointing is always a good idea for long simulations.

The authors of DrWatson might be able to give you good recommendations. I didn’t see any tool in there that directly provides the checkpointing you are aiming for. One could perhaps use produce_or_load on an “every 1000 iterations” basis, but I’m not sure that is a good solution for you.


If you write to disk so often that you need to keep the file stream open, you may run into I/O performance issues. If the simulation runs for days, saving every few hours (e.g. every 1000 iterations) might be sufficient. Then you can simply open and close a file for each save and not worry about errors leaving the file in an invalid state.
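One way to implement “save every few hours” is to check elapsed wall-clock time instead of the iteration count, so the save frequency does not depend on how long a single iteration takes. A minimal sketch, assuming a hypothetical `SimState` type and `run!` loop with a placeholder update; each save calls `serialize(path, state)`, which opens, writes and closes the file by itself:

```julia
using Serialization

mutable struct SimState        # hypothetical minimal state
    step::Int                  # last completed iteration
    x::Float64
end

# Checkpoint whenever `interval` seconds of wall time have passed since the
# last save, plus once at the end. `save` is any callable taking the state.
function run!(state::SimState, n_final::Int, save; interval::Real = 3 * 3600)
    last_save = time()
    for k in (state.step + 1):n_final
        state.x += 1.0         # placeholder for the real update
        state.step = k
        if time() - last_save >= interval
            save(state)
            last_save = time()
        end
    end
    save(state)                # final save after the last iteration
    return state
end

path = tempname()
state = run!(SimState(0, 0.0), 10_000, s -> serialize(path, s); interval = 60)
```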


My usual approach to this problem is to split the data a bit further into:

  • params (as you already have: parameters describing how the simulation starts),
  • state (a full description of the current state; params and state together are all that is needed to reproduce the run),
  • cache (auxiliary data needed to run the simulation fast, e.g. temporary arrays; it can be created with an init_cache(params, state) function),
  • observables (data tracked because it is needed for analysis, but strictly speaking not relevant to the simulation itself).

(Of course, this structure is inspired by SciML projects which use the same or similar structure.)
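A minimal sketch of that split, assuming a toy model (one explicit diffusion step on a periodic vector) and a hypothetical `DiffusionState` type. `params` and `cache` are NamedTuples, `state` is a mutable struct so its iteration counter can be updated in place, and `observables` is a plain vector. `init_cache` allocates the scratch array once, so `do_step!` reuses it instead of allocating on every step:

```julia
mutable struct DiffusionState
    step::Int                 # last completed iteration
    u::Vector{Float64}        # everything needed to continue the run
end

params = (n = 100, alpha = 0.25, n_final = 50)   # describes the run; never mutated

# Auxiliary buffers only; can always be rebuilt from params and state
init_cache(params, state) = (tmp = similar(state.u),)

# Advance the state by one step, writing into cache.tmp instead of allocating
function do_step!(params, state, cache)
    u, tmp = state.u, cache.tmp
    n = length(u)
    for i in 1:n
        left  = u[mod1(i - 1, n)]      # periodic neighbours
        right = u[mod1(i + 1, n)]
        tmp[i] = u[i] + params.alpha * (left - 2u[i] + right)
    end
    copyto!(u, tmp)
    state.step += 1
    return state
end

state = DiffusionState(0, [i == 1 ? 1.0 : 0.0 for i in 1:params.n])
cache = init_cache(params, state)
observables = Float64[]                          # here: total "mass" after each step
while state.step < params.n_final
    do_step!(params, state, cache)
    push!(observables, sum(state.u))
end
```

Only `params` and `state` would need to go into a checkpoint; the cache is recreated with `init_cache` and the observables can be written separately.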

Introducing the cache and an init_cache function matters because it makes it much easier to write an efficient do_step!(params, state, cache) function. The solve function can then be structured as:

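using JLD2  # provides jldsave

function solve(params, state, cache = init_cache(params, state))
   sol = [deepcopy(state)]                       # keep the initial state as well
   for k in (state.step + 1):params.n_final     # continue after the last completed step
      do_step!(params, state, cache)
      state.step = k                             # state must be mutable to record progress
      if k % params.n_saveat == 0                # `if` needs a Bool, not an Int
          push!(sol, deepcopy(state))
      end

      if k % params.n_checkpoint == 0
          jldsave("checkpoint_$k.jld2"; state)   # one file per checkpoint, e.g. checkpoint_1000.jld2
      end
   end
   return sol
end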

With such a structure, you can restart the simulation from any checkpoint simply by loading the corresponding state and passing it to solve.

Make sure that

  • state contains a field for the current iteration (e.g. like a time stamp), and
  • cache is updated during a time step if it needs to be.
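To check that a restart really reproduces an uninterrupted run, one can stop a run partway, reload the last checkpoint and continue. A minimal sketch with a toy update, assuming hypothetical names `CkptState`, `advance!` and `run_from!`; it uses the Serialization standard library instead of JLD2 so it has no dependencies, and `step` plays the role of the iteration field mentioned above:

```julia
using Serialization

mutable struct CkptState
    step::Int          # last completed iteration (the "time stamp")
    x::Float64
end

# Toy update that depends on the iteration number k
advance!(s::CkptState, k::Int) = (s.x += sin(k); s)

# Run iterations s.step+1 through `stop`, overwriting one checkpoint every `every` steps
function run_from!(s::CkptState, stop::Int, file::String; every::Int = 100)
    for k in (s.step + 1):stop
        advance!(s, k)
        s.step = k
        k % every == 0 && serialize(file, s)
    end
    return s
end

file = joinpath(mktempdir(), "checkpoint.jls")
reference = run_from!(CkptState(0, 0.0), 1_000, file)   # uninterrupted run

run_from!(CkptState(0, 0.0), 650, file)   # pretend the job dies after step 650
resumed = deserialize(file)               # last checkpoint was written at step 600
run_from!(resumed, 1_000, file)
```

Because the loop starts at `step + 1`, the resumed run neither repeats nor skips an iteration, and it performs the same floating-point operations in the same order as the reference run.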

A small recommendation: since changing the cache struct type usually means restarting the Julia session, which is cumbersome during development, you could consider using either NamedTuples or the ProtoStructs.jl package while prototyping.


Thank you, this is great feedback. I will have to take a look at how SciML does it.

Right now I have two main structs. The first is an input struct with fields for all the parameters, all the solution arrays, and all the cache arrays.

At the end of the simulation, I pass input to a result function that throws away everything that was only auxiliary to the simulation (like the cache arrays) and returns just the solution arrays and input parameters in a result struct, which is what I save to file.

Based on your response, I think I may need to refactor my code a bit, but I see how it could be made to work with a state. I think I essentially have that already in my input struct.


I would recommend four things:

  • Create a state struct that makes your simulation fully resumable, including the RNG and everything else. (Julia’s Serialization.jl is best for this.)

  • Write the actual results to JLD2 incrementally. Use the JLD2.jldopen API to append to the file, but still only do this every few minutes or so; writing at every iteration can overload the file system.

  • Linux HPC clusters (e.g. those running Slurm) can be configured to send your job a signal (SIGINT, SIGUSR2, etc.) before forced termination. Normally Julia terminates immediately when it receives such signals, but here is an example of how you can override that behaviour and register custom callbacks:
    InterProcessCommunicationExt.jl

  • For structuring code, it can be useful to maintain a set of callbacks that are executed on every iteration of your solver and that then decide for themselves whether to do anything. The signal handling above is an example of that.
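One way to make the state “fully resumable, including RNG” is to keep the random number generator inside the state struct and serialize it together with everything else. A minimal sketch, assuming a toy random walk and a hypothetical `WalkState` type; `Xoshiro` comes from the Random standard library (Julia 1.7 or newer). Because the RNG’s internal state is saved, the resumed run draws exactly the same numbers as an uninterrupted one:

```julia
using Random, Serialization

mutable struct WalkState
    step::Int
    position::Float64
    rng::Xoshiro            # the RNG is part of the resumable state
end

function walk!(s::WalkState, stop::Int)
    for k in (s.step + 1):stop
        s.position += randn(s.rng)
        s.step = k
    end
    return s
end

# Uninterrupted run
full = walk!(WalkState(0, 0.0, Xoshiro(42)), 2_000)

# Interrupted run: checkpoint at step 1_000, reload, continue
path = tempname()
partial = walk!(WalkState(0, 0.0, Xoshiro(42)), 1_000)
serialize(path, partial)
resumed = walk!(deserialize(path), 2_000)
```

If the RNG were recreated from the seed on restart instead, the resumed run would silently repeat the first 1000 random numbers.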

I would go low-tech on this: pass a path to solve, and save files such as path/intermediate_calculations_0001.jld2, path/intermediate_calculations_0002.jld2, etc. That way you are not relying on anything but the filesystem.
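A minimal sketch of this low-tech approach, assuming hypothetical helpers `snapshot_path` and `latest_snapshot`. It writes `.jls` files with the Serialization standard library to stay dependency-free, but the same naming scheme works for `.jld2` files. Zero-padding the counter makes string order match numeric order, so the newest snapshot is simply the largest file name:

```julia
using Printf, Serialization

# dir/intermediate_calculations_0001.jls, ..._0002.jls, ...
snapshot_path(dir, i) = joinpath(dir, @sprintf("intermediate_calculations_%04d.jls", i))

# Return the newest snapshot file in `dir`, or `nothing` if there is none
function latest_snapshot(dir)
    files = filter(f -> startswith(f, "intermediate_calculations_") && endswith(f, ".jls"),
                   readdir(dir))
    isempty(files) && return nothing
    return joinpath(dir, maximum(files))   # zero padding => string order = numeric order
end

dir = mktempdir()
for (i, k) in enumerate(1000:1000:3000)
    serialize(snapshot_path(dir, i), (step = k, x = Float64(k)))
end

newest = latest_snapshot(dir)
restored = deserialize(newest)
```

With four digits the scheme supports up to 9999 snapshots; a wider pad is needed beyond that.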


I’ve found the package DrWatson.jl very useful for saving and organizing this type of scientific result.

Thanks for the suggestions.

Can you please explain a little more why serialization is necessary or recommended here? If I have a struct that contains the snapshot of everything I need to resume the simulation, why couldn’t I just save that to file with JLD2?

Thanks, I was not familiar with that API. However, if I only care about the latest snapshot of the solution, I guess I could just overwrite the previously saved snapshot and wouldn’t need to append anything. Hopefully that makes sense.
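Overwriting a single snapshot works, with one caveat: if the job is killed in the middle of a write, the only snapshot can be left truncated. A common remedy (a technique not mentioned in the thread) is to write the new snapshot to a temporary file first and then move it over the old one. A minimal sketch with a hypothetical `save_latest` helper; whether `mv(...; force = true)` is a single atomic rename depends on the Julia version and filesystem, but either way a complete snapshot always exists in `path` or in the fully written temp file:

```julia
using Serialization

# Replace `path` with a new snapshot without overwriting it in place
function save_latest(path::AbstractString, snapshot)
    tmp = path * ".tmp"
    serialize(tmp, snapshot)          # write the full snapshot to a temp file first
    mv(tmp, path; force = true)       # then move it over the old snapshot
    return path
end

path = joinpath(mktempdir(), "latest.jls")
for k in 1000:1000:5000
    save_latest(path, (step = k, u = fill(Float64(k), 3)))
end
latest = deserialize(path)
```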

That is neat and does seem like it could be useful, but I don’t immediately see how I’d apply it in this case. I am generally unfamiliar with callbacks as you describe them; I don’t think I have ever used them before (not knowingly, anyway). I will have to look into them more to appreciate these recommendations.


If it is plain “data”, you can use JLD2 easily. However, it depends on what your resumable state backup contains: for example, anonymous functions cannot be stored properly with JLD2, but with Serialization.jl they can.
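As an illustration of that difference, the Serialization standard library can round-trip a state that contains an anonymous function. A minimal sketch with a hypothetical `forcing` field. Keep in mind that the Julia manual warns serialized files are generally not readable by a different Julia version or system image, so this format suits short-lived restart files better than long-term results:

```julia
using Serialization

# A resumable state that mixes plain data with a function
state = (step = 3, x = [1.0, 2.0], forcing = t -> 2t + 1)

path = tempname()
serialize(path, state)
restored = deserialize(path)

restored.forcing(restored.step)   # returns 7: the anonymous function still works
```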

Callbacks (or whatever one wants to call them), an example:
(A minimal sketch; the computation itself is only a placeholder, but I hope the idea is understandable.)

using Serialization

mutable struct State{P,C}
    params::P
    callback::C
    n::Int          # last completed iteration
    x::Float64      # placeholder for the actual simulation variables
end

function iterative_computation(params, callback)
    # initialize the state struct; no iterations done yet (state.n = 0)
    state = State(params, callback, 0, 0.0)
    return iterative_computation!(state)
end

function iterative_computation!(state)
    for n = (state.n+1):state.params.large_number
        # do computation
        state.x += 1.0
        # update state variables
        state.n = n
        # hard-coded backup (alternative to the callback below)
        # if n % 1000 == 0
        #     serialize("backup.jls", state)
        # end

        # alternative using a "callback"
        run_callback(state.callback, state)
    end
    return state
end


struct SimpleSaveCallback
    filename::String
    everyn::Int
end

function run_callback(cb::SimpleSaveCallback, state)
    if state.n % cb.everyn == 0
        serialize(cb.filename, state)
    end
end

params = (large_number = 10_000,)
iterative_computation(params, SimpleSaveCallback("backup.jls", 1000))
# if the program was killed: load the last backup and resume
state = deserialize("backup.jls")
iterative_computation!(state)
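The “set of callbacks” idea can be extended along the same lines: each callback is a small struct with a `run_callback` method, and the solver loops over a tuple of them on every iteration, leaving each one to decide whether to act. A minimal sketch, assuming hypothetical names `LoopState`, `SaveEvery`, `RecordEvery` and `loop!`, with a placeholder update:

```julia
using Serialization

mutable struct LoopState
    n::Int
    x::Float64
end

# Saves the state every `everyn` iterations
struct SaveEvery
    filename::String
    everyn::Int
end
run_callback(cb::SaveEvery, s) = s.n % cb.everyn == 0 && serialize(cb.filename, s)

# Records an observable every `everyn` iterations
struct RecordEvery
    values::Vector{Float64}
    everyn::Int
end
run_callback(cb::RecordEvery, s) = s.n % cb.everyn == 0 && push!(cb.values, s.x)

function loop!(s::LoopState, stop::Int, callbacks::Tuple)
    for n in (s.n + 1):stop
        s.x += 1.0                                  # placeholder for the real update
        s.n = n
        foreach(cb -> run_callback(cb, s), callbacks)   # every callback sees every iteration
    end
    return s
end

path = tempname()
recorder = RecordEvery(Float64[], 100)
loop!(LoopState(0, 0.0), 1_000, (SaveEvery(path, 500), recorder))
```

Adding a new behaviour (a progress printout, a signal check) then only means defining another small struct and `run_callback` method, without touching the solver loop.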
