Hi Julia community!
I’m trying to fit a custom one-layer NN to an SDE by minimizing an MSE loss. The code defining the model and computing the loss is below; `noise` and `samples_rff` are CUDA matrices captured as globals inside the drift and diffusion functions, and `nu`, `n_samples`, the time span, and the training data are defined globally as well.
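For completeness, those globals look roughly like this (placeholder values; the exact numbers shouldn't matter for the error):

```julia
# Placeholder definitions for the globals used below (illustrative values only)
nu = 1.5                                       # parameter of the feature distribution
n_samples = 100                                # number of random features
m = 32                                         # number of trajectories
t_in, t_fin = 0.0, 1.0                         # integration interval
training_times = collect(range(t_in, t_fin, length = 50))
noise = cu(randn(4, m))                        # fixed diffusion matrix; 4 = d below
initial_conditions = cu(randn(4, m))           # one column per trajectory
training_trajectories = cu(randn(4, m, length(training_times)))
```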
```julia
using CUDA, Distributions, Statistics
using DifferentialEquations, SciMLSensitivity, Zygote

d = 4
l_list = ones(d, 1)                                 # lengthscales, one per state dimension
samples_rff = rand(TDist(2 * nu), (d, n_samples))   # random feature samples
A = randn(d, n_samples) / sqrt(n_samples)           # output weights

# Send the parameters to the GPU
A = cu(A)
samples_rff = cu(samples_rff)
l_list = cu(l_list)

# Pack everything into a single (d, n_samples + 1) parameter matrix
weights = hcat(A, l_list)

"""
    rff_model(X, A, sample_features, l)

Random-feature model. `X` is a state batch of size `(d, N)`, `A` the output
weights of size `(d, n_samples)`, `sample_features` the feature samples of
size `(d, n_samples)`, and `l` the lengthscales of size `(d, 1)`.
"""
function rff_model(X, A, sample_features, l)
    tau = l .^ (-1)                 # inverse lengthscales, (d, 1)
    W = tau .* sample_features      # scaled features, (d, n_samples)
    M = W' * X                      # projections, (n_samples, N)
    M = cos.(M) .+ sin.(M)          # random Fourier features
    return A * M                    # output, (d, N)
end
```
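A quick shape check (test-only, not part of the training code) confirms the model maps a `(d, N)` batch to a `(d, N)` output:

```julia
# Sanity check of the shapes in the docstring above (illustrative)
Xtest = cu(randn(d, 5))
@assert size(rff_model(Xtest, A, samples_rff, l_list)) == (d, 5)
```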
The drift and diffusion then wrap this model:

```julia
function drift_rff(dstate, state, p, t)
    A = p[:, 1:end-1]      # unpack the output weights
    l = p[:, end:end]      # unpack the lengthscales
    dstate .= rff_model(state, A, samples_rff, l)
end

function diffusion_rff(dstate, state, p, t)
    dstate .= noise        # fixed, state-independent diffusion (global CUDA matrix)
end
```
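The unpacking in `drift_rff` is just the inverse of the `hcat` packing above:

```julia
# Sanity check: unpacking recovers the packed parameters (test-only)
@assert weights[:, 1:end-1] == A
@assert weights[:, end:end] == l_list
```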
The loss solves the SDE and compares against the training trajectories:

```julia
"""
    loss(weights, training_trajectories, initial_conditions, training_time, reg = 0.1)

`weights` is of size `(d, n_samples + 1)`: `A` in the first `n_samples`
columns, `l` in the last column. `initial_conditions` is of size `(d, m)`.
"""
function loss(weights, training_trajectories, initial_conditions, training_time, reg = 0.1)
    prob = SDEProblem(drift_rff, diffusion_rff, initial_conditions, (t_in, t_fin), weights)
    tmp_sol = solve(prob, SOSRI(), saveat = training_time)
    arrsol = CuArray(tmp_sol)
    # l = weights[:, end]
    return mean((arrsol - training_trajectories) .^ 2)  # + reg * sum(l .^ (-1))
end

objective = weights -> loss(weights, training_trajectories, initial_conditions, training_times, 0.1)
val, grads = Zygote.withgradient(objective, weights)
```
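Evaluating the forward pass alone is fine:

```julia
val = objective(weights)   # solves the SDE and returns the MSE, no error
```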
That works without a hitch. However, computing gradients with Zygote or ReverseDiff produces the following warnings and subsequent error:
```
┌ Warning: Potential performance improvement omitted. ZygoteVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:99
┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:116
┌ Warning: Potential performance improvement omitted. TrackerVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:134
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:144

ERROR: LoadError: MethodError: no method matching SciMLBase.SensitivityInterpolation(::Vector{Float64}, ::Vector{CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}})
```

So every reverse-mode VJP choice fails, and the numerical fallback then dies because there is no `SensitivityInterpolation` method for a CPU time vector paired with `CuArray` states.
Computing gradients with ForwardDiff also returns an error:
```
ERROR: LoadError: MethodError: randn!(::CUDA.RNG, ::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{var"#47#48", Float32}, Float64, 12}, 2, CUDA.Mem.DeviceBuffer}) is ambiguous.

Candidates:
  randn!(rng::CUDA.RNG, A::AbstractArray{T}) where T
    @ CUDA ~/.julia/packages/CUDA/rXson/src/random.jl:255
  randn!(rng::Random.AbstractRNG, A::GPUArraysCore.AnyGPUArray)
    @ GPUArrays ~/.julia/packages/GPUArrays/dAUOE/src/host/random.jl:116

To resolve the ambiguity, try making one of the methods more specific, or adding a new method more specific than any of the existing applicable methods.
```
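If I read the candidates right, the ambiguity is hit as soon as the solver tries to draw noise into a `Dual`-valued `CuArray`; a minimal sketch that should reproduce it outside the solver:

```julia
using CUDA, ForwardDiff, Random

# A Dual-valued CuArray, as produced when ForwardDiff pushes Duals through the solver
x = CUDA.zeros(ForwardDiff.Dual{Nothing, Float64, 1}, 4, 4)

# Neither CUDA's randn!(::CUDA.RNG, ::AbstractArray) nor GPUArrays'
# randn!(::AbstractRNG, ::AnyGPUArray) is more specific for this element type
randn!(CUDA.default_rng(), x)
```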
Does anyone know what is going on? Any workaround?
Thank you for your help!