I wonder what the best practice is for callables that should be broadcastable over both their input and their parameters. Here is an example:
struct Scale{T}
    a::T
end
struct Shift{T}
    a::T
end
(f::Scale)(x) = f.a * x
(f::Shift)(x) = f.a + x
struct VCompose{F, G}
    f::F
    g::G
end
(f::VCompose)(x) = f.f.(f.g.(x))
Scale and Shift can be broadcast as usual when their parameters are scalars:
julia> f = VCompose(Scale(10), Shift(0.1));

julia> f([1, 2, 3])
3-element Array{Float64,1}:
11.0
21.0
31.0
However, this does not work when the parameters are vectors:
julia> f = VCompose(Scale([10, 100, 1000]), Shift([10, 20, 30]));

julia> f([1, 2, 3])
ERROR: MethodError: no method matching +(::Array{Int64,1}, ::Int64)
Note that using dot expressions inside the call definitions, as below, is not the right solution:
(f::Scale)(x) = f.a .* x
(f::Shift)(x) = f.a .+ x
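This is because the input would be broadcast “doubly”: once by the dot call in VCompose and once more inside each callable, so every element of the input is combined with the whole parameter vector.
julia> f = VCompose(Scale([10, 100, 1000]), Shift([10, 20, 30]));

julia> f([1, 2, 3])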
3-element Array{Array{Int64,1},1}:
[110, 2100, 31000]
[120, 2200, 32000]
[130, 2300, 33000]
What is the best way to do it?
One thing I can do is overload Broadcast.broadcasted. That is, I can define
Broadcast.broadcasted(f::Scale, x) = Broadcast.broadcasted(*, f.a, x)
Broadcast.broadcasted(f::Shift, x) = Broadcast.broadcasted(+, f.a, x)
so that
julia> f = VCompose(Scale([10, 100, 1000]), Shift([10, 20, 30]));

julia> f([1, 2, 3])
3-element Array{Int64,1}:
110
2200
33000
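As far as I understand, this works because a dot call f.(x) is lowered to Broadcast.materialize(Broadcast.broadcasted(f, x)), so the overload replaces the callable with a plain * or + node before anything is evaluated. Here is a minimal sketch to check that, using only Scale; the variable names s and bc are just for illustration:

```julia
struct Scale{T}
    a::T
end

# Route `s.(x)` to an ordinary elementwise `*` over the parameter and the input.
Broadcast.broadcasted(f::Scale, x) = Broadcast.broadcasted(*, f.a, x)

s = Scale([10, 100, 1000])

# The overload returns a lazy Broadcasted object whose function is `*`, not `s`.
bc = Broadcast.broadcasted(s, [1, 2, 3])
println(bc isa Broadcast.Broadcasted)   # true
println(bc.f === (*))                   # true

# Materializing it gives the same result as the dot call.
println(Broadcast.materialize(bc))      # [10, 200, 3000]
println(s.([1, 2, 3]))                  # [10, 200, 3000]
```

Note that no call method for Scale is defined in this sketch at all; the dot call still works because s itself is never called.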
I’m using Broadcast.broadcasted directly here, which is hard to read. But I’m not worried about this aspect, because Julia may get a “lazy” broadcasting feature, and there is already a macro that does this. Also, LazyArrays.jl may implement it soon.
On the other hand, it’s a bit frustrating that I have to implement each callable twice. Keeping the two definitions consistent just by being careful is a bad idea. So it is probably better to define the normal call in terms of the broadcasted specialization:
(f::Scale)(x) = f.(x)
(f::Shift)(x) = f.(x)
It is a somewhat strange definition, because you’d think it would cause a stack overflow. Or, if you know Julia well, you might assume there are scalar specializations for those callables. Such code could be very annoying to read, but I guess I can solve that by just leaving comments.
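To convince myself that the forwarding definitions do not recurse, here is a self-contained sketch of the whole pattern. It assumes the broadcasted overloads are defined for both callables; the names scalar_result, mixed_result and composed are only for illustration:

```julia
struct Scale{T}
    a::T
end
struct Shift{T}
    a::T
end

# The broadcast specializations are the single place where the math lives.
Broadcast.broadcasted(f::Scale, x) = Broadcast.broadcasted(*, f.a, x)
Broadcast.broadcasted(f::Shift, x) = Broadcast.broadcasted(+, f.a, x)

# The plain call forwards to the specializations above. This does not recurse:
# with `broadcasted(f, x)` overloaded, `f.(x)` never calls `f` itself.
(f::Scale)(x) = f.(x)
(f::Shift)(x) = f.(x)

struct VCompose{F, G}
    f::F
    g::G
end
(f::VCompose)(x) = f.f.(f.g.(x))

# Scalar parameter, scalar input: broadcasting over scalars returns a scalar.
scalar_result = Scale(10)(2)
println(scalar_result)    # 20

# Vector parameter, scalar input.
mixed_result = Shift([10, 20, 30])(1)
println(mixed_result)     # [11, 21, 31]

# Vector parameters, vector input, composed.
composed = VCompose(Scale([10, 100, 1000]), Shift([10, 20, 30]))([1, 2, 3])
println(composed)         # [110, 2200, 33000]
```

One consequence of this design is that a plain call such as Scale(10)([1, 2]) now broadcasts as well, so the callable no longer distinguishes between f(x) and f.(x).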
I think this way of writing broadcastable callables is somewhat OK. (It is actually fun.) But I’d like to know if there are better ways to do this.