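Since Flux has moved from Tracker to Zygote, I tried to port some of my previous code.
My goal is to compute second-order derivatives of a function (built from a neural network) with respect to its inputs, evaluated at every column of a matrix of points, but I ran into the errors below.
With the setup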
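# Rectangular domain [x0, x1] × [y0, y1], sampled on an nx × ny grid
x0 = 0.0; x1 = 2.0
y0 = 0.0; y1 = 1.0
nx = 23; ny = 11
xVec = Vector(range(x0, stop=x1, length=nx))
xMesh = repeat(reshape(xVec, 1, :), ny, 1)   # ny × nx; entry (i, j) is xVec[j]
yVec = Vector(range(y0, stop=y1, length=ny))
yMesh = repeat(yVec, 1, nx)                  # ny × nx; entry (i, j) is yVec[i]
# Flatten into a 2 × (nx*ny) matrix: row 1 holds x, row 2 holds y
xMesh1D = reshape(xMesh, (1, :))
yMesh1D = reshape(yMesh, (1, :))
mesh = cat(xMesh1D, yMesh1D; dims=1)
X = deepcopy(mesh)
Y = zeros(1, nx*ny);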
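Each column of `mesh` (and hence `X`) is one grid point (x, y), with y varying fastest, so `X` is 2×253 and `u(X)` below is 1×253. A more compact way to build the same layout, sketched with a comprehension (`grid_points` is a hypothetical helper name, not part of Flux):

```julia
# Build a 2 × (length(xs)*length(ys)) matrix with one column [x; y] per
# grid point; y varies fastest and x slowest, matching the column order
# of `mesh` in the setup.
grid_points(xs, ys) = reduce(hcat, [[x, y] for x in xs for y in ys])

P = grid_points(0.0:1.0:2.0, 0.0:0.5:1.0)   # 3 x-values, 3 y-values
# P == [0.0 0.0 0.0 1.0 1.0 1.0 2.0 2.0 2.0;
#       0.0 0.5 1.0 0.0 0.5 1.0 0.0 0.5 1.0]
```

With the setup above, `grid_points(xVec, yVec)` should give the same matrix as `mesh`.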
my previous Tracker-based code is
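using Flux
import Tracker
m = Chain(Dense(2, 20, tanh), Dense(20, 20, tanh), Dense(20, 1))
# Pick out the x row and the y row of a 2 × N input
mat_x(x) = [1.0 0.0] * x
mat_y(x) = [0.0 1.0] * x
# The network term is multiplied by x(2 - x)y(1 - y),
# which vanishes on the boundary of [0, 2] × [0, 1]
u(x) = sin.(π .* mat_x(x) ./ 2.) .* mat_y(x) .+
    mat_x(x) .* (2. .- mat_x(x)) .* mat_y(x) .* (1. .- mat_y(x)) .* m(x)
# Tracker.forward returns (output, pullback); the pullback is seeded with 1
# and [1] takes the gradient with respect to x
ux(x) = mat_x(Tracker.forward(u, x)[2](1)[1])
uy(x) = mat_y(Tracker.forward(u, x)[2](1)[1])
uxx(x) = mat_x(Tracker.forward(ux, x)[2](1)[1])
uyy(x) = mat_y(Tracker.forward(uy, x)[2](1)[1])
uxx(X)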
and it returns
Tracked 1×253 Array{Float64,2}:
0.0 -0.266183 -0.451705 -0.562292 … 0.284977 0.160405 -3.02169e-16
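Whichever AD backend is used, second derivatives like these can be sanity-checked against central finite differences, uxx ≈ (u(x+h, y) - 2u(x, y) + u(x-h, y)) / h^2. A minimal sketch, assuming (as for `u` above) that the function maps a 2×N matrix to a 1×N matrix and treats columns independently; `fd_xx`, `g`, and `exact_xx` are hypothetical names used only for illustration:

```julia
# Hypothetical helper (not from Flux/Zygote): central finite-difference
# estimate of ∂²f/∂x² at every column of X, where row 1 of X holds x,
# row 2 holds y, and f maps a 2×N matrix to a 1×N matrix.
function fd_xx(f, X::AbstractMatrix; h = 1e-3)
    ex = [1.0, 0.0]                      # unit step in the x direction
    return (f(X .+ h .* ex) .- 2 .* f(X) .+ f(X .- h .* ex)) ./ h^2
end

# Check on the network-free part of u, g(x, y) = sin(πx/2) * y,
# whose exact second x-derivative is -(π/2)^2 * sin(πx/2) * y.
g(X) = sin.(π .* X[1:1, :] ./ 2) .* X[2:2, :]
exact_xx(X) = -(π / 2)^2 .* sin.(π .* X[1:1, :] ./ 2) .* X[2:2, :]

Xtest = [0.0 0.5 1.0 2.0;
         0.0 0.3 1.0 0.7]               # four sample points, one per column
maxerr = maximum(abs.(fd_xx(g, Xtest) .- exact_xx(Xtest)))
```

For instance, `fd_xx(u, X)` could be compared with `uxx(X)` from the Tracker version, or with whatever the Zygote version eventually returns.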
However, when I switched to Zygote with
vx(x) = mat_x(Flux.Zygote.pullback(u, x)[2](1)[1])
vy(x) = mat_y(Flux.Zygote.pullback(u, x)[2](1)[1])
vxx(x) = mat_x(Flux.Zygote.pullback(vx, x)[2](1)[1])
vyy(x) = mat_y(Flux.Zygote.pullback(vy, x)[2](1)[1])
vxx(X)
it throws
MethodError: no method matching reshape(::Int64, ::Tuple{Int64,Int64})
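Here `back(1)` passes the scalar 1 as the seed, but `u(X)` is a 1×253 matrix, and Zygote's pullback expects a seed (cotangent) with the same shape as the output. The MethodError looks consistent with the scalar being reshaped to that shape somewhere in the backward pass, which Julia cannot do for a plain number. A stdlib-only illustration, where `y` merely stands in for `u(X)`:

```julia
y = zeros(1, 253)                 # stand-in for u(X), a 1×253 output

# A correctly shaped seed: same size and element type as the output.
seed = ones(eltype(y), size(y))

# Reshaping a plain Int into a matrix shape has no method, which gives
# the same kind of MethodError as reported above.
err = try
    reshape(1, size(y))
    nothing
catch e
    e
end
```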
I also tried seeding the pullbacks with arrays of ones instead:
vx(x) = mat_x(Flux.Zygote.pullback(u, x)[2](ones(size(x)))[1])
vy(x) = mat_y(Flux.Zygote.pullback(u, x)[2](ones(size(x)))[1])
vxx(x) = mat_x(Flux.Zygote.pullback(vx, x)[2](ones(size(x)))[1])
vyy(x) = mat_y(Flux.Zygote.pullback(vy, x)[2](ones(size(x)))[1])
vxx(X)
but now it throws
DimensionMismatch("A has dimensions (2,1) but B has dimensions (2,253)")
Stacktrace:
[1] gemm_wrapper!(::Array{Float64,2}, ::Char, ::Char, ::Array{Float64,2}, ::Array{Float64,2}, ::LinearAlgebra.MulAddMul{true,true,Bool,Bool}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:545
[2] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:305 [inlined]
[3] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:372 [inlined]
[4] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:203 [inlined]
[5] * at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:153 [inlined]
[6] #1204 at /home/tianbai/.julia/packages/Zygote/oMScO/src/lib/array.jl:246 [inlined]
[7] #3072#back at /home/tianbai/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[8] mat_x at ./In[3]:6 [inlined]
[9] vx at ./In[6]:1 [inlined]
[10] (::typeof(∂(vx)))(::Array{Float64,2}) at /home/tianbai/.julia/packages/Zygote/oMScO/src/compiler/interface2.jl:0
[11] (::Zygote.var"#28#29"{typeof(∂(vx))})(::Array{Float64,2}) at /home/tianbai/.julia/packages/Zygote/oMScO/src/compiler/interface.jl:38
[12] vxx(::Array{Float64,2}) at ./In[6]:3
[13] top-level scope at In[9]:1
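The stack trace ends in the adjoint of the matrix product inside `mat_x`: `vx(X)` returns a 1×253 matrix, but its pullback is seeded with `ones(size(x))`, which is 2×253. For y = A * x the pullback with respect to x computes A' * Δ, so with A = [1.0 0.0] this is a (2×1) times (2×253) product, matching the reported dimensions. A sketch with a hand-written pullback (`matmul_pullback` is hypothetical, not Zygote's internal code):

```julia
using LinearAlgebra

# Hypothetical hand-written pullback for y = A * x. The returned closure
# maps a cotangent Δ, which must have the shape of y, to A' * Δ.
function matmul_pullback(A::AbstractMatrix, x::AbstractMatrix)
    y = A * x
    back(Δ) = A' * Δ
    return y, back
end

A = [1.0 0.0]                      # the row used by mat_x
x = rand(2, 253)                   # same shape as X in the question
y, back = matmul_pullback(A, x)    # y is 1×253

dx = back(ones(size(y)))           # seed shaped like y: dx is 2×253

err = try
    back(ones(size(x)))            # seed shaped like x: (2×1) * (2×253) fails
    nothing
catch e
    e
end
```

This suggests seeding each pullback with an array shaped like the output of the function being differentiated, e.g. `y, back = Flux.Zygote.pullback(vx, x)` followed by `back(ones(size(y)))[1]`, rather than `ones(size(x))`; the same applies to the inner pullback of `u`, whose output is also 1×253. This only addresses the shape mismatch; whether the nested pullback then runs depends on Zygote's support for differentiating through its own pullbacks.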
Any ideas on how to fix this? Thanks!