Hi. I’m trying to update my Flux model, whose layer type is parameterized by one of two types, `A` or `B`:
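```julia
using Flux

abstract type C end
struct A <: C end
struct B <: C end

struct Block{T<:C}
    C_inout::Int
    time::Vector{Float64}
end

Flux.@functor Block
# The NamedTuple key must match the field name (`time`)
Flux.trainable(m::Block) = (time = m.time,)

function (model::Block{A})(x)
    0.0  # placeholder for the fast forward pass
end

function (model::Block{B})(x)
    0.0  # placeholder for the slow forward pass
end
```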
Here are the gradient and update steps:
```julia
grad = gradient(loss, m, x, y)
Flux.update!(opt_state, m, grad[1])  # Breaks!
```
The update step tries to reconstruct a `Block` without its type parameter and fails:
```
MethodError: no method matching Block(::Int64, ::Vector{Float64})
```
What would be the best way to solve this, or to organize the code?
The problem here is that a constructor `Block(C_inout, time)` doesn’t make sense, because there is an additional type parameter `T` that cannot be inferred from the field values. What is it for?
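The failure can be reproduced without Flux. The sketch below uses only base Julia and the struct from the question: because `T` appears in no field type, Julia generates only the parameterized constructor `Block{T}(...)`, not a bare `Block(...)`, which (as the error message shows) is what the `@functor`-generated reconstruction calls.

```julia
abstract type C end
struct A <: C end
struct B <: C end

# Same struct as in the question: T appears in no field type
struct Block{T<:C}
    C_inout::Int
    time::Vector{Float64}
end

# Works: the parameter is given explicitly
ok = Block{A}(3, [0.1, 0.2])

# Fails: T cannot be inferred from (Int, Vector{Float64}),
# so no method Block(::Int, ::Vector{Float64}) exists
err = try
    Block(3, [0.1, 0.2])
catch e
    e
end

println(typeof(ok))           # Block{A}
println(err isa MethodError)  # true
```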
I have two possible forward modes, one fast and one slow; in the code above, these correspond to types `A` and `B`. I was hoping to dispatch on the type `Block{T}` so that the code would avoid branching. I was worried that an if-else statement in my forward pass would hurt performance… but maybe it isn’t too bad?
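For comparison, here is a minimal sketch of the if-else route, assuming a hypothetical `FlagBlock` with a `fast::Bool` field and placeholder `fast_path`/`slow_path` functions standing in for the two real forward modes. One branch per forward call is typically small next to the array work inside the layer, and with no type parameter the default constructor can rebuild the struct from its field values.

```julia
# Hypothetical alternative: a runtime flag instead of a type parameter.
# With no type parameter, the default constructor
# FlagBlock(fast, C_inout, time) exists.
struct FlagBlock
    fast::Bool
    C_inout::Int
    time::Vector{Float64}
end

# Placeholder forward modes, for illustration only
fast_path(m, x) = x .* m.time
slow_path(m, x) = x .* m.time .+ sum(m.time)

# One branch per call, chosen by the runtime flag
(m::FlagBlock)(x) = m.fast ? fast_path(m, x) : slow_path(m, x)

m = FlagBlock(true, 2, [1.0, 2.0])
y = m([3.0, 4.0])  # fast path: [3.0, 8.0]

# Rebuilding from the field values (what a functor reconstruction needs) works
vals = map(f -> getfield(m, f), fieldnames(FlagBlock))
m2 = FlagBlock(vals...)
```

If you keep the type parameter instead, a branch such as `if T === A` inside a method declared with `where {T}` is resolved when the method is compiled for `Block{A}`, so it does not cost anything at run time; the constructor issue remains, though.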
If you cannot change the struct, you will have to define a method of `functor` yourself instead of having the `@functor` macro do it for you; the macro does not know about this parameter.
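A minimal sketch of that, assuming Functors.jl's `functor(::Type, x)` extension point (the method `@functor` would otherwise generate), which returns the children plus a function that rebuilds the object. The rebuild function captures `T`, so the result keeps its parameter. With this method defined, the `Flux.@functor Block` line would be dropped; `Flux.trainable` can stay. It is worth checking against your Flux/Optimisers versions that `Flux.update!` then goes through.

```julia
using Functors

abstract type C end
struct A <: C end
struct B <: C end

struct Block{T<:C}
    C_inout::Int
    time::Vector{Float64}
end

# Hand-written replacement for `@functor Block`: return the children as a
# NamedTuple plus a function that rebuilds a Block with the same T.
function Functors.functor(::Type{<:Block{T}}, m) where {T}
    children = (C_inout = m.C_inout, time = m.time)
    rebuild(nt) = Block{T}(nt.C_inout, nt.time)
    return children, rebuild
end

m = Block{B}(3, [0.1, 0.2])

# fmap walks the children and reconstructs via `rebuild`, keeping T = B
m2 = fmap(x -> x isa AbstractArray ? 2 .* x : x, m)
```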
Alternatively, store an instance of `A` or `B`, and the type parameter will take care of itself: