I am using `v0.6`; this is an MWE of a problem I want to solve. I would like to make `broadcast` work for a custom type
struct CongruentVector{Td,Tv}
    descriptor::Td
    vector::Tv
end
the following way: given a method `≅` with fallback
≅(a, b) = a == b
I would test that all `CongruentVector` arguments to `broadcast` have descriptors that are `≅`, then broadcast on their `vector`s and nothing else, i.e. every other argument would be treated as a “scalar”.
The idea behind this interface is that it would enforce matching `descriptor`s. In particular, I would use it for posterior analysis on MCMC, and want to ensure that vectors (which are posterior draws) really do come from the same chain(s). Compatible dimensions are necessary but not sufficient for this.
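To make the intended semantics concrete, here is a minimal sketch written as a standalone function rather than as a `broadcast` method, so it sidesteps the dispatch question entirely. The names `congruent_broadcast` and `_unwrap` are made up for illustration, and the sketch assumes a Julia version in which `broadcast` treats `Ref` arguments as scalars.

```julia
struct CongruentVector{Td,Tv}
    descriptor::Td
    vector::Tv
end

# fallback: descriptors are congruent when they are equal
≅(a, b) = a == b

# hypothetical helper: broadcast over the wrapped vector,
# treat anything else as a scalar by wrapping it in a Ref
_unwrap(x::CongruentVector) = x.vector
_unwrap(x) = Ref(x)

# hypothetical standalone stand-in for the desired broadcast behavior
function congruent_broadcast(f, xs...)
    descriptors = [x.descriptor for x in xs if x isa CongruentVector]
    isempty(descriptors) &&
        throw(ArgumentError("need at least one CongruentVector argument"))
    d = first(descriptors)
    all(d ≅ other for other in descriptors) ||
        throw(ArgumentError("Arguments are not ≅."))
    CongruentVector(d, broadcast(f, map(_unwrap, xs)...))
end

a = CongruentVector(:chain1, [1.0, 2.0])
b = CongruentVector(:chain1, [10.0, 20.0])
c = congruent_broadcast(+, a, b, 1)   # the 1 is treated as a scalar
```

With these inputs, `c` carries the descriptor `:chain1`, and passing a `CongruentVector` with a different descriptor raises an error instead of silently combining draws from different chains.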
Here is some mock code (that is incomplete as is):
function _common_descriptor(y, xs...)
    @assert all(y ≅ x for x in xs) "Arguments are not ≅."
    y
end

@generated function broadcast(f, xs::FIXMEunsurehowtotype...)
    # inside a generated function, xs holds the argument *types*
    # descriptors of the CongruentVector arguments only
    ex_descriptors = tuple((:(xs[$i].descriptor) for (i, x) in enumerate(xs)
                            if x <: CongruentVector)...)
    # unwrap CongruentVectors; wrap everything else in Ref so it acts as a scalar
    ex_arguments = tuple((x <: CongruentVector ? :(xs[$i].vector) : :(Ref(xs[$i]))
                          for (i, x) in enumerate(xs))...)
    quote
        descriptor = _common_descriptor($(ex_descriptors...))
        result = broadcast(f, $(ex_arguments...))
        CongruentVector(descriptor, result)
    end
end
From discussions I learned that I should probably be using generated functions. But it is unclear how to write the signature of a method that is only called when at least some arguments are `CongruentVector`s, but otherwise would not interfere with `broadcast`. Any pointers would be appreciated, including suggestions that I should use some other approach because what I am attempting is not currently feasible.