Flux’s update!(opt, x, x̄) function errors when x and x̄ are of type CatView, with the following message:
TypeError: in typeassert, expected Tuple{CatView{1, Float64}, CatView{1, Float64}, Vector{Float64}}, got a value of type Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}
Using the debugger, I think the problem stems from the x̄r = ArrayInterface.restructure(x, x̄) call at the beginning of update!, which re-types x̄ as a plain Vector, so its type no longer matches the CatView type of x.
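To illustrate what I mean (a minimal check of that diagnosis; the toy arrays here are just for demonstration, and my reading of restructure’s fallback is from the debugger, not the ArrayInterface internals):

import ArrayInterface
using CatViews
a, b = randn(2), randn(2)
xCV = CatView(a, b)                          # CatView over two small vectors
x̄ = randn(length(xCV))                      # a gradient of matching length
typeof(ArrayInterface.restructure(xCV, x̄))  # Vector{Float64}, not a CatView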
Is there a way around this error that still lets me use CatView inputs to update!? I could (as shown in the code below) copy x into a plain Vector and then copy the results back into x afterward, but that is ugly code and a small (though annoying) waste of memory and compute. Maybe there is a way to overwrite the restructure function to do nothing? That sounds like bad practice too, though…
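For concreteness, I imagine the override would look something like this (type piracy, and completely untested; I don’t know whether this alone would satisfy the downstream typeassert):

import ArrayInterface
using CatViews
# untested sketch: rebuild the gradient as a CatView with x's structure
# instead of letting restructure fall back to a plain Vector
function ArrayInterface.restructure(x::CatView, y::AbstractArray)
    out = deepcopy(x) # same structure as x, new parent storage
    out .= vec(y)     # copy the gradient values in
    return out
end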
MWE:
import Flux # for the ADAM optimizer
using CatViews

x = [randn(2)]  # variable: a vector of vectors
dx = [randn(2)] # gradient, with the same structure
opt = Flux.ADAM()

# create CatViews of the variable and gradient
xCV = CatView([@view x[k][:] for k = 1:length(x)]...)
dxCV = CatView([@view dx[k][:] for k = 1:length(dx)]...)

tmp = copy(xCV) # tmp is a plain Vector, while xCV is a CatView
Flux.Optimise.update!(opt, tmp, dxCV) # this works
Flux.Optimise.update!(opt, xCV, dxCV) # this errors
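For completeness, here is the full copy-based workaround I mentioned above (it works, but has the extra copies I’d like to avoid):

tmp = copy(xCV)                       # flatten to a plain Vector
Flux.Optimise.update!(opt, tmp, dxCV) # update the copy
xCV .= tmp                            # write the result back through the views into x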
In terms of why I want to use CatView inputs: the variables are actually a collection of OffsetArrays, and the user can decide whether they want to descend with respect to the OffsetArray values and/or other tuning parameters. Using CatViews lets me update everything in place while letting Flux see the variables as a simple vector, even though the true structure is much more complicated.
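A toy version of that setup (the names and structure here are made up for illustration; I go through parent to keep 1-based indexing for CatView):

using CatViews, OffsetArrays
A = OffsetArray(randn(3), -1:1) # model values with offset indices
p = randn(2)                    # extra tuning parameters
θ = CatView(parent(A), p)       # one flat, 1-based vector over both
θ[1] = 0.0                      # writes through to A[-1] in place
@assert A[-1] == 0.0            # the update propagated to the OffsetArray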