`gradient(x -> sum(u(x)), xdata[1])` from `dirichlet_loss_nb(u, x, y)` works because the first element of `xdata` is a vector (with no batch dimension) containing exactly one scalar. `xdata[1:2]`, however, gives:
```
2-element Vector{Vector{Float32}}:
 [0.34105664]
 [-0.054080337]
```
Your original approach, `dirichlet_loss(u, x, y)`, can work with this format because the gradient function is broadcast over every inner vector. But broadcasting the gradient calculation seems to be disproportionately slow, so the new version intentionally calls `gradient` without the broadcasting dot. Instead, `dirichlet_loss_nb(u, x, y)` avoids this by treating multiple samples as a single batch of size `(1, batch_size)`. A large batch can be propagated through the network efficiently in one go, and the differentiation becomes much cheaper as well: a single reverse pass over `sum(u(x))` yields the gradients for all samples at once. This works because each output depends only on its own input column, so column `j` of the gradient is exactly the gradient of `u` at sample `j`.
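To make the "one pass gives every per-sample gradient" argument concrete, here is a small dependency-free check. It is only a sketch under stated assumptions: a one-layer toy model `u_toy(x) = tanh.(W * x .+ b)` with made-up weights stands in for the Flux network, and central finite differences stand in for Zygote so the snippet runs without extra packages (the finite-difference loop is a correctness check, not an efficient method). It confirms that the gradient of the summed batch output has, in column `j`, the gradient for sample `j` alone. This relies on the model treating each column independently; layers that mix samples across the batch (for example `BatchNorm` in training mode) would break it.

```julia
# Toy stand-in for the network (hypothetical weights): u_toy(x) = tanh.(W * x .+ b)
# Works on a single sample (a vector) or a batch (a matrix with one sample per column).
W = [0.5 -1.0 0.25]   # 1×3 weight matrix, i.e. in_features = 3
b = 0.1
u_toy(x) = tanh.(W * x .+ b)

# Analytic gradient of the scalar output for ONE sample x:
# d/dx tanh(W*x + b) = (1 - tanh(W*x + b)^2) * W'
grad_single(x) = (1 - tanh(only(W * x) + b)^2) .* vec(W)

# Central-difference gradient of the batched objective sum(u_toy(X)) w.r.t. every entry of X
function fd_grad_of_sum(X; h = 1e-6)
    G = similar(X)
    for k in eachindex(X)
        Xp = copy(X); Xp[k] += h
        Xm = copy(X); Xm[k] -= h
        G[k] = (sum(u_toy(Xp)) - sum(u_toy(Xm))) / (2h)
    end
    return G
end

# A batch of 4 samples, one per column (hypothetical values)
X = [ 0.1 -0.3  0.7  0.0;
      0.2  0.5 -0.4  1.0;
     -0.6  0.9  0.3 -0.2]

G_batch = fd_grad_of_sum(X)                                        # gradient of the summed batch output
G_per_sample = reduce(hcat, [grad_single(x) for x in eachcol(X)])  # one gradient per sample, as columns

@show isapprox(G_batch, G_per_sample; atol = 1e-6)
```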
So `dirichlet_loss_nb(u, x, y)` expects a matrix of scalars instead of a nested vector of type `Vector{Vector{Float32}}`. For example:
```julia
# what y is here doesn't matter so much because it's currently not really in use
ll_ = dirichlet_loss(u, xdata[1:3], ydata[1:3])
ll2_ = dirichlet_loss_nb(u, xdata_[:, 1:3], ydata[:, 1:3])
@show isapprox(ll_, ll2_)
```
`xdata_[:, 1:3]` gives a matrix of scalars of size `(1, 3)`, where 1 is the input dimension of the first layer in the network and 3 is the batch size.
> Is there anything I need to think about when migrating to that case?
For inputs with more than one feature, `dirichlet_loss_nb(u, x, y)` has to be modified a bit to compute the per-sample vector norms correctly:
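```julia
using Flux, Statistics, LinearAlgebra  # gradient (via Zygote), mean, norm

function dirichlet_loss_nb(u, x, y)
    # ∇x instead of ∇u_ makes more sense here since we're calculating the gradient w.r.t. x
    # x is an (in_features, batch_size) matrix, so column j of ∇x is the gradient for sample j
    ∇x = gradient(x -> sum(u(x)), x)[1]
    # old version, which takes the norm of every scalar entry: return mean(norm.(∇x).^2)
    # take the norm of each column, i.e. of each per-sample gradient vector
    return mean(norm.(eachslice(∇x, dims=2)).^2)
end
```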
```julia
in_features = 5
u3 = Chain(Dense(in_features => 20, tanh), Dense(20 => 1, tanh))
xdata_2 = randn(Float32, in_features, 10^4)
# the access patterns are now a bit more complicated...
# (Julia arrays are column-major, so each run of in_features consecutive elements is one sample)
xdata2 = [xdata_2[i*in_features+1:(i+1)*in_features] for i in 0:(length(xdata_2) ÷ in_features)-1]
# ...or simply use eachslice(...) here as well
xdata2 = convert(Vector{Vector{Float32}}, (eachslice(xdata_2, dims=2)))
data2_loss = @time dirichlet_loss(u3, xdata2, [1])
data2_loss2 = @time dirichlet_loss_nb(u3, xdata_2, [1])
@show isapprox(data2_loss, data2_loss2)
@info isapprox(
    getindex.(gradient.(x -> only(u3(x)), xdata2[1:5]), 1),
    eachslice(gradient(x -> sum(u3(x)), xdata_2[:, 1:5])[1], dims=2)
)
```
Giving me:
```
  0.182084 seconds (1.51 M allocations: 130.425 MiB, 10.60% gc time)
  0.001358 seconds (57 allocations: 4.204 MiB)
isapprox(data2_loss, data2_loss2) = true
true
```
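The manual chunking used for `xdata2` relies on Julia's column-major memory layout. The sketch below checks that it yields the same sample vectors as `eachslice(..., dims=2)`; it uses a small deterministic matrix `X2` (values 1 to 20) as a hypothetical stand-in for the `randn` data so the result is easy to inspect.

```julia
in_features = 5
n_samples = 4
# Deterministic stand-in for randn(Float32, in_features, n_samples): values 1, 2, ..., 20
X2 = reshape(Float32.(1:in_features * n_samples), in_features, n_samples)

# Column-major layout: column i+1 occupies the linear indices
# i*in_features+1 : (i+1)*in_features, so chunking the linear index recovers the columns
chunks = [X2[i*in_features+1:(i+1)*in_features] for i in 0:(length(X2) ÷ in_features)-1]

# Same samples via eachslice(X2, dims = 2) (equivalently eachcol(X2))
cols = [collect(c) for c in eachslice(X2, dims = 2)]

@show chunks == cols          # true
@show chunks[2] == X2[:, 2]   # true: elements 6 through 10
```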
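`norm.(∇x)` would give wrong results because the norm would be computed for each scalar element. `norm.(eachslice(∇x, dims=2))` fixes this by extracting the slices along the batch dimension (the columns), which are exactly the vectors that `getindex.(gradient.(x -> only(u3(x)), xdata2), 1)` would give in the other loss function. With a single input feature every column holds just one element, so both versions agree, which is why `norm.(∇x)` gave correct results in that case.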
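The difference between the two norm computations can be seen on a small hand-made gradient matrix (the values are hypothetical). The sketch also includes an alternative that is not from the code above: since the mean of the squared column norms equals the sum of all squared entries divided by the batch size, `sum(abs2, ∇x) / size(∇x, 2)` computes the same loss without building per-column norms.

```julia
using LinearAlgebra, Statistics

# Hypothetical gradient matrix: in_features = 2, batch of 3 samples (one per column)
∇x = [3.0 0.0 1.0;
      4.0 2.0 1.0]

wrong = mean(norm.(∇x) .^ 2)                       # norm of each scalar entry: mean over 6 entries = 31/6
right = mean(norm.(eachslice(∇x, dims = 2)) .^ 2)  # norm of each column: (25 + 4 + 2)/3 ≈ 31/3
fast  = sum(abs2, ∇x) / size(∇x, 2)                # same quantity without per-column norms = 31/3

@show wrong right fast

# With a single input feature both versions agree, since each column has only one entry
∇x1 = [3.0 -2.0 1.0]
@show mean(norm.(∇x1) .^ 2) ≈ mean(norm.(eachslice(∇x1, dims = 2)) .^ 2)  # true
```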