Flux model learning slow?

Hello everyone,

I’m currently working on training a model using the Flux framework, but I’ve encountered a significant performance issue: the training time for each epoch is around 10 minutes, which seems unusually slow. I’ve experimented with different loss functions, but the training speed remains a bottleneck.

# Define a multi-layer perceptron (MLP) structure
struct DiscreetMLP
    dropout::Dropout        # Dropout layer to prevent overfitting
    layer1::Dense           # First dense layer
    layer2::Dense           # Second dense layer
    layer3::Dense           # Third dense layer
end

# Enable Flux to work with DiscreetMLP
Flux.@functor DiscreetMLP

# Constructor for DiscreetMLP, taking hidden layer size and dropout rate as inputs
function DiscreetMLP(hidden_size, dropout_rate)
    layer1 = Dense(hidden_size, 2 * hidden_size, tanh)    # First layer with activation
    layer2 = Dense(2 * hidden_size, 2 * hidden_size, tanh) # Second layer with activation
    layer3 = Dense(2 * hidden_size, 1)                     # Output layer
    dropout = Dropout(dropout_rate)                         # Dropout layer
    return DiscreetMLP(dropout, layer1, layer2, layer3)     # Return the constructed model
end

# Define the forward pass for the MLP
function (mlp::DiscreetMLP)(input)
    input = mlp.layer1(input)        # Pass input through first layer
    input = mlp.dropout(input)       # Apply dropout
    input = mlp.layer2(input)        # Pass input through second layer
    input = mlp.dropout(input)       # Apply dropout
    input = mlp.layer3(input)        # Pass input through output layer
    return input                     # Return final output
end

# Define a graph neural network structure
struct DiscreetGNN
    init_layer::LayerInit
    layer1::Layer
    layer2::Layer
    layer3::Layer
    layer4::Layer
    layer5::Layer
    mlp::DiscreetMLP
end

Flux.@functor DiscreetGNN

# Constructor for DiscreetGNN, taking input size, hidden size, and dropout rate
function DiscreetGNN(input_size, hidden_size, dropout_rate)
    init_layer = LayerInit(input_size, hidden_size)  # Initialize layer
    layer1 = Layer(hidden_size, hidden_size)          # First GNN layer
    layer2 = Layer(hidden_size, hidden_size)          # Second GNN layer
    layer3 = Layer(hidden_size, hidden_size)          # Third GNN layer
    layer4 = Layer(hidden_size, hidden_size)          # Fourth GNN layer
    layer5 = Layer(hidden_size, hidden_size)          # Fifth GNN layer
    mlp = DiscreetMLP(hidden_size, dropout_rate)      # Create MLP instance
    return DiscreetGNN(init_layer, layer1, layer2, layer3, layer4, layer5, mlp)  # Return the constructed GNN
end

# Define the forward pass for the GNN
function (gnn::DiscreetGNN)(adj1, adj2)
    #I have ensured that inputs and outputs are of type Float32

    # Normalize and propagate through layers
    x_1 = mynormalize(gnn.init_layer(adj2))          # Initial adjacency normalization
    x_2 = mynormalize(gnn.layer1(x_1, adj1))         # First GNN layer propagation
    x_3 = mynormalize(gnn.layer2(x_2, adj1))         # Second GNN layer propagation
    x_4 = mynormalize(gnn.layer3(x_3, adj1))         # Third GNN layer propagation
    x_5 = mynormalize(gnn.layer4(x_4, adj1))         # Fourth GNN layer propagation
    x_6 = gnn.layer5(x_5, adj1)                       # Fifth GNN layer propagation

    # Calculate scores for each layer's output
    s1 = gnn.mlp(x_1')  # MLP output for first layer
    s2 = gnn.mlp(x_2')  # MLP output for second layer
    s3 = gnn.mlp(x_3')  # MLP output for third layer
    s4 = gnn.mlp(x_4')  # MLP output for fourth layer
    s5 = gnn.mlp(x_6')  # MLP output for fifth layer

    # Combine the scores from all layers
    score_total = s1 + s2 + s3 + s4 + s5
    return sigmoid(score_total)  # Apply sigmoid to return probabilities
end

# Function to normalize the input
function mynormalize(x)
    norms = sqrt.(sum(abs2, x, dims=2)) .+ 0.01  # Calculate norms with a small constant to avoid division by zero
    return Float32.(x ./ norms)                  # Normalize the input
end

# Training function to create the model
function create_discreet_model(model_size)
    model = DiscreetGNN(model_size, 20, 0.6)  # Create GNN with specified model size, hidden size, and dropout rate
    return model
end


#MODEL TRAINING

# Custom cost-sensitive binary cross-entropy loss function
function cost_sensitive_loss(predictions, targets; pos_weight=2, neg_weight=1, aggregation=mean, epsilon=1e-10)
    pos_weighted = targets * pos_weight                      # Weight for positive class
    neg_weighted = (1 .- targets) * neg_weight              # Weight for negative class
    losses = @.(-pos_weighted * log(predictions + epsilon) - neg_weighted * log(1 - predictions + epsilon))  # Calculate losses
    return aggregation(losses)                              # Return aggregated loss
end

# Function to round predictions based on a threshold
function custom_round(predictions, threshold=0.9)
    return Int.(predictions .>= threshold ? 1 : 0)  # Convert predictions to binary based on threshold
end

# Function to calculate total accuracy
function overall_accuracy(predictions, targets)
    return sum(custom_round.(predictions) .== targets) / length(targets)  # Calculate accuracy
end

# Function to calculate class-specific accuracy
function class_specific_accuracy(predictions, targets, class)
    pred = custom_round.(predictions) .== class  # Predicted labels for the specific class
    true_vals = targets .== class                 # True labels for the specific class

    correct_predictions = sum(pred .& true_vals)  # Count correct predictions
    total_count = sum(true_vals)                   # Total instances for the class

    return correct_predictions / total_count       # Return class-specific accuracy
end

loss(ŷ, y) = cost_sensitive_binary_crossentropy(ŷ, y)

# Function to train the model using the metrics
function train_discreet_model(model, train_dataset, test_dataset)
    learning_rate = 0.001  # Set learning rate
    opt = Flux.setup(Adam(learning_rate), model)  # Set up optimizer

    # Initialize arrays to store training history
    loss_train_history = []
    #...initialize arrays for the other metrics too
    
    # Calculate initial losses and accuracies
    loss_train = mean([loss(model(ad, adm)[1:n], binary_arr[1:n]) for (ad, adm, binary_arr, n) in train_dataset])
    loss_test = mean([loss(model(ad, adm)[1:n], binary_arr[1:n]) for (ad, adm, binary_arr, n) in test_dataset])

    acc_test_binary_arr = mean([class_specific_accuracy(model(ad, adm)[1:n], binary_arr[1:n], 1) for (ad, adm, binary_arr, n) in test_dataset])
    acc_test_other = mean([class_specific_accuracy(model(ad, adm)[1:n], binary_arr[1:n], 0) for (ad, adm, binary_arr, n) in test_dataset])
    acc_total = mean([overall_accuracy(model(ad, adm)[1:n], binary_arr[1:n]) for (ad, adm, binary_arr, n) in test_dataset])
    est = mean([sum(custom_round.(vec(model(ad, adm)))) for (ad, adm, binary_arr, _) in test_dataset])


    # Store initial metrics
    push!(loss_train_history, loss_train)
    #...push also for the other metrics
    
    @show 0, loss_train, loss_test, acc_test_binary_arr, acc_test_other, acc_total, est

    # Training loop for a specified number of epochs
    for epoch in 1:5
        for (ad, adm, binary_arr, n) in train_dataset
            Flux.train!(loss, params(model), [(ad, adm, binary_arr, n)], opt)  # Update model parameters
        end

        # Compute metrics after each epoch
        loss_train = mean([loss(model(ad, adm)[1:n], binary_arr[1:n]) for (ad, adm, binary_arr, n) in train_dataset])
        loss_test = mean([loss(model(ad, adm)[1:n], binary_arr[1:n]) for (ad, adm, binary_arr, n) in test_dataset])

        acc_test_binary_arr = mean([class_specific_accuracy(model(ad, adm)[1:n], binary_arr[1:n], 1) for (ad, adm, binary_arr, n) in test_dataset])
        acc_test_other = mean([class_specific_accuracy(model(ad, adm)[1:n], binary_arr[1:n], 0) for (ad, adm, binary_arr, n) in test_dataset])
        acc_total = mean([overall_accuracy(model(ad, adm)[1:n], binary_arr[1:n]) for (ad, adm, binary_arr, n) in test_dataset])
        est = mean([sum(custom_round.(vec(model(ad, adm)))) for (ad, adm, binary_arr, _) in test_dataset])

        # Append metrics to histories
        push!(loss_train_history, loss_train)
        #....push for the other metrics
        
        # Log metrics
        @show epoch, loss_train, loss_test, acc_test_binary_arr, acc_test_other, acc_total, est
    end

    # Return the final metrics and histories
    return loss_train_hist #... and the rest metrics
end

I would greatly appreciate any suggestions or insights on what might be causing the slow performance and how to enhance it.

Thank you for your help!

Part of the problem may be that your structs have fields with abstract types. Note that

julia> typeof(Dense(4 => 5))
Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}

julia> isconcretetype(typeof(Dense(4 => 5)))
true

julia> isconcretetype(Dense)
false

Check the [Performance Tips · The Julia Language] for further details. As a quick fix, you can simply use a regular Chain instead of a custom struct for your DiscreetMLP, i.e., define

function DiscreetMLP(hidden_size, dropout_rate)
    layer1 = Dense(hidden_size, 2 * hidden_size, tanh)    # First layer with activation
    layer2 = Dense(2 * hidden_size, 2 * hidden_size, tanh) # Second layer with activation
    layer3 = Dense(2 * hidden_size, 1)                     # Output layer
    dropout = Dropout(dropout_rate)                         # Dropout layer
    return Chain(layer1, dropout, layer2, dropout, layer3) # Does the same as your forward path
end

and remove your struct as well the method for the forward path. As the Chain is fully typed this should give you a speedup already.

1 Like

Ok I have used a regular Chain for DiscreetMLP, but still not a big improvement. However, my structs I think do not have abstract types as I define the layers like this:

struct LayerInit{W<:AbstractMatrix}
    weight::W
    b
    bias
    σ
end

Flux.@functor LayerInit

function LayerInit(in, out; σ=tanh, init=Flux.glorot_uniform, bias::Bool=true)
    W = init(in, out)
    b = bias ? Flux.create_bias(W, true, out) : false
    return LayerInit(W, b, bias, σ)
end

function (li::LayerInit)(A)
    if li.bias
        return li.σ.(A * li.weight .+ li.b')
    else
        return li.σ.(A * li.weight)
    end
end

struct Layer{W<:AbstractMatrix}
    weight::W
    b
    bias
    σ
end

Flux.@functor Layer

function Layer(in, out; σ=tanh, init=Flux.glorot_uniform, bias::Bool=true)
    W = init(in, out)
    b = bias ? Flux.create_bias(W, true, out) : false
    return Layer(W, b, bias, σ)
end

function (l::Layer)(H, A)
    support = H * l.weight

    if l.bias
        return l.σ.(A * support .+ l.b')
    else
        return l.σ.(A * support)
    end
end

struct LayerScore{W<:AbstractMatrix}
    weight::W
    b
    bias
    σ
end

Flux.@functor LayerScore

function LayerScore(in, out;σ=sigmoid,init=Flux.glorot_uniform, bias::Bool=true)
    W = init(in, out)
    b = bias ? Flux.create_bias(W, true, out) : false
    return LayerScore(W, b, bias, σ)
end

function (ls::LayerScore)(A)
    if ls.bias
        return ls.σ(A * ls.weight .+ ls.b')
    else
        return ls.σ(A * ls.weight)
    end
end

A struct definition with no types is the same as ::Any, and will produce instabilities every time you retrieve its fields:

julia> lay = LayerInit(rand(2,2), nothing, [1,2.], tanh);

julia> @code_warntype (l -> l.bias)(lay)
MethodInstance for (::var"#63#64")(::LayerInit{Matrix{Float64}})
  from (::var"#63#64")(l) @ Main REPL[88]:1
Arguments
  #self#::Core.Const(var"#63#64"())
  l::LayerInit{Matrix{Float64}}
Body::Any
1 ─ %1 = Base.getproperty(l, :bias)::Any
└──      return %1

julia> den = Flux.Dense(2=>2,tanh);

julia> @code_warntype (l -> l.bias)(den)
MethodInstance for (::var"#65#66")(::Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}})
  from (::var"#65#66")(l) @ Main REPL[89]:1
Arguments
  #self#::Core.Const(var"#65#66"())
  l::Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}
Body::Vector{Float32}
1 ─ %1 = Base.getproperty(l, :bias)::Vector{Float32}
└──      return %1

For Flux this matters most on innermost layers (containing arrays). If the outermost model struct has abstract types (or none), the cost may not be visible at all, roughly because a lot of other work is done per instability.

Without sample data we can’t run the model above to try improvements ourselves.

Here, and perhaps elsewhere, note that .+ 0.01 is Float64, which you later convert back. .+ 0.01f0 would be better. Not sure this is a big effect. (On the first run after re-starting Julia, watch for warnings about Float64 promotion.)

1 Like

I have updated and added data types for the structs and more. Still training time improved but not significantly.

function MultiLayerPerc(nhid, dropout_rate)
    linear1 = Dense(nhid, 2 * nhid, tanh)
    linear2 = Dense(2 * nhid, 2* nhid, tanh)
    linear3 = Dense(2 * nhid, 1)
    dropout = Dropout(dropout_rate)
    return Chain(linear1, dropout, linear2, dropout, linear3)
end

struct DiscreetGNN
    gc1::LayerInit{Matrix{Float32}}  # Specify concrete type
    gc2::Layer{Matrix{Float32}}       # Specify concrete type
    gc3::Layer{Matrix{Float32}}       # Specify concrete type
    gc4::Layer{Matrix{Float32}}       # Specify concrete type
    gc5::Layer{Matrix{Float32}}       # Specify concrete type
    mlp::Chain                         # MLP remains as it is
end

Flux.@functor DiscreetGNN

function DiscreetGNN(ninput::Int64, nhid::Int64, dropout::Float32)
    gc1 = LayerInit(ninput, nhid)
    gc2 = Layer(nhid, nhid)
    gc3 = Layer(nhid, nhid)
    gc4 = Layer(nhid, nhid)
    gc5 = Layer(nhid, nhid)
    mlp = MultiLayerPerc(nhid, dropout)
    return DiscreetGNN(gc1, gc2, gc3, gc4, gc5, mlp)
end

function (gnn::DiscreetGNN)(adj1::SparseMatrixCSC{Float32, Int64}, adj2::SparseMatrixCSC{Float32, Int64})
     x_1 = mynormalize(gnn.gc1(adj2))
     x_2 = mynormalize(gnn.gc2(x_1, adj1))
     x_3 = mynormalize(gnn.gc3(x_2, adj1))
     x_4 = mynormalize(gnn.gc4(x_3, adj1))
     x_5 = mynormalize(gnn.gc5(x_4, adj1))

     score1_1 = gnn.mlp(x_1')
     score1_2 = gnn.mlp(x_2')
     score1_3 = gnn.mlp(x_3')
     score1_4 = gnn.mlp(x_4')
     score1_5 = gnn.mlp(x_5')
     score1 = score1_1 + score1_2 + score1_3 + score1_4 + score1_5

     return Float32.(sigmoid(score1))
 end

function mynormalize(x::Matrix{Float32})
        norms = sqrt.(sum(abs2, x, dims=2)) .+ 0.01f0
        return Float32.(x ./ norms)
end

#Layers Definition
struct LayerInit{W<:AbstractMatrix}
    weight::W
    b::Union{Bool, Vector{Float32}}
    bias::Bool
    σ::Function
end

Flux.@functor LayerInit

function LayerInit(in::Int64, out::Int64; σ=relu, init=Flux.glorot_uniform, bias::Bool=true)
    W = init(in, out)
    b = bias ? Flux.create_bias(W, true, out) : false
    return LayerInit(W, b, bias, σ)
end

function (li::LayerInit{Matrix{Float32}})(A::SparseMatrixCSC{Float32, Int64})
    if li.bias
        return li.σ.(A * li.weight .+ li.b')
    else
        return li.σ.(A * li.weight)
    end
end

struct Layer{W<:AbstractMatrix}
    weight::W 
    b::Union{Bool, Vector{Float32}}
    bias::Bool
    σ::Function
end

Flux.@functor Layer

function Layer(in::Int64, out::Int64; σ=relu, init=Flux.glorot_uniform, bias::Bool=true)
    W = init(in, out)
    b = bias ? Flux.create_bias(W, true, out) : false
    return Layer(W, b, bias, σ)
end

function (l::Layer{Matrix{Float32}})(H::Matrix{Float32}, A::SparseMatrixCSC{Float32, Int64})
    support = H * l.weight

    if l.bias
        return l.σ.(A * support .+ l.b')
    else
        return l.σ.(A * support)
    end
end

Your b fields in LayerInit and Layer are still ::Any, and mlp::Chain in DiscreetGNN is still an abstract type annotation; you can use isconcretetype to manually check. I’m not certain that’s the sole cause of the performance issue but it generally helps a lot to let all fields have concrete types given the parameters. You can let runtime values and automatically generated type constructor methods determine type parameters instead of incrementally attempting to fill in type annotations manually, if it helps.

To give an example, here’s one solution to write this in a type-inferrable but generic way:

struct DiscreetGNN{L1<:LayerInit,L2<:Layer,L3<:Layer,L4<:Layer,L5<:Layer,C<:Chain}
    gc1::L1
    gc2::L2
    gc3::L3
    gc4::L4
    gc5::L5
    mlp::C
end

I don’t know exactly what you are doing but I see transpositions like li.b' and s1 = gnn.mlp(x_1') that usually are not there since you usually keep the batch dimension (or the node dimension in gnns) last.

I also suggest looking into GitHub - CarloLucibello/GraphNeuralNetworks.jl: Graph Neural Networks in Julia for GNNs,

there is also ConcreteStructs.jl

1 Like

I have removed function MultiLayerPerc, and implementing it directly to function DiscreetGNN like this:

struct DiscreetGNN
    gc1::LayerInit{Matrix{Float32}}
    gc2::Layer{Matrix{Float32}}
    gc3::Layer{Matrix{Float32}}
    gc4::Layer{Matrix{Float32}}
    gc5::Layer{Matrix{Float32}}
    mlp::Chain
end

Flux.@functor DiscreetGNN

function DiscreetGNN(ninput::Integer, nhid::Integer, dropout::Float32)
    gc1 = LayerInit(ninput, nhid)
    gc2 = Layer(nhid, nhid)
    gc3 = Layer(nhid, nhid)
    gc4 = Layer(nhid, nhid)
    gc5 = Layer(nhid, nhid)
    mlp = Chain(Dense(nhid, 2 * nhid, relu), Dropout(dropout), Dense(2 * nhid, 2* nhid, relu), Dropout(dropout), Dense(2 * nhid, 1))
    return DiscreetGNN(gc1, gc2, gc3, gc4, gc5, mlp)
end

However, running @report_opt DiscreetGNN(1,2,3.0f0) , to find runtime dispatch, I have the following dispatch:

┌ DiscreetGNN(ninput::Int64, nhid::Int64, dropout::Float32) @ Main /home/*/*/*/src/Model.jl:72
│┌ Dense(in::Int64, out::Int64, σ::typeof(relu)) @ Flux /home/*/.julia/packages/Flux/htpCe/src/deprecations.jl:6
││┌ Dense(in::Int64, out::Int64, σ::typeof(relu); kw::@Kwargs{}) @ Flux /home/*/.julia/packages/Flux/htpCe/src/deprecations.jl:6
│││┌ Dense(::Pair{Int64, Int64}, σ::typeof(relu)) @ Flux /home/*/.julia/packages/Flux/htpCe/src/layers/basic.jl:164
││││┌ Dense(::Pair{Int64, Int64}, σ::typeof(relu); init::typeof(Flux.glorot_uniform), bias::Bool) @ Flux /home/*/.julia/packages/Flux/htpCe/src/layers/basic.jl:166
│││││┌ Dense(W::Matrix{Float32}, bias::Bool, σ::typeof(relu)) @ Flux /home/*/.julia/packages/Flux/htpCe/src/layers/basic.jl:160
││││││ **runtime dispatch detected: convert(%66::Type, %62::Union{Bool, Vector{Float32}})::Any**
**│││││└────────────────────**

Model.jl:72 is mlp = Chain(Dense(nhid, 2 * nhid, relu), Dropout(dropout), Dense(2 * nhid, 2* nhid, relu), Dropout(dropout), Dense(2 * nhid, 1))

Can you paste a fully reproducible example, leaving out irrelevant logging stuff and making sure that the sequence of operations is correct? We can try to speedup from there.

I am trying to generate some dummy data:

function generate_dummy_data(num_samples=2, model_size=10)
    train_ds = []
    test_ds = []

    for _ in 1:num_samples
        # Generate a random sparse adjacency matrix
        n = 40
        ad = sprand(Float32, n, n, 0.1)  # sparse adjacency matrix
        adm = sprand(Float32, n, n, 0.1)  # sparse adjacency matrix

        # Create dummy labels for binary aray
        bin_ar = rand(Float32, model_size) .< 0.5  # Random binary labels (0 or 1)
        bin_ar = Int.(bin_ar)  # Convert to Int

        # Ensure the bin_ar has the correct size
        bin_ar = vcat(bin_ar, zeros(model_size - length(bin_ar))...)  # Pad to model_size

        # Push the data into the dataset
        push!(train_ds, (ad, adm, bin_ar, n))
        push!(test_ds, (ad, adm, bin_ar, n))
    end
    
    model = create_model(model_size)
    test_layer_init()
    results = train_model(model, train_ds, test_ds)
    return results
end

function test_layer()
    layer = LayerInit(5, 6)
    layer1 = Layer(5, 6)
    println("Layer Init: ", layer)
    println("Layer : ", layer1)
end

Output:

generate_dummy_data()
Layer Init: LayerInit{Matrix{Float32}}(Float32[-0.42002252 -0.6773198 -0.52798855 0.55941707 0.48495016 0.2549217; 0.45783335 -0.017114287 -0.6280818 -0.69759977 -0.5925536 -0.30032533; 0.4021538 0.0956694 0.51514834 0.6089574 0.19965225 -0.440966; 0.3994848 -0.18751901 0.4272211 0.6761862 -0.45498142 -0.5790665; 0.18514858 -0.14462458 0.5682483 -0.6178941 0.7193772 -0.21507059], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], true, NNlib.relu)
Layer : Layer{Matrix{Float32}}(Float32[-0.5239934 -0.7359658 -0.62638366 -0.26138678 0.67564785 -0.42961064; 0.25436994 -0.71448207 0.37783283 -0.5730228 0.10510538 -0.1555842; 0.066669814 -0.16778593 -0.50816566 0.059353705 0.65710515 0.42555973; -0.44090483 0.12359858 0.22497636 0.16602094 -0.39279443 0.29406732; 0.18325181 -0.088156 0.66331255 -0.015707115 0.03858383 0.4399053], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], true, NNlib.relu)