Siamese network in Lux.jl / Re-using parts of network

Thank you for the prompt reply!

You were right - it’s super simple when both networks are wrapped in the same AbstractExplicitContainerLayer. I love the explicit parametrization!

I’ve had some challenges along the way:

  • How do I make the struct Siamese callable, i.e., create the functor? What’s the minimum required / what’s the API?
    The SimpleRNN example was great, but the change from gradient to pullback threw me off.

  • Do I need pullback() or is gradient() enough? If I use pullback(), what arguments do I pass to the returned back() function?
    It took reading the Zygote docs and a few more examples to realize that the first argument is a tuple matching the return value of the function being differentiated (i.e., compute_loss below), so I just need (one(loss), nothing, nothing) (see the small sketch after this list).
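
For anyone else who trips on this, here is a minimal stand-alone sketch of the pullback pattern (a toy function instead of the real compute_loss; f and p are just placeholder names):

using Zygote

# toy stand-in for compute_loss: returns (loss, aux1, aux2)
f(p) = (sum(abs2, p), p .+ 1, (dummy_state = 0,))

p = [1.0f0, 2.0f0, 3.0f0]
(loss, _, _), back = Zygote.pullback(f, p)

# back() takes one cotangent per returned value; only the loss gets a seed
grads = back((one(loss), nothing, nothing))
grads[1] # gradient w.r.t. p, here 2p = [2.0, 4.0, 6.0]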

All in all, everything works great with the current API and primitives. The only thing that would have made it better/faster is a clearer explanation of the points above.

Outstanding question: how would I need to change my code (MWE below) to make it work with the Parallel layer?

My (failed) attempt:

function (s::Siamese)(x1::AbstractArray{T, 2}, x2::AbstractArray{T, 2},
                      ps::NamedTuple,
                      st::NamedTuple) where {T}
    # function to calculate the Euclidean distance column-wise
    eucl_dist(x1, x2) = colwise(Euclidean(), x1, x2)

    # layer that passes each x through the embedding network
    # and joins the results via the column-wise Euclidean distance
    two_towers_to_eucl = Parallel(eucl_dist, s.emb, s.emb)

    # this doesn't work - I think I'm calling Parallel wrong with the x1, x2 arguments
    # I was hoping to achieve something like:
    # Parallel(euclidean, embed_network, embed_network)(x1, x2) = euclidean(embed_network(x1), embed_network(x2))
    dist, st_emb = two_towers_to_eucl((x1, x2), ps.emb, st.emb)

    y, st_classifier = s.classifier(dist, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, emb=st_emb))

    return vec(y), st
end
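
My current guess (untested, so take it with a grain of salt): since Parallel is itself a container layer, it presumably expects its sub-layer parameters and states under names like layer_1, layer_2 rather than the flat ps.emb / st.emb I’m passing above. If that’s right, repackaging the existing parameters before the call might look roughly like this (both entries point at the same ps.emb, so the towers would still share weights):

    # hypothetical fix inside the forward pass above (assumes layer_1/layer_2 naming)
    ps_par = (layer_1 = ps.emb, layer_2 = ps.emb)
    st_par = (layer_1 = st.emb, layer_2 = st.emb)
    dist, st_par = two_towers_to_eucl((x1, x2), ps_par, st_par)
    st_emb = st_par.layer_1 # both towers share the same embedding state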

MWE (this works):

using Lux, Random, Optimisers, Zygote
using Statistics, Distances
using MLUtils



struct Siamese{E, C} <:
       Lux.AbstractExplicitContainerLayer{(:emb, :classifier)}
    emb::E # embedding network
    classifier::C # output layer
end

function (s::Siamese)(x1::AbstractArray{T, 2}, x2::AbstractArray{T, 2},
                      ps::NamedTuple,
                      st::NamedTuple) where {T}

    # pass x1 and x2 through the (shared) embedding network
    emb1, st_emb = s.emb(x1, ps.emb, st.emb)
    emb2, st_emb = s.emb(x2, ps.emb, st_emb)

    # column-wise Euclidean distance between the two embeddings,
    # reshaped to a 1 x batch_size matrix
    dist = reshape(colwise(Euclidean(), emb1, emb2), 1, :)

    # pass the distance through the classifier head
    y, st_classifier = s.classifier(dist, ps.classifier, st.classifier)

    # finally, remember to create the updated state
    st = merge(st, (classifier=st_classifier, emb=st_emb))

    return vec(y), st
end


# simple Siamese network (just compress 4 dims -> 2 dims)
model = Siamese(
    # embedding network
    Chain(
        # it will have more layers in real life
        Dense(4, 2, identity)
        # it might need some dropout for the training to run (otherwise it's trivial)
    ),
    Dense(1, 1, Lux.sigmoid) # sigmoid output
)

# Initialize
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
println("Num parameters: ",Lux.parameterlength(ps))

opt = Optimisers.ADAM(0.01f0)
opt_state=Optimisers.setup(opt, ps);

# generate data: 4 x 100 observations (columns)
# the first 50 column pairs are near-identical (small noise), the last 50 are unrelated
T = Float32;
n = 50
X_same = randn(T, 4, n);
X_almost_same = X_same + T(0.1) * randn(T, 4, n);
X_different = randn(T, 4, n);

X1 = hcat(X_same, X_same)
X2 = hcat(X_almost_same, X_different);
Y = [ones(T, n); zeros(T, n)]; # 1=identical, 0=different

# quick test (X1, X2 must be defined first)
y_pred, st = model(X1, X2, ps, st)

# split data
train_data, val_data = splitobs((X1, X2, Y); at=0.85, shuffle=true);

# create loaders
train_loader = DataLoader(train_data, batchsize=10, shuffle=true);
val_loader = DataLoader(val_data, batchsize=size(val_data[1], 2));

# sense check
for (x1, x2, y) in val_loader
    println(size(x1), size(x2), size(y))
end

# Utility functions
"""
    siamese_contrastive_loss(ŷ, y; margin = 1, agg = mean)
                                    
Return the [contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
which can be useful for training Siamese Networks. It is given by
                                    
    agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)                           
                                 
Specify `margin` to set the baseline for distance at which pairs are dissimilar.
                                    
Forked from: https://github.com/FluxML/Flux.jl/blob/e4f8678f8a389179d173010f6aad75b80189b0eb/src/losses/functions.jl#L530-L540
"""
function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1)
    # _check_sizes(ŷ, y)
    # margin < 0 && throw(DomainError(margin, "Margin must be non-negative"))
    return agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
end
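
For intuition, a quick hand-computed sanity check of the formula with the default margin = 1 (the y = 1 entry is penalized for being below the margin, the y = 0 entry for being away from zero):

ŷ_demo = [0.2f0, 0.9f0]
y_demo = [1f0, 0f0]
# (1 - 1) * 0.2^2 + 1 * max(0, 1 - 0.2)^2 = 0.64
# (1 - 0) * 0.9^2 + 0 * max(0, 1 - 0.9)^2 = 0.81
siamese_contrastive_loss(ŷ_demo, y_demo) # ≈ mean([0.64, 0.81]) = 0.725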

matches(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

function compute_loss(x1, x2, y, model, ps, st)
    y_pred, st = model(x1, x2, ps, st)
    return siamese_contrastive_loss(y_pred, y), y_pred, st
end


# Training loop
for epoch in 1:5
    # Train the model
    for (x1, x2, y) in train_loader
        (loss, y_pred, st), back = pullback(p -> compute_loss(x1, x2, y, model, p, st), ps)
        gs = back((one(loss), nothing, nothing))[1]
        opt_state, ps = Optimisers.update(opt_state, ps, gs)

        println("Epoch [$epoch]: Loss $loss")
    end

    # Validate the model
    st_ = Lux.testmode(st)
    for (x1, x2, y) in val_loader
        (loss, y_pred, st_) = compute_loss(x1, x2, y, model, ps, st_)
        acc = accuracy(y_pred, y)
        println("Validation: Loss $loss Accuracy $acc")
    end
end

# Happy training...