Matrix multiplication with custom types

Hey there.

So I’m very new to coding with Julia and am trying to get a handle on the “generalizability” of the language, as I’ll call it.

Below I created a Point struct (class?) and defined the + and * operations for two points. I want to figure out how to make the last line work the way regular matrix multiplication does, except that instead of scalars the entries of the result would be Points. Could someone point me in the right direction?

import Base.+
import Base.*

struct Point
    x::Number
    y::Number
end

+(p1::Point, p2::Point) = Point(p1.x + p2.x, p1.y + p2.y)
*(p1::Point, p2::Point) = Point(p1.x * p2.x, p1.y * p2.y)

a = Point(1,1)
b = Point(2,2)
c = a + b

point_matrix = [Point(i, j) for i in 1:3, j in 4:6]
point_matrix

point_vector = [a; b; c]
point_vector

point_matrix * point_vector

Below is the error that is generated

MethodError: no method matching zero(::Point)
Closest candidates are:
zero(!Matched::Type{LibGit2.GitHash}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LibGit2/src/oid.jl:220
zero(!Matched::Type{Pkg.Resolve.VersionWeights.VersionWeight}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/Pkg/src/resolve/VersionWeights.jl:19
zero(!Matched::Type{Pkg.Resolve.MaxSum.FieldValues.FieldValue}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/Pkg/src/resolve/FieldValues.jl:44

Stacktrace:
[1] generic_matvecmul!(::Array{Point,1}, ::Char, ::Array{Point,2}, ::Array{Point,1}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/matmul.jl:542
[2] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/matmul.jl:76 [inlined]
[3] *(::Array{Point,2}, ::Array{Point,1}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/matmul.jl:50
[4] top-level scope at In[12]:1

Also, if anyone could give me formatting tips or a link to a guide on how to make a better post that would be great!

Welcome to Julia!

So the error you’re running into is that matrix multiplication wants a zero function defined: the generic algorithm (generic_matvecmul! in your stack trace) uses it to initialize the running sum for each entry of the result. So it wants you to define a method zero(::Point) = ... that returns a zero object for the Point type, i.e. something such that x + zero(x) == x whenever x is a Point. In other words, zero should return the additive identity for the type, but Julia doesn’t know what that is for Points, so you need to define the method.
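To make the role of zero concrete, here is a simplified, hand-written matrix-vector product for Points. It is only a sketch of the same idea, not Julia’s actual LinearAlgebra code, and naive_matvec is a hypothetical name: each output entry is a sum of products, and that sum has to start from some value, which is what zero supplies. Run it in a fresh session, since it defines Point.

```julia
struct Point
    x::Number
    y::Number
end

Base.:+(p1::Point, p2::Point) = Point(p1.x + p2.x, p1.y + p2.y)
Base.:*(p1::Point, p2::Point) = Point(p1.x * p2.x, p1.y * p2.y)

# Additive identity: p + zero(p) == p for every Point p
Base.zero(::Point) = Point(0, 0)

# Hypothetical helper (not part of any library): computes A * v entry by entry
function naive_matvec(A::AbstractMatrix, v::AbstractVector)
    m, n = size(A)
    n == length(v) || throw(DimensionMismatch("size(A, 2) must equal length(v)"))
    result = Vector{eltype(A)}(undef, m)
    for i in 1:m
        # The running sum needs a starting value; zero provides it
        s = zero(A[i, 1] * v[1])
        for j in 1:n
            s += A[i, j] * v[j]
        end
        result[i] = s
    end
    return result
end

A = [Point(i, j) for i in 1:3, j in 4:6]
v = [Point(1, 1), Point(2, 2), Point(3, 3)]
naive_matvec(A, v)   # Points (6, 32), (12, 32), (18, 32)
```

Without the zero method, the s = zero(...) line fails with the same kind of MethodError as in the question.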

Formatting tips: PSA: how to quote code with backticks

General posting tips: Please read: make it easier to help you


Thanks for that! Still getting used to reading Julia’s error messages. I added the lines

import Base.zero
zero(x::Point) = Point(0,0)

to the code and it worked.

Is there a place that lists which functions I have to import from Base to get them to work like this? I’ll look through the docs a bit.

The answer is “each of them that you use”: any Base function that you want to extend with methods for your own type requires that you do one of two things, either

import Base: +, *, zero

zero(x::MyType) = ...

or (does not require the import statement)

Base.zero(x::MyType) = ...
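As a sketch, here is the original program using only qualified definitions, with no import statements. For operators the qualified name is written with a colon, Base.:+ and Base.:*. The using LinearAlgebra line is just there to be explicit about where generic matrix multiplication comes from. Run it in a fresh session, since it defines Point.

```julia
using LinearAlgebra   # generic matrix multiplication lives here

struct Point
    x::Number
    y::Number
end

# Qualified definitions extend Base's functions; no `import` needed.
# Operators are quoted with a colon when qualified: Base.:+ and Base.:*
Base.:+(p1::Point, p2::Point) = Point(p1.x + p2.x, p1.y + p2.y)
Base.:*(p1::Point, p2::Point) = Point(p1.x * p2.x, p1.y * p2.y)
Base.zero(::Point) = Point(0, 0)

a = Point(1, 1)
b = Point(2, 2)
c = a + b

point_matrix = [Point(i, j) for i in 1:3, j in 4:6]
point_vector = [a; b; c]
result = point_matrix * point_vector   # a Vector{Point} instead of a MethodError
```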

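Also, your type will work more quickly if you define it this way, since with fields declared as the abstract type Number the compiler cannot know the concrete types of p.x and p.y, while with a type parameter it can: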

struct Point{T<:Real}
    x::T
    y::T
end
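One way to see the difference is to ask Julia about the field types directly. This is a small sketch; PointAbstract is a hypothetical name for the original, non-parametric definition so that both versions can live in one session. With abstract fields the compiler cannot know what p.x holds, while each Point{T} has concretely typed fields, and a Point{Float64} is a plain bits type that an array can store inline.

```julia
# The original definition, renamed (hypothetical name) so both versions coexist
struct PointAbstract
    x::Number
    y::Number
end

# The parametric definition
struct Point{T<:Real}
    x::T
    y::T
end

isconcretetype(fieldtype(PointAbstract, :x))   # false: Number is abstract
isconcretetype(fieldtype(Point{Int}, :x))      # true: the field is exactly Int
isbitstype(Point{Float64})                     # true: can be stored inline in arrays
isbitstype(PointAbstract)                      # false: fields hold boxed values
```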

See the docs: Constructors · The Julia Language
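One practical consequence of the parametric definition, and the kind of thing the Constructors page covers, is that both fields must now share the same T, so Point(1, 2.5) no longer matches the default constructor. A minimal sketch of the usual fix is an outer constructor that promotes its arguments first (run in a fresh session, since it defines Point):

```julia
struct Point{T<:Real}
    x::T
    y::T
end

# Outer constructor: promote mixed Real arguments to a common type first.
# Same-typed calls like Point(1, 2) still go to the default constructor,
# which is more specific than this method.
Point(x::Real, y::Real) = Point(promote(x, y)...)

p = Point(1, 2)      # a Point{Int}
q = Point(1, 2.5)    # a Point{Float64}: the 1 is promoted to 1.0
```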

Just to add, with the parametric definition (struct Point{T <: Real}), you would want to define zero as

import Base.zero
zero(::Point{T}) where {T} = Point(zero(T), zero(T))

so that you get Points with the right numeric type:

julia> zero(Point(1.0, 2.0))
Point{Float64}(0.0, 0.0)

julia> zero(Point(1, 2))
Point{Int64}(0, 0)
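Building on that, here is a sketch that also defines zero for the type itself, with the instance method forwarding to it. This mirrors how Base provides both zero(x) and zero(T) for numbers, and it matters for generic code that only has a type to work with, such as zeros(Point{Float64}, 2). The operator definitions are repeated from the original post so the block runs on its own in a fresh session.

```julia
struct Point{T<:Real}
    x::T
    y::T
end

Base.:+(p1::Point, p2::Point) = Point(p1.x + p2.x, p1.y + p2.y)
Base.:*(p1::Point, p2::Point) = Point(p1.x * p2.x, p1.y * p2.y)

# zero of the type: generic code like zeros(Point{Float64}, n) calls this
Base.zero(::Type{Point{T}}) where {T} = Point(zero(T), zero(T))
# zero of an instance: forward to the type method
Base.zero(p::Point) = zero(typeof(p))

zs = zeros(Point{Float64}, 2)                # two Point{Float64}(0.0, 0.0) entries
A = [Point(i, j) for i in 1:3, j in 4:6]     # a Matrix{Point{Int}}
v = [Point(1, 1), Point(2, 2), Point(3, 3)]
A * v                                        # a Vector{Point{Int}}
```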

This actually illustrates one of the advantages of the parametric definition; without it, you don’t know ahead of time what the numeric types of p.x and p.y are for a Point p, but in the parametric case you do since you have access to the parameter T to use in zero.

(You could define zero(p::Point) = Point(zero(p.x), zero(p.y)) in the non-parametric case, but this involves a runtime lookup to figure out the type of p.x in order to choose the right zero method to use when you call zero(p.x), whereas in the parametric case it’s all done at compile time.)
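To make the non-parametric case concrete: because the struct only says Number, a zero built this way takes its field types from the values stored in p, which are known only at run time, and the two fields may even end up with different types. A small sketch, run in a fresh session:

```julia
struct Point
    x::Number
    y::Number
end

# Each zero(...) call dispatches on the runtime type of the stored value
Base.zero(p::Point) = Point(zero(p.x), zero(p.y))

p = Point(1, 2.5)        # allowed: x and y are each just "some Number"
z = zero(p)
fieldtype(Point, :x)     # Number: the type says nothing about what p.x holds
(z.x, z.y)               # (0, 0.0): an Int zero and a Float64 zero
```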


Thanks for the extra tidbit Jeffrey!

Thanks for the detailed explanation Eric! I’ll be sure to look this up.
