Let’s say I have a matrix A whose entries are each a function of x, and a matrix B whose entries are the derivatives of the entries of A with respect to x. I don’t need Zygote to calculate B by differentiating through my code. Instead, I want to use Zygote's @adjoint to supply B directly, but I ran into problems while implementing this. In the code below, Get_Mat builds two x-dependent matrices (named A and B there), and Grad_of_Mat computes their derivatives with respect to x using the complex-step method. These derivative matrices are what I want to hand to Zygote. The code is as follows:
using Zygote, Test
using Zygote: bufferfrom
using Zygote: @adjoint
function Get_Mat(x,t)
    d = 2
    A = bufferfrom(zeros(typeof(x),d,d))
    B = bufferfrom(zeros(typeof(x),d,d))
    for i in 1:d
        for j in 1:d
            A[i,j] = sin(i*t)*exp(j*x)
            B[i,j] = sin(j*x)*exp(i*t)
        end
    end
    return copy(A),copy(B)
end
function Grad_of_Mat(x,t)
    x0 = 1e-14
    A,B = Get_Mat(x+x0*im,t)
    g_A = imag.(A) ./ x0
    g_B = imag.(B) ./ x0
    return g_A,g_B
end
function myfunction(x,t)
    A,B = Get_Mat(x,t)
    return sum(A*B)
end
function test1()
    t0 = 0.3
    x0 = 0.7
    f(x) = myfunction(x,t0)
    @time res = f(x0)
    @time grad = gradient(f,t0)[1]
    @show res , grad
end
@time test1()
@adjoint Get_Mat(x,t) = Get_Mat(x,t) , c̄ -> @. c̄ * (Grad_of_Mat(x,t),(zeros(2,2),zeros(2,2)))
function test2()
    t0 = 0.3
    x0 = 0.7
    f(x) = myfunction(x,t0)
    @time res = f(x0)
    @time grad = gradient(f,t0)[1]
    @show res , grad
end
@time test2()
Running this produces the output below. test1 works, but test2 fails with this error:
0.000011 seconds (3 allocations: 288 bytes)
1.361310 seconds (1.39 M allocations: 74.524 MiB, 1.66% gc time, 99.90% compilation time)
(res, grad) = (14.168254994540455, 17.787911265679224)
1.450008 seconds (1.48 M allocations: 79.829 MiB, 1.55% gc time, 99.78% compilation time)
0.000005 seconds (3 allocations: 288 bytes)
ERROR: LoadError: MethodError: no method matching *(::Matrix{Float64}, ::Tuple{Matrix{Float64}, Matrix{Float64}})
Closest candidates are:
*(::Any, ::Any, ::Any, ::Any...) at E:\AppData\Julia-1.7.3\share\julia\base\operators.jl:655
*(::StridedMatrix{T}, ::StridedVector{S}) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:Real} at E:\AppData\Julia-1.7.3\share\julia\stdlib\v1.7\LinearAlgebra\src\matmul.jl:44
*(::StridedMatrix{var"#s861"} where var"#s861"<:Union{Float32, Float64}, ::StridedMatrix{var"#s860"} where var"#s860"<:Union{Float32, Float64, ComplexF32, ComplexF64}) at E:\AppData\Julia-1.7.3\share\julia\stdlib\v1.7\LinearAlgebra\src\matmul.jl:158
...
Stacktrace:
  [1] _broadcast_getindex_evalf
    @ .\broadcast.jl:670 [inlined]
  [2] _broadcast_getindex
    @ .\broadcast.jl:643 [inlined]
  [3] (::Base.Broadcast.var"#29#30"{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(*), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}}, Tuple{Tuple{Matrix{Float64}, Matrix{Float64}}, Tuple{Matrix{Float64}, Matrix{Float64}}}}}})(k::Int64)
    @ Base.Broadcast .\broadcast.jl:1075
  [4] ntuple
    @ .\ntuple.jl:49 [inlined]
  [5] copy
    @ .\broadcast.jl:1075 [inlined]
  [6] materialize
    @ .\broadcast.jl:860 [inlined]
.......
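The stack trace suggests that the MethodError comes from the broadcast, not from Zygote itself. @. turns the * into .*, and broadcasting over tuples pairs c̄[1] (the cotangent of A, a matrix) with the first element of the right-hand tuple. That element is the whole tuple returned by Grad_of_Mat(x, t), so Julia ends up evaluating Matrix * Tuple. The snippet below reproduces this without Zygote. The matrices of ones are hypothetical placeholders for the real cotangents and derivatives.

```julia
# Placeholder values standing in for what the failing pullback sees.
Ā, B̄ = ones(2, 2), ones(2, 2)           # cotangents of the two returned matrices
c̄ = (Ā, B̄)
g_A, g_B = ones(2, 2), ones(2, 2)       # stand-ins for Grad_of_Mat(x, t)
rhs = ((g_A, g_B), (zeros(2, 2), zeros(2, 2)))

# `@. c̄ * rhs` becomes `c̄ .* rhs`; broadcasting over the outer 2-tuples
# evaluates c̄[1] * rhs[1], i.e. Ā * (g_A, g_B): a Matrix times a Tuple.
err = try
    c̄ .* rhs
    nothing
catch e
    e
end
println(typeof(err))                     # MethodError

# For x, the pullback instead needs one scalar built from each cotangent/derivative pair:
x̄ = sum(Ā .* g_A) + sum(B̄ .* g_B)
println(x̄)                               # 8.0
```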
How do I use Zygote's @adjoint to achieve the result I want?
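The sketch below shows one possible way to write the adjoint. It assumes that only the gradient with respect to x is needed, which the zeros placeholder in the original adjoint suggests. The pullback receives c̄ as a tuple (Ā, B̄), with one cotangent per returned matrix, and either entry may be nothing if that output is unused. By the chain rule, the gradient for x is the scalar x̄ = Σᵢⱼ Āᵢⱼ ∂Aᵢⱼ/∂x + Σᵢⱼ B̄ᵢⱼ ∂Bᵢⱼ/∂x. So the precomputed matrices from Grad_of_Mat are contracted with the cotangents instead of being multiplied by them as a tuple.

The following names and changes are choices made for this illustration, not taken from the question:
- The pullback name Get_Mat_pullback is hypothetical.
- The complex-step size is renamed h.
- Plain arrays replace bufferfrom, on the assumption that Zygote no longer traces the body of Get_Mat once the adjoint exists.

Returning nothing for t tells Zygote "no gradient here". That is only correct if t does not depend on the variable being differentiated.

```julia
using Zygote
using Zygote: @adjoint

# Same matrices as in the question. Plain arrays are used on the assumption that,
# with the custom adjoint below, Zygote does not differentiate through this body.
function Get_Mat(x, t)
    d = 2
    A = zeros(typeof(x), d, d)
    B = zeros(typeof(x), d, d)
    for i in 1:d, j in 1:d
        A[i, j] = sin(i * t) * exp(j * x)
        B[i, j] = sin(j * x) * exp(i * t)
    end
    return A, B
end

# Complex-step derivatives dA/dx and dB/dx, as in the question.
function Grad_of_Mat(x, t)
    h = 1e-14
    A, B = Get_Mat(x + h * im, t)
    return imag.(A) ./ h, imag.(B) ./ h
end

function myfunction(x, t)
    A, B = Get_Mat(x, t)
    return sum(A * B)
end

# Custom pullback: supply the derivative matrices directly.
@adjoint function Get_Mat(x, t)
    A, B = Get_Mat(x, t)
    dA, dB = Grad_of_Mat(x, t)          # precomputed ∂A/∂x and ∂B/∂x
    function Get_Mat_pullback(c̄)
        Ā, B̄ = c̄                      # one cotangent per output; may be `nothing`
        x̄ = zero(x)
        Ā === nothing || (x̄ += sum(Ā .* dA))
        B̄ === nothing || (x̄ += sum(B̄ .* dB))
        return (x̄, nothing)             # no gradient supplied for t
    end
    return (A, B), Get_Mat_pullback
end

t0, x0 = 0.3, 0.7
f(x) = myfunction(x, t0)
g = gradient(f, x0)[1]

# Independent check: complex step applied to the whole function.
g_ref = imag(f(x0 + 1e-20 * im)) / 1e-20
println((g, g_ref))
```

Note also that test1 and test2 call gradient(f, t0). Because f(x) = myfunction(x, t0), that evaluates the derivative at x = t0 = 0.3, not at x0 = 0.7. The sketch above differentiates at x0.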