How to define a data-dependent model without using global variables

Hi, I am facing a problem: my network model depends on the data. Specifically, I am trying to train an iterative solver for $Au = f$, where the model performs a single iteration step, `u_next = model(u)`, so the model depends on A and f. For example, Richardson iteration, $u_{k+1} = u_k + \omega\,(f - A u_k)$, reads

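```julia
using Flux

# Richardson step u_next = u + ω (f - A u); A, f and ω come from the enclosing scope
model = Chain(
    x -> x + ω * (f - A * x),
)
```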
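For reference, here is a minimal plain-Julia sketch of the Richardson iteration itself (no Flux involved), assuming A is a dense matrix so the operator is applied as `A * u`. The matrix, right-hand side, ω = 0.5 and iteration count are made-up illustration values; ω is picked inside the convergence range 0 < ω < 2/λ_max for this particular A.

```julia
using LinearAlgebra

# One Richardson step: u_{k+1} = u_k + ω (f - A u_k)
richardson_step(u, A, f, ω) = u .+ ω .* (f .- A * u)

# Run a fixed number of steps, starting from u = 0
function richardson(A, f, ω; max_iter = 100)
    u = zero(f)
    for _ in 1:max_iter
        u = richardson_step(u, A, f, ω)
    end
    return u
end

# Hypothetical small SPD test problem
A = [2.0 1.0; 1.0 2.0]   # eigenvalues 1 and 3
f = [1.0, 2.0]
ω = 0.5                  # converges for 0 < ω < 2/λ_max = 2/3
u = richardson(A, f, ω; max_iter = 200)
rel_residual = norm(f - A * u) / norm(f)
```

With this ω the iteration matrix I - ωA has spectral radius 0.5, so the error roughly halves each step.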

Here, A and f in the model are taken from the data:

```julia
using LinearAlgebra  # for norm

for (x, y) in train_loader
    global A = x |> device
    global f = y |> device
    u = zeros(size(f)) |> device
    for i = 1:max_iter
        u = model(u)
    end
    loss = norm(f - A * u) / norm(f)
    # backward pass ...
end
```

But doing so gives me warnings or even errors:

```
┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f477458c2f0.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
```

How do I correctly define this type of model using Flux.jl? :smiley:

Wouldn't defining your own model (a custom layer) solve the problem? For that, see the docs here.
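A minimal sketch of what such a custom layer could look like, assuming the layer stores only the trainable step size and receives A and f as extra call arguments instead of reading them from globals. The names `Richardson` and `solve` are hypothetical. With Flux one would additionally register the struct (e.g. `Flux.@functor Richardson`, or `Flux.@layer` in newer versions); that line is left out so the sketch runs without Flux. ω is stored as a one-element vector on the assumption that Flux's optimisers update array parameters rather than plain scalars.

```julia
using LinearAlgebra

# Hypothetical custom layer: ω is the trainable parameter, while A and f
# are passed in on every call instead of being captured from globals.
struct Richardson{T}
    ω::T
end

# Make the layer callable: one Richardson step for the system A u = f.
(m::Richardson)(u, A, f) = u .+ m.ω .* (f .- A * u)

# Hypothetical driver: apply the layer max_iter times, starting from u = 0.
function solve(m::Richardson, A, f; max_iter = 100)
    u = zero(f)
    for _ in 1:max_iter
        u = m(u, A, f)
    end
    return u
end

# Illustration values only
m = Richardson([0.5])   # one-element vector, broadcast against u
A = [2.0 1.0; 1.0 2.0]
f = [1.0, 2.0]
u = solve(m, A, f; max_iter = 200)
```

In a training loop, `solve(m, A, f)` can then be called with the A and f of each batch, so nothing needs to be global.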


(Note that this has nothing to do with globals vs. locals. Thanks to lexical scoping, an expression like x -> x + ω*(f-A(x)) works fine to define an anonymous function even if A, f, and ω are local variables.)
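To illustrate the lexical-scoping point, here is a small sketch in which every variable is local to a function and the anonymous step function still sees A, f and ω as a closure. The function name `residual_after` and the numbers are made up for illustration.

```julia
using LinearAlgebra

# Everything inside this function is local; the anonymous function still
# "sees" A, f and ω through lexical scoping (it is a closure).
function residual_after(A, f, ω, max_iter)
    step = u -> u + ω * (f - A * u)   # captures the local A, f, ω
    u = zero(f)
    for _ in 1:max_iter
        u = step(u)
    end
    return norm(f - A * u) / norm(f)
end

r = residual_after([2.0 1.0; 1.0 2.0], [1.0, 2.0], 0.5, 200)
```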


Okay, but what I mean is that the A in the model is supposed to be the A from the for loop. The A in the for loop is a local variable, so a model defined outside the for loop cannot see it. That is why I need to make A a global variable; otherwise I get this error:

ERROR: LoadError: UndefVarError: A not defined
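One way to avoid the global, sketched under assumptions: build the data-dependent step inside the loop body (or inside a loss function called from it), where the batch's A and f are in scope, and pass the trainable ω in explicitly. Here `batch_loss`, `evaluate` and `train_data` (a plain vector standing in for `train_loader`) are hypothetical names with illustration data. With Flux, the gradient for a batch could then be taken with something like `Flux.gradient(w -> batch_loss(w, A, f), ω)`; that part is not run here.

```julia
using LinearAlgebra

# Hypothetical loss: ω is passed in explicitly and the step is built
# inside, from this batch's A and f, so no global variables are needed.
function batch_loss(ω, A, f; max_iter = 50)
    step = u -> u .+ ω .* (f .- A * u)   # closure over this batch's A, f
    u = zero(f)
    for _ in 1:max_iter
        u = step(u)
    end
    return norm(f .- A * u) / norm(f)
end

# Loop over batches; A and f are ordinary loop-local variables.
function evaluate(ω, data)
    losses = Float64[]
    for (A, f) in data
        push!(losses, batch_loss(ω, A, f))
    end
    return losses
end

# Stand-in for train_loader: a vector of (A, f) pairs (illustration data).
train_data = [
    ([2.0 1.0; 1.0 2.0], [1.0, 2.0]),
    ([3.0 0.0; 0.0 1.0], [3.0, 1.0]),
]

losses = evaluate(0.4, train_data)
```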


Thanks for your reply :smiley: Let me take a look.