I’m wondering whether this is a good use case for unsafe_wrap, or whether there is a better way to do it. Basically, I have a vector, and I want to interpret it as a set of arrays that share its memory. I cannot use SubArrays (views) or similar, because they don’t play nicely with Zygote and CuArrays: see “Zygote+Views+CuArrays=Bad News Bears (silently gives incorrect results)”, Issue #600 in FluxML/Zygote.jl on GitHub.
using Revise, Flux, Zygote
using LinearAlgebra # provides diag, used in the loss function
struct mwestruct{TΠ, TA}
    Π::TΠ
    A::TA
    B::TA
end
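# constructor: A and B are stored as arrays that alias the memory of Π
function mwestruct(K::Int, A::TA, B::TA) where TA<:AbstractMatrix
    i = (i->(i-1)%K+1).(1:K^2) # not used anywhere in this MWE
    Π = vcat(vec(A), vec(B)) # this is the parameter vector
    # wrap the first length(A) entries of Π as a matrix shaped like A (no copy)
    startA = 1
    p = pointer(Π, startA)
    linkedA = unsafe_wrap(TA, p, size(A))
    # wrap the following length(B) entries of Π as a matrix shaped like B
    startB = length(vec(A))+1
    p = pointer(Π, startB)
    linkedB = unsafe_wrap(TA, p, size(B))
    # Π is stored in the struct, so it stays alive as long as linkedA and linkedB
    return mwestruct(Π, linkedA, linkedB)
end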
# let's call this the loss function (much simpler than the non-MWE version)
function (m::mwestruct)(x)
    sum(abs.((m.A .- m.A.*m.B .- diag(m.B)) ./ (x .+ 1f0)))
end
Flux.@functor mwestruct
Flux.trainable(a::mwestruct) = (a.A,a.B)
# this computes the gradient as a vector laid out like the parameter vector Π
mwevecgrad(a, gs) = reduce(vcat, [vec(gs[a.A]), vec(gs[a.B])])
function zygotemwe(N)
    A = Matrix{Float32}(reshape(collect(1:N^2), (N,N)))
    B = Matrix{Float32}(reshape(collect(1:N^2), (N,N)))
    x = Matrix{Float32}(reshape(collect(1:N^2), (N,N))) .* 2
    s = mwestruct(N, A, B)
    gs = gradient(()->s(x), Flux.params(s))
    vecgrad = mwevecgrad(s, gs)
    @info "gradient: $(vecgrad)"
end
zygotemwe(2)
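To make the aliasing behind the constructor explicit, here is a minimal CPU-only sketch of what unsafe_wrap is being relied on for. It uses only Base Julia (no Flux, Zygote, or CuArrays), and the names alias_demo, flat, WA, and WB are illustrative, not part of the MWE. It only shows that writes are visible on both sides; it says nothing about how Zygote treats such arrays.

```julia
# Hypothetical demo function; not part of the MWE above.
function alias_demo()
    flat = Float32.(1:8)  # 8 parameters, enough for two 2x2 matrices
    # keep flat rooted while raw pointers into it are in use
    result = GC.@preserve flat begin
        # pointer(flat, i) is the address of flat[i]; unsafe_wrap does not copy
        WA = unsafe_wrap(Array, pointer(flat, 1), (2, 2))
        WB = unsafe_wrap(Array, pointer(flat, 5), (2, 2))
        WA[1, 1] = 100f0  # write through the matrix
        flat[8] = -1f0    # write through the vector
        (flat[1], WB[2, 2], vec(WA) == flat[1:4])
    end
    return result
end

alias_demo()
```

The wrapped arrays are only valid while the underlying vector is alive, which is why the sketch uses GC.@preserve and why the struct above keeps Π as a field.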
Context (probably not important to the question):
I am trying to solve a non-linear least-squares problem using Levenberg–Marquardt. The parameters (~20k) must be organized as multiple matrices to compute the loss function, but the least-squares algorithms I have seen all require the parameters as a single vector. Ordinarily, I would keep a main parameter vector and use a combination of views and reshaping to create the parameter matrices. That does not seem to be an option here because of the interplay between Zygote and CuArrays. It seems I need multiple “authentic” small arrays of the parameters to compute the loss function, and a single parameter vector for the optimizer.
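For reference, the “views and reshaping” layout mentioned above can be sketched as follows. This is a CPU-only illustration in Base Julia, with a hypothetical helper named unflatten and made-up shapes. It is the approach that reportedly misbehaves with Zygote and CuArrays, so it is shown only to make the intended memory layout concrete, not as a fix.

```julia
# Hypothetical helper: carve a flat parameter vector into matrices that
# share its memory, using views plus reshape (no copies are made).
function unflatten(θ::AbstractVector, shapes)
    mats = []
    offset = 0
    for dims in shapes
        n = prod(dims)  # number of entries this matrix takes
        push!(mats, reshape(view(θ, offset+1:offset+n), dims))
        offset += n
    end
    @assert offset == length(θ)  # the shapes must use up θ exactly
    return mats
end

θ = collect(Float32, 1:10)
A, B = unflatten(θ, [(2, 2), (2, 3)])  # A uses θ[1:4], B uses θ[5:10]
A[2, 1] = 0f0                          # writes through to θ[2]
```

With this layout the optimizer would see only θ, while the loss function would work with A and B.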