I am currently developing a comprehensive tutorial for my lab and the broader Julia community on training a Lux model for image segmentation. The tutorial lives in a Pluto notebook that covers the entire process: dataset download, preparation, visualization, model building, and a detailed training loop.
However, I’ve run into an issue where the training loss isn’t decreasing as expected. I’m not sure whether the problem lies in the 3D U-Net configuration or in the training loop itself, and despite my debugging efforts I haven’t made much progress. I would greatly appreciate any guidance or insights. The complete notebook is accessible here, and the environment setup needed for full replication lives in the root of the repo. I’m also attaching the key sections (model, loss function, training loop) that might be contributing to the problem, in case someone can quickly spot the issue.
Any help is very much appreciated!
Model
# ╔═╡ 1494df6e-f407-42c4-8404-1f4871a2f817
md"""
# Model
"""
# ╔═╡ a3f44d7c-efa3-41d0-9509-b099ab7f09d4
using Lux
# ╔═╡ b3fc9578-6b40-4afc-bb58-c772a61a60a5
md"""
## Helper functions
"""
# ╔═╡ 3e938872-a390-40ba-8b00-b132f988e2d3
# Shared constructor for the conv → BN → ReLU stack used by both block types;
# the trailing `sample` layer either downsamples (Conv, stride 2) or upsamples
# (ConvTranspose, stride 2).
function create_unet_layers(
    kernel_size, de_kernel_size, channel_list;
    downsample = true)

    padding = (kernel_size - 1) ÷ 2

    conv1 = Conv((kernel_size, kernel_size, kernel_size), channel_list[1] => channel_list[2], stride=1, pad=padding)
    conv2 = Conv((kernel_size, kernel_size, kernel_size), channel_list[2] => channel_list[3], stride=1, pad=padding)

    relu1 = relu
    relu2 = relu
    bn1 = BatchNorm(channel_list[2])
    bn2 = BatchNorm(channel_list[3])

    bridge_conv = Conv((kernel_size, kernel_size, kernel_size), channel_list[1] => channel_list[3], stride=1, pad=padding)

    if downsample
        sample = Chain(
            Conv((de_kernel_size, de_kernel_size, de_kernel_size), channel_list[3] => channel_list[3], stride=2, pad=(de_kernel_size - 1) ÷ 2, dilation=1),
            BatchNorm(channel_list[3]),
            relu
        )
    else
        sample = Chain(
            ConvTranspose((de_kernel_size, de_kernel_size, de_kernel_size), channel_list[3] => channel_list[3], stride=2, pad=(de_kernel_size - 1) ÷ 2),
            BatchNorm(channel_list[3]),
            relu
        )
    end

    return (conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
end
# ╔═╡ 1d65b1d1-82de-40ca-aaba-9eee23883cf3
md"""
## Contracting Block
"""
# ╔═╡ 40762509-b26e-47f5-8b49-e7100fdeb72a
begin
    struct ContractBlock{
        C1, C2, F1, F2, BN1, BN2, C3, CH
    } <: Lux.AbstractExplicitContainerLayer{
        (:conv1, :conv2, :bn1, :bn2, :bridge_conv, :sample)
    }
        conv1::C1
        conv2::C2
        relu1::F1
        relu2::F2
        bn1::BN1
        bn2::BN2
        bridge_conv::C3
        sample::CH
    end

    function ContractBlock(
        kernel_size, de_kernel_size, channel_list;
        downsample = true
    )
        conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample = create_unet_layers(
            kernel_size, de_kernel_size, channel_list;
            downsample = downsample
        )

        ContractBlock(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
    end

    function (m::ContractBlock)(x, ps, st::NamedTuple)
        # Bridge convolution for the residual connection
        res, st_bridge_conv = m.bridge_conv(x, ps.bridge_conv, st.bridge_conv)

        # Two conv → BN → ReLU stages
        x, st_conv1 = m.conv1(x, ps.conv1, st.conv1)
        x, st_bn1 = m.bn1(x, ps.bn1, st.bn1)
        x = relu(x)

        x, st_conv2 = m.conv2(x, ps.conv2, st.conv2)
        x, st_bn2 = m.bn2(x, ps.bn2, st.bn2)
        x = relu(x)

        # Residual addition, then downsample/upsample for the next stage;
        # `x` is also returned as the skip-connection feature map
        x = x .+ res
        next_layer, st_sample = m.sample(x, ps.sample, st.sample)

        st = (conv1=st_conv1, conv2=st_conv2, bn1=st_bn1, bn2=st_bn2, bridge_conv=st_bridge_conv, sample=st_sample)
        return next_layer, x, st
    end
end
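For readers following along, here is a quick shape sanity check for a ContractBlock. This is a sketch, not a cell from the notebook, and it assumes Random is available: with downsample = true the first output halves each spatial dimension, while the second output keeps the pre-downsampling feature map for the skip connection.

using Random

let
    # Sketch only: verify ContractBlock output shapes on a dummy volume.
    block = ContractBlock(5, 2, [2, 2, 4])
    ps_b, st_b = Lux.setup(Xoshiro(0), block)
    x_b = rand(Float32, 16, 16, 16, 2, 1)   # (W, H, D, C, N)
    down, skip, _ = block(x_b, ps_b, st_b)
    size(down), size(skip)                  # ((8, 8, 8, 4, 1), (16, 16, 16, 4, 1))
end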
# ╔═╡ 91e05c6c-e9b3-4a72-84a5-2ce4b1359b1a
md"""
## Expanding Block
"""
# ╔═╡ 70614cac-2e06-48a9-9cf6-9078bc7436bc
begin
    struct ExpandBlock{
        C1, C2, F1, F2, BN1, BN2, C3, CH
    } <: Lux.AbstractExplicitContainerLayer{
        (:conv1, :conv2, :bn1, :bn2, :bridge_conv, :sample)
    }
        conv1::C1
        conv2::C2
        relu1::F1
        relu2::F2
        bn1::BN1
        bn2::BN2
        bridge_conv::C3
        sample::CH
    end

    function ExpandBlock(
        kernel_size, de_kernel_size, channel_list;
        downsample = false)

        conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample = create_unet_layers(
            kernel_size, de_kernel_size, channel_list;
            downsample = downsample
        )

        ExpandBlock(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
    end

    function (m::ExpandBlock)(x, ps, st::NamedTuple)
        # Concatenate the upsampled features with the skip connection
        # along the channel dimension
        x, x1 = x[1], x[2]
        x = cat(x, x1; dims=4)

        # Bridge convolution for the residual connection
        res, st_bridge_conv = m.bridge_conv(x, ps.bridge_conv, st.bridge_conv)

        # Two conv → BN → ReLU stages
        x, st_conv1 = m.conv1(x, ps.conv1, st.conv1)
        x, st_bn1 = m.bn1(x, ps.bn1, st.bn1)
        x = relu(x)

        x, st_conv2 = m.conv2(x, ps.conv2, st.conv2)
        x, st_bn2 = m.bn2(x, ps.bn2, st.bn2)
        x = relu(x)

        # Residual addition, then upsample for the next stage
        x = x .+ res
        next_layer, st_sample = m.sample(x, ps.sample, st.sample)

        st = (conv1=st_conv1, conv2=st_conv2, bn1=st_bn1, bn2=st_bn2, bridge_conv=st_bridge_conv, sample=st_sample)
        return next_layer, st
    end
end
# ╔═╡ 36885de0-aa0e-4037-929f-44e074fb17f5
md"""
## U-Net
"""
# ╔═╡ af56e2f7-2ab8-4ff2-8295-038b3a565cbc
begin
    struct UNet{
        CH1, CH2, CB1, CB2, CB3, CB4, EB1, EB2, EB3, C1
    } <: Lux.AbstractExplicitContainerLayer{
        (:conv1, :conv2, :conv3, :conv4, :conv5, :de_conv1, :de_conv2, :de_conv3, :de_conv4, :last_conv)
    }
        conv1::CH1
        conv2::CH2
        conv3::CB1
        conv4::CB2
        conv5::CB3
        de_conv1::CB4
        de_conv2::EB1
        de_conv3::EB2
        de_conv4::EB3
        last_conv::C1
    end

    function UNet(channel)
        conv1 = Chain(
            Conv((5, 5, 5), 1 => channel, stride=1, pad=2),
            BatchNorm(channel),
            relu
        )
        conv2 = Chain(
            Conv((2, 2, 2), channel => 2 * channel, stride=2, pad=0),
            BatchNorm(2 * channel),
            relu
        )
        conv3 = ContractBlock(5, 2, [2 * channel, 2 * channel, 4 * channel])
        conv4 = ContractBlock(5, 2, [4 * channel, 4 * channel, 8 * channel])
        conv5 = ContractBlock(5, 2, [8 * channel, 8 * channel, 16 * channel])
        de_conv1 = ContractBlock(
            5, 2, [16 * channel, 32 * channel, 16 * channel];
            downsample = false
        )
        de_conv2 = ExpandBlock(
            5, 2, [32 * channel, 8 * channel, 8 * channel];
            downsample = false
        )
        de_conv3 = ExpandBlock(
            5, 2, [16 * channel, 4 * channel, 4 * channel];
            downsample = false
        )
        de_conv4 = ExpandBlock(
            5, 2, [8 * channel, 2 * channel, channel];
            downsample = false
        )
        last_conv = Conv((1, 1, 1), 2 * channel => 2, stride=1, pad=0)

        UNet(conv1, conv2, conv3, conv4, conv5, de_conv1, de_conv2, de_conv3, de_conv4, last_conv)
    end

    function (m::UNet)(x, ps, st::NamedTuple)
        # Convolutional layers
        x, st_conv1 = m.conv1(x, ps.conv1, st.conv1)
        x_1 = x # Store for skip connection
        x, st_conv2 = m.conv2(x, ps.conv2, st.conv2)

        # Downscaling Blocks
        x, x_2, st_conv3 = m.conv3(x, ps.conv3, st.conv3)
        x, x_3, st_conv4 = m.conv4(x, ps.conv4, st.conv4)
        x, x_4, st_conv5 = m.conv5(x, ps.conv5, st.conv5)

        # Upscaling Blocks
        x, _, st_de_conv1 = m.de_conv1(x, ps.de_conv1, st.de_conv1)
        x, st_de_conv2 = m.de_conv2((x, x_4), ps.de_conv2, st.de_conv2)
        x, st_de_conv3 = m.de_conv3((x, x_3), ps.de_conv3, st.de_conv3)
        x, st_de_conv4 = m.de_conv4((x, x_2), ps.de_conv4, st.de_conv4)

        # Concatenate with first skip connection and apply last convolution
        x = cat(x, x_1; dims=4)
        x, st_last_conv = m.last_conv(x, ps.last_conv, st.last_conv)

        # Merge states
        st = (
            conv1=st_conv1, conv2=st_conv2, conv3=st_conv3, conv4=st_conv4, conv5=st_conv5, de_conv1=st_de_conv1, de_conv2=st_de_conv2, de_conv3=st_de_conv3, de_conv4=st_de_conv4, last_conv=st_last_conv
        )

        return x, st
    end
end
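Before moving on to the loss, a forward-pass sanity check like the following can confirm the full U-Net wiring. This is a sketch rather than a notebook cell, and it assumes Random is available: a 32³ volume with one input channel should come out as a 32³ volume with two output channels (background and foreground logits).

using Random

let
    # Sketch only: dummy forward pass through a small U-Net to confirm shapes.
    model_check = UNet(4)
    ps_c, st_c = Lux.setup(Xoshiro(0), model_check)
    x_dummy = rand(Float32, 32, 32, 32, 1, 1)   # (W, H, D, C, N)
    y_dummy, _ = model_check(x_dummy, ps_c, st_c)
    size(y_dummy)                               # expected: (32, 32, 32, 2, 1)
end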
Loss
# ╔═╡ 496712da-3cf0-4fbc-b869-72372e73612b
function compute_loss(x, y, model, ps, st)
    y_pred, st = model(x, ps, st)

    # Softmax over the channel dimension, then threshold channel 2 (foreground)
    y_pred_softmax = softmax(y_pred, dims=4)
    y_pred_binary = round.(y_pred_softmax[:, :, :, 2, :])
    y_binary = y[:, :, :, 2, :]

    # Compute loss, averaged over the batch dimension
    loss = 0.0
    for b in axes(y, 5)
        _y_pred = y_pred_binary[:, :, :, b]
        _y = y_binary[:, :, :, b]
        dsc = dice_loss(_y_pred, _y)
        loss += dsc
    end
    return loss / size(y, 5), y_pred_binary, st
end
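The dice_loss called above is defined in an earlier cell of the notebook that isn't excerpted here. For context only, a typical soft Dice formulation looks roughly like this; the notebook's actual definition may differ.

# Sketch of a generic soft Dice loss, for context; not the notebook's definition.
function dice_loss_sketch(ŷ, y; ϵ = 1.0f-5)
    intersection = sum(ŷ .* y)
    return 1 - (2 * intersection + ϵ) / (sum(ŷ) + sum(y) + ϵ)
end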
Training Loop
# ╔═╡ 1e79232f-bda2-459a-bc03-85cd8afab3bf
function train_model(model, ps, st, train_loader, val_loader, num_epochs, dev)
    opt_state = create_optimiser(ps)

    # Initialize DataFrame to store metrics
    metrics_df = DataFrame(
        "Epoch" => Int[],
        "Train_Loss" => Float64[],
        "Validation_Loss" => Float64[],
        "Dice_Metric" => Float64[],
        "Hausdorff_Metric" => Float64[],
        "Epoch_Duration" => String[]
    )

    for epoch in 1:num_epochs
        @info "Epoch: $epoch"

        # Start timing the epoch
        epoch_start_time = now()

        # Training Phase
        num_batches_train = 0
        total_train_loss = 0.0
        for (x, y) in train_loader
            num_batches_train += 1
            @info "Step: $num_batches_train"
            x, y = x |> dev, y |> dev

            # Forward pass
            y_pred, st = Lux.apply(model, x, ps, st)
            loss, y_pred, st = compute_loss(x, y, model, ps, st)
            total_train_loss += loss
            # @info "Training Loss: $loss"

            # Backward pass
            (loss_grad, st_), back = Zygote.pullback(p -> Lux.apply(model, x, p, st), ps)
            gs = back((one.(loss_grad), nothing))[1]

            # Update parameters
            opt_state, ps = Optimisers.update(opt_state, ps, gs)
        end
        avg_train_loss = total_train_loss / num_batches_train
        @info "avg_train_loss: $avg_train_loss"

        # Validation Phase
        total_val_loss = 0.0
        total_dice = 0.0
        total_hausdorff = 0.0
        num_batches = 0
        for (x, y) in val_loader
            x, y = x |> dev, y |> dev

            # Forward Pass
            y_pred, st = Lux.apply(model, x, ps, st)

            # Apply softmax and convert to binary
            y_pred_softmax = softmax(y_pred, dims=4)
            y_pred_binary = round.(y_pred_softmax[:, :, :, 2, :])
            y_binary = y[:, :, :, 2, :]

            # Compute loss
            loss, _, _ = compute_loss(x, y, model, ps, st)

            # Process batch for metrics
            for b in axes(y_pred_binary, 5)
                _y_pred = Bool.(y_pred_binary[:, :, :, b]) |> cpu_device()
                _y = Bool.(y_binary[:, :, :, b]) |> cpu_device()
                total_dice += dice_metric(_y_pred, _y)
                total_hausdorff += hausdorff_metric(_y_pred, _y)
            end

            total_val_loss += loss
            num_batches += 1
        end

        # Calculate average metrics
        avg_val_loss = total_val_loss / num_batches
        avg_dice = total_dice / num_batches
        avg_hausdorff = total_hausdorff / num_batches
        @info "avg_val_loss: $avg_val_loss"
        @info "avg_dice: $avg_dice"
        @info "avg_hausdorff: $avg_hausdorff"

        # Calculate and log time taken for the epoch
        epoch_duration = now() - epoch_start_time

        # Append metrics to the DataFrame
        push!(metrics_df, [epoch, avg_train_loss, avg_val_loss, avg_dice, avg_hausdorff, string(epoch_duration)])

        # Write DataFrame to CSV file
        CSV.write("training_metrics.csv", metrics_df)

        @info "Metrics logged for Epoch $epoch"
    end
    return ps, st
end
# ╔═╡ a2e88851-227a-4719-8828-6064f9d3ef81
if LuxCUDA.functional()
    num_epochs = 10
else
    num_epochs = 2
end
# ╔═╡ 5cae73af-471c-4068-b9ff-5bc03dd0472d
ps_final, st_final = train_model(model, ps, st, train_loader, val_loader, num_epochs, dev);