Passing struct with Union type to other function

I’m fairly new to Julia, so I might have missed something elementary.

I have defined the following struct which contains some parameters with default values.

using Parameters

@with_kw struct Par1
    D::Union{Float64, VecOrMat{Float64}} = 5.0

    G::Float64 = 1.0
    dims::Int64 = 1
    Nx::Int64 = 256
    Lx::Float64 = 1.0
    x0::Float64 = 0.0
    x = range(x0, x0+Lx, length=Nx+1)[1:end-1]
    dx::Float64 = x[2] - x[1]
    order::Int64 = 2
end

This struct is then passed to the following function, which is evaluated many times when solving a PDE:

function pressure(h,D,G)
    return G/(D-h)^3
end


function fun!(dh, h, par, t)
    G = par.G
    D = par.D
    order = par.order
    dx = par.dx

    temp1 = pressure.(h, D, G)

    diffp!(dh, h, 2, order, dx)     # diffp! is a self-defined function for doing finite difference
    
    @. dh = - temp1 - dh            
    diffp!(temp1, dh, 1, order, dx)

    @. temp1 = h^3 * temp1
    diffp!(dh, temp1, 1, order, dx)
end

using Parameters, Setfield

Nx = 256
G = 0.05

par1 = Par1(Nx=Nx, G=G)
x = par1.x
D = 5.0*(1 .+ 0.1*cos.(2*pi*x))
par1 = @set par1.D = D

h0 = ones(Nx)
h1 = similar(h0)
fun!(h1, h0, par1, 0.0)

The main problem is that par1.D can be a Float64, a 1D vector or a matrix, so when I run @code_warntype on fun!(h1, h0, par1, 0.0) I get the following (I only show the lines highlighted in red):

Locals
  temp1::VecOrMat{Float64}
  D::Union{Float64, VecOrMat{Float64}}

%5  = Base.broadcasted(Main.fun2, h, D, G)::Union{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(fun2), Tuple{Vector{Float64}, Float64, Float64}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(fun2), Tuple{Vector{Float64}, Vector{Float64}, Float64}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(fun2), Tuple{Vector{Float64}, Matrix{Float64}, Float64}}}

%8  = Base.broadcasted(Main.:-, temp1)::Union{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(-), Tuple{Matrix{Float64}}}}
%9  = Base.broadcasted(Main.:-, %8, dh)::Union{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(-), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(-), Tuple{Matrix{Float64}}}, Vector{Float64}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}, Vector{Float64}}}}

%12 = temp1::VecOrMat{Float64}

%16 = Base.broadcasted(Main.:*, %15, temp1)::Union{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(Base.literal_pow), Tuple{Base.RefValue{typeof(^)}, Vector{Float64}, Base.RefValue{Val{3}}}}, Vector{Float64}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(Base.literal_pow), Tuple{Base.RefValue{typeof(^)}, Vector{Float64}, Base.RefValue{Val{3}}}}, Matrix{Float64}}}}
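To see where the Union comes from without running the full solver, here is a minimal sketch using only Base Julia. ParUnion and getD are hypothetical names, and the struct keeps only the D field of Par1. The point is that the compiler reasons with the declared field type, not with the type of the value actually stored:

```julia
# Hypothetical cut-down version of Par1: only the D field, no Parameters.jl needed.
struct ParUnion
    D::Union{Float64, VecOrMat{Float64}}
end

# The declared field type is a Union, which is not a concrete type.
@show fieldtype(ParUnion, :D)
@show isconcretetype(fieldtype(ParUnion, :D))

# The stored value itself is concrete...
p = ParUnion([1.0, 2.0])
@show typeof(p.D)

# ...but a function reading p.D only knows the declared Union.
getD(p::ParUnion) = p.D
@show Base.return_types(getD, (ParUnion,))
```

If the inferred return type of getD is the full Union, every use of par.D inside a function like fun! has to be resolved at run time, which is consistent with the red entries above.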

I could, of course, give par1.D a more specific type, or simply pass a named tuple to fun!; with a named tuple, @code_warntype shows no red-highlighted entries. The named tuple version is also a bit faster:

using BenchmarkTools

par0 = (D=par1.D, order=par1.order, dx=par1.dx, G=par1.G)

@btime fun!(h1, h0, par0, 0.0)   #  932.143 ns (1 allocation: 2.12 KiB)
@btime fun!(h1, h0, par1, 0.0)   # 1.310 μs (7 allocations: 2.27 KiB)
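One way to see why the named tuple behaves better: a NamedTuple's type records the concrete type of each entry. A rough sketch with made-up values (the D profile mimics the one built above, and getD_nt is a hypothetical helper):

```julia
# Grid and D profile similar to the question's setup (values are illustrative).
x = range(0.0, 1.0, length = 257)[1:end-1]
D = 5.0 .* (1 .+ 0.1 .* cos.(2π .* x))
par0 = (D = D, order = 2, dx = step(x), G = 0.05)

# The NamedTuple type stores Vector{Float64} for D, not a Union.
@show typeof(par0)

# So a function receiving par0 can infer the concrete type of par0.D.
getD_nt(p) = p.D
@show Base.return_types(getD_nt, (typeof(par0),))
```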

But I’d like to know whether there is a better way to define the struct so that it keeps the flexibility of letting D be a scalar, vector or matrix (my actual code has many other parameter fields like this) while still letting Julia infer the concrete types inside the function calls.

How about this?

@with_kw struct Par1{T<:Union{Float64, VecOrMat{Float64}}}
    D::T = 5.0
    G::Float64 = 1.0
    dims::Int64 = 1
    Nx::Int64 = 256
    Lx::Float64 = 1.0
    x0::Float64 = 0.0
    x = range(x0, x0+Lx, length=Nx+1)[1:end-1]
    dx::Float64 = x[2] - x[1]
    order::Int64 = 2
end
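A stripped-down sketch of the same idea without Parameters.jl, assuming only the D and G fields matter for the hot loop (ParParam, DType and pressures are hypothetical names; pressure is the formula from the question). Each kind of D produces a different concrete struct type, so functions that receive the struct are compiled for that exact type:

```julia
const DType = Union{Float64, VecOrMat{Float64}}

# The concrete type of D becomes part of the struct's type.
struct ParParam{T<:DType}
    D::T
    G::Float64
end

pressure(h, D, G) = G / (D - h)^3   # same formula as in the question

# Inside this function the compiler knows the exact type T of p.D.
pressures(h, p::ParParam) = pressure.(h, p.D, p.G)

p_scalar = ParParam(5.0, 0.05)          # ParParam{Float64}
p_vector = ParParam(fill(5.0, 4), 0.05) # ParParam{Vector{Float64}}

@show typeof(p_scalar)
@show typeof(p_vector)
@show Base.return_types(pressures, (Vector{Float64}, typeof(p_vector)))
```

The trade-off is that ParParam{Float64} and ParParam{Vector{Float64}} are distinct types, so a function taking them is compiled once for each combination it actually sees.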

Note that in your current struct, x has no type annotation, so its field type is Any.
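A quick way to check this in the REPL, using a hypothetical cut-down struct Grid. The second struct, TypedGrid, is just one possible way (an assumption, not the only option) to keep x concretely typed by making its type a parameter:

```julia
# One annotated field and one unannotated field.
struct Grid
    x0::Float64
    x          # no annotation: the field type defaults to Any
end

@show fieldtype(Grid, :x0)
@show fieldtype(Grid, :x)

# Hypothetical alternative: let the type of x be a struct parameter.
struct TypedGrid{X<:AbstractVector{Float64}}
    x0::Float64
    x::X
end

g = TypedGrid(0.0, range(0.0, 1.0, length = 257)[1:end-1])
@show typeof(g.x)
```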

Thanks, that seems to work as I intended. Suppose I need to add extra fields D2 and D3 to Par1, which can also be a scalar, 1D vector or matrix (D, D2 and D3 need not have the same type). Is the following the correct way to do it?

@with_kw struct Par1{T1, T2, T3<:Union{Float64, VecOrMat{Float64}}}
    D::T1 = 5.0
    D2::T2 = 1.0
    D3::T3 = 2.0
    G::Float64 = 1.0
    dims::Int64 = 1
    Nx::Int64 = 256
    Lx::Float64 = 1.0
    x0::Float64 = 0.0
    x = range(x0, x0+Lx, length=Nx+1)[1:end-1]
    dx::Float64 = x[2] - x[1]
    order::Int64 = 2
end

And yes, there are some fields like x of type Any in the struct, but since fun! doesn’t access them directly (they are more like intermediate helper variables for my own convenience), that shouldn’t matter, right?
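This can be checked with Base.return_types, assuming a hypothetical struct ParMixed with two typed fields and one untyped helper field. A function that reads only the typed fields should infer a concrete result, while one that reads the untyped field infers Any:

```julia
struct ParMixed
    G::Float64
    dx::Float64
    x          # untyped helper field, i.e. type Any
end

# Reads only concretely typed fields.
g_over_dx(p::ParMixed) = p.G / p.dx
# Reads the Any field, so nothing is known about the result.
first_x(p::ParMixed) = p.x[1]

@show Base.return_types(g_over_dx, (ParMixed,))
@show Base.return_types(first_x, (ParMixed,))
```

So an Any field should cost nothing as long as the hot code never reads it; reading it inside fun! would likely bring the instability back.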

You need to give the bound for each Ti separately. In Par1{T1, T2, T3<:Union{...}} the bound applies only to T3, so T1 and T2 are unconstrained. To keep the code readable, you could do something like

const U = Union{Float64, VecOrMat{Float64}}
@with_kw struct Par1{T1<:U, T2<:U, T3<:U}
    D::T1 = 5.0
    D2::T2 = 1.0
    D3::T3 = 2.0
    G::Float64 = 1.0
    dims::Int64 = 1
    Nx::Int64 = 256
    Lx::Float64 = 1.0
    x0::Float64 = 0.0
    x = range(x0, x0+Lx, length=Nx+1)[1:end-1]
    dx::Float64 = x[2] - x[1]
    order::Int64 = 2
end

Declaring U as a constant is important for this to work properly
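A small sketch of the difference between writing the bound once and writing it for every parameter (OnlyLast and AllBounded are hypothetical names; U is the same Union as above):

```julia
const U = Union{Float64, VecOrMat{Float64}}

# Bound written once: it constrains only T3; T1 and T2 accept any type.
struct OnlyLast{T1, T2, T3<:U}
    D::T1
    D2::T2
    D3::T3
end

# Bound written for every parameter.
struct AllBounded{T1<:U, T2<:U, T3<:U}
    D::T1
    D2::T2
    D3::T3
end

# Accepted, because T1 and T2 are unconstrained.
@show typeof(OnlyLast("not a number", :sym, 2.0))

# Rejected: String is not a subtype of U, so no constructor method matches.
try
    AllBounded("not a number", :sym, 2.0)
catch err
    println("rejected with MethodError: ", err isa MethodError)
end
```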

Are there any potential issues if I declare const U2 = Union{<:Real, VecOrMat{<:Real}} so that my struct accepts more general types for D? A few quick tests with @code_warntype and @benchmark suggest no performance or type-stability issues, but I’m not sure whether I’ve overlooked anything important.

It could be a problem if the caller ends up with an actual abstract Real, for example a Vector{Real}: that type satisfies the bound, but its element type is abstract. But that is more the caller’s problem than the callee’s.
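A sketch of the Vector{Real} case, using a hypothetical one-field struct ParReal bounded by the U2 from the question. The container type Vector{Real} is concrete, so it passes the bound, but its element type is abstract:

```julia
# Bound from the question, allowing any Real element type.
const U2 = Union{<:Real, VecOrMat{<:Real}}

struct ParReal{T<:U2}
    D::T
end

p_ok  = ParReal([1.0, 2.0])    # T = Vector{Float64}
p_bad = ParReal(Real[1.0, 2])  # T = Vector{Real}: accepted, but elements are abstract

@show isconcretetype(typeof(p_bad.D))  # the container type itself is concrete
@show isconcretetype(eltype(p_ok.D))
@show isconcretetype(eltype(p_bad.D))
```

So the struct field stays concretely typed either way; any slowdown from a Vector{Real} would come from working with its abstract elements, which is why it is mostly the caller's concern.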

I don’t see a problem with that either, but I think most Julia packages would write it like this instead:

@with_kw struct Par1{R<:Real, U<:Union{R,VecOrMat{R}}, T1<:U, T2<:U, T3<:U}
    D::T1 = 5.0
    D2::T2 = 1.0
    D3::T3 = 2.0
    G::Float64 = 1.0
    dims::Int64 = 1
    Nx::Int64 = 256
    Lx::Float64 = 1.0
    x0::Float64 = 0.0
    x = range(x0, x0+Lx, length=Nx+1)[1:end-1]
    dx::Float64 = x[2] - x[1]
    order::Int64 = 2
end