How do I customize the derivative of a matrix using Zygote.@adjoint?

Let’s say I have a matrix A in which each entry is a function of x, and a matrix B whose entries are the derivatives of the entries of A with respect to x. When I call Zygote, I don’t need it to calculate B; I want to use Zygote.@adjoint to supply B directly, but I run into problems implementing this. The code is as follows:
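Stated in reverse-mode terms (a sketch using the names from the code below, with indices k, l to avoid a clash with the imaginary unit): Grad_of_Mat approximates the x-derivatives of the entries of A and B with the complex-step formula, where h is the small step called x0 = 1e-14 inside Grad_of_Mat. A custom pullback for Get_Mat, given the cotangents Ā and B̄ of its two outputs, then has to return a single scalar x̄:

$$
\frac{\partial A_{kl}}{\partial x} \approx \frac{\operatorname{Im} A_{kl}(x + \mathrm{i}h,\, t)}{h},
\qquad
\bar{x} = \sum_{k,l} \bar{A}_{kl}\,\frac{\partial A_{kl}}{\partial x} + \sum_{k,l} \bar{B}_{kl}\,\frac{\partial B_{kl}}{\partial x}
$$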

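```julia
using Zygote, Test
using Zygote: bufferfrom
using Zygote: @adjoint

# Build two 2×2 matrices A and B whose entries depend on x and t.
# Zygote.Buffer (created by bufferfrom) allows element-wise writes under Zygote.
function Get_Mat(x,t)
    d=2
    A= bufferfrom(zeros(typeof(x),d,d))
    B= bufferfrom(zeros(typeof(x),d,d))
    for i in 1:d
        for j in 1:d
            A[i,j] = sin(i*t)*exp(j*x)
            B[i,j] = sin(j*x)*exp(i*t)
        end
    end
    return copy(A),copy(B)
end

# Derivatives of A and B with respect to x, computed with the complex-step method.
function Grad_of_Mat(x,t)
    x0 = 1e-14
    A,B=Get_Mat(x+x0*im,t)
    g_A = imag.(A) ./ x0
    g_B = imag.(B) ./ x0
    return g_A,g_B
end

function myfunction(x,t)
    A,B=Get_Mat(x,t)
    return sum(A*B)
end

function test1()
    t0 = 0.3
    x0 = 0.7
    f(x) = myfunction(x,t0)

    @time res = f(x0)
    @time grad = gradient(f,t0)[1]   # note: this evaluates df/dx at x = t0 = 0.3, not at x0
    @show res , grad

end
@time test1()

# Custom adjoint attempt; calling gradient after this definition raises the error below.
@adjoint Get_Mat(x,t) = Get_Mat(x,t) ,  c̄ -> @. c̄ * (Grad_of_Mat(x,t),(zeros(2,2),zeros(2,2)))
function test2()
    t0 = 0.3
    x0 = 0.7
    f(x) = myfunction(x,t0)

    @time res = f(x0)
    @time grad = gradient(f,t0)[1]   # note: this evaluates df/dx at x = t0 = 0.3, not at x0
    @show res , grad
end
@time test2()
```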

The error is:

```
  0.000011 seconds (3 allocations: 288 bytes)
  1.361310 seconds (1.39 M allocations: 74.524 MiB, 1.66% gc time, 99.90% compilation time)
(res, grad) = (14.168254994540455, 17.787911265679224)
  1.450008 seconds (1.48 M allocations: 79.829 MiB, 1.55% gc time, 99.78% compilation time)
  0.000005 seconds (3 allocations: 288 bytes)
ERROR: LoadError: MethodError: no method matching *(::Matrix{Float64}, ::Tuple{Matrix{Float64}, Matrix{Float64}})
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at E:\AppData\Julia-1.7.3\share\julia\base\operators.jl:655
  *(::StridedMatrix{T}, ::StridedVector{S}) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:Real} at E:\AppData\Julia-1.7.3\share\julia\stdlib\v1.7\LinearAlgebra\src\matmul.jl:44
  *(::StridedMatrix{var"#s861"} where var"#s861"<:Union{Float32, Float64}, ::StridedMatrix{var"#s860"} where var"#s860"<:Union{Float32, Float64, ComplexF32, ComplexF64}) at E:\AppData\Julia-1.7.3\share\julia\stdlib\v1.7\LinearAlgebra\src\matmul.jl:158
  ...
Stacktrace:
  [1] _broadcast_getindex_evalf
    @ .\broadcast.jl:670 [inlined]
  [2] _broadcast_getindex
    @ .\broadcast.jl:643 [inlined]
  [3] (::Base.Broadcast.var"#29#30"{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(*), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}}, Tuple{Tuple{Matrix{Float64}, Matrix{Float64}}, Tuple{Matrix{Float64}, Matrix{Float64}}}}}})(k::Int64)
    @ Base.Broadcast .\broadcast.jl:1075
  [4] ntuple
    @ .\ntuple.jl:49 [inlined]
  [5] copy
    @ .\broadcast.jl:1075 [inlined]
  [6] materialize
    @ .\broadcast.jl:860 [inlined]
.......
```

How do I use Zygote.@adjoint to achieve the result I want?

I don’t know if my problem is clearly described.
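The MethodError comes from broadcasting over tuples, as the `Broadcast.Style{Tuple}` in the stack trace indicates: `.*` between two 2-tuples applies `*` to matching elements, so the first cotangent matrix is multiplied by the whole tuple `(g_A, g_B)`. The plain-Julia sketch below (no Zygote; the 2×2 values are made up and only stand in for the real cotangents and derivatives) reproduces the error and then computes the scalar the pullback should return for x instead.

```julia
# Made-up stand-ins with the same shapes as in the question (2×2 matrices).
Abar = [1.0 2.0; 3.0 4.0]   # cotangent of the first output, A
Bbar = [0.5 0.5; 0.5 0.5]   # cotangent of the second output, B
gA   = [1.0 0.0; 0.0 1.0]   # stands in for ∂A/∂x from Grad_of_Mat
gB   = [2.0 2.0; 2.0 2.0]   # stands in for ∂B/∂x from Grad_of_Mat
cbar = (Abar, Bbar)

# Tuple broadcasting pairs cbar[1] with the whole tuple (gA, gB),
# so Julia tries Matrix * Tuple and throws a MethodError.
err = try
    cbar .* ((gA, gB), (zeros(2, 2), zeros(2, 2)))
    nothing
catch e
    e
end
println(typeof(err))

# What the pullback needs instead: one scalar for x, the sum of the
# elementwise products of each output cotangent with that output's x-derivative.
xbar = sum(Abar .* gA) + sum(Bbar .* gB)
println(xbar)
```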

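It’s deceptive, but `@. c̄ * (Grad_of_Mat(x,t),(zeros(2,2),zeros(2,2)))` is not actually doing the same thing as:

```julia
c̄ .* (Grad_of_Mat(x,t),(zeros(2,2),zeros(2,2)))
```

`@.` adds a dot to every function call in the expression, not just to `*`, so `Grad_of_Mat(x,t)` and `zeros(2,2)` are broadcast as well.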
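A minimal sketch of a working `@adjoint`, assuming (as in the question) that only the derivative with respect to x is needed. Instead of broadcasting over the tuple `c̄`, the pullback unpacks it into the two output cotangents, contracts each with its derivative matrix from `Grad_of_Mat`, and returns one gradient per argument of `Get_Mat`, i.e. `(x̄, nothing)`. Returning `nothing` for t assumes t is held constant, and `contract` is a hypothetical helper name. The check at the end evaluates the gradient at `x0` (the question's test functions pass `t0`) and compares it against a central finite difference.

```julia
using Zygote
using Zygote: @adjoint, bufferfrom

# Same construction as in the question: two 2×2 matrices that depend on x and t.
function Get_Mat(x, t)
    d = 2
    A = bufferfrom(zeros(typeof(x), d, d))
    B = bufferfrom(zeros(typeof(x), d, d))
    for i in 1:d, j in 1:d
        A[i, j] = sin(i * t) * exp(j * x)
        B[i, j] = sin(j * x) * exp(i * t)
    end
    return copy(A), copy(B)
end

# Complex-step derivatives of A and B with respect to x (as in the question).
function Grad_of_Mat(x, t)
    h = 1e-14
    A, B = Get_Mat(x + h * im, t)
    return imag.(A) ./ h, imag.(B) ./ h
end

function myfunction(x, t)
    A, B = Get_Mat(x, t)
    return sum(A * B)
end

# Hypothetical helper: sum of elementwise products; an unused output arrives as `nothing`.
contract(::Nothing, g) = 0.0
contract(cot, g) = sum(cot .* g)

@adjoint function Get_Mat(x, t)
    A, B = Get_Mat(x, t)   # forward pass runs as ordinary Julia, not traced by Zygote
    function Get_Mat_pullback(cbar)
        Abar, Bbar = cbar              # cotangents of the outputs A and B
        gA, gB = Grad_of_Mat(x, t)
        xbar = contract(Abar, gA) + contract(Bbar, gB)
        return (xbar, nothing)         # one entry per argument (x, t); t treated as constant
    end
    return (A, B), Get_Mat_pullback
end

t0, x0 = 0.3, 0.7
f(x) = myfunction(x, t0)
g = gradient(f, x0)[1]
# Central finite difference as an independent check.
fd = (f(x0 + 1e-6) - f(x0 - 1e-6)) / 2e-6
@show g fd
```

The pullback returns a single scalar for x because x is a scalar argument; the matrices only appear inside the contraction.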

At the same time, I would recommend using ChainRulesCore.rrule instead of Zygote.@adjoint for this. The ChainRules documentation page "Converting ZygoteRules.@adjoint to rrules" is a migration guide for converting between the two. The reason is that we plan on deprecating @adjoint in the near future and asking users to move to ChainRules. As a bonus, your rule will then work with other ChainRules-compatible ADs!
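A sketch of the same rule written as a `ChainRulesCore.rrule`, assuming ChainRulesCore 1.x conventions. The pullback receives a structural tangent for the `(A, B)` output (Zygote passes a `Tangent`, other ADs may pass a plain tuple), so it is indexed rather than broadcast, and the pullback returns one tangent for the function itself plus one per argument. Returning `NoTangent()` for t again assumes t is held constant, and `contract` is a hypothetical helper name.

```julia
using ChainRulesCore
using Zygote

# Same matrices as in the question. Zygote calls the rrule below instead of
# tracing this body, so plain mutable arrays work here (no Zygote.Buffer needed).
function Get_Mat(x, t)
    d = 2
    A = zeros(typeof(x), d, d)
    B = zeros(typeof(x), d, d)
    for i in 1:d, j in 1:d
        A[i, j] = sin(i * t) * exp(j * x)
        B[i, j] = sin(j * x) * exp(i * t)
    end
    return A, B
end

# Complex-step derivatives of A and B with respect to x (as in the question).
function Grad_of_Mat(x, t)
    h = 1e-14
    A, B = Get_Mat(x + h * im, t)
    return imag.(A) ./ h, imag.(B) ./ h
end

function myfunction(x, t)
    A, B = Get_Mat(x, t)
    return sum(A * B)
end

# Hypothetical helper: contract one output cotangent with its derivative matrix.
# A zero tangent (output not used) contributes nothing.
contract(::AbstractZero, g) = 0.0
contract(cot, g) = sum(unthunk(cot) .* g)

function ChainRulesCore.rrule(::typeof(Get_Mat), x, t)
    A, B = Get_Mat(x, t)
    function Get_Mat_pullback(ybar)
        ȳ = unthunk(ybar)              # tangent of the (A, B) tuple output
        gA, gB = Grad_of_Mat(x, t)
        xbar = contract(ȳ[1], gA) + contract(ȳ[2], gB)
        # NoTangent() for the function itself and for t (t treated as a constant here).
        return NoTangent(), xbar, NoTangent()
    end
    return (A, B), Get_Mat_pullback
end

t0, x0 = 0.3, 0.7
f(x) = myfunction(x, t0)
g = gradient(f, x0)[1]
# Central finite difference as an independent check.
fd = (f(x0 + 1e-6) - f(x0 - 1e-6)) / 2e-6
@show g fd
```

Because the rule lives in ChainRulesCore rather than in Zygote, the forward pass no longer needs to be Zygote-friendly, which is why the Buffer workaround was dropped here.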