Zygote pullback vs gradient

Hello,

I have a sequence-to-sequence Flux model with a custom loss function and a custom train function.

I can do forward propagation, and I can also take gradients of the model using Zygote's gradient function. However, the custom train function (taken from Flux's official site) prints warnings about scalar operations on the GPU, which I understand, since I do have some scalar operations on the GPU. What I do not understand is why the gradient step inside it takes so long: a direct call to gradient returns the gradients almost immediately, whereas the same step inside my_custom_train! seems to take forever.

function my_custom_train!(loss, model, data, opt)
    ps = params(model)
    local epoch = 1;
    for (x, y) in data
      # back is a method that computes the product of the gradient so far with its argument.
      x = x |> gpu; y = y |> gpu;
      println(epoch)
      train_loss, back = Zygote.pullback(() -> wsdrLoss(x, model(x), y), ps)
      # Insert whatever code you want here that needs training_loss, e.g. logging.
      # logging_callback(train_loss)
      # Apply back() to the correct type of 1.0 to get the gradient of loss.
      gs = back(one(train_loss))
      # Insert what ever code you want here that needs gradient.
      # E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge.
      update!(opt, ps, gs)
      epoch += 1;
      println(epoch)
      # Here you might like to check validation set accuracy, and break out to do early stopping.
    end
end

Could someone explain to me the difference between Zygote.pullback and Zygote.gradient? And why do we use pullback instead of gradient inside the train! function?

B.R.

Did you move your model to the GPU as well? Otherwise, it’s difficult to say without seeing a full example.

gradient is just for convenience and ends up calling loss, back = Zygote.pullback(...) and then basically just returns back(one(loss)). The reason your example uses pullback is that this gives you the actual value of the loss function in the same pass as well, which is often useful for logging.
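
For intuition, here is a minimal sketch of that relationship (simplified, not the actual Zygote source):

# gradient is roughly pullback followed by a backward pass seeded with one(loss)
function gradient_sketch(f, args...)
    out, back = Zygote.pullback(f, args...)  # forward pass: value and backpropagator
    return back(one(out))                    # backward pass; the value itself is not returned
end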


Did you move your model to the GPU as well? Otherwise, it’s difficult to say without seeing a full example.

Yes. I think so.

using Flux
using Zygote
using WAV
using DataLoaders
import LearnBase: nobs, getobs

include("./src/DCUNet.jl")
using .DCUNet


global SAMPLE_RATE = 48000
global N_FFT = div((SAMPLE_RATE * 64), 1000) + 4
global HOP_LENGTH = div((SAMPLE_RATE  * 16), 1000) + 4
global MAX_LENGTH = 165000


struct Data
    files::Vector{String}
end

prepare_sample(waveform) = (MAX_LENGTH - length(waveform) > 0) ?  vcat(zeros(Float32, MAX_LENGTH - length(waveform)), waveform) : waveform[1:MAX_LENGTH]

nobs(ds::Data) = (length(ds.files))
getobs(ds::Data, idx::Int) = stft(prepare_sample(vec(Float32.(wavread(ds.files[idx])[1]))))




path = "./data/exp_raw/"

noisy_train_path = path * "noisy_trainset_28spk_wav"
clean_train_path = path * "clean_trainset_28spk_wav"

noisy_test_path = path * "noisy_testset_wav.nosync"
clean_test_path = path * "clean_testset_wav.nosync"

noisyTrainData = Data(readdir(noisy_train_path, join=true));
cleanTrainData = Data(readdir(clean_train_path, join=true));

# noisyTestData = DataSet(readdir(noisy_test_path, join=true));
# cleanTestData = DataSet(readdir(clean_test_path, join=true));
 

@assert nobs(noisyTrainData) == nobs(cleanTrainData)
# @assert nobs(testData)[1] == nobs(testData)[2]

trn = DataLoader((noisyTrainData, cleanTrainData),  8, partial=true);
# tst = DataLoader((testData.noisyFiles , testData.cleanFiles) , 2, partial=true);

model = UNet() |> gpu;
loss(x, y) = wsdrLoss(x, model(x), y)

opt = ADAM(1e-4);

function my_custom_train!(loss, model, data, opt)
    ps = params(model)
    epoch = 1;
    for (x, y) in data
      # back is a method that computes the product of the gradient so far with its argument.
      x = x |> gpu; y = y |> gpu;
      println(epoch)
      train_loss, back = Zygote.pullback(() -> loss(x, y), ps)
      # Insert whatever code you want here that needs training_loss, e.g. logging.
      # logging_callback(train_loss)
      # Apply back() to the correct type of 1.0 to get the gradient of loss.
      gs = back(one(train_loss))
      # Insert what ever code you want here that needs gradient.
      # E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge.
      update!(opt, ps, gs)
      epoch += 1;
      println(epoch)
      # Here you might like to check validation set accuracy, and break out to do early stopping.
    end
end

And to train the model:

my_custom_train!(loss, model, trn, opt)

I see:

┌ Warning: Slow fallback implementation invoked for conv!  You probably don't want this; check your datatypes.
│   yT = ForwardDiff.Dual{Nothing,Float32,2}
│   T1 = ForwardDiff.Dual{Nothing,Float32,2}
│   T2 = Float32
└ @ NNlib /opt/.julia/packages/NNlib/2Wxlq/src/conv.jl:206
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:43

and nothing else comes afterwards. By the way, why do I get the ForwardDiff.Dual{Nothing,Float32,2} warning?

Is it possible to see the loss for each batch?
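Since pullback already returns the loss value, I assume I could just print it inside the loop right after the pullback call, something like:

      train_loss, back = Zygote.pullback(() -> loss(x, y), ps)
      println("batch ", epoch, " loss: ", train_loss)   # hypothetical per-batch logging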

Also, I tried the gradient function directly, dmodel = gradient(wsdrLoss, model(x), x, y)[1]. It returns the gradients, but when I try to use the update function, Flux.update!(opt, params(model), dmodel), I get:

ERROR: InvalidIRError: compiling kernel getindex_kernel(CUDA.CuKernelContext, CuDeviceArray{Complex{Float32},4,1}, CuDeviceArray{Complex{Float32},4,1}, Tuple{Int64}, CuDeviceArray{Float32,4,1}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to #sprint#355(context, sizehint::Integer, ::typeof(sprint), f::Function, args...) in Base at strings/io.jl:100)
Stacktrace:
 [1] #repr#356 at strings/io.jl:227
 [2] limitrepr at strings/io.jl:229
 [3] to_index at indices.jl:297
 [4] to_index at indices.jl:274
 [5] to_indices at indices.jl:325
 [6] to_indices at indices.jl:322
 [7] getindex at abstractarray.jl:1060
 [8] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [9] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported dynamic function invocation (call to print)
Stacktrace:
 [1] print_to_string at strings/io.jl:135
 [2] string at strings/io.jl:174
 [3] to_index at indices.jl:297
 [4] to_index at indices.jl:274
 [5] to_indices at indices.jl:325
 [6] to_indices at indices.jl:322
 [7] getindex at abstractarray.jl:1060
 [8] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [9] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_array_grow_end)
Stacktrace:
 [1] _growend! at array.jl:892
 [2] resize! at array.jl:1085
 [3] print_to_string at strings/io.jl:137
 [4] string at strings/io.jl:174
 [5] to_index at indices.jl:297
 [6] to_index at indices.jl:274
 [7] to_indices at indices.jl:325
 [8] to_indices at indices.jl:322
 [9] getindex at abstractarray.jl:1060
 [10] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [11] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_array_del_end)
Stacktrace:
 [1] _deleteend! at array.jl:901
 [2] resize! at array.jl:1090
 [3] print_to_string at strings/io.jl:137
 [4] string at strings/io.jl:174
 [5] to_index at indices.jl:297
 [6] to_index at indices.jl:274
 [7] to_indices at indices.jl:325
 [8] to_indices at indices.jl:322
 [9] getindex at abstractarray.jl:1060
 [10] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [11] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_array_to_string)
Stacktrace:
 [1] String at strings/string.jl:39
 [2] print_to_string at strings/io.jl:137
 [3] string at strings/io.jl:174
 [4] to_index at indices.jl:297
 [5] to_index at indices.jl:274
 [6] to_indices at indices.jl:325
 [7] to_indices at indices.jl:322
 [8] getindex at abstractarray.jl:1060
 [9] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [10] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_alloc_string)
Stacktrace:
 [1] _string_n at strings/string.jl:60
 [2] StringVector at iobuffer.jl:31
 [3] #IOBuffer#331 at iobuffer.jl:114
 [4] print_to_string at strings/io.jl:133
 [5] string at strings/io.jl:174
 [6] to_index at indices.jl:297
 [7] to_index at indices.jl:274
 [8] to_indices at indices.jl:325
 [9] to_indices at indices.jl:322
 [10] getindex at abstractarray.jl:1060
 [11] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [12] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_string_to_array)
Stacktrace:
 [1] unsafe_wrap at strings/string.jl:71
 [2] StringVector at iobuffer.jl:31
 [3] #IOBuffer#331 at iobuffer.jl:114
 [4] print_to_string at strings/io.jl:133
 [5] string at strings/io.jl:174
 [6] to_index at indices.jl:297
 [7] to_index at indices.jl:274
 [8] to_indices at indices.jl:325
 [9] to_indices at indices.jl:322
 [10] getindex at abstractarray.jl:1060
 [11] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [12] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to __memset_avx2_unaligned_erms)
Stacktrace:
 [1] fill! at array.jl:428
 [2] #IOBuffer#331 at iobuffer.jl:121
 [3] print_to_string at strings/io.jl:133
 [4] string at strings/io.jl:174
 [5] to_index at indices.jl:297
 [6] to_index at indices.jl:274
 [7] to_indices at indices.jl:325
 [8] to_indices at indices.jl:322
 [9] getindex at abstractarray.jl:1060
 [10] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [11] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Stacktrace:
 [1] check_ir(::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget,CUDA.CUDACompilerParams}, ::LLVM.Module) at /opt/.julia/packages/GPUCompiler/uTpNx/src/validation.jl:123
 [2] macro expansion at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:239 [inlined]
 [3] macro expansion at /opt/.julia/packages/TimerOutputs/ZmKD7/src/TimerOutput.jl:206 [inlined]
 [4] codegen(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:237
 [5] compile(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:39
 [6] compile at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:35 [inlined]
 [7] cufunction_compile(::GPUCompiler.FunctionSpec; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:310
 [8] cufunction_compile(::GPUCompiler.FunctionSpec) at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:305
 [9] check_cache(::Dict{UInt64,Any}, ::Any, ::Any, ::GPUCompiler.FunctionSpec{typeof(GPUArrays.getindex_kernel),Tuple{CUDA.CuKernelContext,CuDeviceArray{Complex{Float32},4,1},CuDeviceArray{Complex{Float32},4,1},Tuple{Int64},CuDeviceArray{Float32,4,1}}}, ::UInt64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:40
 [10] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139 [inlined]
 [11] cached_compilation at /opt/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:65 [inlined]
 [12] cufunction(::typeof(GPUArrays.getindex_kernel), ::Type{Tuple{CUDA.CuKernelContext,CuDeviceArray{Complex{Float32},4,1},CuDeviceArray{Complex{Float32},4,1},Tuple{Int64},CuDeviceArray{Float32,4,1}}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:297
 [13] cufunction at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:294 [inlined]
 [14] #launch_heuristic#853 at /opt/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:19 [inlined]
 [15] launch_heuristic at /opt/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:17 [inlined]
 [16] gpu_call(::typeof(GPUArrays.getindex_kernel), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{Float32},4}, ::Tuple{Int64}, ::CuArray{Float32,4}; target::CuArray{Complex{Float32},4}, total_threads::Nothing, threads::Nothing, blocks::Nothing, name::Nothing) at /opt/.julia/packages/GPUArrays/jhRU7/src/device/execution.jl:61
 [17] gpu_call(::typeof(GPUArrays.getindex_kernel), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{Float32},4}, ::Tuple{Int64}, ::CuArray{Float32,4}) at /opt/.julia/packages/GPUArrays/jhRU7/src/device/execution.jl:46
 [18] _getindex(::CuArray{Complex{Float32},4}, ::CuArray{Float32,4}) at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:135
 [19] getindex(::CuArray{Complex{Float32},4}, ::CuArray{Float32,4}) at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:125
 [20] update!(::ADAM, ::Params, ::CuArray{Complex{Float32},4}) at /opt/.julia/packages/Flux/q3zeA/src/optimise/train.jl:28
 [21] top-level scope at REPL[42]:1

So what is happening?
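
p.s. I suspect I might be mixing the explicit and implicit (params-based) forms of gradient here. For reference, a sketch of the two, assuming the same wsdrLoss, model, x, and y as above; update!(opt, ps, gs) looks up gs by each parameter, so I guess it wants the Grads object from the implicit form rather than a plain array:

ps = params(model)

# implicit-parameters form: gs is a Zygote.Grads keyed by the params in ps,
# which is the shape update!(opt, ps, gs) expects
gs = gradient(() -> wsdrLoss(x, model(x), y), ps)
update!(opt, ps, gs)

# explicit form: returns one plain array per argument; [1] is the gradient
# w.r.t. model(x), not w.r.t. the model weights
dmodel = gradient(wsdrLoss, model(x), x, y)[1]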

gradient is just for convenience and ends up calling loss, back = Zygote.pullback(...) and then basically just returns back(one(loss)). The reason your example uses pullback is that this gives you the actual value of the loss function in the same pass as well, which is often useful for logging.

Thank you for the explanation.