There are occasions where `@profview` shows a seemingly inordinate amount of time spent moving data to the GPU given the array size, and possibly excessive GPU memory usage — I'm not sure what I should be expecting. Could this be related to having two array outputs rather than one? I've also seen type instability reported through `@code_warntype` and Cthulhu that I'm not sure how to resolve.
Using a toy example to show the effect:
```julia
]activate --temp
]add cuDNN, CUDA, Flux
using CUDA, cuDNN, Flux, Statistics
"""
    Split{P,W}

Two-headed layer: one `Dense` head (`s1`) producing per-source position
features and a second `Dense` head (`s2`) producing per-source intensities,
together with the fixed maximum number of sources.
"""
struct Split{P <: Dense, W <: Dense}
    s1::P
    s2::W
    max_sources::Int
end
"""
    Split(feature_dim::Int, max_sources::Int)

Construct a `Split` whose position head maps `feature_dim => 2 * max_sources`
(two coordinates per source) and whose intensity head maps
`feature_dim => max_sources`.
"""
function Split(feature_dim::Int, max_sources::Int)
    position_head = Dense(feature_dim => 2 * max_sources)
    intensity_head = Dense(feature_dim => max_sources)
    return Split(position_head, intensity_head, max_sources)
end
# Register Split with Flux so its Dense heads are discovered by
# Flux.setup / train! and moved by `gpu`.
Flux.@layer Split
# Forward pass: returns (positions reshaped to 2 × max_sources × batch, intensities).
function (m::Split)(input)
    return reshape(m.s1(input), 2, m.max_sources, :), m.s2(input)
end
"""
    imagegen_test(batch_size, feature_dim = 2048, n_sources = 20)

Generate one random training batch: a `feature_dim × batch_size` input matrix
plus a target tuple of source positions (`2 × n_sources × batch_size`) and
source weights (`n_sources × batch_size`), all `Float32`.

The defaults reproduce the original hard-coded sizes, so existing single-argument
callers are unaffected; the extra parameters let the generator be reused for
networks of other sizes.
"""
function imagegen_test(batch_size::Int, feature_dim::Int = 2048, n_sources::Int = 20)
    inputs = randn(Float32, feature_dim, batch_size)
    positions = randn(Float32, 2, n_sources, batch_size)
    weights = randn(Float32, n_sources, batch_size)
    return inputs, (positions, weights)
end
"""
    test()

Run a short GPU training loop over freshly generated random batches, so data
transfer and kernel launches are visible under `@profview`. Returns `nothing`.
"""
function test()
    batch_size = 256
    n_iters = 64
    network = Split(2048, 128) |> gpu
    opt_state = Flux.setup(Flux.Optimise.Adam(1E-4), network)
    # Uploaded once, outside the loop, so it is not re-transferred per iteration.
    kernel_sigmas_gpu = Float32[64.0, 320.0, 640.0, 1920.0] |> gpu
    for _ ∈ 1:n_iters
        batch = imagegen_test(batch_size) |> gpu
        Flux.train!(network, (batch,), opt_state) do m, x, y
            θ_pred, intensity_pred = m(x)
            loss_func(θ_pred, intensity_pred, y..., kernel_sigmas_gpu)
        end
    end
    return nothing
end
"""
    pairwise_cityblock(c)

Given `c` of size `d × n × batch`, return the `n × n × batch` array of
city-block (L1) distances between every pair of columns within each batch slice.
"""
function pairwise_cityblock(c)
    diffs = abs.(Flux.unsqueeze(c, 2) .- Flux.unsqueeze(c, 3))
    return dropdims(sum(diffs, dims = 1), dims = 1)
end
"""
    kernel_loss(K, predicted_weights, target_weights)

Batched quadratic form `wᵀ K w` with signed weights `w = [predicted; -target]`.
Returns a vector with one loss value per batch element.
"""
function kernel_loss(K, predicted_weights, target_weights)
    signed_weights = vcat(predicted_weights, -target_weights)
    quad_form = batched_vec(Flux.unsqueeze(signed_weights, 1), batched_vec(K, signed_weights))
    return dropdims(quad_form, dims = 1)
end
"""
    multiscale_l1_laplacian_loss(θ_predicted, w_predicted, θ_target, w_target, inv_scale_factors)

Laplacian-kernel embedding loss summed over multiple kernel scales; returns a
vector with one value per batch element.
"""
function multiscale_l1_laplacian_loss(θ_predicted, w_predicted, θ_target, w_target, inv_scale_factors)
    # Pairwise L1 distances between all predicted and target positions combined.
    D = pairwise_cityblock([θ_predicted θ_target])
    # One Laplacian kernel per scale, stacked along a trailing 4th dimension.
    kernels = exp.(-D ./ reshape(inv_scale_factors, 1, 1, 1, :))
    per_scale = kernel_loss.(eachslice(kernels, dims = 4), Ref(w_predicted), Ref(w_target))
    return sum(per_scale)
end
"""
    loss_func(x1, y1, x2, y2, kernel_sigmas)

Batch-mean of the multiscale L1-Laplacian kernel loss.
"""
function loss_func(x1, y1, x2, y2, kernel_sigmas)
    return mean(multiscale_l1_laplacian_loss(x1, y1, x2, y2, kernel_sigmas))
end
# First call compiles everything; the second, profiled, run measures steady state.
# NOTE(review): `@profview` is assumed to come from the VS Code Julia extension
# or ProfileView.jl — it is not imported above; confirm in your environment.
test()
@profview test()
```