When calling PyTorch using PyCall or PythonCall, I run out of GPU memory

Hi All,

I am interfacing with PyTorch using PythonCall (or PyCall), with data moved between Julia and Python via DLPack.jl. I use PyTorch to compute gradients, and the rest (e.g. optimization) is handled by Julia, since I am far more familiar with it. After a few iterations, PyTorch crashes because it runs out of GPU memory. The data I move between Python and Julia is allocated on the CPU, not the GPU, so Julia should not be responsible for GPU memory management. Has anyone encountered similar problems?
I can prepare an MWE, but that would take me some time and I wanted to ask first.

Thanks a lot in advance.

I need to manually call GC.gc() to free the memory. I guess the problem is that Julia is holding references to Python objects, which prevents PyTorch from freeing the memory.
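
A minimal sketch of that workaround, assuming PythonCall and a loop in which each iteration calls into PyTorch. The helper `run_iterations`, its `step` argument and the `gc_every` keyword are hypothetical names used only for illustration. The call to `torch.cuda.empty_cache()` is an extra step that I assume helps return PyTorch's cached blocks to the driver once the Julia-side wrappers have been finalized; it may not be needed.

```julia
using PythonCall

torch = pyimport("torch")

# Hypothetical helper: `step(i)` stands for one iteration that calls into PyTorch
# (e.g. a gradient computation); `gc_every` controls how often we clean up.
function run_iterations(step, n; gc_every = 10)
    for i in 1:n
        step(i)
        if i % gc_every == 0
            # Run Julia's garbage collector so that unreachable wrappers of
            # Python objects are finalized and their references released.
            GC.gc()
            # Ask PyTorch to release cached GPU memory that is no longer in use.
            torch.cuda.empty_cache()
        end
    end
end
```

It could be used as `run_iterations(100) do i ... end`, with the PyTorch call placed in the `do` block. Running the collector every iteration is the safest choice but also the slowest, so the interval is a trade-off.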