Hi, I am facing a problem: my network model depends on the data. Specifically, I am trying to train an iterator to solve $Au = f$, where the model performs a single iteration step, `u_next = model(u)`, so the model depends on `A` and `f`. For example, Richardson iteration reads
```julia
model = Chain(
    x -> x + ω * (f - A * x),   # one Richardson step; A, f are globals
)
```
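For comparison, here is a minimal pure-Julia sketch of the same Richardson step written so that `A` and `f` are explicit inputs rather than globals captured by an anonymous function. The names `Richardson` and `iterate_solve` and the 2×2 test problem are hypothetical, chosen only for illustration; the step size is assumed to satisfy $0 < \omega < 2/\lambda_{\max}(A)$ for a symmetric positive definite `A`, which is the standard convergence condition for Richardson iteration.

```julia
using LinearAlgebra

# Hypothetical callable struct for one Richardson step u -> u + ω (f - A u).
struct Richardson{T<:Real}
    ω::T    # fixed step size
end

# A and f are passed in explicitly instead of being read from globals.
(m::Richardson)(u, A, f) = u .+ m.ω .* (f .- A * u)

# Apply the step max_iter times, starting from u = 0.
function iterate_solve(step, A, f; max_iter = 200)
    u = zero(f)          # same array type as f, so a GPU f keeps u on the GPU
    for _ in 1:max_iter
        u = step(u, A, f)
    end
    return u
end

A = [4.0 1.0; 1.0 3.0]   # symmetric positive definite test matrix
f = [1.0, 2.0]
u = iterate_solve(Richardson(0.2), A, f)   # 0.2 < 2/λ_max(A) ≈ 0.43, so it converges
```

Writing `A * u` (a matrix-vector product) rather than looping over entries keeps the work in whole-array operations, which is what a GPU array type needs to avoid element-by-element indexing.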
Here, `A` and `f` in the model are taken from the training data and assigned as globals:
```julia
for (x, y) in train_loader
    global A = x |> device
    global f = y |> device
    u = zeros(size(f)) |> device
    for i = 1:max_iter
        u = model(u)
        loss = norm(f - A * u) / norm(f)
        # backward
    end
end
```
But doing so gives me warnings, or even errors:
```
Warning: Performing scalar indexing on task Task (runnable) @0x00007f477458c2f0.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
```
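One possible way to structure this in Flux, sketched here under stated assumptions rather than as a confirmed fix, is to make the step a small custom layer that takes `A` and `f` as arguments, with a learnable step size `ω`, and to unroll a fixed number of steps inside the loss so gradients flow through the whole iteration. The sketch assumes a recent Flux release that provides `Flux.@layer` (older releases used `Flux.@functor`), plus `Flux.setup`, `Flux.withgradient` and `Flux.update!`. The names `LearnedRichardson`, `unrolled_loss`, and the in-memory `data` list standing in for `train_loader` are hypothetical.

```julia
using Flux, LinearAlgebra

# Hypothetical layer: one Richardson step u -> u + ω (f - A u) with a learnable ω.
# ω is a 1-element vector so Flux treats it as a trainable array.
struct LearnedRichardson{V<:AbstractVector}
    ω::V
end
Flux.@layer LearnedRichardson

# A and f are explicit arguments, not globals captured by a closure.
# Broadcasting the 1-element ω (instead of writing m.ω[1]) avoids scalar
# indexing when ω lives on the GPU.
(m::LearnedRichardson)(u, A, f) = u .+ m.ω .* (f .- A * u)

# Unroll max_iter steps from u = 0 and return the relative residual.
function unrolled_loss(model, A, f; max_iter = 10)
    u = zero(f)                      # same array type as f (CPU or GPU)
    for _ in 1:max_iter
        u = model(u, A, f)
    end
    return norm(f .- A * u) / norm(f)
end

# Stand-in data: a hypothetical list of (A, f) pairs in place of train_loader.
A0 = Float32[4 1; 1 3]
f0 = Float32[1, 2]
data = [(A0, f0) for _ in 1:100]

model = LearnedRichardson([0.05f0])
opt_state = Flux.setup(Adam(1f-2), model)
initial_loss = unrolled_loss(model, A0, f0)

for (A, f) in data
    # On a GPU one would move A, f (and the model) with `device` here.
    loss, grads = Flux.withgradient(m -> unrolled_loss(m, A, f), model)
    Flux.update!(opt_state, model, grads[1])
end

final_loss = unrolled_loss(model, A0, f0)
```

The design choice here is that data-dependent quantities enter the forward pass as arguments, so nothing is read from mutable globals and the same layer works for every `(A, f)` pair in a batch loop.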
How do I correctly define this kind of data-dependent model in Flux.jl?

Thanks