Hipshot without having run the code:
IIRC the reparametrization trick is just to avoid taking the gradient of generating random numbers, like you do here: `eps = rand(MvNormal(vec(fill(0., 32)), vec(std)))`. You probably need to tell Zygote not to try to differentiate that line, e.g. using `@nograd`.
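Untested against your model, but here is a minimal sketch of what I mean. It uses `Zygote.ignore` rather than `@nograd`, and it draws standard-normal noise and scales it explicitly instead of sampling from `MvNormal` with `std` baked in. Those choices, the `logsigma` parametrization, and the names `draw_noise` and `sample_z` are my assumptions, not something taken from your code.

```julia
using Zygote

# Hypothetical helper: the do-block is hidden from Zygote, so the sampling
# itself is never differentiated.
draw_noise(n) = Zygote.ignore() do
    randn(Float32, n)
end

# Reparametrized sample: the randomness enters only through eps, which is a
# constant in the backward pass, so gradients reach mu and logsigma.
function sample_z(mu, logsigma)
    eps = draw_noise(length(mu))
    return mu .+ exp.(logsigma) .* eps
end

# Stand-ins for the two encoder outputs (assumed to have length 32)
mu = zeros(Float32, 32)
logsigma = zeros(Float32, 32)

grads = Zygote.gradient((m, s) -> sum(sample_z(m, s)), mu, logsigma)
```

The gradient with respect to `mu` should come out as all ones, and the one with respect to `logsigma` as the noise scaled by `exp.(logsigma)`, which is what you want for training.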
Another issue with the model is that all those anonymous functions are not functors, so their parameters will not be captured by `params`. I'm also not sure whether `Flux.Params(model)` does the same thing as `params(model)`. Either way, you need to give `params` all the layers whose parameters you want to train. I think `params([encoder, eH1, eH2, decoder])` should work.
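A rough sketch of collecting the parameters, assuming a Flux version that still supports implicit parameters through `Flux.params` and the `Dense(in => out)` constructor. The layer sizes are invented placeholders, and I pass the layers as separate arguments instead of a vector because that is the form I'm sure about.

```julia
using Flux

# Hypothetical stand-ins for the layers in the question; sizes are made up.
encoder = Dense(8 => 4, relu)
eH1 = Dense(4 => 2)      # e.g. the head producing the mean
eH2 = Dense(4 => 2)      # e.g. the head producing the (log) std
decoder = Dense(2 => 8)

# Hand every trainable layer to params explicitly.
ps = Flux.params(encoder, eH1, eH2, decoder)

# Each Dense layer contributes a weight matrix and a bias vector.
n_arrays = length(ps)
```

If `n_arrays` comes out smaller than you expect for your real model, some layer is not being picked up and will not be trained.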