Do not update neural network weights with a value of 0

Hello, I am fairly new to deep learning.

Recently I have been trying to build a neural network with 3-dimensional weights by passing a 3D array as the pretrained weights. Each 2D matrix in the weights is used for a different group of data (for example, weights[:, :, 1] for group 1, weights[:, :, 2] for group 2, etc.). A simple example is shown below.

using Flux
using CUDA

struct TestLayer{W<:AbstractArray}
    weight::W
    function TestLayer(weight::W) where {W <: AbstractArray}
        new{W}(weight)
    end 
end
Flux.@functor TestLayer

function (l::TestLayer)(x, cluster)
    xT = Flux._match_eltype(l, x)
    return NNlib.batched_mul(l.weight[:, :, cluster], xT)
end

The problem is that the pretrained weights are sparse, and I don’t want the entries that are 0 to be updated during training. I have no idea how to do this.

One possible solution I thought of is to convert the pretrained array to a SparseArray and then send it to the GPU, as below. However, after being sent to the GPU, the SparseArray becomes a dense CuArray.

using SparseArraysKit

weight = Float32.([1;1;;2;2;;;3;3;;4;4])
SparseWeight = SparseArraysKit.SparseArray(weight)

model = TestLayer(SparseWeight) |> gpu
model.weight
2×2×2 CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}:
[:, :, 1] =
 1.0  2.0
 1.0  2.0

[:, :, 2] =
 3.0  4.0
 3.0  4.0

Is there any solution to this problem? Or is there another way to keep the zeros in the weights from being updated?

You can look up tools for dealing with tensors in Julia (that’s the term for arrays of arbitrary dimension).

Just as a side note: is there a good reason to go 3D? Part of why ML works is that it is computationally fast, and what you describe seems a bit complicated.

Thank you for the reply!

Could you describe the tensor tools in more detail? I’m not quite sure where to start.

As for the 3D design: at the beginning I did consider building multiple neural networks by splitting the weights into 2D matrices, but with thousands of groups in my case that approach is rather inefficient. That’s why I want to create a single 3D neural network.

If your weights are a dense array with zeros, it should be simple to make Optimisers.jl preserve these, something like this (untested):

using Optimisers, Zygote

struct KeepZeros <: Optimisers.AbstractRule  # no fields
end

Optimisers.init(o::KeepZeros, x::AbstractArray) = nothing  # no state to store

function Optimisers.apply!(o::KeepZeros, state, x, dx)
  dx_new = @. !iszero(x) * dx  # new gradient is zero where model's x is zero
  state, dx_new
end

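# Using this (names qualified so they cannot clash with Flux's exports):
rule = Optimisers.OptimiserChain(Optimisers.Adam(), KeepZeros())  # last step is KeepZeros
opt_state = Optimisers.setup(rule, model)

grads = Zygote.gradient(m -> loss(m, ...), model)
Optimisers.update!(opt_state, model, grads[1])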

I think Zygote should by default preserve the zeros of the standard library’s sparse arrays, but those are only vectors and matrices, and its handling of them is unlikely to be fast. That won’t help with SparseArraysKit.jl.
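
To make the masking step concrete, here is a minimal plain-Julia sketch of the same idea outside Optimisers.jl. The function name masked_sgd_step! and the keyword lr are made up for this illustration, and a plain gradient-descent step stands in for Adam. Unlike KeepZeros above, which checks iszero(x) on the current weights at every step, this version computes the mask once from the pretrained weights, so a weight that happens to reach exactly zero during training would not become frozen.

```julia
# Plain-Julia sketch of the masking idea; no packages needed.
# `masked_sgd_step!` and `lr` are hypothetical names for this illustration.

# One gradient-descent step that leaves masked-out entries untouched.
function masked_sgd_step!(w::AbstractArray, dw::AbstractArray,
                          mask::AbstractArray{Bool}; lr = 0.5f0)
    # Where mask is false, `mask * dw` is exactly zero, so no update is applied.
    @. w -= lr * (mask * dw)
    return w
end

w    = Float32[0 1; 2 0]       # "pretrained" weights with structural zeros
mask = .!iszero.(w)            # computed once, before training
dw   = ones(Float32, 2, 2)     # stand-in gradient

masked_sgd_step!(w, dw, mask)  # w is now Float32[0 0.5; 1.5 0]
```

Because false acts as a strong zero in Julia, the masked entries stay at zero even if the gradient there is NaN or Inf.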

Thank you very much for your help, this is probably the best solution for me.
I’ll try it and check if it works.

This method works for me, thank you very much!

One more question: in my project I actually have multiple neural networks, and the outputs of the networks are combined by an ordinary differential equation.
A minimal example is below.

rule = Optimisers.OptimiserChain(Optimisers.Adam(0.1), KeepZeros())

m1 = Dense(...)
m2 = Dense(...)
m3 = Dense(...)

m = (m1, m2, m3)
opt_state = Optimisers.setup(rule, m)

gs = Zygote.gradient(m -> eval_loss(m, ...), m)
opt_state, m = Optimisers.update(opt_state, m, gs[1])

However, under the assumptions of the ODE model, only one of the neural networks needs KeepZeros; in the other networks, zero-valued weights are allowed to change.

Is there a way to use a different optimizer for each neural network?

################################
Update.

The problem is solved by using Optimisers._setup(); example code is below.

w = Float32.([0;1;;1;0])
m1 = Dense(copy(w))  # copy so the two layers do not share (tie) one weight array
m2 = Dense(copy(w))
m = (m1, m2)
x = Float32.([1;2;;])
y = Float32.([10;20;;])

rule1 = Optimisers.OptimiserChain(Optimisers.Adam(0.1), KeepZeros())
rule2 = Optimisers.OptimiserChain(Optimisers.Adam(0.1))
rule = (rule1, rule2)

opt_state = ()
for (model, optimizer) in zip(m, rule)
    cache = IdDict()
    opt_state = tuple(opt_state..., Optimisers._setup(optimizer, model; cache))
end

function test_predict(x, m)
    return m[2](m[1](x))
end

function test_eval_loss(x, m, y)
    loss = Flux.mse(test_predict(x, m), y)
    return loss
end

gs = Zygote.gradient(model -> test_eval_loss(x, model, y), m)
opt_state, m = Optimisers.update(opt_state, m, gs[1])
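
As a possible alternative to the loop above, the per-model state trees can probably also be built with the public Optimisers.setup, which as far as I can tell creates a fresh cache on each call. Below is a minimal sketch under that assumption. To keep it free of Flux and Zygote, the models are stand-in NamedTuples and the gradient is written by hand (both hypothetical, not taken from the original project).

```julia
using Optimisers

# KeepZeros as defined earlier in the thread, repeated so this block runs on its own.
struct KeepZeros <: Optimisers.AbstractRule end
Optimisers.init(o::KeepZeros, x::AbstractArray) = nothing
function Optimisers.apply!(o::KeepZeros, state, x, dx)
    dx_new = @. !iszero(x) * dx   # zero the update wherever the weight is zero
    return state, dx_new
end

# Stand-in "models": NamedTuples holding one weight array each (hypothetical,
# used instead of Dense layers so that only Optimisers.jl is required).
m = ((weight = Float32[0 1; 1 0],), (weight = Float32[0 1; 1 0],))

# One rule per model: only the first model keeps its zeros.
rules = (Optimisers.OptimiserChain(Optimisers.Adam(0.1), KeepZeros()),
         Optimisers.Adam(0.1))

# One state tree per model; the tuple of trees mirrors the tuple of models.
opt_state = map(Optimisers.setup, rules, m)

# Hand-written stand-in gradient with the same structure as `m`.
gs = ((weight = ones(Float32, 2, 2),), (weight = ones(Float32, 2, 2),))

opt_state, m = Optimisers.update(opt_state, m, gs)
```

After this step the zeros in m[1].weight should still be zero, while every entry of m[2].weight is free to move.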

Thank you for your help again!