Using LsqFit on complex data

I’m trying to use LsqFit.jl to fit complex-valued data (of the form x + im*y). To show the problem, I’m using the example provided on LsqFit’s GitHub page, but I’ve changed the model so that it has an imaginary term in the exponential rather than a real one.

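using LsqFit
# Complex-valued model: p[1] * e^(i * x * p[2]), with real parameters p
@. model(x, p) = p[1] * exp(im*x * p[2])

xdata = range(0, stop=10, length=20)
ydata = model(xdata, [1.0 2.0]) + 0.01 * randn(length(xdata))   # complex data plus real noise

p0 = [0.5, 0.5]
fit = curve_fit(model, xdata, ydata, p0)   # fails with the InexactError below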

I get the error message:

LoadError: InexactError: Float64(-0.9292044089784778 + 3.141189448336519im)
Stacktrace:
  [1] Real
    @ ./complex.jl:44 [inlined]
  [2] convert
    @ ./number.jl:7 [inlined]

I’ve tried the same thing in Matlab using lsqcurvefit() with complex data, and it works fine and extracts the correct parameters. I’m aware that fitting complex data is inherently different from fitting real data, but I can’t figure out how to make it work here.
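
A common workaround, until complex residuals are supported directly, is to fit the real and imaginary parts together as one real vector of twice the length. This minimizes the same least-squares cost, because |r|² = Re(r)² + Im(r)² for every residual r. The sketch below assumes real parameters; `cmodel` and `stacked_model` are hypothetical names, and the starting point is moved closer to the true values than the question’s `[0.5, 0.5]`, since a frequency-like parameter such as p[2] usually needs a reasonable initial guess.

```julia
using LsqFit, Random

Random.seed!(42)

# Complex-valued model from the question, written with explicit dots; parameters p are real.
cmodel(x, p) = p[1] .* exp.(im .* x .* p[2])

# Hypothetical wrapper: stack real and imaginary parts into one real vector of length 2n.
function stacked_model(x, p)
    z = cmodel(x, p)
    return vcat(real.(z), imag.(z))
end

xdata = range(0, stop = 10, length = 20)
ptrue = [1.0, 2.0]
noise = 0.01 .* (randn(length(xdata)) .+ im .* randn(length(xdata)))
zdata = cmodel(xdata, ptrue) .+ noise

# Data stacked in the same order as the model output.
ydata = vcat(real.(zdata), imag.(zdata))

# Start reasonably close: the oscillating model makes the cost non-convex in p[2].
p0 = [0.8, 1.9]
fit = curve_fit(stacked_model, xdata, ydata, p0)
@show fit.param
```

Since both the parameters and the stacked residual are real, LsqFit never has to handle complex numbers internally; the only cost is a residual vector twice as long.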

Even if I force all arguments to have type (or eltype) ComplexF64, like this:

using LsqFit
@. model(x, p) = p[1] * exp(im*x * p[2])

xdata = range(0+0im, stop=10, length=20)
ydata = model(xdata, ComplexF64[1.0 2.0]) + 0.01 * randn(length(xdata))

p0 = [0.5+0im, 0.5+0im]
fit = curve_fit(model, xdata, ydata, p0)

I get

ERROR: MethodError: no method matching isless(::Float64, ::ComplexF64)
Closest candidates are:
  isless(::T, ::T) where T<:Union{Float16, Float32, Float64} at float.jl:424
  isless(::AbstractFloat, ::AbstractFloat) at operators.jl:177
  isless(::Real, ::AbstractFloat) at operators.jl:178
  ...
Stacktrace:
  [1] max(x::ComplexF64, y::Float64)
    @ Base .\operators.jl:467
  [2] levenberg_marquardt(df::NLSolversBase.OnceDifferentiable{Vector{ComplexF64}, Matrix{ComplexF64}, Vector{ComplexF64}}, initial_x::Vector{ComplexF64}; x_tol::Float64, g_tol::Float64, maxIter::Int64, lambda::ComplexF64, tau::ComplexF64, lambda_increase::Float64, lambda_decrease::Float64, min_step_quality::Float64, good_step_quality::Float64, show_trace::Bool, lower::Vector{ComplexF64}, upper::Vector{ComplexF64}, avv!::Nothing)

This looks like a problem in LsqFit to me.
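
The MethodError above can be reproduced in plain Base Julia, without LsqFit. The least-squares cost, the sum of |r_i|², is real even when the residuals are complex; but once the parameters are complex, internal scalars such as the damping parameter `lambda` inherit a complex type, and `max` has no meaning for complex numbers because they are not ordered. A minimal illustration (the variable names are only for this sketch):

```julia
# Complex residuals still give a real-valued least-squares cost.
r = [1.0 + 2.0im, -0.5 + 0.1im]
cost = sum(abs2, r)                  # |r1|^2 + |r2|^2, a Float64

# The same cost from the stacked real/imaginary residual vector.
stacked = vcat(real.(r), imag.(r))
@show cost ≈ sum(abs2, stacked)

# A damping parameter that has inherited a complex type cannot be compared.
lambda = complex(cost)
err = try
    max(lambda, 0.5)                 # needs isless(::Float64, ::ComplexF64)
    nothing
catch e
    e
end
@show err isa MethodError
```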

Yes, it looks like LsqFit’s levenberg_marquardt function assumes that the model output has the same element type as the parameters (see the array allocations in its source), which is wrong in this case: real parameters can produce complex residuals. They need to be a bit more careful about propagating the data types.
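
The InexactError from the first post follows directly from such an allocation. The sketch below is a simplified illustration, not LsqFit’s actual code: a buffer allocated from the parameter vector is `Vector{Float64}` and cannot store a complex residual, while a buffer allocated from the residual’s own element type can. `buf_wrong` and `buf_right` are illustrative names.

```julia
p0 = [0.5, 0.5]                      # real parameters, as in the first post
r  = [-0.93 + 3.14im, 0.2 - 0.1im]   # complex model residuals

# Allocation keyed on the parameter type: Vector{Float64}.
buf_wrong = similar(p0, length(r))
err = try
    copyto!(buf_wrong, r)            # attempts Float64(::ComplexF64)
    nothing
catch e
    e
end
@show err isa InexactError

# Allocation keyed on the residual type: Vector{ComplexF64}.
buf_right = similar(r)               # equivalently Vector{eltype(r)}(undef, length(r))
copyto!(buf_right, r)
@show buf_right == r
```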

I would file an issue (and/or work on a PR).

@stevengj: I’ve tried to fix the problem, extending the original tests in curve_fit.jl and curve_fit_inplace.jl to check Float32, Float64, ComplexF32 and ComplexF64. The only open problem for now is complex AD with ForwardDiff, so for complex types the tests currently use only finite differences:

# ForwardDiff (:forward, :forwarddiff) is skipped for complex T; only finite differences run there
for ad in (T <: Complex ? (:finite,) : (:finite, :forward, :forwarddiff))
    fit = curve_fit(model, xdata, ydata, p0; autodiff = ad)
    @show fit.param
    @assert norm(fit.param - [1.0, 2.0]) < 0.05
    @test fit.converged

    # can also get error estimates on the fit parameters
    errors = margin_error(fit, 0.1)
    @assert norm(errors - [0.017, 0.075]) < 0.01
end

What else should I check?

Thanks very much for your help; it’s great to know that you’re willing to help fix the problem! This is the first question I’ve posted during my PhD since switching from Matlab to Julia, and it’s great to know there is help out there. I look forward to seeing this issue resolved!

Ah, okay. I look forward to seeing how this gets resolved! Many thanks for your help on this!

I don’t think ForwardDiff supports this (see ForwardDiff.jl issue #498, “Support for real-valued function with complex arguments”). I think Zygote does?

(Or you can use finite differences here; that’s what Matlab’s lsqcurvefit does IIRC.)
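
Finite differences avoid the ForwardDiff limitation when the parameters stay real and only the model output is complex: each Jacobian column is a difference of complex vectors divided by a real step. A minimal central-difference sketch, assuming real parameters; `fd_jacobian` is a hypothetical helper, not an LsqFit function.

```julia
# Hypothetical helper (not part of LsqFit): central-difference Jacobian of a
# possibly complex-valued model with respect to real parameters p.
function fd_jacobian(f, p::AbstractVector{<:Real}; h = 1e-6)
    f0 = f(p)
    J = similar(f0, length(f0), length(p))   # element type follows the model output
    for j in eachindex(p)
        dp = zeros(eltype(p), length(p))
        dp[j] = h
        J[:, j] = (f(p .+ dp) .- f(p .- dp)) ./ (2h)
    end
    return J
end

# Compare with the analytic derivatives of p[1] * exp(im*x*p[2]).
x = range(0, stop = 10, length = 20)
f(p) = p[1] .* exp.(im .* x .* p[2])
p = [1.0, 2.0]
J = fd_jacobian(f, p)
J_exact = hcat(exp.(im .* x .* p[2]), im .* x .* p[1] .* exp.(im .* x .* p[2]))
@show maximum(abs.(J .- J_exact))            # small: O(h^2) truncation plus rounding
```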

Hi @PazzyBoardman449, you said you got it working in Matlab. Could you explain how? I’m trying to do something similar, but the output I get is always exactly the same as my starting point. I couldn’t find an answer on any Matlab forum, which is why I’m asking here.