Thank you for the prompt reply!
You were right - it’s super simple when put in the same AbstractContainer. I love the explicit parametrization!
I’ve had some challenges:
- How do I create the functor from the struct `Siamese`? What’s the minimum required, i.e., what’s the API?
- This example was great: SimpleRNN, but the change from `gradient` to `pullback` threw me off. Do I need `pullback()`, or is `gradient()` enough? If I use `pullback()`, what arguments do I pass to the function `back()`?
It took reading the Zygote docs and a few more examples to realize that the first argument to `back()` is a tuple that corresponds to the return value of the function I’m differentiating (i.e., `compute_loss` below), and hence I just need `(one(loss), nothing, nothing)`.
All in all, everything works great with the current API and primitives. The only thing that could have made it better/faster would have been a clearer explanation of the points above.
Outstanding question: How would I change my code (MWE below) if I wanted to make it work with the Parallel layer?
My (failed) attempt:
# Forward pass of the Siamese network expressed with Lux's `Parallel` layer:
# both inputs go through the same embedding network, the two embeddings are
# joined by their column-wise Euclidean distance, and the distance is classified.
# Returns the predictions as a vector and the updated state.
function (s::Siamese)(x1::AbstractArray{T, 2}, x2::AbstractArray{T, 2},
                      ps::NamedTuple,
                      st::NamedTuple) where {T}
    # Column-wise Euclidean distance, reshaped to a 1 × batch_size matrix because
    # the classifier (`Dense(1, 1, ...)` in the MWE) expects a matrix, not a vector.
    eucl_dist(e1, e2) = reshape(colwise(Euclidean(), e1, e2), 1, :)
    # With a Tuple input, each branch of `Parallel` receives one element of
    # `(x1, x2)` and `eucl_dist` combines the two branch outputs.
    two_towers_to_eucl = Parallel(eucl_dist, s.emb, s.emb)
    # NOTE(review): assumes Lux's `Parallel` keys per-branch parameters/states as
    # `layer_1`, `layer_2` (passing `ps.emb` directly is what broke the original
    # attempt) — confirm against the installed Lux version.
    # Feeding the *same* `ps.emb` to both branches keeps the weights shared, as in
    # the MWE; Zygote sums the gradient contributions of both branches.
    ps_towers = (layer_1 = ps.emb, layer_2 = ps.emb)
    st_towers = (layer_1 = st.emb, layer_2 = st.emb)
    dist, st_towers = two_towers_to_eucl((x1, x2), ps_towers, st_towers)
    y, st_classifier = s.classifier(dist, ps.classifier, st.classifier)
    # Keep a single embedding state. Unlike the MWE, both branches start from
    # `st.emb`; the second branch's state is kept (identical for stateless layers).
    st = merge(st, (classifier = st_classifier, emb = st_towers.layer_2))
    return vec(y), st
end
MWE (this works):
using Lux, Random, Optimisers, Zygote
using Statistics, Distances
using MLUtils
"""
    Siamese(emb, classifier)

Siamese network container. `emb` is the embedding network applied to each of
the two inputs, and `classifier` is the output layer applied to the distance
between the two embeddings. Listing both field names in the container-layer
type parameter makes `Lux.setup` produce parameters and states that are
accessed as `ps.emb` / `ps.classifier` (and likewise for `st`).
"""
struct Siamese{Emb, Cls} <: Lux.AbstractExplicitContainerLayer{(:emb, :classifier)}
    emb::Emb         # shared embedding network
    classifier::Cls  # output layer on top of the distance
end
# Forward pass: embed both inputs with the shared network `s.emb`, measure the
# column-wise Euclidean distance between the embeddings and classify it.
# Returns the predictions as a vector together with the updated state.
function (s::Siamese)(x1::AbstractMatrix{T}, x2::AbstractMatrix{T},
                      ps::NamedTuple, st::NamedTuple) where {T}
    emb_ps = ps.emb
    # The embedding state is threaded through both calls, so the second tower
    # starts from whatever state the first one produced.
    left, emb_state = s.emb(x1, emb_ps, st.emb)
    right, emb_state = s.emb(x2, emb_ps, emb_state)

    # `colwise` yields one distance per column; the classifier wants 1 × batch.
    distances = colwise(Euclidean(), left, right)
    features = reshape(distances, 1, :)

    scores, cls_state = s.classifier(features, ps.classifier, st.classifier)
    new_state = merge(st, (classifier = cls_state, emb = emb_state))
    return vec(scores), new_state
end
# A deliberately tiny Siamese network: the embedding compresses 4 features to 2,
# and a single sigmoid unit maps the embedding distance to an output score.
model = let
    # In practice this Chain would hold more layers (and possibly dropout, so
    # that training is not trivial).
    embedding = Chain(Dense(4, 2, identity))
    head = Dense(1, 1, Lux.sigmoid)
    Siamese(embedding, head)
end
# Initialize parameters, states and the optimiser.
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
println("Num parameters: ", Lux.parameterlength(ps))
opt = Optimisers.ADAM(0.01f0)
opt_state = Optimisers.setup(opt, ps);
# Quick smoke test of the forward pass. `X1`/`X2` are only created further down,
# so referencing them here fails when the script runs top to bottom. Use small
# deterministic dummy batches instead (4 features × 3 observations, matching the
# `Dense(4, 2)` embedding); they do not draw any random numbers.
y_pred, st = model(ones(Float32, 4, 3), zeros(Float32, 4, 3), ps, st)
# Synthetic pair data. Each input matrix is 4 features × 2n observations
# (observations are columns). In the first n pairs X2 is a slightly noisy copy
# of X1 (label 1 = same); in the last n pairs X2 is unrelated (label 0 = different).
T = Float32;
n = 50
X_same = randn(T, 4, n);
X_almost_same = X_same .+ T(0.1) .* randn(T, 4, n);
X_different = randn(T, 4, n);
X1 = [X_same X_same]
X2 = [X_almost_same X_different];
Y = vcat(ones(T, n), zeros(T, n)); # 1 = identical, 0 = different

# 85% / 15% train/validation split, shuffled.
train_data, val_data = splitobs((X1, X2, Y); at = 0.85, shuffle = true);

# Mini-batches of 10 for training; validation uses a single batch holding every
# validation observation.
train_loader = DataLoader(train_data; batchsize = 10, shuffle = true);
val_loader = DataLoader(val_data; batchsize = size(val_data[1], 2));

# Sanity check: print the shapes of the validation batch(es).
foreach(val_loader) do (x1, x2, y)
    println(size(x1), size(x2), size(y))
end
# Utility functions
"""
    siamese_contrastive_loss(ŷ, y; agg = mean, margin = 1)

Return the [contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
which can be useful for training Siamese Networks. It is given by

    agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)

Specify `margin` to set the baseline for distance at which pairs are dissimilar.

# Throws
- `DomainError` if `margin` is negative.
- `DimensionMismatch` if `ŷ` and `y` are both arrays whose sizes disagree in
  some dimension (trailing singleton dimensions are allowed).

Forked from: https://github.com/FluxML/Flux.jl/blob/e4f8678f8a389179d173010f6aad75b80189b0eb/src/losses/functions.jl#L530-L540
"""
function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1)
    _check_sizes(ŷ, y)
    margin < 0 && throw(DomainError(margin, "Margin must be non-negative"))
    return agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
end

# Reject prediction/label arrays whose sizes differ along any dimension, so that
# broadcasting cannot silently expand them (e.g. a length-n vector against a
# 1×n matrix would otherwise yield an n×n loss matrix). Trailing singleton
# dimensions are tolerated (a length-n vector matches an n×1 matrix).
function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
    for d in 1:max(ndims(ŷ), ndims(y))
        size(ŷ, d) == size(y, d) || throw(DimensionMismatch(
            "loss function expects size(ŷ) = $(size(ŷ)) to match size(y) = $(size(y))"))
    end
    return nothing
end
# Non-array arguments (e.g. a constant scalar label) broadcast intentionally.
_check_sizes(ŷ, y) = nothing
# Number of predictions whose thresholded class (score > 0.5) equals the label.
matches(y_pred, y_true) = count((y_pred .> 0.5) .== y_true)
# Fraction of predictions that match their label.
function accuracy(y_pred, y_true)
    n_correct = matches(y_pred, y_true)
    return n_correct / length(y_pred)
end
# Run `model` on one batch and score it with the contrastive loss.
# Returns `(loss, predictions, updated_state)`; this tuple layout is why the
# training loop seeds `back` with `(one(loss), nothing, nothing)`.
function compute_loss(x1, x2, y, model, ps, st)
    predictions, new_state = model(x1, x2, ps, st)
    loss = siamese_contrastive_loss(predictions, y)
    return loss, predictions, new_state
end
# Training loop
"""
    train_siamese(model, ps, st, opt_state, train_loader, val_loader; epochs = 5)

Train `model` for `epochs` epochs on `train_loader`, printing the loss of every
training batch, and evaluate on `val_loader` (in test mode) after each epoch.
Returns the updated `(ps, st, opt_state)`.

The loop lives in a function because, in a non-interactive script, assigning to
the globals `ps`, `st` and `opt_state` inside a top-level `for` loop creates new
loop-local variables, so the loop fails with `UndefVarError`.
"""
function train_siamese(model, ps, st, opt_state, train_loader, val_loader;
                       epochs::Integer = 5)
    for epoch in 1:epochs
        # Train the model
        for (x1, x2, y) in train_loader
            # Bind the current state to a per-iteration name so the closure does
            # not capture the reassigned `st` (which would box it).
            st_batch = st
            # `compute_loss` returns `(loss, y_pred, st)`, so the cotangent given
            # to `back` has the same shape: seed the loss with 1, ignore the rest.
            (loss, _, st), back = pullback(
                p -> compute_loss(x1, x2, y, model, p, st_batch), ps)
            gs = back((one(loss), nothing, nothing))[1]
            opt_state, ps = Optimisers.update(opt_state, ps, gs)
            println("Epoch [$epoch]: Loss $loss")
        end
        # Validate the model
        st_ = Lux.testmode(st)
        for (x1, x2, y) in val_loader
            loss, y_pred, st_ = compute_loss(x1, x2, y, model, ps, st_)
            acc = accuracy(y_pred, y)
            println("Validation: Loss $loss Accuracy $acc")
        end
    end
    return ps, st, opt_state
end

ps, st, opt_state = train_siamese(model, ps, st, opt_state, train_loader, val_loader)
# Happy training...