Custom Optimisers and Projections in Flux/Zygote: Is There a Canonical Way?

Hi everyone,

After spending quite some time solving inverse problems in optics with custom tools and manual algorithmic differentiation, I’ve recently been exploring more canonical approaches using modern ML frameworks.

Since I’m working in Julia, I tried both Lux and Flux and eventually settled on Flux. I’ve also started comparing manual AD to Zygote’s automatic differentiation, including hybrid approaches with custom rules. So far, it’s been great — but I’ve run into a couple of missing features in Optimisers.jl that I wanted to ask the community about.

Specifically, I’m wondering if the following are supported (or planned), or if I’d need to stick with my own extensions:

  1. The ability to assign different optimisers to different trainable parameters, similar to optax.transforms.partition in JAX.
  2. The ability to apply a projection or prox operator after each update, potentially with a different projection per parameter (useful in proximal gradient descent).

Thanks to Julia’s multiple dispatch, I was able to quickly prototype something that works well for my case without modifying the source of Optimisers.jl. Here’s a simplified sketch of the approach:

# Dispatch to a different rule per parameter via an IdDict, falling back to default_rule:
function Optimisers.setup(rules::IdDict{K,<:Optimisers.AbstractRule}, default_rule::Optimisers.AbstractRule, model) where K
    cache = IdDict()
    Optimisers._setup(rules, default_rule, model; cache)
end

function Optimisers._setup(rules, default_rule, x; cache)
    ...  # recurse over the model like Optimisers._setup, picking each leaf's rule from `rules` (or default_rule)
end

# Wrap any rule together with an in-place projection / prox operator:
struct ProxRule{R <: Optimisers.AbstractRule} <: Optimisers.AbstractRule
    rule :: R
    prox! :: Function
end

# Delegate the gradient transformation to the wrapped rule:
function Optimisers.apply!(o::ProxRule, state, x, x̄)
    return Optimisers.apply!(o.rule, state, x, x̄)
end

Optimisers.init(o::ProxRule, x::AbstractArray) = Optimisers.init(o.rule, x)

# After the usual update step, apply the projection in place:
function Optimisers._update!(ℓ::Optimisers.Leaf{<:ProxRule,S}, x; grads, params) where S
    ...  # as in Optimisers._update!: fetch the gradient and call apply! to get the step x̄′
    Optimisers.subtract!(x, x̄′)
    ℓ.rule.prox!(x)
end
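
To make the intent concrete, usage looks roughly like this. It's only a sketch: the model, layers and rule choices are illustrative, and the IdDict here is keyed by the parameter arrays themselves.

using Flux, Optimisers

model = Chain(Dense(4 => 8, relu), Dense(8 => 1))

# one weight matrix gets Descent plus a non-negativity projection; everything else uses the default Adam
rules = IdDict{Any,Optimisers.AbstractRule}(
    model[1].weight => ProxRule(Descent(0.1), w -> clamp!(w, 0, Inf)),  # keys are the parameter arrays
)
state = Optimisers.setup(rules, Adam(1e-3), model)

# then the usual loop: Zygote gradient via Flux, and update! applies the prox right after the step
# grads = Flux.gradient(m -> loss(m, data), model)
# state, model = Optimisers.update!(state, model, grads[1])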

I’m also comparing the Flux + Zygote + Optimisers stack to Equinox + JAX + Optax, and I noticed that Optax provides optax.transforms.partition and optax.projections, which nicely cover these use cases. I haven’t tried writing a custom projection in Optax yet, but the partitioning transform works perfectly for what I need.

Would something like this fit in the scope of Optimisers.jl, or is there already a canonical way to achieve this pattern that I’ve missed?

Happy to hear your thoughts — and open to contributing back if there’s interest!

Thanks!

Best,
Nicolas

There isn’t a built-in way to do either of these in Optimisers.jl, but as you’ve seen it’s not so hard to hook into its internals to make it do what you want.

To set up different rules for different parameters, this PR has a sketch of one way, assuming you can tell from the shape of the parameter array what to do (e.g. a bias is always a vector, so ndims(x) == 1). A rough sketch of that idea is below.
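
If the shapes are enough to decide, a small rule wrapper works without touching any internals. This is only a sketch with made-up names (RuleBySize, _pick), not the PR's actual code:

struct RuleBySize{V<:Optimisers.AbstractRule, D<:Optimisers.AbstractRule} <: Optimisers.AbstractRule
    vector_rule :: V    # used for 1-D parameters, i.e. biases
    default_rule :: D   # used for all other arrays
end

_pick(o::RuleBySize, x::AbstractArray) = ndims(x) == 1 ? o.vector_rule : o.default_rule

Optimisers.init(o::RuleBySize, x::AbstractArray) = Optimisers.init(_pick(o, x), x)
Optimisers.apply!(o::RuleBySize, state, x, x̄) = Optimisers.apply!(_pick(o, x), state, x, x̄)

# e.g. plain Descent for biases, Adam for everything else:
# state = Optimisers.setup(RuleBySize(Descent(0.1), Adam(1e-3)), model)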

To apply some function to parameters, I’d consider just writing a recursive function:

myclamp!(model) = foreach(myclamp!, Flux.trainable(model))    # recurse into the trainable children
myclamp!(par::AbstractArray{<:Real}) = clamp!(par, -10, 10)   # weights etc.
myclamp!(par::AbstractVector{<:Real}) = clamp!(par, -5, 5)    # bias

Unlike your version, though, you’d have to call this yourself after update!.
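
Roughly like this inside the training loop (loss, data and the Adam choice are just placeholders):

opt_state = Optimisers.setup(Adam(1e-3), model)
for (x, y) in data
    grads = Flux.gradient(m -> loss(m(x), y), model)
    opt_state, model = Optimisers.update!(opt_state, model, grads[1])
    myclamp!(model)   # projection / clamp applied after the optimiser step
end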