CatViews with Flux optimize!

Flux’s update!(opt, x, x̄) function errors when x and x̄ are of type CatView with the following message:
```
TypeError: in typeassert, expected Tuple{CatView{1, Float64}, CatView{1, Float64}, Vector{Float64}}, got a value of type Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}
```

Using the debugger, I think the problem stems from the x̄r = ArrayInterface.restructure(x, x̄) call at the beginning of update!, which re-types x̄ to be a Vector so its type no longer matches the CatView type of x.
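The expected and received tuples in the error differ only in CatView versus Vector. The Base-Julia sketch below shows one common way such a mismatch arises: generic constructors like zero and similar, applied to a wrapper array, return a plain Array, so any type assertion written in terms of typeof(x) fails. A SubArray stands in for CatView so the sketch runs without extra packages (an assumption: that CatView falls back to a plain Array in the same way). The sketch does not establish whether the failing assertion inside Flux involves zero or the restructure call.

```julia
# A SubArray stands in for a CatView: both are AbstractVector wrappers around other arrays.
data = randn(4)
xv = view(data, :)             # typeof(xv) is a SubArray, not a Vector

# Optimisers often allocate per-parameter state with generic constructors such as zero.
# For a SubArray (and many other wrappers) zero falls back to a plain Vector.
state = (zero(xv), zero(xv))

println(typeof(xv))            # a SubArray type
println(typeof(state[1]))      # Vector{Float64}

# An assertion that requires the state to have the same type as xv therefore fails.
try
    state::Tuple{typeof(xv), typeof(xv)}
catch err
    println(err isa TypeError) # true
end
```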

Is there a way around this error where I can still use CatView inputs to update!? I could (as shown in the code below) just copy x into a plain Vector (and then copy the results back to x afterward), but that's ugly code and a (small, but annoying) waste of memory and compute. Maybe there is a way to overwrite the restructure function to do nothing? That sounds like bad practice too, though…

MWE:

```julia
import Flux # ADAM optimiser
using CatViews

x = [randn(2)]
dx = [randn(2)]
opt = Flux.ADAM()

# create CatViews of the variable and gradient
xCV = CatView([@view x[k][:] for k=1:length(x)]...)
dxCV = CatView([@view dx[k][:] for k=1:length(x)]...)

tmp = copy(xCV) # tmp is a Vector while xCV is a CatView
Flux.Optimise.update!(opt, tmp, dxCV) # this works
xCV .= tmp # copy the result back into the arrays wrapped by xCV

Flux.Optimise.update!(opt, xCV, dxCV) # this does not work
```

In terms of why I want to use CatView input: the input variables are actually a collection of OffsetArrays and the user can decide if they want to descend with respect to the OffsetArray values and/or other tuning parameters. Using CatViews allows me to update everything in place while letting Flux see the variables as a simple vector where the true structure is much more complicated.

Can't you update the arrays inside the CatView instead?
I think updating through CatViews would perform poorly, since a CatView has to work out which wrapped array (and which index within it) every single access refers to. For a nice API, you can just overload update!.
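To illustrate where that per-access cost comes from, here is a deliberately simplified, hypothetical MiniCatView. It is not the CatViews.jl implementation, which may organise its index bookkeeping differently; it only shows that every getindex/setindex! on a concatenated view must first locate the wrapped array holding the requested index, and a broadcasted optimiser update pays that lookup once per element.

```julia
# Hypothetical, simplified concatenated view (NOT CatViews.jl's actual code).
struct MiniCatView{T} <: AbstractVector{T}
    parts::Vector{Vector{T}}
end

Base.size(v::MiniCatView) = (sum(length, v.parts),)

function Base.getindex(v::MiniCatView, i::Int)
    j = i
    for p in v.parts               # linear scan to find the array that holds index i
        j <= length(p) && return p[j]
        j -= length(p)
    end
    throw(BoundsError(v, i))
end

function Base.setindex!(v::MiniCatView, val, i::Int)
    j = i
    for p in v.parts               # same lookup on every write
        if j <= length(p)
            p[j] = val
            return v
        end
        j -= length(p)
    end
    throw(BoundsError(v, i))
end

a, b = [1.0, 2.0], [3.0, 4.0, 5.0]
cv = MiniCatView([a, b])
cv[4]          # 4.0 (the second element of b)
cv[4] = 40.0   # writes through to b
b              # [3.0, 40.0, 5.0]
```

Updating the wrapped arrays directly skips this lookup entirely, which is the motivation for overloading update! rather than updating through the view.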

I'm not sure if this is what you meant, but

```julia
[Flux.Optimise.update!(opt, x[i], dx[i]) for i=1:length(x)]
```

works (Flux.Optimise.update!(opt, x, dx) does not).

I’ll have to think more if that or overloading update! makes more sense long-term. Thanks for the ideas!

This is sort of what I meant. If you look at how getindex is implemented in CatViews, you will see that updating through it would be very wasteful.
I would overload update!, and you do not need a comprehension:

```julia
foreach(i -> Flux.Optimise.update!(opt, x[i], dx[i]), 1:length(x))
```
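The difference between the comprehension and foreach can be seen without Flux. In this sketch a hypothetical in-place function step! stands in for Flux.Optimise.update!; the point is only that the comprehension builds a Vector of return values nobody uses, while foreach runs the same calls for their side effects and returns nothing.

```julia
# Hypothetical in-place gradient step standing in for Flux.Optimise.update!.
function step!(x, dx; eta = 0.1)
    x .-= eta .* dx
    return x
end

x  = [[1.0, 2.0], [3.0]]
dx = [[1.0, 1.0], [2.0]]

# Comprehension: performs the updates AND collects every return value into a new Vector.
results = [step!(x[i], dx[i]) for i in 1:length(x)]

# foreach: performs the same updates purely for their side effects and returns nothing.
ret = foreach(i -> step!(x[i], dx[i]), 1:length(x))
```

Both lines mutate the arrays in x, so after running the sketch each array has been stepped twice.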


The method in question is the default for dealing with non-standard array types. I think overloading would be appropriate here:

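```julia
import Flux
using CatViews

# Update a CatView by updating each wrapped array separately.
# Caveat: integer indexing a CatView (x[i]) returns a single element, not the i-th
# wrapped array, so x[i] and dx[i] here must refer to the underlying arrays
# (e.g. the original Vector of arrays the CatView was built from) for this to work.
function Flux.Optimise.update!(opt, x::CatView, dx::CatView)
  foreach(i -> Flux.Optimise.update!(opt, x[i], dx[i]), 1:length(x))
  return x
end
```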
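The overloading pattern can be sketched end to end without Flux or CatViews. The module ToyOpt, its Descent optimiser, and the wrapper type Parts below are all hypothetical stand-ins: Parts plays the role of a CatView that keeps direct access to its wrapped arrays, and the extra method on ToyOpt.update! forwards to each wrapped array, so the generic method only ever sees plain Vectors.

```julia
# Hypothetical stand-in for an optimiser package (not Flux's code).
module ToyOpt
struct Descent
    eta::Float64
end
# Generic method for ordinary arrays: plain gradient descent step, in place.
function update!(opt::Descent, x::AbstractArray, dx::AbstractArray)
    x .-= opt.eta .* dx
    return x
end
end # module

# Hypothetical wrapper presenting several arrays as one vector, like a CatView.
struct Parts{T} <: AbstractVector{T}
    parts::Vector{Vector{T}}
end
Base.size(p::Parts) = (sum(length, p.parts),)
function Base.getindex(p::Parts, i::Int)
    j = i
    for a in p.parts               # locate the wrapped array that holds index i
        j <= length(a) && return a[j]
        j -= length(a)
    end
    throw(BoundsError(p, i))
end

# Extend the optimiser's function for the wrapper: forward to each wrapped array.
function ToyOpt.update!(opt::ToyOpt.Descent, x::Parts, dx::Parts)
    foreach((xi, dxi) -> ToyOpt.update!(opt, xi, dxi), x.parts, dx.parts)
    return x
end

x  = Parts([[1.0, 2.0], [3.0]])
dx = Parts([[10.0, 10.0], [10.0]])
ToyOpt.update!(ToyOpt.Descent(0.1), x, dx)
collect(x)   # [0.0, 1.0, 2.0]
```

Typing both the optimiser and the two array arguments makes the new method strictly more specific than the generic one, which avoids method-ambiguity errors; an overload that leaves opt untyped while the library's generic method types it can end up ambiguous.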