Split an array into multiple smaller arrays - alternative to unsafe_wrap without views

I’m wondering if this is a good use case for unsafe_wrap, or if there is a better way to do this. Basically, I have a parameter vector and I want to interpret contiguous blocks of it as a set of arrays that share its memory. I cannot use SubArrays (views) or similar, as something about them doesn’t play nice with Zygote/CuArrays (https://github.com/FluxML/Zygote.jl/issues/600).

using Revise, Flux, Zygote
using LinearAlgebra # for diag

struct mwestruct{TΠ, TA}
  Π::TΠ
  A::TA
  B::TA
end

#constructor
function mwestruct(K::Int, A::TA, B::TA) where TA<:AbstractMatrix

  i = (i->(i-1)%K+1).(1:K^2) # unused in this MWE
  Π = vcat(vec(A), vec(B)) #this is the parameter vector

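  # A occupies the first length(A) entries of Π; wrap them without copying
  startA = 1
  p = pointer(Π, startA)
  linkedA = unsafe_wrap(TA, p, size(A))

  # B starts right after A
  startB = length(A)+1
  p = pointer(Π, startB)
  linkedB = unsafe_wrap(TA, p, size(B))

  # Π is stored in the struct, which keeps the memory behind linkedA/linkedB alive
  return mwestruct(Π, linkedA, linkedB)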
end

# let's call this the loss function (much simpler than the non-MWE version)
function (m::mwestruct)(x)
  sum(abs.((m.A .- m.A.*m.B .- diag(m.B)) ./ (x .+ 1f0)))
end

Flux.@functor mwestruct
Flux.trainable(a::mwestruct) = (a.A,a.B)

#this computes the gradient as a vector corresponding to the parameter vector
mwevecgrad(a, gs) = reduce(vcat, [vec(gs[a.A]),vec(gs[a.B])])

function zygotemwe(N)
  A = Matrix{Float32}(reshape(collect(1:N^2), (N,N)))
  B = Matrix{Float32}(reshape(collect(1:N^2), (N,N)))
  x = Matrix{Float32}(reshape(collect(1:N^2), (N,N))) .* 2

  s=mwestruct(N, A, B)

  gs = gradient(()->s(x),Flux.params(s))
  vecgrad = mwevecgrad(s,gs)
  @info "gradient: $(vecgrad)" 
end

zygotemwe(2)
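
To make the aliasing that the constructor relies on explicit, here is a minimal sketch using only Base on CPU arrays. It assumes a plain Vector{Float32} as the parameter vector and says nothing about how this behaves under Zygote or on the GPU; the names Π, A and B just mirror the MWE.

```julia
# Minimal sketch (CPU arrays, Base only): unsafe_wrap aliases the parent vector.
Π = collect(1f0:8f0)                              # stand-in parameter vector

# unsafe_wrap does not keep Π alive by itself, so guard the pointer-based code.
GC.@preserve Π begin
  A = unsafe_wrap(Array, pointer(Π, 1), (2, 2))   # entries 1:4 of Π
  B = unsafe_wrap(Array, pointer(Π, 5), (2, 2))   # entries 5:8 of Π

  Π[1] = 42f0      # write through the vector ...
  B[2, 2] = -1f0   # ... or through a wrapped matrix

  @assert A[1, 1] == 42f0   # A sees the change made via Π
  @assert Π[8] == -1f0      # Π sees the change made via B
end
```

The wrapped matrices are ordinary Arrays, which is the point of the approach, but they are only valid while Π is reachable and is not resized.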

Context (probably not important to the question):
I am trying to solve a non-linear least-squares problem using Levenberg–Marquardt. The parameters (~20k) must be organized as multiple matrices to compute the loss function. Unfortunately, the least-squares algorithms I have seen all require the parameters as a single vector. Ordinarily, I would keep a main parameter vector and build the parameter matrices from it with a combination of views and reshaping. That does not seem to be an option here because of the interplay between Zygote and CuArrays. It seems I need multiple “authentic” small arrays of the parameters to compute the loss function, and a single parameter vector for the optimizer.
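
For reference, the "views and reshaping" layout mentioned above could look roughly like the sketch below. The helper name splitparams is hypothetical and not part of any package; it uses only Base, and it is exactly the kind of construction that runs into the Zygote/CuArrays issue linked at the top, so it is shown to illustrate the layout rather than as a fix.

```julia
# Hypothetical helper (not from any package): interpret consecutive blocks of a
# flat parameter vector as matrices of the given sizes, without copying.
function splitparams(Π::AbstractVector, dims::Tuple{Int,Int}...)
  lens = collect(prod.(dims))      # number of entries in each block
  stops = cumsum(lens)             # last linear index of each block
  starts = stops .- lens .+ 1      # first linear index of each block
  return [reshape(view(Π, starts[k]:stops[k]), dims[k]) for k in eachindex(stops)]
end

Π = collect(1f0:8f0)
A, B = splitparams(Π, (2, 2), (2, 2))

A[1, 1] = 100f0                    # writes through to Π[1]
@assert Π[1] == 100f0
@assert B[1, 1] == 5f0             # B starts right after A's four entries
```

Each returned matrix is a reshaped view, so the optimizer can keep working on Π while the loss function reads A and B.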