Broadcasting to all fields of a struct

I have a data structure of nested AbstractArrays and structs in which all the elements (the leaves of my data tree) are, for example, Float64. I want to convert them all to Float32 by rebuilding the data structure. I wrote the following code, which works:

using StaticArrays
struct S{R}
    a::SArray{Tuple{2,2},R,2,4}
    b::R
end
s = S(SMatrix{2,2,Float64}(randn(2,2)),randn())

cast(s::S            ,cv) = S(cast(s.a,cv),cast(s.b,cv))
#cast(s::S            ,cv) = cast.(s,cv)
cast(a::AbstractArray,cv) = cast.(a,cv)
cast(a::Real         ,cv) = cv(a)

cv(x) = Float32(x)
@code_warntype cast(s,cv) 
@show          cast(s,cv) 

However, my ambition (as the commented-out line suggests) is to write a cast method that recursively applies cast to all fields of any composite type.

If the answer is metaprogramming, all right, I'll do it (@generated and fieldnames…), but I'd like to know whether there is a lazier way.
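For reference, here is a rough sketch of the metaprogramming route. It assumes that every struct can be rebuilt by calling its unparametrized type (e.g. `Pair2` rather than `Pair2{Float64}`) with all of its fields in declaration order. `Pair2` is a hypothetical stand-in for `S` that uses a plain `Matrix`, so no packages are needed, and `Base.typename` is an internal, unexported Base function.

```julia
# Hypothetical stand-in for S: a plain Matrix replaces the SArray.
struct Pair2{R}
    m::Matrix{R}
    b::R
end

cast(a::Real, cv) = cv(a)                  # leaf: apply the conversion
cast(a::AbstractArray, cv) = cast.(a, cv)  # arrays: recurse elementwise

# Fallback for everything else. The generator runs once per concrete type T
# and emits one cast(getfield(s, i), cv) call per field, so the result type
# can be inferred.
@generated function cast(s::T, cv) where {T}
    if !Base.isstructtype(T)
        msg = "cast: cannot handle $T"
        return :(throw(ArgumentError($msg)))
    end
    ctor = Base.typename(T).wrapper        # e.g. Pair2, not Pair2{Float64}
    args = Expr[]
    for i in 1:fieldcount(T)
        push!(args, :(cast(getfield(s, $i), cv)))
    end
    return :($ctor($(args...)))            # assumes a positional constructor
end

p = Pair2([1.0 2.0; 3.0 4.0], 5.0)
q = cast(p, Float32)                       # a Pair2{Float32}
```

`@code_warntype cast(p, Float32)` is the way to check whether inference actually succeeds for your own types.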

You mean you could, but it would probably fail for any struct that is not like yours, that is, one whose leaves are not all of a single type that can be converted to whatever you want, right?

It really seems like a generated function is the right choice here, but I myself have never used them, so I cannot offer any practical advice.

Meanwhile, why did you choose to create a cast function? Did convert not seem adequate?

I think you shouldn’t call it cast but instead extend convert. Perhaps something like this:

function Base.convert(::Type{S{R}}, s::S) where {R}
    vals = getfield.((s,), fieldnames(S))
    return S{R}(vals...)
end
jl> s = S(SMatrix{2,2,Float64}(randn(2,2)),randn())
S{Float64}([-0.33679898580000156 -1.6556986308629353; -1.0290895856851072 0.4700396382426737], 0.041295723749969614)

jl> convert(S{ComplexF32}, s)
S{ComplexF32}(ComplexF32[-0.336799f0 + 0.0f0im -1.6556987f0 + 0.0f0im; -1.0290896f0 + 0.0f0im 0.47003964f0 + 0.0f0im], 0.041295722f0 + 0.0f0im)

Note that using getfield and fieldnames apparently isn't type stable, so there may be a better way to get vals, perhaps with a generated function.

If you type s.a and s.b directly into the constructor, it will of course be fast.
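One possible alternative to `getfield.((s,), fieldnames(S))` is to collect the fields with `ntuple` and a `Val` field count, which gives the compiler a better chance of inferring the tuple. This is only a sketch: `S2` is a hypothetical stand-in for `S` that avoids StaticArrays, and you should still confirm the result with `@code_warntype`.

```julia
# Hypothetical stand-in for S, using a Vector instead of an SArray.
struct S2{R}
    a::Vector{R}
    b::R
end

function Base.convert(::Type{S2{R}}, s::S2) where {R}
    n = fieldcount(typeof(s))                   # should fold to a constant
    vals = ntuple(i -> getfield(s, i), Val(n))  # fields as a tuple
    return S2{R}(vals...)   # the default constructor converts each field
end

s = S2([1.0, 2.0], 3.0)
t = convert(S2{Float32}, s)                     # an S2{Float32}
```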

Hi Henrique,

Yes to your first question.

I prefer not to call it convert because of the convention that convert is not lossy. For reasons that would bore everyone, my cv function may be seriously lossy, and I don't want to create a convert method that might be called implicitly in some assignment.

Yes, this kind of thing seems to often call for a generated function to guarantee type-stability.

Hi DNF,

Great snippet, which I will use. I will end up metaprogramming, and this will appear in the code.

The solution you propose assumes I have easy access to the exact type of the struct I want to create.
However, I will write
unparametredtype(s::S{R}) where{R} = S
whose output I can use to call the constructor.
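If you would rather not write `unparametredtype` by hand for every struct, a generic variant could read the wrapper type from the type's name. This sketch relies on `Base.typename`, an internal, unexported Base function, so treat it as an assumption that may change between Julia versions. `Box` is a hypothetical example type.

```julia
# Generic variant: returns the UnionAll behind a concrete parametric type.
# Base.typename is internal to Base (not exported).
unparametredtype(x) = Base.typename(typeof(x)).wrapper

struct Box{R}
    v::R
end

b = Box(1.0)
unparametredtype(b)                  # Box, not Box{Float64}
unparametredtype(b)(Float32(b.v))    # calls the Box constructor: a Box{Float32}
```

One caveat: for aliased types the wrapper is the underlying type, so a `Vector{Float64}` yields `Array`.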

Almost there!!!

I must have missed something. I thought you knew the types. Can you give a more concrete example showing what your wanted function does have access to?

1 Like

Convert can be lossy, converting from Float64 to Float32, for example. At least for the use case in your example, convert seems like the right name, but perhaps you have a different example in mind?

Well, I want a method cast(s,cv) that accepts any struct (and I don’t know how to dispatch for that, by the way…), and recursively applies cv to all fields of s.
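On the dispatch question: Julia has no supertype meaning "any struct", so one common approach is an unannotated fallback method. Julia always picks the most specific applicable method, so the `Real` and `AbstractArray` methods win whenever they match. This sketch makes no use of metaprogramming. It reuses the assumption that the unparametrized type (via the internal `Base.typename`) works as a positional constructor. `Inner` and `Outer` are hypothetical types.

```julia
struct Inner{R}
    v::Vector{R}
    x::R
end
struct Outer{R}
    inner::Inner{R}
    y::R
end

cast(a::Real, cv) = cv(a)                  # leaf: apply the conversion
cast(a::AbstractArray, cv) = cast.(a, cv)  # arrays: recurse elementwise
function cast(s, cv)                       # fallback: any other value
    T = typeof(s)
    Base.isstructtype(T) || throw(ArgumentError("cast: cannot handle $T"))
    # Convert every field, then rebuild with the unparametrized type
    # (e.g. Outer), assuming its constructor takes all fields positionally.
    fields = ntuple(i -> cast(getfield(s, i), cv), fieldcount(T))
    return Base.typename(T).wrapper(fields...)
end

o = Outer(Inner([1.0, 2.0], 3.0), 4.0)
cast(o, Float32)                           # an Outer{Float32}
```

Unlike a `@generated` function, this leaves type inference entirely to the compiler, so it is worth checking with `@code_warntype` on your real types.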

The parameters of cast are not an MWE: I want to write exactly that function.

Yes, of course I know the type of s (via typeof). My challenge (now solved) was to get a name for the constructor I need to call. I do not easily know the exact type of the output. The good news is that we can call the constructor for S (not S{hard to get}), and that one I know how to get:
unparametredtype(s::S{R}) where{R} = S

I am working on a parallel solution in which I use a named tuple instead of a struct; that's a lot easier.
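The named-tuple version is easier because `map` over a `NamedTuple` returns a `NamedTuple` with the same keys, so no constructor has to be looked up. Here is a minimal sketch that reuses the cast name from the thread and adds a `NamedTuple` method. It assumes the leaves are `Real` numbers or arrays of them.

```julia
cast(x::Real, cv) = cv(x)                           # leaf
cast(x::AbstractArray, cv) = cast.(x, cv)           # arrays: elementwise
cast(x::NamedTuple, cv) = map(v -> cast(v, cv), x)  # keys are preserved

nt = (a = [1.0 2.0; 3.0 4.0], b = 5.0, inner = (c = 6.0,))
cast(nt, Float32)   # same keys, Float32 leaves
```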

Indeed, if you want this to work on arbitrary types, you should not overload convert (that would be very bad). But then you should expect cast to fail pretty frequently, right, since most fields are not convertible to Float32, and the constructor syntax is also generally unknown.

Maybe convertfields is an appropriate name?

OK, all cards on the table.

In the real world, the Float64 in my MWE might be an automatic differentiation object full of partial-derivative data.

So I might need to extract the Float64 from it (the current challenge), but I never want this to happen implicitly: I would silently get the wrong derivative. I go numb with fear just at the thought!

I see. Beware of the XY problem by the way.

1 Like

convertfields it is! Good name!

Sorry! And thank you for helping!