Zygote Update with Parametric Type

Hi. I’m trying to update my Flux model, which has a type parameter that takes one of two types:

using Flux

abstract type C end

struct A <: C end
struct B <: C end

struct Block{T<:C}
    C_inout::Int
    time::Vector{Float64}
end
Flux.@functor Block
Flux.trainable(m::Block) = (time = m.time,)  # key must match the field name

function (model::Block{A})(x)
    0.0
end
function (model::Block{B})(x)
    0.0
end

Here are the gradient and update steps:

grad = gradient(loss, m, x, y)
Flux.update!(opt_state, m, grad[1]) # Breaks!

The update step attempts to construct an unparameterized Block:

MethodError: no method matching Block(::Int64, ::Vector{Float64})

What would be the best way to solve this, or to reorganize the code?

The problem here is that a constructor Block(C_inout, time) doesn’t make sense: the additional type parameter T appears in none of the fields, so it cannot be inferred from the arguments. What is T for?
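To make the point above concrete, here is a minimal sketch in plain Julia (no Flux involved) using the struct from the question. It only illustrates which constructors exist; the values passed in are arbitrary.

```julia
abstract type C end
struct A <: C end
struct B <: C end

struct Block{T<:C}
    C_inout::Int
    time::Vector{Float64}
end

# T appears in no field, so Julia cannot generate an outer constructor
# Block(C_inout, time); calling it throws a MethodError.
err = try
    Block(1, [2.0])
catch e
    e
end
println(err isa MethodError)

# Supplying the parameter explicitly works.
m = Block{A}(1, [2.0])
println(m isa Block{A})
```

Any code that rebuilds the struct from its fields alone, as the update step does, runs into the first case.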

I have two possible forward modes, one fast and one slow. In the code above, these correspond to types A and B. I was hoping to dispatch on the type Block{T} so that the code would avoid branching. I was worried that an if-else statement in my forward pass would hurt performance… but maybe it isn’t too bad?

If you cannot change the struct, you will have to define a method of functor yourself instead of having the macro @functor do it for you, because the macro does not know about this type parameter.
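A hand-written method might look roughly like the sketch below. It assumes the Functors.jl package and its functor(::Type, x) convention of returning the children plus a function that rebuilds the struct; the name rebuild and the choice to expose only time as a child are illustrative, and I have not checked this against a particular Flux version.

```julia
using Functors

abstract type C end
struct A <: C end
struct B <: C end

struct Block{T<:C}
    C_inout::Int
    time::Vector{Float64}
end

# Replacement for `@functor Block`: expose only `time` as a child and
# rebuild with the same type parameter T, which the macro cannot know about.
function Functors.functor(::Type{<:Block{T}}, m) where {T}
    children = (time = m.time,)
    rebuild(nt) = Block{T}(m.C_inout, nt.time)  # hypothetical helper name
    return children, rebuild
end

m = Block{A}(1, [2.0])
m2 = fmap(t -> 2 .* t, m)  # walks the children and reconstructs a Block{A}
println(typeof(m2) == typeof(m))
```

Because the reconstruction calls Block{T}(...) rather than Block(...), the missing outer constructor is never needed.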

Alternatively, store an instance, and the type parameter will take care of itself:

struct Block{T<:C}
    C_instance::T
    C_inout::Int
    time::Vector{Float64}
end

Flux.@functor Block

Block(A(), 1, [2.0])

(Either way, no need to overload trainable.)


Interesting. How would I dispatch against C_instance for the forward pass?

Ah, never mind. I see what I missed. Thanks!
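For anyone landing here with the same question, a minimal sketch of the two ways to dispatch once the instance is stored. The forward-pass bodies are placeholders and the helper name forward_mode is made up for illustration; neither comes from the original model.

```julia
abstract type C end
struct A <: C end
struct B <: C end

struct Block{T<:C}
    C_instance::T
    C_inout::Int
    time::Vector{Float64}
end

# Option 1: T is still a type parameter, so the original methods keep working.
(m::Block{A})(x) = x + first(m.time)   # placeholder "fast" path
(m::Block{B})(x) = x * first(m.time)   # placeholder "slow" path

# Option 2: dispatch on the stored instance through a helper function.
forward_mode(::A) = :fast
forward_mode(::B) = :slow

fast_block = Block(A(), 1, [2.0])   # T = A is inferred from the first field
slow_block = Block(B(), 1, [2.0])

println(fast_block(3.0))
println(slow_block(3.0))
println(forward_mode(slow_block.C_instance))
```

Both options are resolved by dispatch on the concrete type, so no if-else is needed in the forward pass.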