# Gradient and update of custom struct with Flux

Hi everyone,

I have created a struct that I need to optimise with SGD. My struct is

```julia
using SparseArrays

struct tdmat
    d
    dv
    dh
    s
end

function tdmat(d::AbstractVector, dv::AbstractVector, dh::AbstractVector, s::Int64, m::Int64)
    if length(dv) == m - 1
        dvzeros = dv
    else
        # pad dv with zeros so that every s-th entry of the superdiagonal is zero
        dvzeros = zeros(m)
        for j = 1:Int64(m / s)
            dvzeros[s*(j-1)+1:s*j-1] = dv[(s-1)*(j-1)+1:(s-1)*j]
        end
        dvzeros = dvzeros[1:end-1]
    end

    tdmat(d, dvzeros, dh, s)
end

function constructTdmat(mat::tdmat)
    m = length(mat.d)

    rows = vcat(1:m, 1:m-1, 1:m-mat.s)
    cols = vcat(1:m, 2:m, mat.s+1:m)
    vals = vcat(mat.d, mat.dv, mat.dh)

    return sparse(rows, cols, vals)
end
```
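To make the layout concrete, here is a quick sanity check (a sketch of my own, reusing the definitions above; the concrete values are made up for illustration): for `m = 6` and `s = 3`, `d` fills the main diagonal, the zero-padded `dv` fills the first superdiagonal, and `dh` fills the `s`-th superdiagonal.

```julia
# Sketch: inspect the dense form of a small tdmat (assumes the definitions above).
T = tdmat([1.0, 2, 3, 4, 5, 6], [1.0, 2, 3, 4], [1.0, 2, 3], 3, 6)
M = Matrix(constructTdmat(T))

M[1, 1]  # d[1] on the main diagonal -> 1.0
M[1, 2]  # dv[1] on the first superdiagonal -> 1.0
M[3, 4]  # a zero-padded slot of dv -> 0.0
M[1, 4]  # dh[1] on the s-th superdiagonal -> 1.0
```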

It represents a sparse matrix with a special structure. I also had to define some operations

```julia
function tdmat_mul(mat::tdmat, x::AbstractVector)
    m = length(mat.d)

    mx = mat.d .* x
    mx += vcat(mat.dv .* x[2:end], [0])
    mx += vcat(mat.dh .* x[mat.s+1:end], zeros(mat.s))
    return mx
end

import Base: *

function *(mat::tdmat, x::AbstractVector)
    @assert size(x, 1) == length(mat.d)
    tdmat_mul(mat, x)
end

function *(a::Real, mat::tdmat)
    tdmat(a * mat.d, a * mat.dv, a * mat.dh, mat.s)
end

import Base: -

function -(mat1::tdmat, mat2::tdmat)
    if length(mat1.d) != length(mat2.d)
        error("Matrices have different size")
    end
    if mat1.s == mat2.s  # same s -> the result is a tdmat
        return tdmat(mat1.d - mat2.d, mat1.dv - mat2.dv, mat1.dh - mat2.dh, mat1.s)
    else  # different s -> the result is a sparse matrix
        return constructTdmat(mat1) - constructTdmat(mat2)
    end
end
```
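One way to gain confidence in `tdmat_mul` (again a sketch of my own, with made-up test values) is to compare it against the dense product obtained through `constructTdmat`:

```julia
# Sketch: tdmat multiplication should agree with the sparse-matrix product
# (assumes the definitions above).
T = tdmat([1.0, 2, 3, 4, 5, 6], [1.0, 2, 3, 4], [1.0, 2, 3], 3, 6)
x = [1.0, 2, 3, 4, 5, 6]

T * x ≈ constructTdmat(T) * x  # expected to hold
```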

I am using ChainRules to define the derivative of the multiplication of my special matrix with a vector.

```julia
using ChainRules

function ChainRules.rrule(::typeof(*), A::tdmat, x::AbstractVector)
    function tdmat_multiply_pb(Δ)
        ΔA = Δ * x'
        Δx = Matrix(constructTdmat(A))' * Δ
        return (NO_FIELDS, ΔA, Δx)
    end

    return A * x, tdmat_multiply_pb
end
```
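A pullback like this can be checked numerically (my own check, not from the original post): for the scalar function `x -> sum(T * x)`, the output cotangent is a vector of ones, so `Δx` from the pullback should match a central finite difference.

```julia
# Sketch: compare the pullback's Δx against a finite difference of
# x -> sum(T * x). Assumes the tdmat definitions and rrule from above.
using LinearAlgebra

T = tdmat([1.0, 2, 3, 4, 5, 6], [1.0, 2, 3, 4], [1.0, 2, 3], 3, 6)
x = [1.0, 2, 3, 4, 5, 6]

y, pb = ChainRules.rrule(*, T, x)
_, ΔA, Δx = pb(ones(length(y)))

g(v) = sum(T * v)
h = 1e-6
numgrad = [(g(x + h * e) - g(x - h * e)) / 2h for e in eachcol(Matrix(I, 6, 6))]
Δx ≈ numgrad  # expected to hold up to finite-difference error
```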

The loss function for testing is

```julia
T2 = tdmat(1:6, 2*(1:4), 3*(1:3), 3, 6)
u = 1:6
Topt = tdmat(1:6, (1:4), (1:3), 3, 6)
v = T2 * u

loss(Topt) = sum(v.^2 - Topt * u)
```

I also use

```julia
using Flux

Flux.@functor tdmat

function Flux.trainable(mat::tdmat)
    ps = (mat.d, mat.dv, mat.dh)
end
```

but I am not sure if it is necessary.
Then when I run

```julia
opts = Descent(0.1)
pars = Flux.params([Topt.d, Topt.dv, Topt.dh])
gs = Flux.gradient(() -> loss(Topt), pars)
Flux.Optimise.update!(opts, pars, gs)
```

I get an error

```
ERROR: Only reference types can be differentiated with `Params`.
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] getindex at C:\Users\anton\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:142 [inlined]
[3] update!(::Descent, ::Zygote.Params, ::Zygote.Grads) at C:\Users\anton\.julia\packages\Flux\NpkMm\src\optimise\train.jl:28
[4] top-level scope at REPL[330]:1
```

but the gradients were created and seem OK. What does this mean? I guess it has something to do with the `s` parameter of `tdmat`, but how should I fix it?

Moreover, I am totally unable to figure out how to approach gradients of a structure. Why doesn't `gs[Topt]` work as usual?

Also, I noticed that with the struct `xy` below, which does not give this error, the values are updated only if I use `Flux.params([yx.x, yx.y])`, not `Flux.params(yx)`. Is there an easy way to update it without having to write all the fields of the struct down?
The code for the struct is

```julia
struct xy
    x
    y
end

function XY()
    x = [1.0]
    y = [25.0]
    xy(x, y)
end

f(p::xy) = sum(sin.(p.x) .+ p.y.^2)

yx = XY()

opts = Descent(0.1)
pars2 = Flux.params([yx.x, yx.y])
gs2 = Flux.gradient(() -> f(yx), pars2)
Flux.Optimise.update!(opts, pars2, gs2)
```

Apart from that, I noticed that when I use a different loss function `f(p::xy) = sum(sin.(p.x) .+ p.y)`, where `p.y` isn't squared, the gradient returns

```
:(Main.yx) => (x = [0.540302], y = 1-element FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}} = 1.0)
```

which cannot be updated; `update!` fails with

```
ERROR: ArgumentError: Cannot setindex! to 0.1 for an AbstractFill with value 1.0.
Stacktrace:
[1] setindex! at C:\Users\anton\.julia\packages\FillArrays\NjFh2\src\FillArrays.jl:47 [inlined]
[2] copyto!(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}, ::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at .\multidimensional.jl:962
[6] apply!(::Descent, ::Array{Float64,1}, ::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at C:\Users\anton\.julia\packages\Flux\NpkMm\src\optimise\optimisers.jl:39
[7] update!(::Descent, ::Array{Float64,1}, ::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at C:\Users\anton\.julia\packages\Flux\NpkMm\src\optimise\train.jl:23
[8] update!(::Descent, ::Zygote.Params, ::Zygote.Grads) at C:\Users\anton\.julia\packages\Flux\NpkMm\src\optimise\train.jl:29
[9] top-level scope at REPL[350]:1
```

What should the function look like, then?

This is a bit lengthy, but I will appreciate help with any of these questions.

For the first error, I think this is because you are passing in unit ranges as parameters, and Zygote doesn't like that for one reason or another.

Compare:

```julia
julia> Topt = tdmat(1:6, (1:4), (1:3), 3, 6)
tdmat(1:6, [1.0, 2.0, 0.0, 3.0, 4.0], 1:3, 3)

julia> Topt = tdmat(collect(Float32, 1:6), collect(Float32, 1:4), collect(Float32, 1:3), 3, 6)
tdmat(Float32[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [1.0, 2.0, 0.0, 3.0, 4.0], Float32[1.0, 2.0, 3.0], 3)
```

There are a few variables missing from your loss function, but with `loss(Topt) = sum((2 * Topt).d)` I don't get the error when accessing the gradients, as long as I use the collected version.

> I use `Flux.params([yx.x, yx.y])`, not `Flux.params(yx)`. Is there an easy way to update it without having to write all parts of the struct down?

This is exactly what the `Flux.@functor` and `Flux.trainable` definitions you weren't sure were necessary do for you. `Flux.@functor` makes your model mappable (e.g. from CPU to GPU) and is also used for training, so in most cases it is all that is needed. `Flux.trainable` can be used if, for some reason, the set of mappable parameters is different from the set of trainable parameters; one example is the running mean and variance of batch normalization.
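As a sketch of that point (using the poster's `xy` struct and the `yx`, `f`, and `opts` defined above): once the struct is functored, `Flux.params` can walk it, so spelling out the fields is unnecessary.

```julia
# Sketch: after @functor, params collects the array fields automatically.
Flux.@functor xy

pars = Flux.params(yx)   # collects yx.x and yx.y without listing them
gs = Flux.gradient(() -> f(yx), pars)
Flux.Optimise.update!(opts, pars, gs)
```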

That last error is a bit interesting. My guess is that it is some kind of edge case where not enough ops are applied to the parameter, causing Zygote to use an unexpected type for the gradient. There doesn't seem to be anything technically wrong with the output, but the type is probably not what one would expect and will probably throw errors in many other contexts. Might be worth opening an issue with a smaller MWE.
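One possible workaround in the meantime (my own suggestion, not a confirmed fix): `Descent`'s `apply!` mutates the gradient buffer in place, which a `FillArrays.Fill` forbids, so materializing each gradient with `collect` before updating the corresponding parameter sidesteps the error. This assumes `pars2` and `gs2` were built as in your snippet, with `gs2` coming from the unsquared loss.

```julia
# Sketch: convert lazy Fill gradients to plain Arrays before the in-place update.
for p in pars2
    g = gs2[p]
    g === nothing && continue                   # skip params without a gradient
    Flux.Optimise.update!(opts, p, collect(g))  # collect makes a mutable copy
end
```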

Sorry about the missing variables. You are right, it needs a different element type. With the loss you suggested it does not update, but if I just write `tdmat([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0], 3, 6)`, it works.

Thanks for explaining the `Flux.@functor` and `Flux.trainable`.

What do you mean by “opening an issue with a smaller MWE”? I am a bit inexperienced.