Hi everyone!

Have you seen any examples of re-using parts of the networks in Lux.jl?

I haven’t been able to find anything.

I’ve been trying to put together a toy example of a Siamese network (à la the Keras “Siamese network with contrastive loss” example), but I cannot figure out from the Lux.jl docs how to re-use my network blocks (i.e., how to suppress a fresh initialization for `embedding_network`).
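To illustrate what I mean: if I naively use the same block twice, `Lux.setup` appears to give each occurrence its own freshly initialized parameters (at least that’s what I observe):

```
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

emb = Dense(4, 2)                    # the block I would like to share
model = Parallel(nothing, emb, emb)  # same layer object used twice
ps, st = Lux.setup(rng, model)

# each occurrence got its own independently initialized weights
ps.layer_1.weight == ps.layer_2.weight  # false (almost surely)
```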

Questions:

- Can you think of any relevant examples?
- How would you re-use a network block (perhaps the same question: how do you suppress initializing a `Chain()` with new parameters)?
- Related to the above: how do you operate on several inputs (several `X`s)?
- Slightly off-topic: am I right in thinking that SimpleChains.jl cannot support such parallel/branching networks? (I couldn’t see any such facility in its docs.)

Thank you for your help and ideas!

MRE:

```
using Lux, Random, Optimisers, Zygote
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)
# Generate 2 x 4 x 40 data
# the first dimension are tower_1 / tower_2 dataset (to be passed to the embedding block separately)
# the last dimension are observations, where first 20 are identical, last 20 are different
# I would rather have X1 and X2 datasets separately, but I haven't figured out how to do that
T = Float32;
X_same=randn(T, 1, 4, 20);
X = cat(
    vcat(X_same, X_same),              # identical data
    vcat(X_same, randn(T, 1, 4, 20)),  # different data
    dims = 3
)
size(X)  # (2, 4, 40)
y=[ones(T,20);zeros(T,20)]; # 1=identical, 0=different
# Embedding block that will take an input and compress it into a certain dimension
# This should be re-used for two different sets of data
embedding_network = Chain(
    # it will have more layers in real life
    Dense(4, 2, identity)
    # it might need some dropout for the training to run (otherwise it's trivial)
)
# Two towers, ideally re-using the same block
# I wasn't sure how to pass different data around
# so I concatenated them in dimension one and here is how I would run on it
tower_1 = Chain(SelectDim(1, 1), embedding_network)
tower_2 = Chain(SelectDim(1, 2), embedding_network)
# Bring together
# I haven't found a better way to bring the towers together than Parallel
# Using vcat as the connection should concatenate the two embeddings
siamese = Parallel(vcat, tower_1, tower_2)
# Parameter and State Variables
ps, st = Lux.setup(rng, siamese)
# Loss
# I intend to get inspiration here: https://github.com/FluxML/Flux.jl/blob/e4f8678f8a389179d173010f6aad75b80189b0eb/src/losses/functions.jl#L530-L540
siamese_contrastive_loss(model,x,p,st,y) = TBU...
# Gradients
gs = gradient(p -> siamese_contrastive_loss(siamese, X, p, st, y), ps)[1]
# Optimization
st_opt = Optimisers.setup(Optimisers.Adam(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)
# etc.
```
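For what it’s worth, the only workaround I’ve come up with so far is to skip the container layers entirely and apply the block twice myself, passing the same `ps` both times (the `shared_towers` name and the two separate `X1`/`X2` inputs below are my own invention, not anything from the docs):

```
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

embedding_network = Chain(Dense(4, 2, identity))
ps, st = Lux.setup(rng, embedding_network)  # a single set of parameters

# apply the same block, with the same ps, to both inputs -> shared weights
function shared_towers(x1, x2, ps, st)
    e1, st = Lux.apply(embedding_network, x1, ps, st)
    e2, st = Lux.apply(embedding_network, x2, ps, st)
    return vcat(e1, e2), st
end

X1 = randn(Float32, 4, 40)
X2 = randn(Float32, 4, 40)
out, st = shared_towers(X1, X2, ps, st)
size(out)  # (4, 40)
```

It works, but it doesn’t compose with `Chain`/`Parallel`, which is why I’m hoping there is a more idiomatic way.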
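As for the loss, my rough adaptation of the linked Flux implementation would be something along these lines. Note that I flip the label convention relative to Flux so that `y = 1` means “identical pair” (matching my data above), and `margin` is just my own default:

```
using Statistics: mean

# adapted from the Flux.jl siamese_contrastive_loss linked above;
# here y = 1 means "identical pair" (the opposite of Flux's convention)
function contrastive_loss(e1, e2, y; margin = 1f0)
    # Euclidean distance between paired embeddings, one value per observation
    d = vec(sqrt.(sum(abs2, e1 .- e2; dims = 1)))
    return mean(@. y * d^2 + (1 - y) * max(margin - d, 0)^2)
end

e = randn(Float32, 2, 40)
contrastive_loss(e, e, ones(Float32, 40))  # identical pairs at zero distance -> 0
```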