Gradients of custom functions with Flux

Hi, I’m trying to implement MAML, but I’m getting an error when calculating a gradient using Flux.

The error occurs when trying to take the gradient of another function, inner_loop. Some searching suggests it might be an out-of-memory error…? I could be wrong, though. The error:

Assertion failed: (token.V->getType()->isTokenTy()), function emit_expr, file /Users/julia/buildbot/worker/package_macos64/build/src/codegen.cpp, line 4357.

signal (6): Abort trap: 6
...
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 304915733 (Pool: 304850148; Big: 65585); GC: 209

The inner_loop function itself calculates gradients, so I assume this could have something to do with it. Can Flux not take the gradient of a function that itself calculates gradients?
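(To be concrete about the nesting, here’s a toy example of what I mean; as far as I understand, a simple second derivative like this is supposed to work, which is why I’m surprised my case crashes:)

using Flux  # gradient here is Zygote's, re-exported by Flux

# second derivative of y^2 at x = 3; I'd expect this to return (2.0,)
d2 = gradient(x -> gradient(y -> y^2, x)[1], 3.0)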

I have looked at implementations of MAML and foMAML (MAML, foMAML), which make sense, but they use the old version of Flux with Tracker. I’d like to use the current version.

A minimal working example of my code that reproduces the error:

using Random
using Flux


"Get a randomly generated 'image' and corresponding label"
function rand_img_label(label)
	img = Float32.(rand(0:1, (105,105)))
	img = reshape(img, (105,105,1,1))
	label = Float32.(Flux.onehot(label, 1:5))
	return (img, label)
end

"""
get_task_batch()

Returns a list of 5 tasks. Each task is a tuple containing the support set and target set.
	A support set is a list of five (image, label) tuples.
	A target set consists of a single (image, label) tuple.

All images are randomly generated for the purpose of this example code.
"""
function get_task_batch()
	batch = []
	for task_idx in 1:5
		support = [rand_img_label(task_idx) for i in 1:5]
		target = rand_img_label(task_idx)
		push!(batch, (support, target))
	end
	return batch
end

"""
get_model(n_way=5)

Returns a Flux model for n_way classification
"""
function get_model(n_way=5)
    model = Chain(
        Conv((3,3), 1=>16, relu),   # 103x103x16
        MaxPool((2,2)),             # 51x51x16
        Conv((3,3), 16=>32, relu),  # 49x49x32
        MaxPool((2,2)),             # 24x24x32
        Conv((3,3), 32=>64, relu),  # 22x22x64
        MaxPool((2,2)),             # 11x11x64
        x -> reshape(x, :, size(x, 4)),
        Dense(7744, n_way, relu)
        )
    return model
end

######## MAML stuff ###########

"""
inner_loop(model, n_way, loss, opt)

The inner loop of the MAML algorithm.
Used as a loss function; returns a scalar loss value.
"""
function inner_loop(model, n_way, loss, opt)
    task_batch = get_task_batch()
    meta_loss = []
    for (i, Ti) in enumerate(task_batch) # for each task
        support_set, target_set = Ti

        params_copy = deepcopy(params(model)) # get copy of params
        m = get_model(n_way) # get a new model
        Flux.loadparams!(m, params_copy) # load the copied params to the new model

        for (x_support, y_support) in support_set
            grads = gradient(() -> loss(m, x_support, y_support), params(m)) # gradient of loss wrt params
            Flux.update!(opt, params(m), grads) # update params
        end

        (x_target, y_target) = target_set
        loss_target = loss(m, x_target, y_target)
        push!(meta_loss, loss_target)
    end
    loss_sum = sum(meta_loss)
    return loss_sum
end


"""
MAML!(model, n_way, epochs)

Updates model parameters to best minimize the inner_loop() output.
"""
function MAML!(model, n_way, epochs)
    loss(m, x, y) = Flux.logitcrossentropy(m(x), y) # loss function. we aren't using softmax output, so need to work with logit output
    opt = Flux.Optimise.Descent()
    for _ in 1:epochs
        grads = gradient(() -> inner_loop(model, n_way, loss, opt), params(model))
        Flux.update!(opt, params(model), grads)
    end
end


####### Tests ########

model = get_model(5)
loss(m,x,y) = Flux.logitcrossentropy(m(x), y)
opt = Flux.Optimise.Descent()

l = inner_loop(model, 5, loss, opt) # make sure inner_loop works as expected

gs = gradient(() -> inner_loop(model, 5, loss, opt), params(model))  ### ERROR!!

The last line generates the error. To be clear, this code uses random matrices instead of images just for the sake of this example.

Any tips for getting gradients of complicated functions? To frame the question another way, how would you update this model zoo example to work with the current version of Flux?

I’m open to any other critiques or suggestions as well; I’m trying to learn! Thanks in advance.

I’m running Julia 1.4 with all packages up to date.


Flux uses Zygote for automatic differentiation, and Zygote doesn’t yet support mutation of arrays inside differentiated code (which is what Flux.update!() does).

I don’t know if that’s really the reason for this crash, and in any case it would be nicer if it resulted in a better error message, but it might give you something to investigate.
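As a quick illustration (this is just a toy function I made up, not your code), even a trivial function that writes into an array in place tends to fail under Zygote’s gradient, usually with a "Mutating arrays is not supported" error rather than a crash:

using Zygote

function mutating_loss(x)
    buf = zeros(eltype(x), length(x))
    for i in eachindex(x)
        buf[i] = 2 * x[i]   # in-place write; Zygote cannot differentiate setindex!
    end
    return sum(buf)
end

# expected to throw an error along the lines of "Mutating arrays is not supported"
gradient(mutating_loss, rand(Float32, 3))

Your inner_loop does its mutation indirectly through update! and loadparams!, so the failure mode may look different, but the underlying limitation would be the same.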


Thanks for the reply, this gives me some direction. I’m working on making the function non-mutating with the help of Zygote.@nograd.

The crash actually turned out to be related to Flux.loadparams!.
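For anyone who finds this later, this is the kind of Zygote.@nograd usage I mean; just a toy sketch (log_value! is a made-up helper, not part of my MAML code), showing that a function marked with @nograd still runs in the forward pass but is skipped when differentiating:

using Zygote

# hypothetical bookkeeping helper that mutates an array
function log_value!(buffer, v)
    push!(buffer, v)
    return nothing
end

# tell Zygote to treat log_value! as non-differentiable
Zygote.@nograd log_value!

buffer = Float64[]
g = gradient(x -> (log_value!(buffer, x); x^2), 3.0)
# g should be (6.0,), and buffer should now contain 3.0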
