Lux Loss Not Decreasing

I am developing a tutorial for my lab and the broader Julia community on training a Lux model for image segmentation. The tutorial lives in a Pluto notebook that covers the entire process: downloading and preparing the dataset, visualization, model building, and a detailed training loop.

However, I’ve run into an issue where the training loss isn’t decreasing as expected. I’m not sure whether this is due to the 3D U-Net model configuration or a problem in the training loop, and despite my debugging efforts I haven’t made much progress, so I would greatly appreciate any guidance or insights. The complete notebook is accessible here; the environment setup needed for full replication is in the root of the repo. I’m also attaching the key sections (model, loss function, training loop) that might be contributing to the problem, in case someone can spot the issue quickly.

Any help is much appreciated!

Model

# ╔═╡ 1494df6e-f407-42c4-8404-1f4871a2f817
md"""
# Model
"""

# ╔═╡ a3f44d7c-efa3-41d0-9509-b099ab7f09d4
using Lux

# ╔═╡ b3fc9578-6b40-4afc-bb58-c772a61a60a5
md"""
## Helper functions
"""

# ╔═╡ 3e938872-a390-40ba-8b00-b132f988e2d3
function create_unet_layers(
    kernel_size, de_kernel_size, channel_list;
    downsample = true)

    padding = (kernel_size - 1) ÷ 2

	conv1 = Conv((kernel_size, kernel_size, kernel_size), channel_list[1] => channel_list[2], stride=1, pad=padding)
	conv2 = Conv((kernel_size, kernel_size, kernel_size), channel_list[2] => channel_list[3], stride=1, pad=padding)

    relu1 = relu
    relu2 = relu
    bn1 = BatchNorm(channel_list[2])
    bn2 = BatchNorm(channel_list[3])

	bridge_conv = Conv((kernel_size, kernel_size, kernel_size), channel_list[1] => channel_list[3], stride=1, pad=padding)

    if downsample
        sample = Chain(
			Conv((de_kernel_size, de_kernel_size, de_kernel_size), channel_list[3] => channel_list[3], stride=2, pad=(de_kernel_size - 1) ÷ 2, dilation=1),
            BatchNorm(channel_list[3]),
            relu
        )
    else
        sample = Chain(
			ConvTranspose((de_kernel_size, de_kernel_size, de_kernel_size), channel_list[3] => channel_list[3], stride=2, pad=(de_kernel_size - 1) ÷ 2),
            BatchNorm(channel_list[3]),
            relu
        )
    end

    return (conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
end

# ╔═╡ 1d65b1d1-82de-40ca-aaba-9eee23883cf3
md"""
## Contracting Block
"""

# ╔═╡ 40762509-b26e-47f5-8b49-e7100fdeb72a
begin
    struct ContractBlock{
		C1, C2, F1, F2, BN1, BN2, C3, CH
	} <: Lux.AbstractExplicitContainerLayer{
        (:conv1, :conv2, :bn1, :bn2, :bridge_conv, :sample)
    }
        conv1::C1
        conv2::C2
        relu1::F1
        relu2::F2
        bn1::BN1
        bn2::BN2
        bridge_conv::C3
        sample::CH
    end

    function ContractBlock(
        kernel_size, de_kernel_size, channel_list;
        downsample = true
    )

		conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample = create_unet_layers(
            kernel_size, de_kernel_size, channel_list;
            downsample = downsample
        )

        ContractBlock(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
    end

    function (m::ContractBlock)(x, ps, st::NamedTuple)
        res, st_bridge_conv = m.bridge_conv(x, ps.bridge_conv, st.bridge_conv)
        x, st_conv1 = m.conv1(x, ps.conv1, st.conv1)
        x, st_bn1 = m.bn1(x, ps.bn1, st.bn1)
        x = relu(x)

        x, st_conv2 = m.conv2(x, ps.conv2, st.conv2)
        x, st_bn2 = m.bn2(x, ps.bn2, st.bn2)
        x = relu(x)

        x = x .+ res

        next_layer, st_sample = m.sample(x, ps.sample, st.sample)

		st = (conv1=st_conv1, conv2=st_conv2, bn1=st_bn1, bn2=st_bn2, bridge_conv=st_bridge_conv, sample=st_sample)
        return next_layer, x, st
    end
end
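
As a quick sanity check (not part of the notebook), a ContractBlock can be run standalone on a dummy 5-D volume; the shapes and channel counts below are illustrative only:

using Random

let
    block = ContractBlock(5, 2, [8, 8, 16])      # 8 -> 16 channels, with downsampling
    ps, st = Lux.setup(Random.default_rng(), block)
    x = rand(Float32, 32, 32, 32, 8, 1)          # (W, H, D, C, N)
    next, skip, _ = block(x, ps, st)
    size(next), size(skip)                       # ((16, 16, 16, 16, 1), (32, 32, 32, 16, 1))
end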

# ╔═╡ 91e05c6c-e9b3-4a72-84a5-2ce4b1359b1a
md"""
## Expanding Block
"""

# ╔═╡ 70614cac-2e06-48a9-9cf6-9078bc7436bc
begin
    struct ExpandBlock{
		C1, C2, F1, F2, BN1, BN2, C3, CH
	} <: Lux.AbstractExplicitContainerLayer{
        (:conv1, :conv2, :bn1, :bn2, :bridge_conv, :sample)
    }
        conv1::C1
        conv2::C2
        relu1::F1
        relu2::F2
        bn1::BN1
        bn2::BN2
        bridge_conv::C3
        sample::CH
    end

    function ExpandBlock(
        kernel_size, de_kernel_size, channel_list;
        downsample = false)

		conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample = create_unet_layers(
            kernel_size, de_kernel_size, channel_list;
            downsample = downsample
        )

        ExpandBlock(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
    end

    function (m::ExpandBlock)(x, ps, st::NamedTuple)
        x, x1 = x[1], x[2]
        x = cat(x, x1; dims=4)

        res, st_bridge_conv = m.bridge_conv(x, ps.bridge_conv, st.bridge_conv)

        x, st_conv1 = m.conv1(x, ps.conv1, st.conv1)
        x, st_bn1 = m.bn1(x, ps.bn1, st.bn1)
        x = relu(x)

        x, st_conv2 = m.conv2(x, ps.conv2, st.conv2)
        x, st_bn2 = m.bn2(x, ps.bn2, st.bn2)
        x = relu(x)

        x = x .+ res

        next_layer, st_sample = m.sample(x, ps.sample, st.sample)

		st = (conv1=st_conv1, conv2=st_conv2, bn1=st_bn1, bn2=st_bn2, bridge_conv=st_bridge_conv, sample=st_sample)
        return next_layer, st
    end
end
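
An ExpandBlock is called with a tuple (features, skip); the two inputs are concatenated along the channel dimension (dims = 4) and then upsampled via ConvTranspose. A standalone sketch with illustrative shapes (not part of the notebook):

using Random

let
    block = ExpandBlock(5, 2, [32, 8, 8])        # 16 + 16 = 32 channels in, 8 out
    ps, st = Lux.setup(Random.default_rng(), block)
    x    = rand(Float32, 8, 8, 8, 16, 1)         # decoder features
    skip = rand(Float32, 8, 8, 8, 16, 1)         # encoder skip connection
    y, _ = block((x, skip), ps, st)
    size(y)                                      # expected (16, 16, 16, 8, 1)
end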

# ╔═╡ 36885de0-aa0e-4037-929f-44e074fb17f5
md"""
## U-Net
"""

# ╔═╡ af56e2f7-2ab8-4ff2-8295-038b3a565cbc
begin
    struct UNet{
		CH1, CH2, CB1, CB2, CB3, CB4, EB1, EB2, EB3, C1
	} <: Lux.AbstractExplicitContainerLayer{
        (:conv1, :conv2, :conv3, :conv4, :conv5, :de_conv1, :de_conv2, :de_conv3, :de_conv4, :last_conv)
    }
        conv1::CH1
        conv2::CH2
        conv3::CB1
        conv4::CB2
        conv5::CB3
        de_conv1::CB4
        de_conv2::EB1
        de_conv3::EB2
        de_conv4::EB3
        last_conv::C1
    end

    function UNet(channel)
        conv1 = Chain(
            Conv((5, 5, 5), 1 => channel, stride=1, pad=2),
            BatchNorm(channel),
            relu
        )
        conv2 = Chain(
            Conv((2, 2, 2), channel => 2 * channel, stride=2, pad=0),
            BatchNorm(2 * channel),
            relu
        )
        conv3 = ContractBlock(5, 2, [2 * channel, 2 * channel, 4 * channel])
        conv4 = ContractBlock(5, 2, [4 * channel, 4 * channel, 8 * channel])
        conv5 = ContractBlock(5, 2, [8 * channel, 8 * channel, 16 * channel])

        de_conv1 = ContractBlock(
            5, 2, [16 * channel, 32 * channel, 16 * channel];
            downsample = false
        )
        de_conv2 = ExpandBlock(
            5, 2, [32 * channel, 8 * channel, 8 * channel];
            downsample = false
        )
        de_conv3 = ExpandBlock(
            5, 2, [16 * channel, 4 * channel, 4 * channel];
            downsample = false
        )
        de_conv4 = ExpandBlock(
            5, 2, [8 * channel, 2 * channel, channel];
            downsample = false
        )

        last_conv = Conv((1, 1, 1), 2 * channel => 2, stride=1, pad=0)

		UNet(conv1, conv2, conv3, conv4, conv5, de_conv1, de_conv2, de_conv3, de_conv4, last_conv)
    end

    function (m::UNet)(x, ps, st::NamedTuple)
        # Convolutional layers
        x, st_conv1 = m.conv1(x, ps.conv1, st.conv1)
        x_1 = x  # Store for skip connection
        x, st_conv2 = m.conv2(x, ps.conv2, st.conv2)

        # Downscaling Blocks
        x, x_2, st_conv3 = m.conv3(x, ps.conv3, st.conv3)
        x, x_3, st_conv4 = m.conv4(x, ps.conv4, st.conv4)
        x, x_4, st_conv5 = m.conv5(x, ps.conv5, st.conv5)

        # Upscaling Blocks
        x, _, st_de_conv1 = m.de_conv1(x, ps.de_conv1, st.de_conv1)
        x, st_de_conv2 = m.de_conv2((x, x_4), ps.de_conv2, st.de_conv2)
        x, st_de_conv3 = m.de_conv3((x, x_3), ps.de_conv3, st.de_conv3)
        x, st_de_conv4 = m.de_conv4((x, x_2), ps.de_conv4, st.de_conv4)

        # Concatenate with first skip connection and apply last convolution
        x = cat(x, x_1; dims=4)
        x, st_last_conv = m.last_conv(x, ps.last_conv, st.last_conv)

        # Merge states
        st = (
		conv1=st_conv1, conv2=st_conv2, conv3=st_conv3, conv4=st_conv4, conv5=st_conv5, de_conv1=st_de_conv1, de_conv2=st_de_conv2, de_conv3=st_de_conv3, de_conv4=st_de_conv4, last_conv=st_last_conv
        )

        return x, st
    end
end
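
For readers who want to check shapes without the full dataset, a dummy forward pass (not in the notebook) looks like this; the spatial size must be divisible by 16 because of the four stride-2 stages:

using Random

let
    model = UNet(4)                              # small channel count, illustrative only
    ps, st = Lux.setup(Random.default_rng(), model)
    x = rand(Float32, 32, 32, 32, 1, 1)          # single-channel volume, batch of 1
    y, _ = model(x, ps, st)
    size(y)                                      # expected (32, 32, 32, 2, 1): two classes
end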

Loss

# ╔═╡ 496712da-3cf0-4fbc-b869-72372e73612b
function compute_loss(x, y, model, ps, st)

    y_pred, st = model(x, ps, st)

    y_pred_softmax = softmax(y_pred, dims=4)
    y_pred_binary = round.(y_pred_softmax[:, :, :, 2, :])
    y_binary = y[:, :, :, 2, :]

    # Compute loss
    loss = 0.0
    for b in axes(y, 5)
        _y_pred = y_pred_binary[:, :, :, b]
        _y = y_binary[:, :, :, b]
		
		dsc = dice_loss(_y_pred, _y)
		loss += dsc
    end
	
    return loss / size(y, 5), y_pred_binary, st
end
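
dice_loss is defined elsewhere in the notebook; for anyone reading only this excerpt, a minimal soft-Dice sketch of the usual form (not necessarily the notebook's exact implementation) is:

# Hypothetical reference; the notebook's own dice_loss may differ.
function dice_loss(ŷ, y; ϵ = 1f-5)
    intersection = sum(ŷ .* y)
    return 1f0 - (2f0 * intersection + ϵ) / (sum(ŷ) + sum(y) + ϵ)
end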

Training Loop

# ╔═╡ 1e79232f-bda2-459a-bc03-85cd8afab3bf
function train_model(model, ps, st, train_loader, val_loader, num_epochs, dev)
    opt_state = create_optimiser(ps)

    # Initialize DataFrame to store metrics
    metrics_df = DataFrame(
        "Epoch" => Int[], 
		"Train_Loss" => Float64[],
        "Validation_Loss" => Float64[], 
        "Dice_Metric" => Float64[],
        "Hausdorff_Metric" => Float64[],
        "Epoch_Duration" => String[]
    )

    for epoch in 1:num_epochs
        @info "Epoch: $epoch"

        # Start timing the epoch
        epoch_start_time = now()

		# Training Phase
		num_batches_train = 0
		total_train_loss = 0.0
        for (x, y) in train_loader
			num_batches_train += 1
			@info "Step: $num_batches_train"
			x, y = x |> dev, y |> dev
			
            # Forward pass
            y_pred, st = Lux.apply(model, x, ps, st)
            loss, y_pred, st = compute_loss(x, y, model, ps, st)
			total_train_loss += loss
			# @info "Training Loss: $loss"

            # Backward pass
			(loss_grad, st_), back = Zygote.pullback(p -> Lux.apply(model, x, p, st), ps)
            gs = back((one.(loss_grad), nothing))[1]

            # Update parameters
            opt_state, ps = Optimisers.update(opt_state, ps, gs)
        end
		
		avg_train_loss = total_train_loss / num_batches_train
		@info "avg_train_loss: $avg_train_loss"

		# Validation Phase
		total_val_loss = 0.0
		total_dice = 0.0
		total_hausdorff = 0.0
		num_batches = 0
		for (x, y) in val_loader
		    x, y = x |> dev, y |> dev
		    
		    # Forward Pass
		    y_pred, st = Lux.apply(model, x, ps, st)
		
		    # Apply softmax and convert to binary
		    y_pred_softmax = softmax(y_pred, dims=4)
		    y_pred_binary = round.(y_pred_softmax[:, :, :, 2, :])
		    y_binary = y[:, :, :, 2, :]
		
		    # Compute loss
		    loss, _, _ = compute_loss(x, y, model, ps, st)
		
		    # Process batch for metrics
		    for b in axes(y_pred_binary, 5)
		        _y_pred = Bool.(y_pred_binary[:, :, :, b]) |> cpu_device()
		        _y = Bool.(y_binary[:, :, :, b]) |> cpu_device()
		
		        total_dice += dice_metric(_y_pred, _y)
		        total_hausdorff += hausdorff_metric(_y_pred, _y)
		    end
		
		    total_val_loss += loss
		    num_batches += 1
		end
		
		# Calculate average metrics
		avg_val_loss = total_val_loss / num_batches
		avg_dice = total_dice / num_batches
		avg_hausdorff = total_hausdorff / num_batches
		@info "avg_val_loss: $avg_val_loss"
		@info "avg_dice: $avg_dice"
		@info "avg_hausdorff: $avg_hausdorff"

        # Calculate and log time taken for the epoch
        epoch_duration = now() - epoch_start_time

        # Append metrics to the DataFrame
		push!(metrics_df, [epoch, avg_train_loss, avg_val_loss, avg_dice, avg_hausdorff, string(epoch_duration)])

        # Write DataFrame to CSV file
        CSV.write("training_metrics.csv", metrics_df)

        @info "Metrics logged for Epoch $epoch"
    end

    return ps, st
end
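
create_optimiser is also defined elsewhere in the notebook. A typical definition, assuming Optimisers.jl with an Adam rule (a sketch, not necessarily what the notebook uses):

using Optimisers

# Hypothetical helper; the learning rate is illustrative only.
create_optimiser(ps) = Optimisers.setup(Optimisers.Adam(0.001f0), ps)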

# ╔═╡ a2e88851-227a-4719-8828-6064f9d3ef81
if LuxCUDA.functional()
	num_epochs = 10
else
	num_epochs = 2
end

# ╔═╡ 5cae73af-471c-4068-b9ff-5bc03dd0472d
ps_final, st_final = train_model(model, ps, st, train_loader, val_loader, num_epochs, dev);

False alarm: it seems the learning rate for the optimizer was just too high.
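
For anyone hitting the same symptom: with an Optimisers.Adam-style rule as sketched above, the change amounts to passing a smaller learning rate, for example:

# Illustrative value only; pick something appropriate for your data.
opt_state = Optimisers.setup(Optimisers.Adam(1f-4), ps)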
