Flux loss with gradient contributions is slow

I am trying to replicate some computations on solving PDEs with neural networks (the committor problem, see [1802.10275] Solving for high dimensional committor functions using artificial neural networks). This involves using a loss function with contributions of the form
\int |\nabla_x u(x;\theta)|^2 dx
where u is scalar valued. I have implemented something that gives results, but is extremely slow (at least on my CPU), and I’m looking for any suggestions on how to speed it up. An example that reveals the slowness is the following:

using ProgressMeter
using LinearAlgebra
using Random
using Statistics
using Flux

Random.seed!(10)
xdata_ = randn(Float32, 1, 10^4)
xdata = [[x_] for x_ in xdata_]
ydata = rand(Float32,size(xdata));
traindata = zip(xdata, ydata)

Random.seed!(500);
u = Chain(Dense(1 => 20, tanh), Dense(20 => 1, tanh));
opt_state = Flux.setup(Flux.Adam(0.01), u)

function dirichlet_loss(u, x, y)
    u_(x) = only(u(x))
    ∇u_ = gradient.((u_), x)[1]
    return mean(norm.(∇u_).^2)
end

@show l_= dirichlet_loss(u, xdata, ydata)

loader = Flux.DataLoader((xdata, ydata), batchsize=100, shuffle=true);

losses = [l_];
@showprogress for epoch in 1:10
    Flux.train!(dirichlet_loss, u, loader, opt_state)
    l_ = dirichlet_loss(u, xdata, ydata);
    push!(losses,l_)
end

This takes about 10 minutes for only 10 training epochs. Any advice would be appreciated.

This seems to be due to the nested gradient calls, i.e., Flux.train! will compute the gradient of dirichlet_loss which already involves a gradient call. Unfortunately, Zygote does not seem to handle this well.
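
As a minimal illustration of that nesting (reverse-over-reverse), here is a scalar toy example; this tiny case should go through, but the same structure applied to a Chain, broadcast over 10^4 inputs, is what makes the training above so expensive:

using Zygote  # Flux re-exports gradient, but be explicit here

# the inner gradient call is itself differentiated by the outer one
gradient(x -> gradient(sin, x)[1], 1.0)   # second derivative of sin, ≈ -sin(1.0)
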
What I ended up doing is writing my own training loop and handling the gradients explicitly using AbstractDifferentiation. It is a bit more verbose, but it makes it very easy to try different AD backend combinations:

using AbstractDifferentiation
using ForwardDiff
using ReverseDiff
using BenchmarkTools

function dirichlet_loss(ad, u, x, y)
    u_(x) = only(u(x))
    # per-sample input gradients; Ref(...) protects ad and u_ from being broadcast over
    ∇u_ = getindex.(AbstractDifferentiation.gradient.(Ref(ad), Ref(u_), x), 1)
    return mean(norm.(∇u_).^2)
end

# Required because AbstractDifferentiation works on flat parameter vectors
θ, re = Flux.destructure(u)

# feel free to try other backend combinations (I have found none involving Zygote that work)
@btime AbstractDifferentiation.gradient(
    AbstractDifferentiation.ReverseDiffBackend(),
    p -> dirichlet_loss(AbstractDifferentiation.ForwardDiffBackend(), re(p), $xdata, $ydata),
    $θ)
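
The training loop itself can then be written by hand on top of this; something like the following sketch (full-batch updates, no DataLoader; train_custom! is just an illustrative name, and Optimisers.jl is applied directly to the flat parameter vector θ):

using Optimisers

# illustrative helper, not a fixed recipe: full-batch Adam updates on the
# destructured parameters, with forward-over-reverse for the nested gradients
function train_custom!(θ, re, xdata, ydata; epochs = 10)
    opt = Optimisers.setup(Optimisers.Adam(0.01), θ)
    for epoch in 1:epochs
        g = AbstractDifferentiation.gradient(
            AbstractDifferentiation.ReverseDiffBackend(),
            p -> dirichlet_loss(AbstractDifferentiation.ForwardDiffBackend(), re(p), xdata, ydata),
            θ)[1]
        opt, θ = Optimisers.update!(opt, θ, g)
    end
    return re(θ)   # rebuild a Flux model from the updated flat parameters
end

u_trained = train_custom!(θ, re, xdata, ydata)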

Also not sure about that line ∇u_ = gradient.((u_), x)[1]. First, (u_) does nothing; to protect u_ from broadcasting you would need (u_,) or Ref(u_). Secondly, you are broadcasting gradient and then selecting only the first element of the resulting vector; to unwrap all of the gradients, getindex would need to be broadcast as well.

@bertschi Firstly, I also noticed that issue and replaced that part with getindex.(gradient.(u_, x), 1).
@gideonsimpson The other aspect I noticed was that you are creating a vector of vectors, where the inner vectors act as the inputs to a dense layer, because dense layers need an array (not a scalar) as input. The same behavior can be achieved by simply using a second batch dimension instead of nested vectors, I think. I implemented a second loss function that avoids broadcasting the gradient call and compared the results before and after training; they appear to be the same. On my machine, training with the non-broadcasting version (dirichlet_loss_nb) is approximately 120x faster (with a batch size of 100)!

using ProgressMeter
using LinearAlgebra
using Flux
using Random
using Statistics

@show LinearAlgebra.BLAS.get_num_threads()
@show Threads.nthreads()

Random.seed!(10)
xdata_ = randn(Float32, 1, 10^4)
xdata = [[x_] for x_ in xdata_]
ydata = rand(Float32,size(xdata))
traindata = zip(xdata, ydata)

Random.seed!(500)
u = Chain(Dense(1 => 20, tanh), Dense(20 => 1, tanh))
opt_state = Flux.setup(Flux.Adam(0.01), u)

u2 = Chain(Dense(1 => 20, tanh), Dense(20 => 1, tanh))
# copy the weights and biases to the new model
u2[1].weight .= copy(u[1].weight)
u2[1].bias .= copy(u[1].bias)
u2[2].weight .= copy(u[2].weight)
u2[2].bias .= copy(u[2].bias)
opt_state2 = Flux.setup(Flux.Adam(0.01), u2)

function dirichlet_loss(u, x, y)
    u_(x) = only(u(x))
    ∇u_ = getindex.(gradient.(u_, x), 1)
    return mean(norm.(∇u_).^2)
end

function dirichlet_loss_nb(u, x, y)
    # ∇x instead of ∇u_ makes more sense here since we're calculating the gradient w.r.t. x
    ∇x = gradient(x -> sum(u(x)), x)[1]
    return mean(norm.(∇x).^2)
end

@show l_= dirichlet_loss(u, xdata, ydata)
# note that xdata_ of size (1, 10^4) acting as one big batch is used here
@show l2_= dirichlet_loss_nb(u2, xdata_, ydata)
@show isapprox(l_, l2_)

# note that shuffling here would lead to different results between the two models 
# because then the data wouldn't be in the same order anymore
loader = Flux.DataLoader((xdata, ydata), batchsize=100, shuffle=false)
# important to use xdata_ for the second data loader!
loader2 = Flux.DataLoader((xdata_, ydata), batchsize=100, shuffle=false)

for epoch in 1:3 # note the compilation times during the first epoch
    @time Flux.train!(dirichlet_loss, u, loader, opt_state)
    @time Flux.train!(dirichlet_loss_nb, u2, loader2, opt_state2)
end

@show isapprox(u[1].weight, u2[1].weight)
@show isapprox(u[1].bias, u2[1].bias)
@show isapprox(u[2].weight, u2[2].weight)
@show isapprox(u[2].bias, u2[2].bias)

With these results on my machine (Ryzen 9 5900X):

LinearAlgebra.BLAS.get_num_threads() = 12
Threads.nthreads() = 12
l_ = dirichlet_loss(u, xdata, ydata) = 0.0005279569f0
l2_ = dirichlet_loss_nb(u2, xdata_, ydata) = 0.0005279569f0
isapprox(l_, l2_) = true
 88.302447 seconds (292.72 M allocations: 20.608 GiB, 7.28% gc time, 74.64% compilation time)
 33.158363 seconds (69.19 M allocations: 3.622 GiB, 4.08% gc time, 99.27% compilation time)
 22.465978 seconds (90.13 M allocations: 10.016 GiB, 15.14% gc time, 0.01% compilation time)
  0.178209 seconds (749.92 k allocations: 108.945 MiB, 11.89% gc time)
 22.114513 seconds (90.13 M allocations: 10.016 GiB, 14.56% gc time)
  0.178573 seconds (749.92 k allocations: 108.945 MiB, 11.75% gc time)
isapprox((u[1]).weight, (u2[1]).weight) = true
isapprox((u[1]).bias, (u2[1]).bias) = true    
isapprox((u[2]).weight, (u2[2]).weight) = true
isapprox((u[2]).bias, (u2[2]).bias) = true

So this definitely gives me much more reasonable performance. I have a few follow-up questions. I now understand the difference between these two:

julia> gradient.((u_), xdata[1:2])
2-element Vector{Tuple{Vector{Float32}}}:
 ([-0.00028919056],)
 ([-0.0003214106],)
julia> getindex.(gradient.(u_, xdata[1:2]), 1)
2-element Vector{Vector{Float32}}:
 [-0.00028919056]
 [-0.0003214106]

But one thing I"m concerned about (maybe it doesn’t matter) is that gradient(x -> sum(u(x)), x)[1] isn’t vectorized:

julia> gradient(x -> sum(u(x)), xdata[1])
(Float32[-0.00028919056],) # which I can use getindex on

but

julia> gradient(x -> sum(u(x)), xdata[1:2])

generates an error, though this is fixed if I switch it to gradient.(x -> sum(u(x)), xdata[1:2]) (were you just missing a period?) and use getindex.

Another thing is that in this example, the input is scalar. In some problems I want to do, the inputs will live in \mathbb{R}^d with d > 1; is there anything I need to think about when migrating to that case?

julia> gradient(x -> sum(u(x)), xdata[1]) from dirichlet_loss_nb(u, x, y) works because the first element in xdata is a vector (no batch dimension) containing exactly one scalar. xdata[1:2], however, gives:

2-element Vector{Vector{Float32}}:
 [0.34105664]
 [-0.054080337]

Your original approach dirichlet_loss(u, x, y) can work with this format since the gradient function is broadcast over every inner vector. But broadcasting the gradient calculation seems to be disproportionately slow (so there is intentionally no period before the gradient call in the new version). Instead, dirichlet_loss_nb(u, x, y) avoids this by treating multiple samples as a batch of size (1, batch_size). A big batch can be propagated through the network efficiently in one go, and the cost of the differentiation also drops considerably.
So dirichlet_loss_nb(u, x, y) expects a matrix of scalars instead of a nested vector of type Vector{Vector{Float32}}. For example:

# what y is here doesn't matter so much because it's currently not really in use
ll_ = dirichlet_loss(u, xdata[1:3], ydata[1:3])
ll2_ = dirichlet_loss_nb(u, xdata_[:, 1:3], ydata[:, 1:3])
@show isapprox(ll_, ll2_)

xdata_[:, 1:3] gives a matrix of scalars of size (1, 3), where 1 is the input dimension of the first layer in the network and 3 is the batch size.
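
As a side note, if the data is already stored as a Vector{Vector{Float32}}, it can be stacked into the expected (in_features, batch_size) matrix layout with something like:

xmat = reduce(hcat, xdata)   # (1, 10^4) matrix layout for the data above
# on Julia ≥ 1.9 the same can be written as stack(xdata)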

is there anything I need to think about when migrating to that case?

The dirichlet_loss_nb(u, x, y) has to be modified a bit to compute the vector norms correctly:

function dirichlet_loss_nb(u, x, y)
    # ∇x instead of ∇u_ makes more sense here since we're calculating the gradient w.r.t. x
    ∇x = gradient(x -> sum(u(x)), x)[1]
    # return mean(norm.(∇x).^2)
    return mean(norm.(eachslice(∇x, dims=2)).^2)
end

in_features = 5
u3 = Chain(Dense(in_features => 20, tanh), Dense(20 => 1, tanh))

xdata_2 = randn(Float32, in_features, 10^4)
# the access patterns are now a bit more complicated...
xdata2 = [xdata_2[i*in_features+1:(i+1)*in_features] for i in 0:(length(xdata_2) ÷ in_features)-1]
# ...or simply use eachslice(...) here as well
xdata2 = convert(Vector{Vector{Float32}}, (eachslice(xdata_2, dims=2)))

data2_loss = @time dirichlet_loss(u3, xdata2, [1])
data2_loss2 = @time dirichlet_loss_nb(u3, xdata_2, [1])
@show isapprox(data2_loss, data2_loss2)

@info isapprox(
    getindex.(gradient.(x -> only(u3(x)), xdata2[1:5]), 1),
    eachslice(gradient(x -> sum(u3(x)), xdata_2[:, 1:5])[1], dims=2)
)

Giving me:

  0.182084 seconds (1.51 M allocations: 130.425 MiB, 10.60% gc time)
  0.001358 seconds (57 allocations: 4.204 MiB)
isapprox(data2_loss, data2_loss2) = true
true

norm.(∇x) would give wrong results because the norm would be computed per scalar element. norm.(eachslice(∇x, dims=2)) fixes this by extracting the slices along the batch dimension, which are exactly the vectors that getindex.(gradient.(x -> only(u3(x)), xdata2), 1) would give in the other loss function.
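
To make the difference concrete, here is a small made-up example (with LinearAlgebra and Statistics loaded) for a 2×3 gradient matrix, i.e. 2 input features and a batch of 3:

∇x = Float32[1 3 5;
             2 4 6]
mean(norm.(∇x) .^ 2)                      # per-entry: mean of 1, 4, 9, 16, 25, 36 ≈ 15.17
mean(norm.(eachslice(∇x, dims=2)) .^ 2)   # per-column: mean of 5, 25, 61 ≈ 30.33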

How come the Julia code looks so ugly while the Python code looks more elegant?

import torch
import torch.nn as nn
from torch.func import grad, vmap, functional_call

def u(model, params, x):
    return torch.sum(functional_call(model, params, x))

def dirichlet_loss(model, params, x):
    u_x = grad(u, argnums=2)(model, params, x)
    return torch.mean(torch.norm(u_x) ** 2)  # squared gradient norm, to match the Dirichlet loss above

def loss(model, params, x):
    return grad(dirichlet_loss, argnums=1)(model, params, x)

model = nn.Sequential(
    nn.Linear(1, 20),
    nn.Tanh(),
    nn.Linear(20, 1),
    nn.Tanh()
)
    
params = dict(model.named_parameters())
x = torch.randn(1, requires_grad=True)
u(model, params, x)
dirichlet_loss(model, params, x)

# batched version
batch_size, feature_size = 32, 10
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.Tanh(),
    nn.Linear(20, 1),
    nn.Tanh()
)
params = dict(model.named_parameters())
x = torch.randn(batch_size, feature_size, requires_grad=True)
grad_params = vmap(loss, in_dims=(None, None, 0))(model, params, x)