I want to compute the gradient of pinv(hessian) with respect to the parameters of a neural network: the scalar network output is differentiated twice with respect to the second half of its input, that Hessian block is pseudo-inverted with pinv, and the result feeds into the loss. Here is my attempt:
using Flux, Zygote, ReverseDiff, LinearAlgebra, Statistics

struct hess_nn
    model
    re
    params
    # Inner constructor: destructure the model into a flat parameter vector
    # and a rebuilding closure, optionally starting from a given p
    function hess_nn(model; p = nothing)
        _p, re = Flux.destructure(model)
        if p === nothing
            p = _p
        end
        return new(model, re, p)
    end
end
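For context, my understanding is that Flux.destructure splits the model into a flat parameter vector and a closure that rebuilds a model of the same shape from such a vector; that is what params and re hold above. A quick sanity check of that assumption, using the packages already imported above (variable names here are only illustrative):

m = Chain(Dense(4, 10, σ), Dense(10, 1, σ))
p0, rebuild = Flux.destructure(m)
xs = rand(Float32, 4, 1)
rebuild(p0)(xs) ≈ m(xs)   # the rebuilt model reproduces the original output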
function (nn::hess_nn)(x, p = nn.params)
    M = div(size(x, 1), 2)
    re = nn.re
    # We have to compute the whole Hessian of the scalar output with respect
    # to x, then keep only the block for the second half of the inputs
    hess = Zygote.hessian_reverse(z -> sum(re(p)(z)), x)[M+1:end, M+1:end]
    out = pinv(hess)
    return out
end
# Initialize the struct
g = Chain(Dense(4, 10, σ), Dense(10, 1, σ))
model_net = hess_nn(g)
params = model_net.params
re = model_net.re
# Data
x = rand(4,1);
y = rand(1);
# some toy loss function
loss(x, y, p) = sum((y .- mean.(model_net(x, p))).^2)
# Training
epochs = 1
opt = ADAM(0.01)
@show loss(x,y,params)
@show model_net(x, params)
Flux.trainable(nn::hess_nn) = (nn.params,)
println("loss before training $(loss(x, y, params))")
for epoch in 1:epochs
    gs = ReverseDiff.gradient(p -> loss(x, y, p), params)
    println("$epoch, $gs")
    Flux.Optimise.update!(opt, params, gs)
end
# println("loss after training $(loss(x, y, params))")
There is an error:
MethodError: no method matching svd!(::Array{ReverseDiff.TrackedReal{Float64,Float32,Nothing},2}; full=false, alg=LinearAlgebra.DivideAndConquer())
Closest candidates are:
svd!(::LinearAlgebra.AbstractTriangular; kwargs...) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/triangular.jl:2672
svd!(::CUDA.CuArray{T,2}; full, alg) where T at /home/solar/.julia/packages/CUDA/mbPFj/lib/cusolver/linalg.jl:102
svd!(::StridedArray{T, 2}; full, alg) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/svd.jl:93
...
Stacktrace:
[1] svd(::Array{ReverseDiff.TrackedReal{Float64,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},2}; full::Bool, alg::LinearAlgebra.DivideAndConquer) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/svd.jl:158
[2] pinv(::Array{ReverseDiff.TrackedReal{Float64,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},2}; atol::Float64, rtol::Float64) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/dense.jl:1356
[3] pinv(::Array{ReverseDiff.TrackedReal{Float64,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},2}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/dense.jl:1335
[4] (::hess_nn)(::Array{Float64,2}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at ./In[145]:23
[5] loss(::Array{Float64,2}, ::Array{Float64,1}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at ./In[145]:42
[6] (::var"#571#572")(::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at ./In[145]:54
[7] ReverseDiff.GradientTape(::var"#571#572", ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/tape.jl:199
[8] gradient(::Function, ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/gradients.jl:22 (repeats 2 times)
[9] top-level scope at In[145]:54
[10] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091
Can this be overcome? As far as I can tell from the "closest candidates" list, pinv falls back to svd!, which only has methods for plain Float32/Float64 (BLAS) matrices, so it cannot handle a matrix of ReverseDiff.TrackedReal.
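For reference, here is a minimal sketch that, as far as I understand the stack trace, reproduces just the failing call without any network involved (the 2x2 matrix and the function f below are purely illustrative):

using ReverseDiff, LinearAlgebra

# Inside ReverseDiff.gradient the matrix is built from TrackedReal entries,
# and pinv dispatches to svd!, which has no method for that element type
f(p) = sum(pinv([p[1] p[2]; p[3] p[4]]))
ReverseDiff.gradient(f, rand(4))   # MethodError: no method matching svd!(...)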