Speed up a derivative calculation?

Hi,

I am trying to write a function to compute the derivative of a cost function.
The cost function expresses how well a distance function separates labeled data into its respective classes. I want to use the derivative to run gradient descent and find a good set of weights for this distance function. The optimization will evaluate the derivative many times, so I want the calculation to be as fast as possible.

I am open to any packages or other work that can help me solve this problem. I checked some Julia optimization packages, but none seemed suited to a large derivative like this, so I wrote my own function representing the derivative.

Here’s the code:

"""
Computes the derivative of the cost function wrt the weights of the distance function
param dist : Distance object from Distances.jl package
param S : R+ -> R+, Score function, monotonically decreasing (such as exp(-x))
param Sβ€²: R+ -> R, Derivative of score function
param X: Vector of feature matrices per class. 
"""
function costd(dist::WeightedPeriodicNormalizedMinkowski, S::Function, Sβ€²::Function, X::Vector{<:Matrix})
  d(p, q) = (x -> x == 0 ? 1e-6 : x)(dist(p, q)) # clamp zero distances so d^(1 - p) stays finite
  cost = 0
  dw = zeros(length(dist.weights))
  for Xi in X # for class in X
    for xi in eachcol(Xi) # for point in current class
      # scored distances to points in the same class
      DXi = sum(S(dist(xi, q)) for q = eachcol(Xi)) 
      DXl = sum(S(dist(xi, q)) for Xl = X if Xl != Xi for q = eachcol(Xl) )
      cost += DXl / DXi 
      for i in 1:length(dist.weights)
        dDXi = sum(Distances.eval_op(dist, xi[i], q[i], dist.periods[i], dist.variances[i], 1) * d(xi, q) ^ (1 - dist.p) * Sβ€²(dist(xi, q)) for q  = eachcol(Xi) )
        dDXl = sum(Distances.eval_op(dist, xi[i], q[i], dist.periods[i], dist.variances[i], 1) * dist(xi, q) ^ (1 - dist.p) * Sβ€²(dist(xi, q)) for Xl in X if Xl != Xi for q = eachcol(Xl) )
        
        dw[i] += 1/dist.p * (dDXl - DXl * dDXi / DXi ) / DXi 
      end
    end
  end
  dw, cost
end

With two classes of 300 points each, this takes 30s to compute.
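For reference, reading off the code: for each point $x$ in class $X_i$,

$$\mathrm{DX}_i(x) = \sum_{q \in X_i} S(d(x, q)), \qquad \mathrm{DX}_l(x) = \sum_{X_l \neq X_i} \sum_{q \in X_l} S(d(x, q)), \qquad \mathrm{cost} = \sum_{X_i} \sum_{x \in X_i} \frac{\mathrm{DX}_l(x)}{\mathrm{DX}_i(x)},$$

and the gradient entries come from the quotient rule together with $\partial d / \partial w_k = \tfrac{1}{p}\, d^{\,1-p} K_k$, where $K_k$ is the unweighted per-coordinate term (eval_op with weight 1):

$$\frac{\partial\,\mathrm{cost}}{\partial w_k} = \sum_{X_i} \sum_{x \in X_i} \frac{1}{\mathrm{DX}_i(x)} \left( \frac{\partial\,\mathrm{DX}_l(x)}{\partial w_k} - \frac{\mathrm{DX}_l(x)}{\mathrm{DX}_i(x)} \frac{\partial\,\mathrm{DX}_i(x)}{\partial w_k} \right).$$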

And here it is embedded in a MWE:

MWE code
module Optimizer

import Distances 

struct WeightedPeriodicNormalizedMinkowski{W <: AbstractArray, T <: AbstractArray, V <: AbstractArray, P <: Real} <: Distances.UnionMetric
  weights::W
  periods::T
  variances::V
  p::P
end
Base.repeat(wpm::WeightedPeriodicNormalizedMinkowski, r:: Integer) = WeightedPeriodicNormalizedMinkowski(
  repeat(wpm.weights, r)/r, repeat(wpm.periods, r), repeat(wpm.variances, r), wpm.p)


# -------------------------------------------------------------------------------------
# Help me optimize this please! 
# for reference, computing a distance matrix takes only 2-3 s on a dataset with 20 classes of ~300 data points each.
"""
Computes the derivative of the cost function wrt the weights of the distance function
param dist : Distance object from Distances.jl package
param S : R+ -> R+, Score function, monotonically decreasing (such as exp(-x))
param Sβ€²: R+ -> R, Derivative of score function
param X: Vector of feature matrices per class. 
"""
function costd(dist::WeightedPeriodicNormalizedMinkowski, S::Function, Sβ€²::Function, X::Vector{<:Matrix})
  d(p, q) = (x -> x == 0 ? 1e-6 : x)(dist(p, q)) # clamp zero distances so d^(1 - p) stays finite
  cost = 0
  dw = zeros(length(dist.weights))
  for Xi in X # for class in X
    for xi in eachcol(Xi) # for point in current class
      # scored distances to points in the same class
      DXi = sum(S(dist(xi, q)) for q = eachcol(Xi)) 
      DXl = sum(S(dist(xi, q)) for Xl = X if Xl != Xi for q = eachcol(Xl) )
      cost += DXl / DXi 
      for i in 1:length(dist.weights)
        dDXi = sum(Distances.eval_op(dist, xi[i], q[i], dist.periods[i], dist.variances[i], 1) * d(xi, q) ^ (1 - dist.p) * Sβ€²(dist(xi, q)) for q  = eachcol(Xi) )
        dDXl = sum(Distances.eval_op(dist, xi[i], q[i], dist.periods[i], dist.variances[i], 1) * dist(xi, q) ^ (1 - dist.p) * Sβ€²(dist(xi, q)) for Xl in X if Xl != Xi for q = eachcol(Xl) )
        
        dw[i] += 1/dist.p * (dDXl - DXl * dDXi / DXi ) / DXi 
      end
    end
  end
  dw, cost
end
# ------------------------------------------------------------------------------------

# Example of how to use this function: 

X = [ 100 * rand(Float64, (48, rand(150:400))), 75 * rand(Float64, (48, rand(300:400)))]
S(x) = exp(-x)
Sp(x) = -exp(-x)

normalize(v::Vector) = v / sum(v)

distance = repeat(WeightedPeriodicNormalizedMinkowski(
  normalize([12.0, 2.0, 1.0, 6.0, 3.0, 2.0, 2.0, 6.0]), # weights
  [Inf, 360.0, 180, Inf, Inf, Inf, Inf, 360], # periods
  [1.0, 5, 10, 20, 20, 20, 20, 10], # measurement variance
  2), 
6)

# takes about 30s 
# costd(distance, S, Sp, X)

# A real world use case would be about 50 classes with 500 data points each
# X = [ 100 * rand(Float64, (48, rand(150:400))) for _ = 1:50]
# costd(distance, S, Sp, X) # doesn't finish

# --------------------------------------------------------
# Everything below this patches Distances.jl for a metric with three parameters; you can ignore it
# --------------------------------------------------------


WeightedPeriodicNormalizedMinkowski(weights::AbstractArray) = (periods::AbstractArray) -> (variances::AbstractArray) -> (p :: Real) -> WeightedPeriodicNormalizedMinkowski(weights, periods, variances, p)
WeightedPeriodicNormalizedMinkowski(weights::AbstractArray, periods::AbstractArray, variances::AbstractArray) = WeightedPeriodicNormalizedMinkowski(weights)(periods)(variances)
WeightedPeriodicNormalizedMinkowski(wpm::WeightedPeriodicNormalizedMinkowski, p::Real) = WeightedPeriodicNormalizedMinkowski(wpm.weights, wpm.periods, wpm.variances, p)

Distances.parameters(wpm::WeightedPeriodicNormalizedMinkowski) = (wpm.periods,  wpm.variances, wpm.weights)


@inline function Distances.eval_op(d::WeightedPeriodicNormalizedMinkowski, ai, bi, Ti, vi, wi)
  # Taken from the PeriodicEuclidean function 
  s1 = abs(ai - bi)
  s2 = mod(s1, Ti)
  # taken from the WeightedMinkowski function with some modifications.
  (min(s2, Ti - s2) / vi)^d.p * wi
end
@inline Distances.eval_end(d::WeightedPeriodicNormalizedMinkowski, s) = s^(1/d.p)
wpnminkowski(a, b, w::AbstractArray, T::AbstractArray, v::AbstractArray, p::Real) = WeightedPeriodicNormalizedMinkowski(w, T, v, p)(a, b)
wpnminkowski(a, w::AbstractArray, T::AbstractArray, v::AbstractArray, p::Real) = wpnminkowski(a, a, w, T, v, p)
(w::WeightedPeriodicNormalizedMinkowski)(a,b) = Distances._evaluate(w, a, b)



Distances._eval_start(d::Distances.UnionMetrics, ::Type{Ta}, ::Type{Tb}, (p1, p2, p3)) where {Ta,Tb} =
    zero(typeof(Distances.eval_op(d, oneunit(Ta), oneunit(Tb), oneunit(Distances.eltype(p1)), oneunit(Distances.eltype(p2)), oneunit(Distances.eltype(p3)) )))

function Distances._evaluate(dist::Distances.UnionMetrics, a::Number, b::Number, p1::Number, p2::Number, p3::Number)
    Distances.eval_end(dist, Distances.eval_op(dist, a, b, p1, p2, p3))
end 

Distances.result_type(dist::Distances.UnionMetrics, ::Type{Ta}, ::Type{Tb},  (p1, p2, p3)) where {Ta,Tb} =
    typeof(Distances._evaluate(dist, oneunit(Ta), oneunit(Tb), oneunit(Distances.eltype(p1)), oneunit(Distances.eltype(p2)), oneunit(Distances.eltype(p3))))

Base.@propagate_inbounds function Distances._evaluate(d::Distances.UnionMetrics, a::AbstractArray, b::AbstractArray, (p1, p2, p3)::Tuple{AbstractArray, AbstractArray, AbstractArray})
    @boundscheck if length(a) != length(b)
        throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
    end
    @boundscheck if length(a) != length(p1)
        throw(DimensionMismatch("arrays have length $(length(a)) but parameter 1 has length $(length(p1))."))
    end
    @boundscheck if length(a) != length(p2)
        throw(DimensionMismatch("arrays have length $(length(a)) but parameter 2 has length $(length(p2))."))
    end
    @boundscheck if length(a) != length(p3)
        throw(DimensionMismatch("arrays have length $(length(a)) but parameter 3 has length $(length(p3))."))
    end
    if length(a) == 0
        return zero(Distances.result_type(d, a, b))
    end
    @inbounds begin
        s = Distances.eval_start(d, a, b)
        if (IndexStyle(a, b, p1, p2, p3) === IndexLinear() && eachindex(a) == eachindex(b) == eachindex(p1) == eachindex(p2) == eachindex(p3)) ||
                axes(a) == axes(b) == axes(p1) == axes(p2) == axes(p3)
            @simd for I in eachindex(a, b,  p1, p2, p3)
                ai = a[I]
                bi = b[I]
                p1i = p1[I]
                p2i = p2[I]
                p3i = p3[I]
                s = Distances.eval_reduce(d, s, Distances.eval_op(d, ai, bi, p1i, p2i, p3i))
            end
        else
            for (ai, bi,  p1i, p2i, p3i) in zip(a, b, p1, p2, p3)
                s = Distances.eval_reduce(d, s, Distances.eval_op(d, ai, bi, p1i, p2i, p3i))
            end
        end
        return Distances.eval_end(d, s)
    end
end


end 

Try using @code_warntype to spot type instabilities. Your costd code contains a hint at a likely instability.
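
A minimal sketch of that kind of instability (assuming the culprit here is the cost = 0 initialization, an Int that becomes a Float64 once distance ratios are accumulated):

function unstable(xs)
  cost = 0        # Int
  for x in xs
    cost += x     # becomes Float64 on the first iteration for Float64 data
  end
  cost            # @code_warntype flags cost as Union{Float64, Int64}
end

Initializing with cost = 0.0 avoids the Union.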

Ok, I vectorized it; here's an improvement:

"""
Structure of this code taken from Distances.jl pairwise function
"""
function _pairwise_expanded(r::Array{Float64, 3}, metric::Distances.UnionMetric, a::AbstractMatrix, b::AbstractMatrix)
  params = collect(zip(Distances.parameters(metric)...))
  na = size(a, 2)
  nb = size(b, 2)
  np = length(params)
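  # note: below, k indexes features (rows of a and b), so this assumes np == size(a, 1)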
  size(r) == (na, nb, np) || throw(DimensionMismatch("Incorrect size of r."))
  for j = 1:nb
    bj = view(b, :, j)
    for i = 1:na
      ai = view(a, :, i)
      for k = 1:np
        r[i, j, k] = Distances.eval_op(metric, ai[k], bj[k], params[k]...)
      end
    end
  end
  r
end


"""
Computes a 3D matrix with indices [1:size(a,2), 1:size(b,2), 1:size(a,1)]

In order to make this a regular 2D distance matrix, `sum(..., dims=3) .^ (1/dist.p)`
where ... represents the results of this function. 
"""
function pairwise_expanded(metric::Distances.UnionMetric, a::AbstractMatrix, b::AbstractMatrix)
  m = size(a, 2)
  n = size(b, 2)
  r = Array{Distances.result_type(metric, a, b), 3}(undef, m, n, length(Distances.parameters(metric)[1]))
  _pairwise_expanded(r, metric, a, b)
end
pairwise_expanded(metric::Distances.UnionMetric, a::AbstractMatrix) = pairwise_expanded(metric, a, a) 


"""
Computes the derivative of the cost function wrt the weights of the distance function
param dist : Distance object from Distances.jl package
param S : R+ -> R+, Score function, monotonically decreasing  (such as E^-x)
param Sβ€²: R+ -> R, Derivative of score function
param X: Vector of feature matrices per class. 
"""
function costd(dist::Distances.PreMetric, S::Function, Sβ€²::Function, X::Vector{Matrix{Float64}})
  cost = 0.0
  dw = zeros(length(dist.weights))
  #unweighted distance
  uwdist = WeightedPeriodicNormalizedMinkowski(ones(length(dist.weights)), dist.periods, dist.variances, dist.p)
  for i in eachindex(X) # for class in X
    # points in other classes combined
    allXl::Matrix{Float64} = hcat(X[1:i-1]..., X[i+1:end]...)

    sameclassdmat = Distances.pairwise(dist, X[i], X[i], dims=2)
    otherclassdmat = Distances.pairwise(dist, allXl, X[i], dims=2) # each column represents one xi in Xi

    # All the scored sums for each column of Xi in one go - faster than computing in a for loop
    allDXi::Matrix{Float64} = sum(S.(sameclassdmat), dims=1)
    allDXl::Matrix{Float64} = sum(S.(otherclassdmat), dims=1)
    cost += sum(allDXl ./ allDXi) # should have same length = |Xi|

    allKXi = pairwise_expanded(uwdist, X[i], X[i]) #[i, i, k]
    allKXl = pairwise_expanded(uwdist, allXl, X[i]) #[l, i, k]

    sameclassdmat[sameclassdmat .== 0] .= 1e-6
    #otherclassdmat[otherclassdmat .== 0] .= 1e-6

    for k in 1:length(dist.weights)
      # both of these should have the same length = |Xi| 
      alldDXi = sum(allKXi[:, :, k] .* (sameclassdmat .^(1 - dist.p)) .* Sβ€².(sameclassdmat), dims=1)
      alldDXl = sum(allKXl[:, :, k] .* (otherclassdmat .^(1 - dist.p)) .* Sβ€².(otherclassdmat), dims=1)
      dw[k] += 1/dist.p * sum(((alldDXl - (allDXl .* alldDXi ./ allDXi) ) ./ allDXi ))
    end
  end
  dw, cost
end
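
A quick way to sanity-check the vectorized version against the original on small data (a sketch only; costd_orig is a hypothetical name for the first version, since both are called costd):

Xsmall = [100 * rand(48, 20), 75 * rand(48, 25)]
dw1, c1 = costd_orig(distance, S, Sp, Xsmall)  # first version
dw2, c2 = costd(distance, S, Sp, Xsmall)       # vectorized version above
isapprox(c1, c2, rtol=1e-8) && isapprox(dw1, dw2, rtol=1e-8)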

  

@code_warntype comes up clean for all these functions (although I don’t understand why I’m required to use a concrete type like Vector{Matrix{Float64}} instead of Vector{AbstractMatrix{T}} where T in the method signature to make this work).

On two classes with 200 data points each:
2.576388 seconds (118.00 k allocations: 178.669 MiB, 6.00% gc time)
That’s 10x better, can it go any faster?

There’s still a lot of repeated effort going into this function call - I don’t need to calculate the distances and the K’s separately, for example.

You need

Vector{<:AbstractMatrix{T}} where T

because parametric types in Julia are invariant: Vector{Matrix{Float64}} is not a subtype of Vector{AbstractMatrix{Float64}}, so the abstract-element signature never matches. The <: makes the element type itself a free parameter.
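
You can check the invariance in the REPL:

julia> Vector{Matrix{Float64}} <: Vector{AbstractMatrix}
false

julia> Vector{Matrix{Float64}} <: Vector{<:AbstractMatrix}
true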

I didn’t try anything, but here you should probably use a view (add @views to the whole function; that’s cleaner). And there is probably a dot missing in the assignment; otherwise you are not doing the operation in place.
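
For reference, the difference (a generic sketch, not the exact line from the function):

x = rand(100); y = similar(x)
y = x .^ 2 .+ 1  # allocates a new array and rebinds the name y
y .= x .^ 2 .+ 1 # fused broadcast that writes into the existing y, no allocation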


I updated the function to remove the calls to Distances.pairwise. This did not result in much of a speedup (only about 10%), which confuses me, since Distances.pairwise recalculates the entire distance matrix from scratch instead of summing up what’s already there.

New code for function here
@views function costd(dist::Distances.PreMetric, S::Function, Sβ€²::Function, X::Vector{Matrix{Float64}})
  cost = 0.0
  dw = zeros(length(dist.weights))
  #unweighted distance
  uwdist = WeightedPeriodicNormalizedMinkowski(ones(length(dist.weights)), dist.periods, dist.variances, dist.p)
  for i in eachindex(X) # for class in X
    # points in other classes combined
    allXl::Matrix{Float64} = hcat(X[1:i-1]..., X[i+1:end]...)
    allKXi = pairwise_expanded(uwdist, X[i], X[i]) #[i, i, k]
    allKXl = pairwise_expanded(uwdist, allXl, X[i]) #[l, i, k]
    
    # this creates error up to 5e-16 compared to calculating distance directly.
    sameclassdmat = (mapslices(x -> x * dist.weights, allKXi, dims=[2, 3])[:, :, 1]) .^ (1/dist.p)
    otherclassdmat = (mapslices(x -> x * dist.weights, allKXl, dims=[2, 3])[:, :, 1]) .^ (1/dist.p)
    # All the scored sums for each row of Xi in one go - faster than computing in a for loop
    allDXi::Matrix{Float64} = sum(S.(sameclassdmat), dims=1)
    allDXl::Matrix{Float64} = sum(S.(otherclassdmat), dims=1)
    cost += sum(allDXl ./ allDXi) # should have same length = |Xi|

    sameclassdmat[sameclassdmat .== 0] .= 1e-6
    #otherclassdmat[otherclassdmat .== 0] .= 1e-6

    for k in 1:length(dist.weights)
      # both of these should have the same length = |Xi| 
      alldDXi = sum(allKXi[:, :, k] .* (sameclassdmat .^(1 - dist.p)) .* Sβ€².(sameclassdmat), dims=1)
      alldDXl = sum(allKXl[:, :, k] .* (otherclassdmat .^(1 - dist.p)) .* Sβ€².(otherclassdmat), dims=1)
      dw[k] += 1/dist.p * sum(((alldDXl - (allDXl .* alldDXi ./ allDXi) ) ./ allDXi ))
    end
  end
  dw, cost
end

I tried adding views to the function in various places (the whole function, the inner for loop, the sum operations). Couldn’t see a speedup in the benchmark:


julia> include("pattern_test.jl")
WARNING: replacing module Optimizer.
Main.Optimizer

julia> @benchmark Optimizer.costd(Optimizer.distance, Optimizer.S, Optimizer.Sp, Optimizer.X())
BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.301 s …   1.463 s  β”Š GC (min … max):  4.36% … 16.57%
 Time  (median):     1.368 s              β”Š GC (median):    11.01%
 Time  (mean Β± Οƒ):   1.375 s Β± 83.407 ms  β”Š GC (mean Β± Οƒ):  10.92% Β±  6.04%

  β–ˆβ–ˆ                                          β–ˆ           β–ˆ
  β–ˆβ–ˆβ–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–ˆβ–β–β–β–β–β–β–β–β–β–β–β–ˆ ▁
  1.3 s          Histogram: frequency by time        1.46 s <

 Memory estimate: 491.93 MiB, allocs estimate: 15124.

julia> include("pattern_test.jl")
WARNING: replacing module Optimizer.
Main.Optimizer

julia> @benchmark Optimizer.costd(Optimizer.distance, Optimizer.S, Optimizer.Sp, Optimizer.X())
BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.289 s …   1.459 s  β”Š GC (min … max):  6.91% … 16.52%
 Time  (median):     1.370 s              β”Š GC (median):    11.28%
 Time  (mean Β± Οƒ):   1.372 s Β± 90.507 ms  β”Š GC (mean Β± Οƒ):  11.17% Β±  5.87%

  β–ˆ  β–ˆ                                              β–ˆ     β–ˆ
  β–ˆβ–β–β–ˆβ–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–ˆβ–β–β–β–β–β–ˆ ▁
  1.29 s         Histogram: frequency by time        1.46 s <

 Memory estimate: 491.93 MiB, allocs estimate: 15128.

In the second one I have @views around the whole function.
I looked at the profile (which I wish showed times for each bar), and outside of the expected chunk for computing the β€œexpanded” 3-D distance matrix, I find that the lines of the form something = sum(...) are also taking a lot of time. I would have thought those would be fast, since they’re matrix operations, right?

These are the lines that take up time that I didn’t expect:

    sameclassdmat = (mapslices(x -> x * dist.weights, allKXi, dims=[2, 3])[:, :, 1]) .^ (1/dist.p)
    otherclassdmat = (mapslices(x -> x * dist.weights, allKXl, dims=[2, 3])[:, :, 1]) .^ (1/dist.p)

Seems the mapslices call is taking a long time.

and

    for k in 1:length(dist.weights)
      # both of these should have the same length = |Xi| 
      alldDXi = sum(allKXi[:, :, k] .* (sameclassdmat .^(1 - dist.p)) .* Sβ€².(sameclassdmat), dims=1)
      alldDXl = sum(allKXl[:, :, k] .* (otherclassdmat .^(1 - dist.p)) .* Sβ€².(otherclassdmat), dims=1)
      dw[k] += 1/dist.p * sum(((alldDXl - (allDXl .* alldDXi ./ allDXi) ) ./ allDXi ))
    end

Here there’s a call to materialize (the broadcast machinery) that seems to take a long time as well, despite using @views on the whole function.

I think you can get better advice if you prepare two minimal examples where that mapslices call and those sum calls are the only thing happening, with data similar to what you are feeding these functions.

You are probably allocating unnecessary intermediate arrays there, and perhaps expanding those operations into optimized loops is worthwhile, but the code is too long for me to look at closely and understand everything that is going on (for instance, I’m not sure whether alldDXl is an array or a scalar, and I don’t have time right now to look closer).


Ok, here is a short MWE of the first one (I will do the second one too when I get a minute):

julia> @benchmark (mapslices(x -> x * rand(Float64, (50,)), $(rand(Float64, (300, 300, 50))), dims=[2, 3])[:, :, 1]) .^ (1/2)
BenchmarkTools.Trial: 17 samples with 1 evaluation.
 Range (min … max):  278.437 ms … 352.189 ms  β”Š GC (min … max): 0.00% … 1.45%
 Time  (median):     289.182 ms               β”Š GC (median):    1.94%
 Time  (mean Β± Οƒ):   301.138 ms Β±  24.439 ms  β”Š GC (mean Β± Οƒ):  1.69% Β± 1.01%

  β–ˆ  β–ˆβ–ˆ ▁▁▁   ▁ ▁     ▁         ▁▁                  ▁      ▁  ▁
  β–ˆβ–β–β–ˆβ–ˆβ–β–ˆβ–ˆβ–ˆβ–β–β–β–ˆβ–β–ˆβ–β–β–β–β–β–ˆβ–β–β–β–β–β–β–β–β–β–ˆβ–ˆβ–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–ˆβ–β–β–β–β–β–β–ˆβ–β–β–ˆ ▁
  278 ms           Histogram: frequency by time          352 ms <

 Memory estimate: 3.10 MiB, allocs estimate: 2446.

I run this call (number of classes in the dataset)^2 times, so that explains a lot of the time I’m seeing.
How can I speed this up?

Unfortunately, the answer might be to not use mapslices; it’s known to be slow.
You also allocate a new array on every call to rand; try allocating that array beforehand and filling it with new random values using rand! instead.

Slicing an array with [:, :, 1] further creates a copy. Try using @views to avoid making the copy.
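
For this particular operation (a weighted sum over the third dimension), one alternative (a sketch, not tested on your data) is to flatten the first two dimensions and use a single matrix-vector product, avoiding mapslices entirely:

K = rand(300, 300, 50); w = rand(50)
# dmat[i, j] = sum(K[i, j, k] * w[k] for k in 1:50), done as one BLAS mat-vec
dmat = reshape(reshape(K, :, size(K, 3)) * w, size(K, 1), size(K, 2)) .^ (1/2)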


The BenchmarkTools documentation says that if I put $(...) around a variable, it is interpolated before the benchmark runs, so it doesn’t count toward the measurement. (I verified that it significantly reduces the reported memory, so I’m guessing that’s what’s happening.)
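
For reference, the pattern from the BenchmarkTools docs looks like this:

using BenchmarkTools
A = rand(1000)
@benchmark sum($A) # $A is interpolated once, so constructing A is not part of the measurement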

Ok, I tried using what’s written in that post, and it works better:

julia> @benchmark x= reduce(hcat, [s*$(rand(Float64, 50)) for s = eachslice($(rand(Float64, (300,350,50))), dims=1)]).^(1/2)'
BenchmarkTools.Trial: 138 samples with 1 evaluation.
 Range (min … max):  34.508 ms … 45.334 ms  β”Š GC (min … max): 0.00% … 18.42%
 Time  (median):     36.107 ms              β”Š GC (median):    0.00%
 Time  (mean Β± Οƒ):   36.350 ms Β±  1.364 ms  β”Š GC (mean Β± Οƒ):  0.60% Β±  2.55%

            β–ˆβ–†β–‡β–‡β–‚β–„β– β–ƒβ–„
  β–„β–„β–ƒβ–‡β–„β–„β–†β–†β–ˆβ–‡β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–†β–ˆβ–ˆβ–„β–†β–†β–„β–β–ˆβ–β–ƒβ–β–β–β–β–β–ƒβ–β–β–β–ƒβ–β–β–β–β–β–β–β–β–β–β–ƒβ–β–β–β–ƒβ–β–ƒβ–β–ƒβ–β–ƒ β–ƒ
  34.5 ms         Histogram: frequency by time        41.3 ms <

 Memory estimate: 2.47 MiB, allocs estimate: 912.

And the overall function is also about 15x faster for 2 classes:

BenchmarkTools.Trial: 36 samples with 1 evaluation.
 Range (min … max):  134.299 ms … 157.559 ms  β”Š GC (min … max): 3.59% … 8.58%
 Time  (median):     139.221 ms               β”Š GC (median):    6.92%
 Time  (mean Β± Οƒ):   140.069 ms Β±   4.762 ms  β”Š GC (mean Β± Οƒ):  6.70% Β± 1.60%

    β–ƒβ–ˆβ–ˆ     β–ƒ β–ƒβ–ˆ β–ƒ β–ˆ  β–ƒ    β–ƒ                                     
  β–‡β–β–ˆβ–ˆβ–ˆβ–‡β–‡β–‡β–‡β–‡β–ˆβ–β–ˆβ–ˆβ–β–ˆβ–‡β–ˆβ–β–β–ˆβ–β–β–β–β–ˆβ–β–β–β–‡β–β–β–‡β–β–‡β–β–β–β–‡β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–‡ ▁
  134 ms           Histogram: frequency by time          158 ms <

 Memory estimate: 84.30 MiB, allocs estimate: 4615.

BenchmarkTools.Trial: 29 samples with 1 evaluation.
 Range (min … max):  173.998 ms … 183.841 ms  β”Š GC (min … max): 6.93% … 4.94%
 Time  (median):     176.675 ms               β”Š GC (median):    7.53%
 Time  (mean Β± Οƒ):   177.023 ms Β±   2.066 ms  β”Š GC (mean Β± Οƒ):  7.45% Β± 0.53%

       β–ƒβ–ƒ   β–ƒ β–ƒ   β–ˆβ–ƒ β–ƒ β–ƒ            β–ƒ                            
  β–‡β–β–β–β–‡β–ˆβ–ˆβ–β–‡β–β–ˆβ–β–ˆβ–β–β–‡β–ˆβ–ˆβ–β–ˆβ–β–ˆβ–β–β–‡β–β–β–‡β–‡β–β–β–‡β–β–β–ˆβ–β–β–β–‡β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–β–‡ ▁
  174 ms           Histogram: frequency by time          184 ms <

 Memory estimate: 104.42 MiB, allocs estimate: 4845.

It’s still coming out to about 5-6 s for 10 classes and 60 s for 50 classes, which I’ll deal with if that’s the best I can do for this use case; still slower than I hoped.

The same line remains a hotspot in the profile even with this improvement…

The other hotspot looks like this:

function test()
  dw = zeros(50)
  K = rand(Float64, (300, 350, 50))
  M = (K[:, :, 1] .+ 1) .^ (1/2) .+ K[:, :, 1] # some 2-D matrix calculated outside the loop

  @views for k in 1:50
    # !!! this line !!!
    Kk = sum(K[:, :, k] .* M, dims=1)

    # some matrix calculations using that matrix; these don't take much time
    dw[k] += 1/3 * sum(Kk .+ (Kk .+ 1) .^ (1/2))
  end
  dw
end

How can I speed this one up?

In general, you will get better speed by allocating less memory. In this case, for example, if you made K = rand(300, 350) and reused that memory on each iteration of the loop, you would go a lot faster.
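
A sketch of that suggestion (a hypothetical rewrite of test; rand! stands in for however slice k is actually produced, and the explicit loop avoids the (300, 350) broadcast temporary):

using Random

function test_reuse()
  dw = zeros(50)
  Kk = Matrix{Float64}(undef, 300, 350) # one reusable slice buffer
  M = rand(300, 350)                    # stand-in for the 2-D matrix computed outside the loop
  colsums = zeros(1, 350)
  for k in 1:50
    rand!(Kk)                           # stand-in for computing slice k in place
    fill!(colsums, 0.0)
    @inbounds for j in axes(Kk, 2), i in axes(Kk, 1)
      colsums[1, j] += Kk[i, j] * M[i, j] # fused reduction, no temporary matrix
    end
    dw[k] += 1/3 * sum(colsums .+ (colsums .+ 1) .^ (1/2))
  end
  dw
end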