Hi,
I have tested the gradient of the llama2-7b model with respect to its input, and it gives me wrong results (tested only on GPU). I have created an MWE as follows:
using ProfileSummarizer
using Transformers
using Flux
using TextEncodeBase
using NeuralAttentionlib
using HuggingFaceApi
using FiniteDifferences
using Zygote
using CUDA
using StatsBase
CUDA.device!(0)
textenc = HuggingFace.load_tokenizer("meta-llama/Llama-2-7b-chat-hf"; auth_token = access_token);
model = f32(HuggingFace.load_model("llama", "meta-llama/Llama-2-7b-chat-hf", "forCausallm", auth_token = access_token));
model = gpu(model);
embeddings = model.model.embedding.token.embeddings;
decoder = model.model.decoder;
# Compute gradients with respect to a few randomly selected elements and compare them
# to finite-difference estimates. Only a subset is checked, because otherwise the test would take ages.
let
    tokens = TextEncodeBase.encode(textenc, "the most important text for this test").token
    θ = cpu(embeddings * tokens)
    ii = sample(eachindex(θ), 1000, replace = false)
    # Overridden with 10 fixed indices; the output below is for these elements.
    ii = [23, 5641, 18928, 4354, 16801, 294, 18929, 17793, 1277, 17098]
    sub_θ = θ[ii]
    # Objective as a function of the selected elements only (writes them into θ in place)
    function sub_f(sub_θ)
        θ[ii] = sub_θ
        f(θ)
    end
    # Sum of the decoder output for hidden state θ, evaluated on the GPU
    function f(θ)
        hidden_state = gpu(θ)
        sum(decoder((;hidden_state)).hidden_state)
    end
    fin_gs = grad(central_fdm(5, 1), sub_f, sub_θ)[1]
    zyg_gs = Zygote.gradient(f, θ)[1][ii]
    hcat(fin_gs, zyg_gs)
    # Absolute differences between the two gradients, largest first
    sort(abs.(fin_gs .- zyg_gs), rev = true)
end
which returns
10-element Vector{Float32}:
4.4233932
2.523777
1.9660205
1.4594455
0.67348194
0.589447
0.5284171
0.5228882
0.36881256
0.18504143
which seems to me to be off by a large margin.
I use Julia 1.9.2, and Pkg.status() reports:
[7d9f7c33] Accessors v0.1.32
[6e4b80f9] BenchmarkTools v1.3.2
⌃ [052768ef] CUDA v4.4.0
[d360d2e6] ChainRulesCore v1.16.0
[a93c6f00] DataFrames v1.6.1
[26cc04aa] FiniteDifferences v0.12.30
⌅ [587475ba] Flux v0.13.17
[d9f16b24] Functors v0.4.5
[3cc741c3] HuggingFaceApi v0.1.0
[f1d291b0] MLUtils v0.4.3
[5da4648a] NVTX v0.3.2
⌃ [12afc1b8] NeuralAttentionlib v0.2.11
[0b1bfda6] OneHotArrays v0.2.4
[6099a3de] PythonCall v0.9.14
[2913bbd2] StatsBase v0.34.0
[354b36f9] StringViews v1.3.3
⌃ [899adc3e] TensorBoardLogger v0.1.21
[f92c20c0] TextEncodeBase v0.6.0
[21ca0261] Transformers v0.2.7
⌃ [e88e6eb3] Zygote v0.6.63
I will try to narrow the problem down by testing individual layers, but this is going to take a bit of time.
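A per-layer check could reuse the same comparison on something much smaller. The following is only a minimal sketch: `gradient_errors` is a hypothetical helper, and `layer` is a toy stand-in evaluated in Float64 on the CPU, not one of the actual llama layers. It compares `Zygote.gradient` against `central_fdm(5, 1)` on a few selected entries of the input.

```julia
using FiniteDifferences
using Zygote

# Hypothetical helper: absolute differences between the Zygote gradient and a
# finite-difference estimate, restricted to the entries of x listed in idx.
function gradient_errors(f, x::AbstractArray, idx)
    # Work on a copy so the caller's x is never perturbed.
    sub_f(v) = (y = copy(x); y[idx] = v; f(y))
    fd = grad(central_fdm(5, 1), sub_f, x[idx])[1]
    ad = Zygote.gradient(f, x)[1][idx]
    abs.(fd .- ad)
end

# Toy stand-in for a single layer (an assumption, not part of the model).
W = randn(4, 6)
layer(x) = sum(tanh.(W * x))

x = randn(6, 3)
errs = gradient_errors(layer, x, [1, 5, 12])
```

For a real layer, `layer` would be replaced by a closure over that layer, in the same way `f` wraps `decoder` above. Running the toy case in Float64 first gives a baseline for how small the differences are when both gradients agree.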