Differentiating the Custom Layer and Regularizer

If you’re referring to including a regularization penalty in the loss, that is standard, and @mcabbott’s snippet above shows that it works. Note also WeightDecay, which works much as it does in other frameworks: it is a more efficient way to add L2 regularization to a network.
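
For concreteness, here is a minimal sketch of both approaches. The names model, x, y, the 0.01 coefficient, and penalising only the first layer's weight are placeholders, and the optimiser API shown assumes a recent Flux version:

using Flux

# Option 1: add an L2 penalty directly to the loss
# (only the first layer's weight here, purely for brevity).
loss(m, x, y) = Flux.logitcrossentropy(m(x), y) + 0.01 * sum(abs2, m[1].weight)

# Option 2: fold the penalty into the optimiser via WeightDecay,
# which never has to differentiate a penalty term at all.
opt_state = Flux.setup(OptimiserChain(WeightDecay(0.01), Adam()), model)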

No. I think the confusion comes from this:

I presume you’re familiar with Python frameworks, so this is basically equivalent to:

criterion = nn.CrossEntropyLoss()
outputs = []
for xs in x_train:
  outputs.append(model(xs))
loss = criterion(outputs, y_train)
loss.backward()

Which likewise does not run. Much like you’d call .backward once per batch in PyTorch, gradient should be called once per batch and not across all batches simultaneously. We try to cover this in Fitting a Line · Flux, so please let us know on the issue tracker if any part of that is unclear.
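
In Flux terms, a minimal sketch of the per-batch pattern looks like the following (assuming a recent Flux with the explicit-gradient API; train_loader is a placeholder for any iterator of (x, y) batches):

using Flux

opt_state = Flux.setup(Adam(), model)

for (x, y) in train_loader
    # one gradient call per batch, mirroring one .backward() per batch in PyTorch
    grads = Flux.gradient(m -> Flux.logitcrossentropy(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
end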

This also likely explains:

The other part is that the Julia JIT needs some warmup time to compile everything on the first run of a model. Assuming the loop above is tweaked to work properly with batches as I described earlier, that one-time compilation cost should be lower, and any subsequent runs should be pretty much what you’d expect from a Python framework (sometimes slightly faster, sometimes slightly slower).
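
One way to see this for yourself (purely an illustration, with loss, model, x, and y as placeholders) is to time the same gradient call twice:

@time Flux.gradient(m -> loss(m, x, y), model)  # first call: includes compilation time
@time Flux.gradient(m -> loss(m, x, y), model)  # later calls: just the run time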

For posterity, the ChainRules docs have a nice section on this topic.

If you look at the part of the error message after the “not supported”, it should tell you which function Zygote is complaining about. Most of the time, that is setindex! (i.e. array[i] = ...), push!/pop! or copyto! (which comes up in .= broadcasting).
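
As a tiny self-contained illustration (the exact error text varies a bit across Zygote versions), an in-place write inside the differentiated function is enough to trigger it:

using Zygote

function f(x)
    x[1] = 0.0        # setindex! on the input array
    return sum(x)
end

Zygote.gradient(f, [1.0, 2.0])
# ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)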

If you then look at the stacktrace, you can roughly see which part of the call stack is calling a function that mutates arrays. In this case, we know it’s somewhere in the implementation of mapslices. Now, mapslices does not mutate any of its inputs, but because there isn’t an explicit AD rule for Zygote to use on it, Zygote will try to drill down as far as it can until it finds functions with AD rules. This usually works fine (it’s the “automatic” part of automatic differentiation), but if things bottom out in one of the mutating functions mentioned previously, then you’ll receive a mutating-arrays error.
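
As a rough sketch of how that plays out with mapslices (whether a given rewrite differentiates cleanly depends on your Zygote/ChainRules versions):

using Zygote

# mapslices has no AD rule and fills an output buffer in place internally,
# so Zygote drills into it and hits the mutation:
f_bad(x) = sum(mapslices(sum, x; dims = 1))
# Zygote.gradient(f_bad, rand(3, 4))   # -> mutating-arrays error

# a mutation-free formulation of the same computation
f_good(x) = sum(map(sum, eachcol(x)))
Zygote.gradient(f_good, rand(3, 4))    # works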
