I've run into many errors using ForwardDiff and Zygote to automatically differentiate my objective function, which works on complex-valued fields (the optimization variable itself is a real phase array) and returns a real number. I don't think I have the expertise to debug this myself, so I'm posting the problem here. ForwardDiff didn't work for me, and a quick search suggested that it doesn't support complex numbers as well as Zygote does. With Zygote, the most common error I get is that `plan_fft` has no method for ForwardDiff's dual-number type:
```
MethodError: no method matching plan_fft(::Array{Complex{ForwardDiff.Dual{Nothing,Float64,12}},3}, ::UnitRange{Int64})
```
I then tried computing the plans ahead of time, `pfft = plan_fft(e_in)` and `pifft = plan_ifft(e_in)`, passing them into the functions and applying them with `*`, but that gave the error `type ScaledPlan has no field region`. A Google search turned up nothing.
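The `Dual{Nothing,Float64,12}` in the first error is, I believe, created by `Zygote.hessian`, which (per its docstring) runs ForwardDiff over Zygote's reverse pass, so `fft` receives dual numbers even though only Zygote is called directly; FFTW's plans are defined for Float32/Float64-based element types, not dual numbers. One possible workaround for grids as small as 16×16 is to replace the FFT by products with precomputed DFT matrices, which are plain generic Julia code. The sketch below is a stand-in for `plan_fft`/`plan_ifft` on a single 2-D slice, under that assumption: `dft_matrix`, `MatrixDFT2`, `apply_fft` and `apply_ifft` are hypothetical names (not FFTW API), and I have not tested them under Zygote or ForwardDiff.

```julia
using LinearAlgebra

# Hypothetical stand-in for plan_fft/plan_ifft (not part of FFTW):
# precompute dense DFT matrices once and apply them with matrix products,
# which are generic Julia code rather than calls into the FFTW C library.
dft_matrix(n) = [cis(-2π * j * k / n) for j in 0:n-1, k in 0:n-1]

struct MatrixDFT2
    Fx::Matrix{ComplexF64}  # DFT along the first (x) axis
    Fy::Matrix{ComplexF64}  # DFT along the second (y) axis
end
MatrixDFT2(nx::Integer, ny::Integer) = MatrixDFT2(dft_matrix(nx), dft_matrix(ny))

# Forward 2-D DFT with the same sign convention as fft: Fx * X * Fy^T.
apply_fft(p::MatrixDFT2, X::AbstractMatrix) = p.Fx * X * transpose(p.Fy)

# Inverse 2-D DFT: conjugated matrices plus the 1/(nx*ny) factor that ifft applies.
apply_ifft(p::MatrixDFT2, Y::AbstractMatrix) =
    conj(p.Fx) * Y * transpose(conj(p.Fy)) / length(Y)
```

Each transform costs O(n³) for an n×n slice instead of O(n² log n), which should still be cheap at this size; for the `(nmod, nx, ny)` arrays in the code it would be applied to each `e[m, :, :]` slice.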
Moreover, my objective function is not supposed to return complex values, but I got an error saying the output is complex, so I had to wrap it in `real()`.
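As far as I can tell, `norm` of a complex array already returns a real `Float64`, so in plain Julia the `real()` wrapper should be a no-op; the quick check below uses made-up values. `sum(abs2, z)` gives the same squared norm without taking a square root first. If Zygote still reports a complex output, it presumably originates elsewhere in the function.

```julia
using LinearAlgebra

z = [1.0 + 2.0im, 3.0 - 1.0im]  # made-up complex field values
n2 = norm(z)^2                  # norm of a complex array is a real Float64
s2 = sum(abs2, z)               # same squared norm, no sqrt: 5.0 + 10.0
println(typeof(n2), " ", n2)
println(typeof(s2), " ", s2)
```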
At this point I'm not sure whether I'm just using these packages incorrectly, or whether this kind of objective function, involving complex fields and fast Fourier transforms, is simply beyond what AD packages like ForwardDiff and Zygote can handle. I've attached an abbreviated version of my code below with dummy values. Thank you so much for your help.
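For reference, this is the model the code computes, as I understand it, with the `fftshift`/`ifftshift` calls omitted. Here $\mathcal{F}$ is the 2-D DFT over the transverse axes (applied to each mode), $\odot$ is element-wise multiplication, $\lambda$ is `lmbd`, $d_p$ is `d[p]`, $\theta_p$ is `theta[p, :, :]`, and the norm runs over all entries of the `(nmod, nx, ny)` array:

$$
k = \frac{2\pi}{\lambda}, \qquad
H_p(k_x, k_y) = \exp\!\left(i\, d_p \sqrt{k^2 - k_x^2 - k_y^2}\right),
$$

$$
E_0 = E_{\mathrm{in}}, \qquad
E_p = e^{i\theta_p} \odot \mathcal{F}^{-1}\!\left[\mathcal{F}[E_{p-1}] \odot H_p\right], \quad p = 1, \dots, n_{\mathrm{pl}},
$$

$$
f(\theta) = \frac{1}{n_{\mathrm{mod}}} \left\| E_{n_{\mathrm{pl}}} \odot \overline{E_{\mathrm{target}}} \right\|_F^2 - 1 .
$$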
```julia
using LinearAlgebra, Polynomials, Printf, Random, ForwardDiff, PyPlot, PyCall, OffsetArrays, FFTW, Zygote
np = pyimport("numpy")

# Angular-spectrum propagation kernel for distance z on an nx×ny grid with
# pixel pitch dl and wavelength lmbd, returned in FFT ordering (ifftshift).
function make_prop_kernel(sz, z; dl=2, lmbd=0.5)
    nx = sz[1]
    ny = sz[2]
    k = 2*pi/lmbd  # wavenumber
    dkx = 2*pi/((nx-1)*dl)
    dky = 2*pi/((ny-1)*dl)
    kx = (LinRange(0, nx-1, nx) .- nx/2)*dkx
    ky = (LinRange(0, ny-1, ny) .- ny/2)*dky
    inflate(f, kx, ky) = [f(x, y) for x in kx, y in ky]
    # complex() keeps sqrt from throwing a DomainError for evanescent components
    # (kx^2 + ky^2 > k^2); with the values used here the argument stays positive.
    f(kx, ky) = exp(1im*sqrt(complex(k^2 - kx^2 - ky^2))*z)
    prop_kernel = inflate(f, kx, ky)
    prop_kernel = ifftshift(prop_kernel)
    return prop_kernel
end

# Free-space propagation: FFT, multiply by the kernel, inverse FFT.
function light_prop(e_in, prop_kernel)
    if ndims(e_in) == 3
        # broadcast the (nx, ny) kernel over the leading mode dimension
        prop_kernel = reshape(prop_kernel, (1, size(prop_kernel)...))
    end
    # For 3-D input, fft/ifft also transform the mode dimension; since the kernel
    # is constant along it, this is equivalent to a 2-D transform per mode.
    ek_in = fft(ifftshift(e_in))
    ek_out = ek_in .* prop_kernel
    e_out = fftshift(ifft(ek_out))
    return e_out
end

function phase_mod(e_in, theta; samp_ratio=1)
    #=
    e_in is the input field
    theta is the phase mask
    samp_ratio is the pixel size ratio between the phase mask and the e field
    =#
    if ndims(e_in) == 2
        if samp_ratio == 1
            e_out = e_in .* exp.(1im*theta)
        else
            e_out = e_in .* kron(exp.(1im*theta), ones((samp_ratio, samp_ratio)))
        end
    elseif ndims(e_in) == 3
        if samp_ratio == 1
            M = exp.(1im*theta)
            e_out = e_in .* reshape(M, (1, size(M)...))
        else
            M = kron(exp.(1im*theta), ones((samp_ratio, samp_ratio)))
            e_out = e_in .* reshape(M, (1, size(M)...))
        end
    end
    return e_out
end

# Alternate free-space propagation and phase masks, one plate at a time.
# Uses the airKernels argument instead of a global helper.
function propagatedOutput(e_in, airKernels, theta)
    e_out = e_in
    for jj = 1:size(theta, 1)  # loop over plates
        e_out = light_prop(e_out, airKernels[jj, :, :])
        e_out = phase_mod(e_out, theta[jj, :, :], samp_ratio=1)
    end
    return e_out
end

# Hessian via Zygote, symmetrized by mirroring its lower triangle.
function symmetricADHessian(f, x)
    Hx = Zygote.hessian(f, x)
    return LinearAlgebra.tril(Hx, -1) + LinearAlgebra.tril(Hx)'
end

nx = 16
ny = 16
npl = 5    # number of phase plates
nmod = 10  # number of modes
mat(x) = reshape(x, (npl, nx, ny))
e_in = randn(nmod, nx, ny)
e_target = randn(nmod, nx, ny)
d = [2e4, 2.5e4, 2.5e4, 2.5e4, 2.5e4, 2e4]  # propagation distances
airKernels = zeros(ComplexF64, (npl+1, nx, ny))
for jj = 1:npl+1
    airKernels[jj, :, :] = make_prop_kernel((nx, ny), d[jj])
end
# note: propagatedOutput only uses airKernels[1:npl, :, :]

# objective function (real() added because Zygote reported a complex output)
f(x) = real(norm(propagatedOutput(e_in, airKernels, mat(x)) .* conj(e_target))^2/nmod - 1);
x0 = zeros(npl, nx, ny)
gx = Zygote.gradient(f, x0)
Hx = symmetricADHessian(f, x0)
```
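To tell whether an AD result (once it runs) is trustworthy, one option is to compare `gx[1]` against a central finite-difference gradient of the same `f`. The helper `fd_gradient` below is hypothetical and written in plain Julia; it needs two evaluations of `f` per entry, so about 2,560 evaluations for the 5×16×16 `x0` here.

```julia
# Hypothetical helper: central finite-difference gradient of a scalar function
# f at a real array x, for cross-checking an AD gradient entry by entry.
function fd_gradient(f, x::AbstractArray{<:Real}; h=1e-6)
    xp = copy(float(x))          # perturbed copy; x itself is left untouched
    g = zeros(eltype(xp), size(x))
    for i in eachindex(xp)
        xi = xp[i]
        xp[i] = xi + h
        fplus = f(xp)
        xp[i] = xi - h
        fminus = f(xp)
        xp[i] = xi               # restore before moving to the next entry
        g[i] = (fplus - fminus) / (2h)
    end
    return g
end
```

For example, `maximum(abs, fd_gradient(f, x0) .- gx[1])` should be small relative to `maximum(abs, gx[1])` if the AD gradient is correct.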