I am training a UNet with a 3D input of size 512×512×128. With the same model, the same loss function, and inputs and labels of the same size in Float32, PyTorch can fit and train the model while Julia cannot. Julia also needs much longer to prepare for training. Is something wrong here, or is this just the nature of Julia?
I’m not quite familiar with Lux, but rand(512,512,128,1,1) creates a Float64 array. Does dev convert it to Float32 in the GPU case? Did you check?
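For example, a quick check along these lines would show it (just a sketch; gpu_device() is the usual Lux/MLDataDevices device helper, and LuxCUDA is assumed as the backend):

using Lux, LuxCUDA

dev = gpu_device()

x = rand(512, 512, 128, 1, 1)              # Float64 by default
eltype(x)                                   # Float64
eltype(x |> dev)                            # does dev change the element type?

x32 = rand(Float32, 512, 512, 128, 1, 1)    # allocates Float32 directly, half the memory
eltype(x32 |> dev)                          # Float32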
Did you also close the Python process beforehand, so that Julia and Python don’t have to share GPU memory?
It would be helpful to share the code for the PyTorch model and the Julia model
PyTorch:

class UNet(nn.Module):
    ...

Julia:

function UNet()
    return Chain(...)
end
I converted all the arrays to Float32, moved them with |> dev, and tried again. Still no luck. And yes, I made sure the Python process was closed before running Julia; I’m watching nvidia-smi every 2 seconds.
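Roughly what the conversion looks like on my side (simplified sketch; input_volume and label_volume stand in for my real arrays, and dev = gpu_device()):

using Lux, LuxCUDA

dev = gpu_device()

# input_volume / label_volume are placeholders for the real data arrays
x = Float32.(input_volume) |> dev   # input converted to Float32 before moving to the GPU
y = Float32.(label_volume) |> dev   # label gets the same treatment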
Please share an MWE so that others can help debug better. Without both models and the full training loops it’s hard to know where to start.
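Even something as small as this would help (a sketch of the shape such an MWE could take, assuming Lux.jl with a tiny stand-in Chain instead of the full UNet and shrunken array sizes so it runs anywhere):

using Lux, Random, Optimisers, Zygote

rng = Random.default_rng()

# tiny stand-in model, just to illustrate the structure of an MWE
model = Chain(Conv((3, 3, 3), 1 => 8, relu; pad=1),
              Conv((3, 3, 3), 8 => 1; pad=1))
ps, st = Lux.setup(rng, model)

# small Float32 volumes so the example stays cheap on CPU
x = rand(Float32, 32, 32, 16, 1, 1)
y = rand(Float32, 32, 32, 16, 1, 1)

opt = Optimisers.setup(Adam(1f-3), ps)

for epoch in 1:5
    global ps, opt   # needed when running this as a script rather than in the REPL
    loss, back = Zygote.pullback(ps) do p
        ŷ, _ = model(x, p, st)
        sum(abs2, ŷ .- y) / length(y)   # simple MSE-style loss
    end
    gs = back(one(loss))[1]
    opt, ps = Optimisers.update(opt, ps, gs)
    @info "epoch $epoch" loss
end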