Identically shaped inputs lead to different behaviour in Zygote (DimensionMismatch only with GATConv)

In the following script, the forward pass runs without error:

using GNNGraphs, GraphNeuralNetworks, NNlib, Flux

# Heterograph with two node types (:A with 3 nodes, :B with 5 nodes)
# and one relation in each direction.
graph = GNNHeteroGraph(
    Dict(
        (:A, :a, :B) => ([1, 2], [3, 4]),
        (:B, :a, :A) => ([1], [2]),
    );
    num_nodes = Dict(:A => 3, :B => 5)
)

# Two structurally identical hetero layers, one GATConv per relation.
layer = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25)) for
        (src, edge, dst) in keys(graph.edata)
    ];
)
layer2 = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25)) for
        (src, edge, dst) in keys(graph.edata)
    ];
)

# Random node features, 4 per node.
x = (
    A = rand(Float32, 4, 3),
    B = rand(Float32, 4, 5),
)

# Forward pass through both layers works fine.
x1 = layer(graph, x)
x2 = layer2(graph, x1)
@info "$x2"

# The backward pass is where the error occurs.
g = Flux.gradient(x) do x
    y = layer(graph, x)
    sum(y[:A])
end

but the backward pass fails. The first frame in the stack trace below that belongs to the GNN-related libraries is [22].

ERROR: LoadError: DimensionMismatch: arrays could not be broadcast to a common size: a has axes Base.OneTo(3) and b has axes Base.OneTo(5)
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:535 [inlined]
  [2] _bcs (repeats 3 times)
    @ ./broadcast.jl:529 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:523 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:504 [inlined]
  [5] instantiate
    @ ./broadcast.jl:313 [inlined]
  [6] materialize
    @ ./broadcast.jl:894 [inlined]
  [7] broadcast_preserving_zero_d
    @ ./broadcast.jl:883 [inlined]
  [8] accum(x::Array{Float32, 3}, ys::Array{Float32, 3})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:17
  [9] (::Zygote.Pullback{Tuple{typeof(gat_conv), GATConv{Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}}, GNNHeteroGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}, Tuple{Matrix{Float32}, Matrix{Float32}}, Nothing}, Any})(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:100
...
 [22] HeteroGraphConv
    @ ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/heteroconv.jl:67 [inlined]
...
 [27] #gradient#1
    @ ~/.julia/packages/Flux/WMUyh/src/gradient.jl:44 [inlined]
 [28] gradient(f::Function, args::@NamedTuple{A::Matrix{Float32}, B::Matrix{Float32}})
    @ Flux ~/.julia/packages/Flux/WMUyh/src/gradient.jl:31
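
Frame [8] points at Zygote.accum, which combines two gradient contributions for the same variable by broadcasting `+` (that is what the broadcast frames above it show), so it throws exactly this kind of DimensionMismatch whenever the two contributions disagree in size. The failure mode is easy to reproduce in isolation; the shapes below are only illustrative (the real arrays in the trace are 3-dimensional):

    using Zygote

    # Two made-up gradient contributions for the same variable, shaped like
    # features of the 3 :A nodes and the 5 :B nodes respectively.
    g1 = rand(Float32, 4, 3)
    g2 = rand(Float32, 4, 5)

    # Zygote.accum broadcasts `+` over array gradients, so mismatched sizes
    # produce the same kind of DimensionMismatch as in the trace above.
    Zygote.accum(g1, g2)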

The interesting thing is that if GATConv is replaced by GraphConv, the error disappears. At ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/heteroconv.jl:67 (frame [22] above) there is a function call of the form

    return _reduceby_node_t(hgc.aggr, outs, dst_ntypes)

and inserting a diagnostic block right before this call:

    ChainRulesCore.ignore_derivatives() do
        @info "dst_ntypes: $dst_ntypes"
        for (i, out) in enumerate(outs)
            @info "out [$i]: $(size(out))"
        end
    end

shows

[ Info: dst_ntypes: [:B, :A]
[ Info: out [1]: (4, 5)
[ Info: out [2]: (4, 3)

With GATConv replaced by GraphConv, the printed result is identical:

[ Info: dst_ntypes: [:B, :A]
[ Info: out [1]: (4, 5)
[ Info: out [2]: (4, 3)
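
To make sure I am reading the call correctly: my understanding of the reduce-by-destination-node-type step is that it only groups the per-relation outputs by destination node type and combines same-type outputs with hgc.aggr. The sketch below is hypothetical (it is not the library's _reduceby_node_t, just a reading consistent with the call and the printed shapes), but it illustrates that such a purely shape-driven forward step cannot distinguish the two cases:

    # Hypothetical sketch, NOT the real _reduceby_node_t: group per-relation
    # outputs by destination node type and combine same-type outputs with `aggr`.
    function reduce_by_dst_type(aggr, outs, dst_ntypes)
        acc = Dict{Symbol, Any}()
        for (out, t) in zip(outs, dst_ntypes)
            acc[t] = haskey(acc, t) ? aggr(acc[t], out) : out
        end
        return NamedTuple(acc)
    end

    # With the shapes printed above, both GATConv and GraphConv feed this step
    # identically: one (4, 5) output for :B and one (4, 3) output for :A.
    reduce_by_dst_type(+, [rand(Float32, 4, 5), rand(Float32, 4, 3)], [:B, :A])

If that reading is right, the forward aggregation sees exactly the same data in both cases, and the difference seems to come from the pullbacks Zygote records for the two convolutions (the trace does error inside the gat_conv pullback, frame [9]), not from the forward data.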

How can identically shaped inputs here lead to different behaviours in Zygote? Is there any way to debug this?
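
The only further debugging step I can think of is to split the forward and reverse passes with Zygote.pullback, so the failing reverse pass can be isolated and wrapped in a try/catch or stepped through under a debugger:

    using Zygote

    # Run the forward pass and capture the pullback without triggering the
    # reverse pass yet.
    y, back = Zygote.pullback(x -> sum(layer(graph, x)[:A]), x)

    # The reverse pass is where the DimensionMismatch surfaces; isolating it
    # makes it easier to attach a debugger or add logging around it.
    try
        back(one(y))
    catch err
        @show err
    end

From there, something like Debugger.jl or Infiltrator.jl could perhaps be attached inside the failing pullback to see which cotangent arrives with the :B node count where the :A node count is expected, but maybe there is a better way.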