Did you move your model to the GPU as well? Otherwise, it’s difficult to say without seeing a full example.
Yes. I think so.
using Flux
using Zygote
using WAV
using DataLoaders
import LearnBase: nobs, getobs
using .DCUNet
const SAMPLE_RATE = 48000
const N_FFT = div((SAMPLE_RATE * 64), 1000) + 4
const HOP_LENGTH = div((SAMPLE_RATE * 16), 1000) + 4
const MAX_LENGTH = 165000
struct Data
    files::Vector{String}
end
prepare_sample(waveform) = (MAX_LENGTH - length(waveform) > 0) ? vcat(zeros(Float32, MAX_LENGTH - length(waveform)), waveform) : waveform[1:MAX_LENGTH]
nobs(ds::Data) = (length(ds.files))
getobs(ds::Data, idx::Int) = stft(prepare_sample(vec(Float32.(wavread(ds.files[idx])[1]))))
path = "./data/exp_raw/"
noisy_train_path = path * "noisy_trainset_28spk_wav"
clean_train_path = path * "clean_trainset_28spk_wav"
noisy_test_path = path * "noisy_testset_wav.nosync"
clean_test_path = path * "clean_testset_wav.nosync"
noisyTrainData = Data(readdir(noisy_train_path, join=true));
cleanTrainData = Data(readdir(clean_train_path, join=true));
# noisyTestData = DataSet(readdir(noisy_test_path, join=true));
# cleanTestData = DataSet(readdir(clean_test_path, join=true));
@assert nobs(noisyTrainData) == nobs(cleanTrainData)
# @assert nobs(testData)[1] == nobs(testData)[2]
trn = DataLoader((noisyTrainData, cleanTrainData), 8, partial=true);
# tst = DataLoader((testData.noisyFiles , testData.cleanFiles) , 2, partial=true);
model = UNet() |> gpu;
loss(x, y) = wsdrLoss(x, model(x), y)
opt = ADAM(1e-4);
function my_custom_train!(loss, model, data, opt)
    ps = params(model)
    epoch = 1  # note: incremented once per batch, so this actually counts batches
    for (x, y) in data
        x = x |> gpu; y = y |> gpu
        # back is a method that computes the product of the gradient so far with its argument.
        train_loss, back = Zygote.pullback(() -> loss(x, y), ps)
        # Insert whatever code you want here that needs train_loss, e.g. logging.
        # Apply back() to the correct type of 1.0 to get the gradient of loss.
        gs = back(one(train_loss))
        # Insert whatever code you want here that needs the gradient.
        # E.g. logging with TensorBoardLogger.jl as a histogram so you can see if it is becoming huge.
        Flux.Optimise.update!(opt, ps, gs)
        epoch += 1
        # Here you might like to check validation set accuracy, and break out to do early stopping.
    end
end
And to train the model:
my_custom_train!(loss, model, trn, opt)
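Because Zygote.pullback already returns the loss value for the current batch, printing it inside the loop (right after the pullback call) is enough to see the loss per batch. Below is a minimal CPU-only sketch of that pattern. It assumes a hypothetical linear model (W, b) and random data in place of UNet and the audio DataLoader, and it uses a plain gradient-descent step in place of update!(opt, ps, gs) only so that it does not depend on a particular Flux version; in the real loop you would keep the update! call.

```julia
using Zygote

# Hypothetical toy model: one weight matrix and one bias, standing in for UNet.
W = randn(Float32, 1, 3)
b = zeros(Float32, 1)
predict(x) = W * x .+ b
loss(x, y) = sum(abs2, predict(x) .- y) / size(x, 2)

# Hypothetical data: 4 batches of 8 samples each, as (x, y) tuples.
batches = [(randn(Float32, 3, 8), randn(Float32, 1, 8)) for _ in 1:4]

function train_logging!(loss, ps, data; lr = 1f-2)
    losses = Float32[]
    for (i, (x, y)) in enumerate(data)
        # pullback returns the loss value and a closure that computes the gradient.
        train_loss, back = Zygote.pullback(() -> loss(x, y), ps)
        gs = back(one(train_loss))   # Zygote.Grads, indexed by the parameter arrays
        println("batch $i: loss = $train_loss")
        push!(losses, train_loss)
        for p in ps                  # plain gradient-descent step, in place
            p .-= lr .* gs[p]
        end
    end
    return losses
end

ps = Zygote.Params([W, b])
losses = train_logging!(loss, ps, batches)
```

Collecting the values in a vector as well as printing them makes it easy to average them per epoch or send them to a logger later.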
I see the following warnings:
┌ Warning: Slow fallback implementation invoked for conv! You probably don't want this; check your datatypes.
│ yT = ForwardDiff.Dual{Nothing,Float32,2}
│ T1 = ForwardDiff.Dual{Nothing,Float32,2}
│ T2 = Float32
└ @ NNlib /opt/.julia/packages/NNlib/2Wxlq/src/conv.jl:206
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:43
and nothing else comes afterwards. By the way, why do I get the ForwardDiff.Dual{Nothing,Float32,2} warning?
Is it possible to see the loss for each batch?
Also, I tried the gradient function, dmodel = gradient(wsdrLoss, model(x), x, y)[1]. It returns the gradients, but when I try to use the update function Flux.update!(opt, params(model), dmodel) I get:
ERROR: InvalidIRError: compiling kernel getindex_kernel(CUDA.CuKernelContext, CuDeviceArray{Complex{Float32},4,1}, CuDeviceArray{Complex{Float32},4,1}, Tuple{Int64}, CuDeviceArray{Float32,4,1}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to #sprint#355(context, sizehint::Integer, ::typeof(sprint), f::Function, args...) in Base at strings/io.jl:100)
[1] #repr#356 at strings/io.jl:227
[2] limitrepr at strings/io.jl:229
[3] to_index at indices.jl:297
[4] to_index at indices.jl:274
[5] to_indices at indices.jl:325
[6] to_indices at indices.jl:322
[7] getindex at abstractarray.jl:1060
[8] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
[9] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported dynamic function invocation (call to print)
[1] print_to_string at strings/io.jl:135
[2] string at strings/io.jl:174
[3] to_index at indices.jl:297
[4] to_index at indices.jl:274
[5] to_indices at indices.jl:325
[6] to_indices at indices.jl:322
[7] getindex at abstractarray.jl:1060
[8] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
[9] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_array_grow_end)
[1] _growend! at array.jl:892
[2] resize! at array.jl:1085
[3] print_to_string at strings/io.jl:137
[4] string at strings/io.jl:174
[5] to_index at indices.jl:297
[6] to_index at indices.jl:274
[7] to_indices at indices.jl:325
[8] to_indices at indices.jl:322
[9] getindex at abstractarray.jl:1060
[10] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
[11] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_array_del_end)
[1] _deleteend! at array.jl:901
[2] resize! at array.jl:1090
[3] print_to_string at strings/io.jl:137
[4] string at strings/io.jl:174
[5] to_index at indices.jl:297
[6] to_index at indices.jl:274
[7] to_indices at indices.jl:325
[8] to_indices at indices.jl:322
[9] getindex at abstractarray.jl:1060
[10] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
[11] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_array_to_string)
[1] String at strings/string.jl:39
[2] print_to_string at strings/io.jl:137
[3] string at strings/io.jl:174
[4] to_index at indices.jl:297
[5] to_index at indices.jl:274
[6] to_indices at indices.jl:325
[7] to_indices at indices.jl:322
[8] getindex at abstractarray.jl:1060
[9] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
[10] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_alloc_string)
[1] _string_n at strings/string.jl:60
[2] StringVector at iobuffer.jl:31
[3] #IOBuffer#331 at iobuffer.jl:114
[4] print_to_string at strings/io.jl:133
[5] string at strings/io.jl:174
[6] to_index at indices.jl:297
[7] to_index at indices.jl:274
[8] to_indices at indices.jl:325
[9] to_indices at indices.jl:322
[10] getindex at abstractarray.jl:1060
[11] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
[12] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_string_to_array)
[1] unsafe_wrap at strings/string.jl:71
[2] StringVector at iobuffer.jl:31
[3] #IOBuffer#331 at iobuffer.jl:114
[4] print_to_string at strings/io.jl:133
[5] string at strings/io.jl:174
[6] to_index at indices.jl:297
[7] to_index at indices.jl:274
[8] to_indices at indices.jl:325
[9] to_indices at indices.jl:322
[10] getindex at abstractarray.jl:1060
[11] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
[12] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to __memset_avx2_unaligned_erms)
[1] fill! at array.jl:428
[2] #IOBuffer#331 at iobuffer.jl:121
[3] print_to_string at strings/io.jl:133
[4] string at strings/io.jl:174
[5] to_index at indices.jl:297
[6] to_index at indices.jl:274
[7] to_indices at indices.jl:325
[8] to_indices at indices.jl:322
[9] getindex at abstractarray.jl:1060
[10] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
[11] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
[1] check_ir(::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget,CUDA.CUDACompilerParams}, ::LLVM.Module) at /opt/.julia/packages/GPUCompiler/uTpNx/src/validation.jl:123
[2] macro expansion at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:239 [inlined]
[3] macro expansion at /opt/.julia/packages/TimerOutputs/ZmKD7/src/TimerOutput.jl:206 [inlined]
[4] codegen(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:237
[5] compile(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:39
[6] compile at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:35 [inlined]
[7] cufunction_compile(::GPUCompiler.FunctionSpec; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:310
[8] cufunction_compile(::GPUCompiler.FunctionSpec) at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:305
[9] check_cache(::Dict{UInt64,Any}, ::Any, ::Any, ::GPUCompiler.FunctionSpec{typeof(GPUArrays.getindex_kernel),Tuple{CUDA.CuKernelContext,CuDeviceArray{Complex{Float32},4,1},CuDeviceArray{Complex{Float32},4,1},Tuple{Int64},CuDeviceArray{Float32,4,1}}}, ::UInt64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:40
[10] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139 [inlined]
[11] cached_compilation at /opt/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:65 [inlined]
[12] cufunction(::typeof(GPUArrays.getindex_kernel), ::Type{Tuple{CUDA.CuKernelContext,CuDeviceArray{Complex{Float32},4,1},CuDeviceArray{Complex{Float32},4,1},Tuple{Int64},CuDeviceArray{Float32,4,1}}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:297
[13] cufunction at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:294 [inlined]
[14] #launch_heuristic#853 at /opt/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:19 [inlined]
[15] launch_heuristic at /opt/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:17 [inlined]
[16] gpu_call(::typeof(GPUArrays.getindex_kernel), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{Float32},4}, ::Tuple{Int64}, ::CuArray{Float32,4}; target::CuArray{Complex{Float32},4}, total_threads::Nothing, threads::Nothing, blocks::Nothing, name::Nothing) at /opt/.julia/packages/GPUArrays/jhRU7/src/device/execution.jl:61
[17] gpu_call(::typeof(GPUArrays.getindex_kernel), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{Float32},4}, ::Tuple{Int64}, ::CuArray{Float32,4}) at /opt/.julia/packages/GPUArrays/jhRU7/src/device/execution.jl:46
[18] _getindex(::CuArray{Complex{Float32},4}, ::CuArray{Float32,4}) at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:135
[19] getindex(::CuArray{Complex{Float32},4}, ::CuArray{Float32,4}) at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:125
[20] update!(::ADAM, ::Params, ::CuArray{Complex{Float32},4}) at /opt/.julia/packages/Flux/q3zeA/src/optimise/train.jl:28
[21] top-level scope at REPL[42]:1
So what is happening?
The gradient function is just for convenience: it ends up calling loss, back = Zygote.pullback(...) and then basically just returns back(one(loss)). The reason your example uses pullback is that it also gives you the actual value of the loss function in the same pass, which is often useful for logging.
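A small CPU sketch of that equivalence with Zygote, using a hypothetical 2×2 matrix W as the only parameter. It also illustrates why the earlier update! call failed: update! with a Params collection looks up each parameter's gradient as gs[p], so it needs the Grads object returned when differentiating with respect to Params. That lookup is what frames [19] and [20] of the trace show, with a CuArray being indexed by another CuArray. By contrast, gradient(wsdrLoss, model(x), x, y) differentiates with respect to its arguments, so its first element is the gradient with respect to the model output, a plain array that cannot be indexed by the parameters.

```julia
using Zygote

# Hypothetical toy parameter and input, standing in for the model and a batch.
W = Float32[1 2; 3 4]
x = Float32[1, 1]
f() = sum(W * x)                    # scalar "loss" that closes over W and x

ps = Zygote.Params([W])

# Route 1: gradient, the convenience wrapper.
g1 = gradient(f, ps)

# Route 2: pullback, which also hands back the loss value.
val, back = Zygote.pullback(f, ps)
g2 = back(one(val))

println(val)                        # prints 10.0
println(g1[W] == g2[W])             # prints true

# Grads are indexed by the parameter array itself, which is what
# update!(opt, ps, gs) relies on. Differentiating with respect to an
# argument instead gives a plain array (here, w.r.t. the output W * x).
output_grad = gradient(y -> sum(y), W * x)[1]
println(size(output_grad))          # prints (2,)
```

So for the update step, the gradient has to be taken with respect to params(model), e.g. via the pullback pattern in the training loop, rather than with respect to model(x).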
Thank you for the explanation.