How to Manage Memory with Sequential, GPU-Intensive (e.g., PyTorch) Python Calls via PythonCall.jl

My Question

I’m running into an issue where a long-running Python function seems to be terminated mid-execution. My question is:

Are there any lower-level features or design patterns within PythonCall.jl that could help with this problem? For example, is there a way to periodically run a cleanup function (like torch.cuda.empty_cache()) or something similar to manage resources during a single, long-running pycall?
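
For concreteness, this is the kind of cleanup I can already run from Julia between separate pycalls; it is only a minimal sketch and assumes torch is importable from the same Python environment that PythonCall uses. What I don’t see is how to trigger something like it while a single long pycall is still executing.

using PythonCall

torch = pyimport("torch")

# Ask PyTorch's CUDA caching allocator to return cached, currently unused
# memory blocks to the driver. This does not free tensors that are still live.
torch.cuda.empty_cache()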

The Scenario

I’m using PythonCall to interface with TotalSegmentator, a deep learning library built on PyTorch. The library has a main function, totalsegmentator(), which, when run in its high-resolution mode (fast=false), executes a sequence of 5 different neural network models internally to segment a full CT scan.

My initial, straightforward code looks like this:

Simple (Not Working) Script:

using Pkg
Pkg.activate(".")
Pkg.instantiate()

import CUDA
@assert CUDA.functional()  # fail early if the CUDA driver/runtime is not usable from Julia

using PythonCall
# pyimport with multiple module names returns a tuple of module handles
np, pydicom, totalsegmentator, nib = pyimport(
    "numpy",
    "pydicom",
    "totalsegmentator.python_api",
    "nibabel"
)

# Assume `nifti_image` is a valid NIfTI object loaded from DICOMs
# e.g., Python: <nibabel.nifti1.Nifti1Image object at 0x7fee3d265670>
# nifti_image = create_nifti_from_dicom(...)

# This single call should run the full segmentation
segmentation_result = totalsegmentator.totalsegmentator(
    nifti_image,
    fast=false,    # high-resolution mode: runs the full multi-model pipeline
    device="gpu",
    ml=true,       # return a single multi-label segmentation image
    quiet=false
)

# This line is never reached
println("Segmentation complete!")

Observed Behavior

When I run this on a machine with an NVIDIA A100 80GB GPU, the script prints the first stage of the Python tool’s output and then hangs indefinitely:

Resampling...
  Resampled in 1.91s
Predicting part 1 of 5 ...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 12.29it/s]
# --- HANGS HERE ---

By monitoring nvidia-smi, I observed the following:

  1. The Julia process (e.g., PID 258497) is running and using a baseline amount of VRAM.
  2. The moment the prediction starts, GPU utilization spikes to 100%, but the Julia process disappears from nvidia-smi’s process list.
  3. The GPU keeps running at 100% for a short time with no owning process listed, and then finishes its work.
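
(For reference, the monitoring above can also be scripted from Julia with a small polling loop like the sketch below; the helper name is mine, not part of any package.)

# Hypothetical helper: print the PIDs currently holding GPU memory every few
# seconds, which makes it easy to see when the Julia PID drops off the list.
function watch_gpu_processes(; interval = 5, iterations = 120)
    for _ in 1:iterations
        println(readchomp(`nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader`))
        sleep(interval)
    end
end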

I’ve also noted that, because the hang happens inside the single totalsegmentator() call, any cleanup code placed after that call never executes. Furthermore, I’ve observed that calling Julia’s GC.gc() can sometimes free the underlying Python object behind my nifti_image input, which could cause problems in other contexts.
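
In case it is relevant, here is the kind of guard I have in mind for that GC concern: a minimal sketch using Julia’s standard GC.@preserve to keep nifti_image rooted for the duration of the call (PythonCall also documents PythonCall.GC.disable()/PythonCall.GC.enable(), but I haven’t verified how they behave around a call this long).

# Sketch: keep nifti_image rooted so a GC.gc() on the Julia side cannot
# collect the wrapper (and release the underlying Python object) while the
# segmentation call is still running.
segmentation_result = GC.@preserve nifti_image begin
    totalsegmentator.totalsegmentator(
        nifti_image,
        fast=false,
        device="gpu",
        ml=true,
        quiet=false
    )
end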