Running MNIST from model-zoo on Flux v0.10.0/Zygote/GPU


I am trying to run the MNIST digit recognition example from model-zoo using the latest Flux version with Zygote on a GPU. I had to make some tweaks to the code to get it to run. One strange pattern I observe is that, after training for a while and almost converging (test accuracy > 0.99), the accuracy suddenly drops to < 0.1 and stays there. When I inspect the model parameters, they are full of NaNs. If I save the model just before it turns into NaNs and compute its gradients, NaNs are returned non-deterministically (i.e. sometimes I get a valid gradient, and sometimes NaNs).
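For anyone trying to reproduce this, a small check like the one below can help pin down when the NaNs first appear. It is only a sketch in plain Julia: `has_nan` and `nan_counts` are hypothetical helper names, and the arrays `W` and `b` are stand-ins for real model parameters (with Flux one would presumably pass the arrays from `params(model)` instead).

```julia
# Hypothetical helper: true if any array in the collection contains a NaN.
has_nan(arrays) = any(a -> any(isnan, a), arrays)

# Hypothetical helper: number of NaN entries in each array,
# useful for seeing which parameter goes bad first.
nan_counts(arrays) = [count(isnan, a) for a in arrays]

# Stand-ins for model parameters (assumption: not taken from the real model).
W = Float32[0.1 0.2; 0.3 0.4]
b = Float32[0.0, NaN]

has_nan([W, b])      # true
nan_counts([W, b])   # [0, 1]
```

Calling the same check on the gradients after every batch would show whether the NaNs enter through the gradients or through the parameter update.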

I think this might be related to running on a GPU, since I have not observed this behaviour on a CPU. Interestingly, the learning curve on the CPU looks completely different from the one on the GPU (in addition to training being much slower, which is expected). For example, on the GPU the accuracy after the first epoch is > 0.9, while on the CPU it is < 0.11.

Has anyone else run into this, or can anyone suggest a workaround?

@MikeInnes @dhairyagandhi96

Thanks in advance.

julia> versioninfo()
Julia Version 1.3.0
Commit 46ce4d7933 (2019-11-26 06:09 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) CPU E5-2640 0 @ 2.50GHz
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.1 (ORCJIT, sandybridge)
  JULIA_DEPOT_PATH = /data/.julia
  JULIA_EDITOR = atom  -a
$ nvidia-smi -L
GPU 0: Quadro P2000
$ nvidia-smi | grep SMI
| NVIDIA-SMI 418.87.00    Driver Version: 440.44       CUDA Version: 10.2     |

It would be nice if you could collect those fixes into a PR against model-zoo.

The non-deterministic NaNs on the GPU could be due to the non-deterministic order in which operations are evaluated during parallel execution. This sounds like a deeper issue (not an error in the model-zoo code, but rather in Flux itself).
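To illustrate why evaluation order can matter at all, here is a minimal sketch in plain Julia (no GPU involved, and not a diagnosis of this particular bug). Floating-point addition is not associative, so summing the same numbers in a different order can give a different result, and near the top of the Float32 range the grouping can decide whether an Inf, and then a NaN, shows up. All variable names are made up for the example.

```julia
# Regrouping the same three Float32 values changes the rounding.
a, b, c = 1f8, -1f8, 1f0
left  = (a + b) + c   # 1.0f0
right = a + (b + c)   # 0.0f0, because -1f8 + 1f0 rounds back to -1f8

# Near the largest Float32, the grouping decides whether the sum overflows.
fmax = floatmax(Float32)
overflowed = (fmax + fmax) - fmax   # Inf32
finite     = fmax + (fmax - fmax)   # still floatmax(Float32)

# Once an Inf exists, subtracting it from another Inf gives NaN.
poisoned = (fmax + fmax) - (fmax + fmax)   # NaN32
```

A parallel reduction that combines partial sums in a different order from run to run could, under this assumption, produce a finite value on one run and a NaN on the next from identical inputs.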

There are a few PRs against model-zoo (including some of my own) that are not being merged. I’d rather get some response from the maintainers before raising more.

Re ‘deeper issue’ - yes, I understand that the issue is most likely with Flux. I am asking if anyone else has seen it, and maybe has a workaround.
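In the meantime, one defensive pattern (a stopgap that hides the symptom, not a fix for the underlying cause) is to skip any update whose gradients contain NaN or Inf. The sketch below uses plain Julia arrays and a hand-written SGD step; `all_finite` and `guarded_sgd_step!` are hypothetical names, not part of Flux.

```julia
# Hypothetical helper: true only if every entry of every gradient array is finite.
all_finite(grads) = all(g -> all(isfinite, g), grads)

# Hypothetical plain SGD step that leaves the parameters untouched
# when any gradient entry is NaN or Inf.
# Returns true if the update was applied, false if the batch was skipped.
function guarded_sgd_step!(params, grads, lr)
    all_finite(grads) || return false
    for (p, g) in zip(params, grads)
        p .-= lr .* g   # in-place update of each parameter array
    end
    return true
end

w = [1.0, 2.0]
guarded_sgd_step!([w], [[0.5, NaN]], 0.1)   # false, w is unchanged
guarded_sgd_step!([w], [[0.5, 1.0]], 0.1)   # true, w is now approximately [0.95, 1.9]
```

Counting how often the step is skipped would also give a rough measure of how frequently the bad gradients occur.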