Hi everyone,
After spending quite some time solving inverse problems in optics with custom tools and manual algorithmic differentiation, I’ve recently been exploring more canonical approaches using modern ML frameworks.
Since I’m working in Julia, I tried both Lux and Flux and eventually settled on Flux. I’ve also started comparing manual AD to Zygote’s automatic differentiation, including hybrid approaches with custom rules. So far, it’s been great — but I’ve run into a couple of missing features in Optimisers.jl that I wanted to ask the community about.
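As a toy illustration of what I mean by a hybrid approach (not my actual rules), a hand-written `ChainRulesCore.rrule` for one piece of the forward model whose adjoint is known analytically lets Zygote differentiate the rest automatically; `sq_norm` here is just a stand-in I made up for this example:

```julia
using ChainRulesCore

# Toy stand-in for a forward-model component: the squared 2-norm.
sq_norm(x::AbstractArray) = sum(abs2, x)

# Hand-written adjoint: the gradient of sum(abs2, x) is 2x, so Zygote
# uses this rule instead of tracing through the reduction itself.
function ChainRulesCore.rrule(::typeof(sq_norm), x::AbstractArray)
    y = sq_norm(x)
    sq_norm_pullback(ȳ) = (NoTangent(), (2ȳ) .* x)
    return y, sq_norm_pullback
end

# using Zygote; Zygote.gradient(sq_norm, randn(3))  # picks up the rule above
```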
Specifically, I’m wondering if the following are supported (or planned), or if I’d need to stick with my own extensions:
- The ability to assign different optimisers to different trainable parameters, similar to `optax.transforms.partition` in JAX.
- The ability to apply a projection or prox operator after each update, potentially with a different projection per parameter, as in proximal gradient descent (see the sketch after this list).
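To make the second point concrete, this is the kind of update I have in mind, sketched with a non-negativity projection; the function names are mine, purely for illustration:

```julia
# One step of projected / proximal gradient descent: take an ordinary
# gradient step, then map the result back onto the feasible set.
project_nonneg!(x) = (x .= max.(x, 0); x)

function projected_gradient_step!(x, grad, η)
    x .-= η .* grad        # plain gradient-descent update
    project_nonneg!(x)     # prox / projection applied after the update
    return x
end
```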
Thanks to Julia’s multiple dispatch, I was able to quickly prototype something that works well for my case without modifying the source of Optimisers.jl. Here’s a simplified sketch of the approach:
```julia
using Optimisers
using Optimisers: AbstractRule, Leaf, apply!, init, subtract!

# Extended setup: per-parameter rules keyed by object identity, plus a fallback rule.
function Optimisers.setup(rules::IdDict{K,<:AbstractRule}, default_rule::AbstractRule, model) where K
    cache = IdDict()
    Optimisers._setup(rules, default_rule, model; cache)
end

# Walks the model like the stock `_setup`, choosing each leaf's rule from `rules`
# and falling back to `default_rule`.
function Optimisers._setup(rules, default_rule, x; cache)
    ...
end

# Wrapper rule that applies a prox / projection operator after the inner rule's update.
struct ProxRule{R <: AbstractRule} <: AbstractRule
    rule  :: R
    prox! :: Function
end

function Optimisers.apply!(o::ProxRule, state, x, x̄)
    return apply!(o.rule, state, x, x̄)
end

Optimisers.init(o::ProxRule, x::AbstractArray) = init(o.rule, x)

# Same as the stock `_update!`, but with the projection applied in place afterwards.
function Optimisers._update!(ℓ::Leaf{<:ProxRule,S}, x; grads, params) where S
    ...
    subtract!(x, x̄′)
    ℓ.rule.prox!(x)
end
```
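For illustration, here is roughly how I then use it. The toy model, the rule choices, and the non-negativity prox are just assumptions for this example, and it of course relies on the elided `_setup` / `_update!` parts being filled in:

```julia
using Flux, Optimisers

model = Chain(Dense(4 => 8, relu), Dense(8 => 1))

# First layer's weights: gradient descent followed by a non-negativity
# projection; every other parameter falls back to Adam.
rules = IdDict{Any,Optimisers.AbstractRule}(
    model[1].weight => ProxRule(Descent(1e-2), w -> (w .= max.(w, 0))),
)
state = Optimisers.setup(rules, Adam(1e-3), model)

x, y = randn(Float32, 4, 16), randn(Float32, 1, 16)
grads = Flux.gradient(m -> sum(abs2, m(x) .- y), model)[1]
state, model = Optimisers.update!(state, model, grads)
```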
I’m also comparing the Flux + Zygote + Optimisers stack to Equinox + JAX + Optax, and I noticed that Optax provides `optax.transforms.partition` and `optax.projections`, which nicely cover these use cases. I haven’t tried writing a custom projection in Optax yet, but the partitioning transform works perfectly for what I need.
Would something like this fit in the scope of Optimisers.jl, or is there already a canonical way to achieve this pattern that I’ve missed?
Happy to hear your thoughts — and open to contributing back if there’s interest!
Thanks,
Nicolas