Hi, I’m trying to implement MAML (Model-Agnostic Meta-Learning), but I’m getting an error when calculating a gradient using Flux.
The error occurs when trying to take the gradient of another function, inner_loop. Some research suggests this might be an out-of-memory error…? I could be wrong though. The error:
Assertion failed: (token.V->getType()->isTokenTy()), function emit_expr, file /Users/julia/buildbot/worker/package_macos64/build/src/codegen.cpp, line 4357.
signal (6): Abort trap: 6
...
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 304915733 (Pool: 304850148; Big: 65585); GC: 209
The inner_loop function calculates gradients, so I assume this could have something to do with it. Can Flux not calculate the gradient of a function that itself calculates gradients?
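To illustrate what I mean by a gradient of a gradient, here is a toy scalar case (not my actual model, just the pattern I assumed would work):
using Flux  # `gradient` is re-exported from Zygote

# d/dy y^3 = 3y^2, so the inner gradient at x is 3x^2;
# differentiating that wrt x gives 6x, i.e. 12.0 at x = 2.0
gradient(x -> gradient(y -> y^3, x)[1], 2.0)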
I have looked at implementations of MAML and foMAML (MAML, foMAML), which make sense, but they use the old Tracker-based version of Flux. I’d like to use the current (Zygote-based) version.
A minimal working example of my code that generates the error:
using Random
using Flux
"Get a randomly generated 'image' and corresponding label"
function rand_img_label(label)
    img = Float32.(rand(0:1, (105, 105)))
    img = reshape(img, (105, 105, 1, 1))
    label = Float32.(Flux.onehot(label, 1:5))
    return (img, label)
end
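For reference, a quick check of the shapes this helper produces (not part of the training code):
img, y = rand_img_label(3)
size(img)  # (105, 105, 1, 1) -- WHCN layout, as Flux's Conv expects
size(y)    # (5,) -- one-hot label as a Float32 vector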
"""
get_task_batch()
Returns a list of 5 tasks. Each task is a tuple containing the support set and target set.
A support set is a list of five (image, label) tuples.
A target set consists of a single (image, label) tuple.
All images are randomly generated for the purpose of this example code.
"""
function get_task_batch()
    batch = []
    for task_idx in 1:5
        support = [rand_img_label(task_idx) for i in 1:5]
        target = rand_img_label(task_idx)
        push!(batch, (support, target))
    end
    return batch
end
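And a similar sanity check of the batch structure:
batch = get_task_batch()
length(batch)                       # 5 tasks
support_set, target_set = batch[1]
length(support_set)                 # 5 (image, label) pairs in the support set
size(target_set[1])                 # (105, 105, 1, 1) -- a single target image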
"""
get_model(n_way=5)
returns a Flux model for [n_way] classification
"""
function get_model(n_way=5)
    model = Chain(
        Conv((3, 3), 1 => 16, relu),   # 103x103x16
        MaxPool((2, 2)),               # 51x51x16
        Conv((3, 3), 16 => 32, relu),  # 49x49x32
        MaxPool((2, 2)),               # 24x24x32
        Conv((3, 3), 32 => 64, relu),  # 22x22x64
        MaxPool((2, 2)),               # 11x11x64
        x -> reshape(x, :, size(x, 4)),
        Dense(7744, n_way, relu)       # 11*11*64 = 7744 flattened features
    )
    return model
end
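The layer-size comments above check out when pushing a dummy image through; the (5, 1) output is the logits for 5 classes at batch size 1:
m = get_model(5)
x = rand(Float32, 105, 105, 1, 1)
size(m(x))  # (5, 1)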
######## MAML stuff ###########
"""
inner_loop(model, loss, opt)
The inner loop of the MAML algorithm.
Used as a loss function, returns the a number value
"""
function inner_loop(model, n_way, loss, opt)
    task_batch = get_task_batch()
    meta_loss = []
    for Ti in task_batch                           # for each task
        support_set, target_set = Ti
        params_copy = deepcopy(params(model))      # get a copy of the meta-model's params
        m = get_model(n_way)                       # get a new model
        Flux.loadparams!(m, params_copy)           # load the copied params into the new model
        for (x_support, y_support) in support_set
            grads = gradient(() -> loss(m, x_support, y_support), params(m))  # gradient of loss wrt params
            Flux.update!(opt, params(m), grads)    # inner (task-level) update
        end
        (x_target, y_target) = target_set
        loss_target = loss(m, x_target, y_target)
        push!(meta_loss, loss_target)
    end
    return sum(meta_loss)
end
"""
MAML!(model, epochs)
Updates model parameters to best minimize the inner_loop() output.
"""
function MAML!(model, n_way, epochs)
    # we aren't using a softmax output, so we need to work with logit output
    loss(m, x, y) = Flux.logitcrossentropy(m(x), y)
    opt = Flux.Optimise.Descent()
    for _ in 1:epochs
        grads = gradient(() -> inner_loop(model, n_way, loss, opt), params(model))
        Flux.update!(opt, params(model), grads)
    end
end
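For completeness, this is how I intend to call it once the error is resolved (the epoch count here is arbitrary):
meta_model = get_model(5)
MAML!(meta_model, 5, 10)  # 5-way tasks, 10 meta-updates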
####### Tests ########
model = get_model(5)
loss(m,x,y) = Flux.logitcrossentropy(m(x), y)
opt = Flux.Optimise.Descent()
l = inner_loop(model, 5, loss, opt) # make sure inner_loop works as expected
gs = gradient(() -> inner_loop(model, 5, loss, opt), params(model)) ### ERROR!!
The last line generates the error. To be clear, this code uses random matrices instead of real images, purely for the sake of the example.
Any tips for taking gradients of complicated functions? To frame the question another way: how would you update this model zoo example to work with the current version of Flux?
I’m open to any random critiques or suggestions as well; I’m trying to learn! Thanks in advance.
I’m running Julia 1.4 with all packages up to date.