Overwhelmed by the broadcast interface. Help me 🥺

I am writing code that would benefit from generic generator arithmetic over a vector type. In general, Vector{PrimeGenerator} suffices, but it duplicates irrelevant data, since every element stores its own copy of the modulus. Thus, I am attempting to subtype AbstractVector. The MWE is as follows:

# Vector elements are generators for modular arithmetic

struct PrimeGenerator
    val::BigInt
    mod::BigInt
end

value(x::PrimeGenerator) = x.val
modulus(x::PrimeGenerator) = x.mod

import Base.*
function *(x::PrimeGenerator, y::PrimeGenerator) 
    @assert x.mod == y.mod "Groups must be equal"
    mod = modulus(x)
    val = mod(value(x) * value(y), mod)
    PrimeGenerator(val, mod)
end

import Base.^
^(g::PrimeGenerator, n::Integer) = PrimeGenerator(powermod(value(g), n, modulus(g)), modulus(g))

# Definition of the vector

struct GVector{G} <: AbstractVector{G}
    x::Vector{T} where T
    g::G
end


Base.IndexStyle(::Type{<:GVector}) = IndexLinear()
Base.setindex!(𝐠::GVector, val::PrimeGenerator, i::Int) = 𝐠.x[i] = value(val)
Base.getindex(𝐠::GVector, i::Int) = PrimeGenerator(𝐠.x[i], modulus(𝐠.g))
Base.length(g::GVector) = length(g.x)
Base.size(g::GVector) = size(g.x)

With this implementation, when I run a broadcast like:

a = PrimeGenerator(3, 23)
gv = GVector([a, a^2, a^3], a)

s = gv .^ 2

The resulting type of s is Vector{PrimeGenerator}, whereas I want it to be a GVector. Roughly, I want broadcasting to work like this:

mybroadcast(f::Function, 𝐠::GVector, x::Integer) = GVector([value(f(i, x)) for i in 𝐠], 𝐠.g)
# mybroadcast(^, gv, 2)

function mybroadcast(::typeof(*), 𝐠::GVector{T}, 𝐡::GVector{T}) where T
    @assert 𝐠.g == 𝐡.g "The groups must be equal"
    p = modulus(𝐠.g)
    v = GVector([mod(i*j, p) for (i, j) in zip(𝐠.x, 𝐡.x)], 𝐠.g)
    return v
end

# mybroadcast(*, gv, gv)

To get the desired behaviour, I tried to adapt the ArrayAndChar example from the documentation and came up with:

Base.BroadcastStyle(::Type{<:GVector}) = Broadcast.ArrayStyle{GVector}()

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{GVector}}, ::Type{ElType}) where ElType
    # Scan the inputs for the GVector:
    A = find_aac(bc)
    # Use the g field of A to create the output
    GVector(similar(Array{ElType}, axes(bc)), A.g)
end

"`A = find_aac(As)` returns the first GVector among the arguments."
find_aac(bc::Base.Broadcast.Broadcasted) = find_aac(bc.args)
find_aac(args::Tuple) = find_aac(find_aac(args[1]), Base.tail(args))
find_aac(x) = x
find_aac(::Tuple{}) = nothing
find_aac(a::GVector, rest) = a
find_aac(::Any, rest) = find_aac(rest)

However, it fails when attempting to construct GVector(::Vector{PrimeGenerator}, ::PrimeGenerator), which is not what I want. Surely I could strip the irrelevant data by defining

GVector(x::Vector{PrimeGenerator}, g::PrimeGenerator) = GVector(value.(x), g)

But that is a poor fix, and I would be better off using the Vector{PrimeGenerator} type in calculations directly.

Adding

GVector(gs::Vector{G}, g::G) where G = GVector(value.(gs), g)

seems to work for me. I see no need for specialized broadcasting; here is what I checked:

# Vector elements are generators for modular arithmetic

struct PrimeGenerator
    val::BigInt
    mod::BigInt
end

value(x::PrimeGenerator) = x.val
modulus(x::PrimeGenerator) = x.mod

import Base.*
function *(x::PrimeGenerator, y::PrimeGenerator) 
    @assert x.mod == y.mod "Groups must be equal"
    mod = modulus(x)
    val = mod(value(x) * value(y), mod)
    PrimeGenerator(val, mod)
end

import Base.^
^(g::PrimeGenerator, n::Integer) = PrimeGenerator(powermod(value(g), n, modulus(g)), modulus(g))

# Definition of the vector

struct GVector{G} <: AbstractVector{G}
    x::Vector{T} where T 
    g::G
end

GVector(gs::Vector{G}, g::G) where G = GVector(value.(gs), g)

Base.IndexStyle(::Type{<:GVector}) = IndexLinear()
Base.setindex!(𝐠::GVector, val::PrimeGenerator, i::Int) = 𝐠.x[i] = value(val)
Base.getindex(𝐠::GVector, i::Int) = PrimeGenerator(𝐠.x[i], modulus(𝐠.g))
Base.length(g::GVector) = length(g.x)
Base.size(g::GVector) = size(g.x)

a = PrimeGenerator(3, 23)
gv = @show GVector([a, a^2, a^3], a)
s = gv .^ 2

That looks like a good start 🙂. However, one drawback is that it makes an unnecessary allocation of a Vector{PrimeGenerator}. It also runs @assert x.mod == y.mod for every multiplication, whereas ideally the check would run only once for the whole vector. I wonder whether these costs could be avoided, so that writing code with GVector would be as performant as writing mod(x, p) for integer arithmetic by hand 🤔
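One way to get the single group check, sketched under the assumption that only the plain two-operand case `g .* h` needs the fast path: overload `Base.broadcasted` for `*` on two `GVector`s, compare the moduli once, and then work on the raw `BigInt` vectors, bypassing the element-wise `PrimeGenerator` multiplication entirely. The hook `Base.broadcasted` is Base's; the specialization itself is a hypothetical sketch, not code from the thread.

```julia
struct PrimeGenerator
    val::BigInt
    mod::BigInt
end
value(x::PrimeGenerator) = x.val
modulus(x::PrimeGenerator) = x.mod

struct GVector{G, T} <: AbstractVector{G}
    x::Vector{T}   # raw values only; the modulus lives once in `g`
    g::G
end
Base.IndexStyle(::Type{<:GVector}) = IndexLinear()
Base.size(v::GVector) = size(v.x)
Base.getindex(v::GVector, i::Int) = PrimeGenerator(v.x[i], modulus(v.g))

# Hypothetical fast path, called for `g .* h` when both operands are GVectors.
function Base.broadcasted(::typeof(*), g::GVector, h::GVector)
    # The group check runs once per vector, not once per element.
    modulus(g.g) == modulus(h.g) || throw(ArgumentError("groups must be equal"))
    length(g) == length(h) || throw(DimensionMismatch("lengths differ"))
    p = modulus(g.g)
    # Plain BigInt broadcasting; no PrimeGenerator objects are created.
    return GVector(mod.(g.x .* h.x, p), g.g)
end

a = PrimeGenerator(3, 23)
gv = GVector(BigInt[3, 9, 4], a)
s = gv .* gv   # a GVector holding 9, 12, 16
```

Because this method returns an already materialized `GVector`, fused expressions such as `gv .* gv .^ 2` still take the generic broadcasting path, and `gv .* gv .* gv` is evaluated eagerly in two steps rather than fused.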

Do you mean in the construction process? You could certainly add constructors that accept a BigInt[] directly. For small sizes, could we use a Tuple or an SVector from StaticArrays instead?

Perhaps I don’t understand how broadcasting works in the current implementation. I suppose that when gv .^ 2 is evaluated, the intermediate result is stored in a Vector{PrimeGenerator}, which in the last step is passed as an argument to the GVector constructor. Please correct me if I got that wrong.
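Broadcasting stays lazy until the last step. Roughly, a dot expression such as `gv .^ 2` is lowered to `Base.materialize(Base.broadcasted(...))` (for a literal exponent the function passed is `Base.literal_pow`): `broadcasted` builds a `Broadcasted` object without computing anything, `materialize` asks `similar(bc, ElType)` for a destination and fills it via `copyto!`, and the in-place form `s .= ...` calls `materialize!`, which writes into the existing `s` without calling `similar`. So the result type is whatever `similar` returns for the broadcast style. The sketch below replays these steps by hand on a plain `Vector{Int}`, using functions from `Base.Broadcast`:

```julia
using Base.Broadcast: broadcasted, instantiate, materialize, materialize!

v = [1, 2, 3]

# What `v .* 2` does, step by step.
bc = instantiate(broadcasted(*, v, 2))   # lazy Broadcasted object, nothing computed yet
dest = similar(bc, Int)                  # the hook a custom array type overloads
copyto!(dest, bc)                        # evaluates the elements into `dest`

# The same thing in one call; this is what the dot syntax lowers to.
out = materialize(broadcasted(*, v, 2))

# In-place form `w .= v .* 2`: no `similar`, results go straight into `w`.
w = zeros(Int, 3)
materialize!(w, broadcasted(*, v, 2))
```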

Ah, I see, you want something like s .= gv .^ 2? This indeed will need custom broadcasting, I believe.

Edit: I double-checked again: even this seems to work without custom broadcasting!


Yeah, that is roughly what I mean. I like how the problem is solved in the Mods package, where the modulus is stored as a type parameter, so no data is duplicated and a plain Vector type can be used. But this introduces a compilation step for each new prime modulus, and it is not applicable when the modulus is a BigInt or when the group has, for example, an elliptic-curve specification. Thus I am left to implement similar behaviour at runtime.
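For comparison, a minimal sketch of the type-parameter design described here. `ModN` is a hypothetical type in the spirit of the Mods package, not its actual API: the modulus `N` lives in the type, each element stores only its value, a mismatched modulus fails at dispatch, and an ordinary `Vector` broadcasts without any custom code. It also shows the limitation: a type parameter must be a bits value, and a `BigInt` is not one.

```julia
# Hypothetical type in the spirit of the Mods package (NOT the Mods API):
# the modulus N is a type parameter, so every element stores only its value.
struct ModN{N}
    val::Int
    ModN{N}(x::Integer) where {N} = new{N}(mod(x, N))
end

# Both operands must share N; otherwise dispatch fails with a MethodError.
Base.:*(x::ModN{N}, y::ModN{N}) where {N} = ModN{N}(x.val * y.val)

v = [ModN{23}(3), ModN{23}(9), ModN{23}(4)]
w = v .* v   # an ordinary Vector{ModN{23}}; no custom broadcasting needed

# Each distinct N is a distinct type, so methods get compiled per modulus,
# and N must be an isbits value: a BigInt cannot serve as the parameter.
bigint_is_bits = isbits(big(23))
```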

When I run your code, I see that execution does not enter the Base.similar method, and typeof(s) == Vector{PrimeGenerator} instead of GVector{PrimeGenerator}.
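This matches Julia's fallback: an `AbstractArray` subtype that does not define `Base.BroadcastStyle` gets `DefaultArrayStyle{N}`, and the `similar` method for that style allocates a plain `Array`, so a `similar` overload written for another style is never reached. A minimal sketch with a hypothetical `Wrapped` vector type (not from the thread):

```julia
# Hypothetical minimal AbstractVector wrapper, used only to show the default behaviour.
struct Wrapped <: AbstractVector{Int}
    data::Vector{Int}
end
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int) = w.data[i]

w = Wrapped([1, 2, 3])

# No BroadcastStyle method was defined, so the fallback DefaultArrayStyle{1} is used ...
style = Base.BroadcastStyle(typeof(w))

# ... and its `similar` allocates an ordinary Vector for the result.
r = w .+ 1
```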


You are right! The following looks slightly better to me, but certainly can be improved…

# Vector elements are generators for modular arithmetic

struct PrimeGenerator
    val::BigInt
    mod::BigInt
end

value(x::PrimeGenerator) = x.val
modulus(x::PrimeGenerator) = x.mod

import Base.*
function *(x::PrimeGenerator, y::PrimeGenerator) 
    @assert x.mod == y.mod "Groups must be equal"
    mod = modulus(x)
    val = mod(value(x) * value(y), mod)
    PrimeGenerator(val, mod)
end

import Base.^
^(g::PrimeGenerator, n::Integer) = PrimeGenerator(powermod(value(g), n, modulus(g)), modulus(g))

# Definition of the vector

struct GVector{G, T} <: AbstractVector{G}
    x::Vector{T} 
    g::G
end

GVector(gs::Vector{G}, g::G) where {G} = GVector(value.(gs), g)

Base.IndexStyle(::Type{<:GVector}) = IndexLinear()
Base.setindex!(𝐠::GVector, val::PrimeGenerator, i::Int) = 𝐠.x[i] = value(val)
Base.getindex(𝐠::GVector, i::Int) = PrimeGenerator(𝐠.x[i], modulus(𝐠.g))
Base.length(g::GVector) = length(g.x)
Base.size(g::GVector) = size(g.x)
Base.similar(g::GVector) = GVector(similar(g.x), g.g)

Base.BroadcastStyle(::Type{<:GVector{G, T}}) where {G, T} = Broadcast.ArrayStyle{GVector{G, T}}() 

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{GVector{G, T}}}, ::Type{G}) where {G, T}
    GVector(similar(Vector{T}, axes(bc)), find(bc).g)
end

find(bc::Base.Broadcast.Broadcasted) = find(bc.args)
find(args::Tuple) = find(find(args[1]), Base.tail(args))
find(x) = x
find(::Tuple{}) = nothing
find(a::GVector, rest) = a
find(::Any, rest) = find(rest)

a = PrimeGenerator(3, 23)
gv = GVector([a, a^2, a^3], a)
s = similar(gv)
s .= gv .^ 2

This is quite close to how I would like GVector to work with broadcasting. I thought I could just polish the code to fill in the last missing pieces, but after hours of reading the documentation I am still lost on how to fix these remaining issues:

  • Multiplication between two GVectors does not work (gv .* gv throws an error)
  • Multiplication between a PrimeGenerator and a GVector fails (gv .* a)
  • For consistency's sake, I also wonder whether it’s possible to make value.(gv) work

Yeah, I believe the documentation is quite terse (examples are rare?). You shouldn’t need hours to get this working (famous last words ;). So should we investigate this here or in separate threads?

BTW: your use case is rather exemplary!

I prefer to continue in this thread, since the MWE is well defined here and the remaining issues all concern how broadcasting works for a custom vector type.

Perhaps the example could make it into the docs 🥹

OK, we’ll start with

  • The multiplication between the two GVector does not work (gv .* gv throws an error)

Edit: try

import Base.*
function *(x::PrimeGenerator, y::PrimeGenerator) 
    @assert x.mod == y.mod "Groups must be equal"
    _mod = modulus(x)
    _val = mod(value(x) * value(y), _mod)
    PrimeGenerator(_val, _mod)
end

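You can’t mix ‘variable’ and ‘function’ names: they live in the same namespace. Assigning mod = modulus(x) makes mod a local variable that shadows the function mod, so mod(value(x) * value(y), mod) tries to call a BigInt.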
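A stripped-down sketch of the failure, with hypothetical function names: because `mod` is assigned somewhere in the function body, it is a local variable throughout that body, so `mod(x, mod)` tries to call an integer instead of `Base.mod`. Renaming the local, as in the `_mod` fix, keeps `Base.mod` visible.

```julia
# `mod` is assigned in this body, so inside it `mod` refers to the local Int,
# and `mod(x, mod)` attempts to call that Int, which throws a MethodError.
function shadowed(x)
    mod = 7
    return mod(x, mod)
end

# With a different local name, `mod` still refers to Base.mod.
function renamed(x)
    _mod = 7
    return mod(x, _mod)
end

renamed(10)   # 10 mod 7
```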


Hm, should this really be modeled as a broadcast (it isn’t for reals, AFAIU)?
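For reals, `2 * [1, 2, 3]` works because Base defines `*` between a number and an array, without any dot syntax. A hypothetical analogue for this MWE is a plain `*` method between a `PrimeGenerator` and a `GVector` that checks the group once. The method below is a sketch, not code from the thread, and the reversed argument order relies on the multiplicative group modulo a prime being commutative.

```julia
struct PrimeGenerator
    val::BigInt
    mod::BigInt
end
value(x::PrimeGenerator) = x.val
modulus(x::PrimeGenerator) = x.mod

struct GVector{G, T} <: AbstractVector{G}
    x::Vector{T}
    g::G
end
Base.IndexStyle(::Type{<:GVector}) = IndexLinear()
Base.size(v::GVector) = size(v.x)
Base.getindex(v::GVector, i::Int) = PrimeGenerator(v.x[i], modulus(v.g))

# Hypothetical scalar-times-vector product with plain `*`, mirroring 2 * [1, 2, 3].
function Base.:*(a::PrimeGenerator, v::GVector)
    modulus(a) == modulus(v.g) || throw(ArgumentError("groups must be equal"))
    p = modulus(a)
    return GVector(mod.(value(a) .* v.x, p), v.g)   # one check, integer arithmetic
end
# The multiplicative group modulo a prime is commutative.
Base.:*(v::GVector, a::PrimeGenerator) = a * v

a = PrimeGenerator(3, 23)
gv = GVector(BigInt[3, 9, 4], a)
left = a * gv
right = gv * a
```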

In the constructor? I don’t believe so. You explicitly ask for it via

gv = GVector([a, a^2, a^3], a)

Ohh, I indeed overlooked the use of mod 😊

With numbers, it is natural to multiply a vector by a constant in ordinary calculations, as in 2 .* [1, 2, 3]. In the NIZK proofs I have implemented so far, and in those I am planning to implement, such an operation has not yet been needed. But such gaps really can break your flow when they do come up, so it would be nice to have.

OK, I think I have fixed the value.(gv) issue by defining:

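using Base.Broadcast: Broadcasted, ArrayStyle

# value.(gv) returns the stored raw values directly (no copy)
Base.broadcasted(::typeof(value), gv::GVector) = gv.x
# value.(gv .^ 2): materialize the inner broadcast into a GVector, then take its values
Base.broadcasted(::typeof(value), gv::Broadcasted{ArrayStyle{GVector{G, T}}}) where {G, T} = Base.materialize(gv).x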

I solved the scalar multiplication issue (gv .* a and a .* gv) by defining:

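# Only `*` is intercepted; other functions keep the default broadcasting behaviour.
# The generator is collected into a Vector, which then broadcasts against y.
Base.broadcasted(::typeof(*), x::PrimeGenerator, y::GVector) = (x for i in 1:length(y)) .* y
Base.broadcasted(::typeof(*), y::GVector, x::PrimeGenerator) = y .* (x for i in 1:length(y))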
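A possible alternative, assuming the goal is simply to let a `PrimeGenerator` act as a scalar in any dot call: without further definitions, broadcasting falls back to `collect(a)`, which fails because a `PrimeGenerator` is not iterable. The Julia manual's broadcasting interface section describes defining `Base.broadcastable(x) = Ref(x)` for scalar-like types. A `Ref` has style `DefaultArrayStyle{0}`, which loses to the custom `ArrayStyle`, so the result is still allocated by the `GVector`-aware `similar`, and one definition covers `a .* gv`, `gv .* a` and every other function. The sketch reuses the reply's style, `similar` and `find` definitions, together with the renamed `_mod` fix for `*`:

```julia
struct PrimeGenerator
    val::BigInt
    mod::BigInt
end
value(x::PrimeGenerator) = x.val
modulus(x::PrimeGenerator) = x.mod

function Base.:*(x::PrimeGenerator, y::PrimeGenerator)
    @assert x.mod == y.mod "Groups must be equal"
    _mod = modulus(x)   # not named `mod`, so Base.mod stays callable
    PrimeGenerator(mod(value(x) * value(y), _mod), _mod)
end

struct GVector{G, T} <: AbstractVector{G}
    x::Vector{T}
    g::G
end
Base.IndexStyle(::Type{<:GVector}) = IndexLinear()
Base.size(v::GVector) = size(v.x)
Base.getindex(v::GVector, i::Int) = PrimeGenerator(v.x[i], modulus(v.g))
Base.setindex!(v::GVector, val::PrimeGenerator, i::Int) = (v.x[i] = value(val))

# Custom style, so the `similar` below allocates the broadcast output.
Base.BroadcastStyle(::Type{<:GVector{G, T}}) where {G, T} = Broadcast.ArrayStyle{GVector{G, T}}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{GVector{G, T}}}, ::Type{G}) where {G, T}
    GVector(similar(Vector{T}, axes(bc)), find(bc).g)
end

# Return the first GVector among the broadcast arguments.
find(bc::Broadcast.Broadcasted) = find(bc.args)
find(args::Tuple) = find(find(args[1]), Base.tail(args))
find(x) = x
find(::Tuple{}) = nothing
find(a::GVector, rest) = a
find(::Any, rest) = find(rest)

# Treat a PrimeGenerator as a scalar in every dot call, for any function.
Base.broadcastable(x::PrimeGenerator) = Ref(x)

a = PrimeGenerator(3, 23)
gv = GVector(BigInt[3, 9, 4], a)
left = a .* gv
right = gv .* a
both = gv .* gv
```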