Siamese network in Lux.jl / Re-using parts of network

Hi everyone!

Have you seen any examples of re-using parts of the networks in Lux.jl?
I haven’t been able to find anything.

I’ve been trying to put together a toy example of a Siamese network (à la the Keras Siamese Contrastive Loss example), but I cannot figure out from the Lux.jl docs how to re-use my network blocks (i.e., how to suppress a new initialization for embedding_network).

Questions:

  • Can you think of any relevant examples?
  • How would you re-use a network block (maybe the same thing: how to suppress the initialization of a Chain() with new parameters)?
  • Related to the above: how do I operate with several Xs?
  • Slightly off-topic: am I right in thinking that SimpleChains.jl cannot support such Parallel/branching networks? (I couldn’t see any facility for it in the docs.)

Thank you for your help and ideas!

MRE:

using Lux, Random, Optimisers, Zygote

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

# Generate 2 x 4 x 40 data
# the first dimension indexes the tower_1 / tower_2 datasets (to be passed to the embedding block separately)
# the last dimension indexes observations: the first 20 pairs are identical, the last 20 are different
# I would rather have X1 and X2 datasets separately, but I haven't figured out how to do that
T = Float32;
X_same=randn(T, 1, 4, 20);
X=cat(
    vcat(X_same,X_same), #identical data
    vcat(X_same,randn(T, 1, 4, 20))# different data
    ,dims=3
)
size(X)
y=[ones(T,20);zeros(T,20)]; # 1=identical, 0=different

# Embedding block that will take an input and compress it into a certain dimension
# This should be re-used for two different sets of data
embedding_network = Chain(
    # it will have more layers in real life
    Dense(4, 2, identity)
    # it might need some dropout for the training to run (otherwise it's trivial)
)

# Two towers, ideally re-using the same block
# I wasn't sure how to pass different data around
# so I concatenated them in dimension one and here is how I would run on it
tower_1=Chain(SelectDim(1, 1),embedding_network)
tower_2=Chain(SelectDim(1, 2),embedding_network)

# Bring together
# I haven't found a better way than using Parallel and Flatten
# Ideally, I'd use Concatenation instead of FlattenLayer, so I'm hoping it will work like this...
siamese=Parallel(FlattenLayer,tower_1,tower_2)

# Parameter and State Variables
ps, st = Lux.setup(rng, siamese)

# Loss 
# I intend to get inspiration here: https://github.com/FluxML/Flux.jl/blob/e4f8678f8a389179d173010f6aad75b80189b0eb/src/losses/functions.jl#L530-L540
siamese_contrastive_loss(model,x,p,st,y) = TBU...

# Gradients
gs = gradient(p -> siamese_contrastive_loss(siamese,X,p,st,y), ps)[1]

# Optimization
st_opt = Optimisers.setup(Optimisers.ADAM(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)

etc.

struct Siamese{F, S1, S2, E} <: Lux.AbstractExplicitContainerLayer{(:fl, :sd1, :sd2, :em)}
    fl::F   # FlattenLayer()
    sd1::S1 # SelectDim(1, 1)
    sd2::S2 # SelectDim(1, 2)
    em::E   # embedding_network
end
  • To have several Xs, just pass them as a tuple (see the sketch below).
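
A rough sketch (not from the original reply) of how the forward pass for such a container could look; both towers call the same `em` field with the same `ps.em`, which is what shares the parameters. To pass several Xs instead, accept a tuple `(x1, x2)` and skip the SelectDims:

# Sketch only: assumes the Siamese container above with fields fl, sd1, sd2, em.
function (s::Siamese)(x, ps, st)
    x1, _ = s.sd1(x, ps.sd1, st.sd1)     # slice out tower 1 data (4 x N)
    x2, _ = s.sd2(x, ps.sd2, st.sd2)     # slice out tower 2 data (4 x N)
    e1, st_em = s.em(x1, ps.em, st.em)   # shared embedding, first pass
    e2, st_em = s.em(x2, ps.em, st_em)   # same ps.em again => shared weights
    y, st_fl = s.fl(vcat(e1, e2), ps.fl, st.fl)
    return y, merge(st, (em = st_em, fl = st_fl))
end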

Thank you for the prompt reply!

You were right - it’s super simple when put in the same AbstractContainer. I love the explicit parametrization!

I’ve had some challenges along the way:

  • How do I create the functor from the struct Siamese? What’s the minimum required / what’s the API?
    The SimpleRNN example was great, but the change from gradient to pullback threw me off.

  • Do I need pullback() or is gradient() enough? If I use pullback(), what arguments do I provide to the back() function?
    It took reading the Zygote docs and a few more examples to realize that the first argument is a tuple that corresponds to the return value of the function I’m differentiating (i.e., compute_loss below), and hence I just need (one(loss), nothing, nothing); see the sketch after this list.
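
For reference, here is that pattern in isolation (a minimal sketch using the compute_loss from the MWE below):

# pullback returns the primal outputs plus a closure `back`; seeding `back` with
# (one(loss), nothing, nothing) requests the gradient of the scalar loss only and
# ignores the y_pred and st outputs of compute_loss.
(loss, y_pred, st), back = pullback(p -> compute_loss(x1, x2, y, model, p, st), ps)
gs = back((one(loss), nothing, nothing))[1]  # [1] = gradient w.r.t. ps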

All in all, it all works great with the current API and primitives. The only thing that would have made it better/faster is a clearer explanation of the points above.

Outstanding question: How would I change my code (MWE below) if I wanted to make it work with the Parallel layer?

My (failed) attempt:

function (s::Siamese)(x1::AbstractArray{T, 2},x2::AbstractArray{T, 2},
                        ps::NamedTuple,
                        st::NamedTuple) where {T}
    # function to calculate Euclidean distance col-wise
    eucl_dist(x1,x2)=colwise(Euclidean(), x1,x2)
    
    # function that will pass each x through the embedding network
    # and join them via Euclidean distance (col-wise applied)
    two_towers_to_eucl=Parallel(eucl_dist,s.emb,s.emb)
    
    # this doesn't work - I think I'm calling Parallel wrong with x1,x2 arguments
    # I was hoping to achieve something like: 
    # Parallel(euclidean,embed_network,embed_network)(x1,x2)=euclidean(embed_network(x1),embed_network(x2))
    dist,st_emb=two_towers_to_eucl((x1,x2),ps.emb,st.emb)

    y, st_classifier = s.classifier(dist, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, emb=st_emb))
    
    return vec(y), st
end

MWE (this works):

using Lux, Random, Optimisers, Zygote
using Statistics, Distances
using MLUtils



struct Siamese{E, C} <:
       Lux.AbstractExplicitContainerLayer{(:emb, :classifier)}
    emb::E # embedding network
    classifier::C # output layer
end

function (s::Siamese)(x1::AbstractArray{T, 2},x2::AbstractArray{T, 2},
                        ps::NamedTuple,
                        st::NamedTuple) where {T}
    
    # pass each x1,2 through the embedding network
    emb1,st_emb=s.emb(x1,ps.emb,st.emb)
    emb2,st_emb=s.emb(x2,ps.emb,st_emb)
    
    # Euclidean distance col-wise
    # and reshape to 1xbatch_size Matrix
    dist=reshape(colwise(Euclidean(), emb1,emb2),1,:)
    
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(dist, ps.classifier, st.classifier)
    
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, emb=st_emb))
    
    return vec(y), st
end


# simple Siamese network (just compress 4 dims -> 2 dims)
model=Siamese(
    # embedding network
    Chain(
        # it will have more layers in real life
        Dense(4, 2, identity)
        # it might need some dropout for the training to run (otherwise it's trivial)
    ),
    Dense(1,1,Lux.sigmoid) #sigmoid output
)

# Initialize
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
println("Num parameters: ",Lux.parameterlength(ps))

opt = Optimisers.ADAM(0.01f0)
opt_state=Optimisers.setup(opt, ps);

# generate data: X1 and X2, each 4 x 2n (n=50)
# in the first n columns the (X1, X2) pair is nearly identical (small noise), in the last n columns it is different
T = Float32;
n=50
X_same=randn(T, 4, n);
X_almost_same=X_same+T(0.1)*randn(T, 4, n);
X_different=randn(T,4,n);

X1=hcat(X_same,X_same)
X2=hcat(X_almost_same,X_different);
Y=[ones(T,n);zeros(T,n)]; # 1=identical, 0=different

# quick test (now that X1, X2 exist)
y_pred,st=model(X1,X2, ps, st)

# split data
train_data, val_data = splitobs((X1,X2, Y); at=0.85,shuffle=true);

# create loaders
train_loader = DataLoader(train_data, batchsize=10, shuffle=true);
val_loader = DataLoader(val_data,batchsize=size(val_data[1],2));

# sense check
for (x1,x2,y) in val_loader
   println(size(x1),size(x2),size(y)) 
end

# Utility functions
"""
    siamese_contrastive_loss(ŷ, y; margin = 1, agg = mean)
                                    
Return the [contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
which can be useful for training Siamese Networks. It is given by
                                    
    agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)                           
                                 
Specify `margin` to set the baseline for distance at which pairs are dissimilar.
                                    
Forked from: https://github.com/FluxML/Flux.jl/blob/e4f8678f8a389179d173010f6aad75b80189b0eb/src/losses/functions.jl#L530-L540
"""
function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1)
    # _check_sizes(ŷ, y)
    # margin < 0 && throw(DomainError(margin, "Margin must be non-negative"))
    return agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
end
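
# Quick numeric check of the formula above with made-up values (margin = 1):
#   y = 1 term: max(0, 1 - 0.2)^2 = 0.64;  y = 0 term: 0.9^2 = 0.81;  mean = 0.725
siamese_contrastive_loss(Float32[0.2, 0.9], Float32[1, 0])  # ≈ 0.725f0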

matches(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

function compute_loss(x1,x2, y, model, ps, st)
    y_pred, st = model(x1,x2, ps, st)
    return siamese_contrastive_loss(y_pred, y),y_pred,st
end


# Training loop
for epoch in 1:5
    # Train the model
    for (x1,x2, y) in train_loader
        (loss, y_pred, st), back = pullback(p -> compute_loss(x1,x2, y, model, p, st), ps)
        gs = back((one(loss), nothing, nothing))[1]
        opt_state, ps = Optimisers.update(opt_state, ps, gs)

        println("Epoch [$epoch]: Loss $loss")
    end

    # Validate the model
    st_ = Lux.testmode(st)
    for (x1,x2, y) in val_loader
        (loss, y_pred, st_) = compute_loss(x1,x2, y, model, ps, st_)
        acc = accuracy(y_pred, y)
        println("Validation: Loss $loss Accuracy $acc")
    end
end

# Happy training...

You don’t need to define a functor for Lux layers. As long as your parameter state doesn’t have any custom types (i.e. it is composed of tuples, NamedTuples, and arrays), Optimisers.jl will just work.
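
For example, this runs on a plain NamedTuple of arrays without defining any functor (toy values, just for illustration):

using Optimisers

# A plain NamedTuple of arrays is enough for Optimisers.jl to traverse.
ps = (emb = (weight = randn(Float32, 2, 4), bias = zeros(Float32, 2)),
      classifier = (weight = randn(Float32, 1, 1), bias = zeros(Float32, 1)))
# Fake "gradients" with the same structure, just for illustration.
gs = (emb = (weight = ones(Float32, 2, 4), bias = ones(Float32, 2)),
      classifier = (weight = ones(Float32, 1, 1), bias = ones(Float32, 1)))

opt_state = Optimisers.setup(Optimisers.ADAM(0.01f0), ps)
opt_state, ps = Optimisers.update(opt_state, ps, gs)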

If your loss is a scalar, you should not need pullback. That’s why we don’t mention it in the Flux docs.
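
For instance, with the compute_loss from the MWE above, taking just the scalar part is enough for gradient (a sketch):

# `first` extracts the scalar loss from the (loss, y_pred, st) tuple
gs = gradient(p -> first(compute_loss(x1, x2, y, model, p, st)), ps)[1]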

It seems like you’re duplicating the layer s.emb when constructing the Parallel, but not duplicating the layer state and params when calling it:

Try:

dist, st_emb = two_towers_to_eucl((x1, x2), (layer_1=ps.emb, layer_2=ps.emb), (layer_1=st.emb, layer_2=st.emb))

That’s good to know! I wasn’t sure, so I went with all bells and whistles.

One thing I’d add: it seems that you also need pullback if you want to be able to capture the state update as well, and I wanted to write the generic version.

Of course! Thank you for spotting this.

It also needs one more tweak: the Euclidean distance returns a vector but a Matrix is expected, so the working example is:

function (s::Siamese)(x1::AbstractArray{T, 2},x2::AbstractArray{T, 2},
                        ps::NamedTuple,
                        st::NamedTuple) where {T}
    # Euclidean distance col-wise
    eucl_dist(x1,x2)=colwise(Euclidean(), x1,x2)
    
    # function that will pass each x through the embedding network
    # and join them via Euclidean distance (col-wise applied)
    two_towers_to_eucl=Parallel(eucl_dist,s.emb,s.emb)
    
    dist, st_emb = two_towers_to_eucl((x1, x2), 
        (layer_1=ps.emb, layer_2=ps.emb), 
        (layer_1=st.emb, layer_2=st.emb))

    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(reshape(dist,1,:), ps.classifier, st.classifier)
    
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, emb=st_emb))
    
    return vec(y), st
end

Thank you for your help!

Not necessarily. You can also assign to an external variable. e.g.

local newstate
gradient(params) do p
  # `newstate` is captured from the enclosing scope, so it survives the closure
  output, newstate = Lux.apply(model, x, p, state)
  return loss(output, target)
end

I think the pullback with nothing pattern is more stable, but if you’re working with an API which hard-codes a call to gradient, this may be helpful.


Hello @svilupp, is your Siamese neural network implementation open source by any chance?

You can find the E2E working example 2 posts above. It has the layers, model, loss, and the training loop.


Hello,

I was reading the code of the MWE and noticed that you wrote:

    # pass each x1,2 through the embedding network
    emb1,st_emb=s.emb(x1,ps.emb,st.emb)
    emb2,st_emb=s.emb(x2,ps.emb,st_emb)

The last argument of s.emb for the “first” embedding layer is st.emb, but for the “second” one it is the output of the first call (st_emb). Since we don’t have any non-trainable parameters, does it really matter if we pass the same state to both? Or am I thinking about this the wrong way: since this is a Siamese network architecture, should the first one be fixed?

B.R.

If st is empty, i.e. NamedTuple(), then it doesn’t matter. But you might not have any non-trainable parameters and still want to control stochasticity, say with Dropout. In that case, if you don’t pass st_emb, the layer will generate the same mask for two different inputs, which would not be the desired behaviour. Hence, the recommendation is to always pass the updated states (even when empty), since that makes it safer to generalize across changes in model architecture.
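
For example, with a standalone Dropout layer (a rough sketch with my own variable names):

using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

drop = Dropout(0.5f0)
ps, st = Lux.setup(rng, drop)
x = ones(Float32, 4, 6)

y1, st1 = drop(x, ps, st)
y2, _   = drop(x, ps, st)    # same state reused   -> same dropout mask
y3, _   = drop(x, ps, st1)   # updated state passed -> fresh mask

y1 == y2   # true
y1 == y3   # almost surely false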


Thank you for the clarification. And thank you also for the package :+1:
