I am trying the GeometricFlux library by doing some test (coming from torch_geometric).
So test that I am doing is a simple graphs classification and I have considered very simple input:
N = 100
k = 10
g = erdos_renyi(N,float(k)/N)
a = Matrix{Float32}(adjacency_matrix(g))
features = Matrix{Float32}(1.0*I,(N,N))
train_data = (FeaturedGraph(a,features), labels)
and labels
is a OneHotVector
from Flux, keeping track that this is a Erdos-graph (just an example).
Clearly here train_data is just a Tuple
but in the following code will be a list of tuple (many graphs).
The following code is the graph-classification, just the train part:
num_nodes = 100
num_features = 100
hidden = 50
hidden_second = 15
epochs = 100
@load "data/train_data.jld2" train_data
model = Chain(
GCNConv(num_features => hidden),
GCNConv(hidden => hidden_second),
)
loss(x,y) = logitcrossentropy(model(x),y)
ps = Flux.params(model)
opt = ADAM(0.01)
@showprogress for epoch in 1:epochs
for (x,y) in train_data
gs = Flux.gradient(ps) do
j = model(x)
ŷ = mean(j.nf,dims = 2)
loss(ŷ, y)
end
end
Flux.Optimise.update!(opt, ps, gs)
end
This code give me back the following AssertionError
:
ERROR: AssertionError: A GCNConv created without a graph must be given a FeaturedGraph as an input.
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0 [inlined]
[2] _pullback(ctx::Zygote.Context, f::typeof(throw), args::AssertionError)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:9
[3] _pullback
@ ~/.julia/packages/GeometricFlux/kwwxb/src/layers/conv.jl:60 [inlined]
[4] _pullback(ctx::Zygote.Context, f::GCNConv{Float32, typeof(identity), NullGraph}, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[5] _pullback
@ ~/.julia/packages/Flux/qp1gc/src/layers/basic.jl:36 [inlined]
[6] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{GCNConv{Float32, typeof(identity), NullGraph}, GCNConv{Float32, typeof(identity), NullGraph}}, ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[7] _pullback
@ ~/.julia/packages/Flux/qp1gc/src/layers/basic.jl:38 [inlined]
[8] _pullback(ctx::Zygote.Context, f::Chain{Tuple{GCNConv{Float32, typeof(identity), NullGraph}, GCNConv{Float32, typeof(identity), NullGraph}}}, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[9] _pullback
@ ./REPL[16]:1 [inlined]
[10] _pullback(::Zygote.Context, ::typeof(loss), ::Matrix{Float32}, ::Flux.OneHotArray{UInt32, 11, 0, 1, UInt32})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[11] _pullback
@ ./REPL[20]:6 [inlined]
[12] _pullback(::Zygote.Context, ::var"#3#4"{Flux.OneHotArray{UInt32, 11, 0, 1, UInt32}, FeaturedGraph{Matrix{Float32}, Matrix{Float32}, FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[13] pullback(f::Function, ps::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:343
[14] gradient(f::Function, args::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:75
[15] macro expansion
@ REPL[20]:3 [inlined]
[16] top-level scope
@ ~/.julia/packages/ProgressMeter/Vf8un/src/ProgressMeter.jl:940
I don’t understand if I have to add something in the middle, because if I do model(train_data[1][1])
for example I get no error