Gradient penalty problems with relu

I am trying to implement the WGAN-GP gradient penalty with Flux and Zygote, but I get an error whenever my model contains `relu` activations. See the minimal working example (MWE) below. If I replace the `relu` with `tanh`, it works fine. Any thoughts?

using Zygote
using Flux

# Problem dimensions for the minimal working example.
idim = 5          # number of input features (rows of `x`)
batch_size = 8    # number of samples (columns of `x` and `y`)

# Two-layer network: a relu hidden layer of width 2*idim, then a single output per sample.
m = Chain(
    Dense(idim, 2idim, relu),
    Dense(2idim, 1),
)

# Random Float32 inputs (idim × batch_size) and targets (1 × batch_size).
x = rand(Float32, idim, batch_size)
y = rand(Float32, 1, batch_size)


"""
    grad_loss(m, x)

Return the mean, over the columns of `x`, of the L2 norm of the gradient of `m(x)` with
respect to `x`. This is used as the gradient-penalty term of the total loss.

The pullback of `m(x)` is seeded with an array of ones that has the same size and element
type as the model output, so the result does not depend on any global batch size.

# Arguments
- `m`: callable model; `m(x)` must return an array.
- `x`: input array; assumed to hold one sample per column (TODO confirm for other layouts).

# Returns
A scalar: the mean of the per-column gradient norms.

Note: this function is itself differentiated when it is called from inside another
pullback, so every operation used here has to support nested differentiation.
"""
function grad_loss(m, x)
    out, back = Flux.pullback(() -> m(x), Flux.params(x))
    # Seed with ones shaped like the model output rather than reading the global
    # `batch_size`, so any number of columns in `x` works.
    grads = back(ones(eltype(out), size(out)))
    # Collapse the feature dimension (rows) to get one gradient norm per column.
    column_norms = sqrt.(sum(grads[x] .^ 2, dims=1))
    return Flux.mean(column_norms)
end


"""
    total_loss()

Return the sum of the mean-squared error between `m(x)` and `y` and the gradient term
`grad_loss(m, x)`.

Takes no arguments and reads the globals `m`, `x` and `y`, so it can be passed directly
to `Flux.pullback` together with `Flux.params(m)`.
"""
function total_loss()
    fit_term = Flux.mse(m(x), y)
    penalty_term = grad_loss(m, x)
    return fit_term + penalty_term
end

# Take the pullback of the total loss with respect to the model parameters, then run the
# backward pass with a unit Float32 seed. This second line is where the error is raised.
l, b = Flux.pullback(total_loss, Flux.params(m))
grad = b(one(Float32))

Here is the stack trace:

LoadError: MethodError: no method matching size(::Nothing, ::Int64)
Closest candidates are:
  size(::BitArray{1}, ::Integer) at bitarray.jl:107
  size(::Tuple, ::Integer) at tuple.jl:22
  size(::Number, ::Integer) at number.jl:63
  ...
Stacktrace:
 [1] (::Zygote.var"#1086#1087"{Nothing})(::Int64) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/lib/broadcast.jl:51
 [2] ntuple at ./ntuple.jl:41 [inlined]
 [3] trim(::Array{Float32,1}, ::Nothing) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/lib/broadcast.jl:51
 [4] unbroadcast at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/lib/broadcast.jl:53 [inlined]
 [5] #1091 at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/lib/broadcast.jl:74 [inlined]
 [6] map at ./tuple.jl:158 [inlined]
 [7] #1090 at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/lib/broadcast.jl:74 [inlined]
 [8] #2471#back at /home/anthonycorso/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [9] #2627#back at /home/anthonycorso/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [10] #175 at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/lib/lib.jl:182 [inlined]
 [11] #359#back at /home/anthonycorso/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [12] adjoint at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/lib/broadcast.jl:73 [inlined]
 [13] #175 at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/lib/lib.jl:182 [inlined]
 [14] (::Zygote.var"#359#back#177"{Zygote.var"#175#176"{typeof(∂(adjoint)),Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{Nothing,Nothing}}}})(::Tuple{Array{Nothing,2},NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}) at /home/anthonycorso/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [15] _pullback at /home/anthonycorso/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [16] (::typeof(∂(_pullback)))(::Tuple{Array{Nothing,2},NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}}}) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface2.jl:0
 [17] Dense at /home/anthonycorso/.julia/packages/Flux/05b38/src/layers/basic.jl:123 [inlined]
 [18] (::typeof(∂(_pullback)))(::Tuple{Nothing,NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2684#_back"),),Tuple{NamedTuple{(:x,),Tuple{Array{Nothing,2}}}}},NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface2.jl:0
 [19] Dense at /home/anthonycorso/.julia/packages/Flux/05b38/src/layers/basic.jl:134 [inlined]
 [20] (::typeof(∂(_pullback)))(::Tuple{Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2684#_back"),),Tuple{NamedTuple{(:x,),Tuple{Array{Nothing,2}}}}},NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}}}}) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface2.jl:0
 [21] applychain at /home/anthonycorso/.julia/packages/Flux/05b38/src/layers/basic.jl:36 [inlined]
 [22] (::typeof(∂(_pullback)))(::Tuple{Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}}},Nothing}}},Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2684#_back"),),Tuple{NamedTuple{(:x,),Tuple{Array{Nothing,2}}}}},NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}}},Nothing}}}}) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface2.jl:0
 [23] Chain at /home/anthonycorso/.julia/packages/Flux/05b38/src/layers/basic.jl:38 [inlined]
 [24] (::typeof(∂(_pullback)))(::Tuple{Nothing,NamedTuple{(:t,),Tuple{Tuple{Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}}},Nothing}}},Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2684#_back"),),Tuple{NamedTuple{(:x,),Tuple{Array{Nothing,2}}}}},NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}}},Nothing}}}}}}}) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface2.jl:0
 [25] #37 at /home/anthonycorso/grad_test.jl:12 [inlined]
 [26] (::typeof(∂(_pullback)))(::Tuple{Nothing,NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,NamedTuple{(:t,),Tuple{Tuple{Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}}},Nothing}}},Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2684#_back"),),Tuple{NamedTuple{(:x,),Tuple{Array{Nothing,2}}}}},NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}}},Nothing}}}}}}}}}}) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface2.jl:0
 [27] pullback at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface.jl:172 [inlined]
 [28] (::typeof(∂(pullback)))(::Tuple{Nothing,NamedTuple{(:ps, :cx, :back),Tuple{Nothing,Nothing,NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,NamedTuple{(:t,),Tuple{Tuple{Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}}},Nothing}}},Nothing,NamedTuple{(:t,),Tuple{Tuple{NamedTuple{(:t,),Tuple{Tuple{Nothing,Nothing,Nothing,NamedTuple{(Symbol("#2684#_back"),),Tuple{NamedTuple{(:x,),Tuple{Array{Nothing,2}}}}},NamedTuple{(Symbol("#2470#_back"),),Tuple{NamedTuple{(:xs,),Tuple{Tuple{Nothing,Nothing}}}}},Nothing,Nothing}}}}}},Nothing}}}}}}}}}}}}) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface2.jl:0
 [29] grad_loss at /home/anthonycorso/grad_test.jl:12 [inlined]
 [30] (::typeof(∂(grad_loss)))(::Float32) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface2.jl:0
 [31] total_loss at /home/anthonycorso/grad_test.jl:19 [inlined]
 [32] (::typeof(∂(total_loss)))(::Float32) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface2.jl:0
 [33] (::Zygote.var"#50#51"{Params,Zygote.Context,typeof(∂(total_loss))})(::Float32) at /home/anthonycorso/.julia/packages/Zygote/pmW1K/src/compiler/interface.jl:177
 [34] top-level scope at /home/anthonycorso/grad_test.jl:23
 [35] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1088
in expression starting at /home/anthonycorso/grad_test.jl:23