Dear All,
I am a student learning about neural networks. I implemented a small graph neural network model using Flux and MetaGraphs, but I ran into a problem that puzzles me. I have posted the error and the problematic code below, and I hope someone can help me find my mistake.
ERROR: LoadError: MethodError: no method matching getindex(::Dict{Any,Any})
Closest candidates are:
getindex(::Dict{K,V}, ::Any) where {K, V} at dict.jl:465
getindex(::AbstractDict, ::Any) at abstractdict.jl:489
getindex(::AbstractDict, ::Any, ::Any, ::Any...) at abstractdict.jl:499
Stacktrace:
module GNN
using MetaGraphs:AbstractMetaGraph,set_prop!,get_prop,rem_prop!,has_prop,props,nv
using LightGraphs:inneighbors,vertices
using Flux
using Flux:gradient,Params,update!
using LinearAlgebra:norm

# NOTE(review): `transition!`, `update_state!` and `forward` mutate the graph through
# `set_prop!`/`rem_prop!`. Zygote (used by `Flux.gradient`/`Flux.train!`) does not
# generally support differentiating through mutation. This is a likely cause of the
# `MethodError: no method matching getindex(::Dict{Any,Any})` seen when `forward` runs
# inside a loss passed to `Flux.train!`. TODO confirm. A fix would keep node states in
# local, non-mutated values during the differentiated pass.

"""
    VanillaGNN(T, A, b, O)

Parameters of a recurrent ("vanilla") graph neural network.

- `T`: `NamedTuple` from a node-type symbol (a vertex's `:t` property) to the layer
  that embeds a vertex label of that type.
- `A`: applied to `[t_n; t_u; l_un]`; its result is multiplied with the neighbour
  state `x_u`.
- `b`: applied to the neighbour embedding `t_u`; added to `A * x_u`.
- `O`: applied to `[t_n; x_n]` to produce a vertex output.
"""
struct VanillaGNN{TT<:NamedTuple,TA,TB,TO}
    T::TT
    A::TA
    b::TB
    O::TO
end

"""
    VanillaGNN(lns::Dict, le::Integer, s::Integer, o::Integer; σ=identity)

Build a `VanillaGNN` made of freshly initialised `Dense` layers.

# Arguments
- `lns`: maps each node-type symbol to the length of that type's vertex label.
- `le`: length of an edge label.
- `s`: length of vertex embeddings and vertex states.
- `o`: length of a vertex output.

# Keywords
- `σ`: activation used by every `Dense` layer.
"""
function VanillaGNN(lns::Dict, le::Integer, s::Integer, o::Integer; σ=identity)
    # Iterate the pairs once so each type symbol is paired with its own label width.
    kv = collect(lns)
    tnames = Tuple(first.(kv))
    T = NamedTuple{tnames}(Tuple(Dense(last(p), s, σ) for p in kv))
    A = Chain(
        Dense(s + s + le, s, σ),
        Dense(s, s * s, σ),
        x -> reshape(x, (s, s)),   # turn the s*s vector into an s × s matrix
    )
    b = Chain(
        Dense(s, s ÷ 2, σ),
        Dense(s ÷ 2, s, σ),
    )
    O = Chain(
        Dense(s + s, s, σ),
        Dense(s, o, σ),
    )
    return VanillaGNN(T, A, b, O)
end

Flux.@functor VanillaGNN

"""
    get_label(g, n)

Return the `:l` property of vertex `n`.
"""
get_label(g::AbstractMetaGraph, n::Integer) = get_prop(g, n, :l)

"""
    get_label(g, u, v)

Return the `:l` property of the edge `u → v`.
"""
get_label(g::AbstractMetaGraph, u::Integer, v::Integer) = get_prop(g, u, v, :l)

"""
    transform(gnn, g, n)

Embed vertex `n` by applying the layer `gnn.T[t]` to its label, where `t` is the
vertex's `:t` property.
"""
function transform(gnn::VanillaGNN, g::AbstractMetaGraph, n::Integer)
    t = get_prop(g, n, :t)
    return gnn.T[t](get_label(g, n))
end

"""
    transition!(gnn, g, n)

Compute the next state of vertex `n` and store it as the `:x_new` property of `n`.

The next state is the mean over in-neighbours `u` of `A * x_u + b`, where
`A = gnn.A([t_n; t_u; l_un])` and `b = gnn.b(t_u)`. `x_u` is the stored `:x` of `u`,
or its embedding `t_u` when none is stored yet. A vertex with no in-neighbours gets
the empty sum, a zero state.
"""
function transition!(gnn::VanillaGNN, g::AbstractMetaGraph, n::Integer)
    t_n = transform(gnn, g, n)
    nbrs = inneighbors(g, n)
    x_new = if isempty(nbrs)
        # `sum` over an empty collection throws, and the mean would divide by zero.
        zero(t_n)
    else
        msgs = sum(nbrs) do u
            t_u = transform(gnn, g, u)
            l_un = get_label(g, u, n)
            A = gnn.A(cat(t_n, t_u, l_un, dims=1))
            b = gnn.b(t_u)
            x_u = has_prop(g, u, :x) ? get_prop(g, u, :x) : t_u
            A * x_u + b
        end
        msgs / length(nbrs)
    end
    return set_prop!(g, n, :x_new, x_new)
end

"""
    transition!(gnn, g)

Run `transition!(gnn, g, n)` for every vertex `n` of `g`.
"""
function transition!(gnn::VanillaGNN, g::AbstractMetaGraph)
    for n in vertices(g)
        transition!(gnn, g, n)
    end
end

"""
    update_state!(g)

For every vertex, copy `:x_new` into `:x` and then remove `:x_new`. Expects every
vertex to carry `:x_new`, as left by `transition!`.
"""
function update_state!(g::AbstractMetaGraph)
    for n in vertices(g)
        x = get_prop(g, n, :x_new)
        set_prop!(g, n, :x, x)
        rem_prop!(g, n, :x_new)
    end
end

"""
    local_output(gnn, g, n)

Return `gnn.O([t_n; x_n])`. `t_n` is the embedding of vertex `n` and `x_n` is its
stored `:x` state, which must already be set (for example by `forward`).
"""
function local_output(gnn::VanillaGNN, g::AbstractMetaGraph, n::Integer)
    t_n = transform(gnn, g, n)
    x_n = get_prop(g, n, :x)
    return gnn.O(cat(t_n, x_n, dims=1))
end

# Run one transition over every vertex of `g`, writing `:x_new`. Return the sum over
# vertices of the L1 change from current to new state, divided by `n_v`. A vertex
# without a stored `:x` is compared against its embedding.
function _sweep!(gnn::VanillaGNN, g::AbstractMetaGraph, n_v::Integer)
    return sum(vertices(g)) do n
        transition!(gnn, g, n)
        x_old = has_prop(g, n, :x) ? get_prop(g, n, :x) : transform(gnn, g, n)
        norm(get_prop(g, n, :x_new) - x_old, 1) / n_v
    end
end

"""
    forward(gnn, g, ϵ=1.)

Iterate `transition!`/`update_state!` over all vertices until the mean L1 change of
the vertex states is at most `ϵ`. Leave the final states in each vertex's `:x`.
Return `nothing`.

There is no iteration cap. If the mean change never drops to `ϵ`, this never returns.
"""
function forward(gnn::VanillaGNN, g::AbstractMetaGraph, ϵ=1.)
    n_v = nv(g)
    # No vertices means nothing to propagate, and `sum` over no vertices would throw.
    n_v == 0 && return nothing
    Δ = _sweep!(gnn, g, n_v)
    while Δ > ϵ
        update_state!(g)
        Δ = _sweep!(gnn, g, n_v)
    end
    update_state!(g)
    return nothing
end

end
using LightGraphs
using Flux
using MetaGraphs

# Toy data: a directed path on 10 vertices. Each vertex gets a type (:t), a label of
# length 5 (:l) and a target of length 7 (:y). Each edge gets a label of length 3.
g = PathDiGraph(10)
mg = MetaGraph(g)
for v in vertices(mg)
    set_prop!(mg, v, :t, :node)
    set_prop!(mg, v, :l, rand(5))
    set_prop!(mg, v, :y, rand(7))
end
foreach(e -> set_prop!(mg, e, :l, rand(3)), edges(mg))

# One node type with label width 5, edge labels of width 3, states of width 6 and
# outputs of width 7.
gnn = GNN.VanillaGNN(Dict(:node => 5), 3, 6, 7)

# Propagate states through `g`, then sum the per-vertex MSE between each vertex's
# output and its `:y` target.
function loss(g::AbstractMetaGraph)
    GNN.forward(gnn, g)
    vertex_loss(n) = Flux.Losses.mse(GNN.local_output(gnn, g, n), get_prop(g, n, :y))
    return sum(vertex_loss, vertices(g))
end

opt = Descent(0.1)
ps = params(gnn)
# NOTE(review): `train!` differentiates `loss`, and `GNN.forward` mutates `mg` via
# `set_prop!`/`rem_prop!`. Zygote does not generally support mutation, which is a
# likely source of the reported `getindex(::Dict{Any,Any})` MethodError. TODO confirm.
Flux.train!(loss, ps, [mg], opt)