Just some general remarks:
- You should avoid abstractly typed fields in structs, i.e., use type parameters instead:

  ```julia
  using Flux, GraphNeuralNetworks

  struct CustomLayer{Q,K,V} <: GNNLayer  # If you know/want that all fields have the same type, use a single type parameter
      Wq::Q
      Wk::K
      Wv::V
      # the remaining fields are concretely typed already
      dk::Float64
      σ::typeof(softmax)  # with this type, the field can only hold the softmax function anyways
  end
  ```
- The closest analog to a linear layer, i.e., `nn.Linear`, would be `Dense(n_in => n_out, identity)` (see the first sketch below the list).
- Torch has row-major arrays, whereas Julia uses column-major. Accordingly, Torch batches along the first dimension and Flux along the last. It's often easiest (and fastest) to translate code from Python to Julia by simply reversing all tensor dimensions, i.e. `xi = rand(Float32, 1000, 1)` and `m.Wk(xj)' * m.Wq(xi)`, etc. (see the second sketch below).
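
To illustrate the `nn.Linear` analogy, here's a minimal sketch; the 64/32 sizes are placeholders I picked, not from your code:

```julia
using Flux

# torch: nn.Linear(64, 32) computes y = W*x + b with no activation
layer = Dense(64 => 32, identity)   # identity is also the default activation

# Flux expects (features, batch), with the batch along the last dimension
x = rand(Float32, 64, 10)           # 10 samples of 64 features each
y = layer(x)                        # size (32, 10)
```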
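
And here's a rough sketch of the reversed-dimension convention applied to the attention score from your code; the `n_in`/`dk` sizes are assumptions for illustration:

```julia
using Flux

n_in, dk = 1000, 64                 # placeholder dimensions
Wq = Dense(n_in => dk)
Wk = Dense(n_in => dk)

# node features as columns, (features, batch) -- the reverse of torch's (batch, features)
xi = rand(Float32, n_in, 1)
xj = rand(Float32, n_in, 1)

# unnormalized attention score, cf. m.Wk(xj)' * m.Wq(xi)
score = Wk(xj)' * Wq(xi) ./ sqrt(Float32(dk))   # 1×1 matrix
```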