In the following script, the forward pass runs without error:
using GNNGraphs, GraphNeuralNetworks, NNlib, Flux
graph = GNNHeteroGraph(
    Dict(
        (:A, :a, :B) => ([1, 2], [3, 4]),
        (:B, :a, :A) => ([1], [2]),
    );
    num_nodes = Dict(:A => 3, :B => 5)
)
layer = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25)) for
        (src, edge, dst) in keys(graph.edata)
    ];
)
layer2 = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25)) for
        (src, edge, dst) in keys(graph.edata)
    ];
)
x = (
    A = rand(Float32, 4, 3),
    B = rand(Float32, 4, 5),
)
x1 = layer(graph, x)
x2 = layer2(graph, x1)
@info "$x2"
g = Flux.gradient(x) do x
    y = layer(graph, x)
    sum(y[:A])
end
but the backward pass errors. The first frame in this stack trace that belongs to a GNN-related package is [22].
ERROR: LoadError: DimensionMismatch: arrays could not be broadcast to a common size: a has axes Base.OneTo(3) and b has axes Base.OneTo(5)
Stacktrace:
[1] _bcs1
@ ./broadcast.jl:535 [inlined]
[2] _bcs (repeats 3 times)
@ ./broadcast.jl:529 [inlined]
[3] broadcast_shape
@ ./broadcast.jl:523 [inlined]
[4] combine_axes
@ ./broadcast.jl:504 [inlined]
[5] instantiate
@ ./broadcast.jl:313 [inlined]
[6] materialize
@ ./broadcast.jl:894 [inlined]
[7] broadcast_preserving_zero_d
@ ./broadcast.jl:883 [inlined]
[8] accum(x::Array{Float32, 3}, ys::Array{Float32, 3})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:17
[9] (::Zygote.Pullback{Tuple{typeof(gat_conv), GATConv{Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}}, GNNHeteroGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}, Tuple{Matrix{Float32}, Matrix{Float32}}, Nothing}, Any})(Δ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:100
...
[22] HeteroGraphConv
@ ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/heteroconv.jl:67 [inlined]
...
[27] #gradient#1
@ ~/.julia/packages/Flux/WMUyh/src/gradient.jl:44 [inlined]
[28] gradient(f::Function, args::@NamedTuple{A::Matrix{Float32}, B::Matrix{Float32}})
@ Flux ~/.julia/packages/Flux/WMUyh/src/gradient.jl:31
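Frame [8] is Zygote.accum, which broadcast-adds two gradient contributions inside the gat_conv pullback. The mismatched axes, 3 and 5, are exactly the two node counts (num_nodes[:A] = 3, num_nodes[:B] = 5), so my reading, which may be wrong, is that gradient pieces belonging to the two node types end up accumulated into the same array. The exact array shapes are not visible in the trace, but the broadcast failure itself is trivial to reproduce in isolation (a sketch; the (4, 1, _) shapes are made up):
using Zygote  # or reach it as Flux.Zygote if Zygote is not a direct dependency
# Zygote.accum broadcast-adds gradient contributions; any pair of arrays whose
# last axes are 3 and 5 (the :A and :B node counts) raises the same
# DimensionMismatch as in frame [8] above.
Zygote.accum(zeros(Float32, 4, 1, 3), zeros(Float32, 4, 1, 5))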
Interestingly, if GATConv is replaced by GraphConv, the error disappears. On the line referenced in frame [22], ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/heteroconv.jl:67, there is a function call of the form
return _reduceby_node_t(hgc.aggr, outs, dst_ntypes)
and inserting a diagnostics block right before this:
ChainRulesCore.ignore_derivatives() do
    @info "dst_ntypes: $dst_ntypes"
    for (i, out) in enumerate(outs)
        @info "out [$i]: $(size(out))"
    end
end
shows
[ Info: dst_ntypes: [:B, :A]
[ Info: out [1]: (4, 5)
[ Info: out [2]: (4, 3)
After replacing GATConv with GraphConv, the printed output at this point is identical:
[ Info: dst_ntypes: [:B, :A]
[ Info: out [1]: (4, 5)
[ Info: out [2]: (4, 3)
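One probe I can think of to narrow this down (a sketch, not something I have verified): check whether the failure depends on which node type's output is summed, and take the pullback explicitly so the cotangent can be restricted to one node type at a time. The NamedTuple cotangent below assumes the layer output is a NamedTuple with fields A and B, as the y[:A] indexing above suggests.
# Probe 1 (sketch): does the failure also occur when summing the :B output?
Flux.gradient(x) do x
    y = layer(graph, x)
    sum(y[:B])
end
# Probe 2 (sketch, Zygote loaded as above): explicit pullback, feeding a
# cotangent restricted to one node type, to see which edge-type layer
# produces the mismatched gradient contribution.
y, back = Zygote.pullback(x -> layer(graph, x), x)
back((A = ones(Float32, size(y[:A])), B = zeros(Float32, size(y[:B]))))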
How can identically shaped inputs here lead to different behaviours in Zygote? Is there any way to debug this?
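For reference, a sanity check I would also run (again only a sketch) to see whether GATConv differentiates fine outside the heterogeneous wrapper; rand_graph comes from GNNGraphs:
# Sanity check (sketch): a standalone GATConv on a small homogeneous random
# graph with the same feature width, differentiated the same way as above.
gh = rand_graph(5, 8)  # 5 nodes, 8 edges
lh = GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25))
xh = rand(Float32, 4, 5)
Flux.gradient(x -> sum(lh(gh, x)), xh)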