Dear all,
I want to include the gradient of a neural network (i.e., the gradient of the NN output with respect to its input) in the loss function. However, it seems that Zygote cannot differentiate through that. Here is my code:
using Flux
using Zygote
using Random
Xtrain = rand(Float32, (4, 100));
Ytrain = rand(Float32, (1, 100));
#******************* NN ********************
model = Chain(x -> x[1:2,:],
              Dense(2, 50, tanh),
              Dense(50, 50, tanh),
              Dense(50, 1))
θ = params(model);
#******************* Loss function ********************
function loss(x, y)
    ps = Flux.Params([x])
    # inner gradient: d(NN output)/d(input)
    g = Flux.gradient(ps) do
        model(x)[1]          # scalar output (batch size is 1)
    end
    # match rows 1:2 of the input gradient against rows 3:4 of x,
    # plus the usual data-fitting term
    return Flux.mse(g[x][1:2,:], x[3:4,:]) + Flux.mse(model(x), y)
end
#******************* Training ********************
train_loader = Flux.Data.DataLoader(Xtrain, Ytrain, batchsize=1, shuffle=true);
Nepochs = 10;
opt = ADAM(0.001)
for epoch in 1:Nepochs
    @time Flux.train!(loss, θ, train_loader, opt)
end
However, the code fails to run. Here is the error message:
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] get at ./iddict.jl:87 [inlined]
[3] (::typeof(∂(get)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[4] accum_global at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/lib/lib.jl:56 [inlined]
[5] (::typeof(∂(accum_global)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[6] #89 at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/lib/lib.jl:67 [inlined]
[7] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[8] #1550#back at /home/wuxxx184/ma000311/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[9] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[10] getindex at ./tuple.jl:24 [inlined]
[11] gradindex at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/reverse.jl:12 [inlined]
[12] #3 at ./REPL[8]:5 [inlined]
[13] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[14] #54 at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177 [inlined]
[15] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[16] gradient at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54 [inlined]
[17] (::typeof(∂(gradient)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[18] loss at ./REPL[8]:4 [inlined]
[19] (::typeof(∂(loss)))(::Float32) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[20] #145 at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/lib/lib.jl:175 [inlined]
[21] #1681#back at /home/wuxxx184/ma000311/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[22] #15 at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:83 [inlined]
[23] (::typeof(∂(λ)))(::Float32) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[24] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177
[25] gradient(::Function, ::Params) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54
[26] macro expansion at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:82 [inlined]
[27] macro expansion at /home/wuxxx184/ma000311/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
[28] train!(::Function, ::Params, ::Flux.Data.DataLoader{Tuple{Array{Float32,2},Array{Float32,2}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:80
[29] train!(::Function, ::Params, ::Flux.Data.DataLoader{Tuple{Array{Float32,2},Array{Float32,2}}}, ::ADAM) at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:78
[30] macro expansion at ./timing.jl:174 [inlined]
[31] top-level scope at ./REPL[12]:2
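If it helps to narrow things down: as far as I can tell, the inner gradient call runs fine on its own, outside of Flux.train!; the error only appears once it is nested inside the outer gradient over the model parameters. A reduced sketch (assuming I have isolated it correctly):

x1 = Xtrain[:, 1:1]                      # one sample, 4×1
g1 = Flux.gradient(Flux.Params([x1])) do
    model(x1)[1]                         # scalar output for batch size 1
end
g1[x1]                                   # 4×1 gradient of the output w.r.t. the input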
Does anyone know how to fix this?
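Would switching the inner call to an explicit-argument Zygote.gradient be the right direction? A sketch of what I have in mind (loss2 is just a hypothetical name, and I am not sure Zygote can differentiate through its own gradient call this way either):

function loss2(x, y)
    # inner gradient with an explicit argument instead of implicit Params
    g = Zygote.gradient(z -> model(z)[1], x)[1]   # dNN/dx as a plain array
    return Flux.mse(g[1:2,:], x[3:4,:]) + Flux.mse(model(x), y)
end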
Thanks a lot!
Xiaodong