Zygote gradient with custom type results in error

Hi! I’m struggling with Zygote once more. I’m trying to maximize the kernel correlation between two point sets in 2D. Imagine two scattered sets of points: compute the distance from every point in one set to every point in the other, pass each distance through a kernel function, and sum the results. That sum is the kernel correlation. Now I want to apply a rigid transform (shift and rotation) to one of the point sets so that the correlation is maximized, which yields the optimal shift and rotation. Let’s dive into some code:
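For reference, the objective described above can be written out as follows. The symbols are my own notation rather than anything from the code: r_i are the reference points, t_j the points being transformed, s the shift, φ the rotation angle, c the rotation center, and K the kernel (K(d) = exp(-d^2) in the example).

```latex
C(s, \phi) = \sum_{i} \sum_{j} K\bigl( \lVert r_i - T_{s,\phi}(t_j) \rVert \bigr),
\qquad
T_{s,\phi}(p) = R(\phi)\,(p - c) + c + s,
\qquad
R(\phi) = \begin{pmatrix} \cos\phi & -\sin\phi \\ \sin\phi & \cos\phi \end{pmatrix}
```

The goal is to find the s and φ that maximize C. The code that follows computes this double sum with two nested loops.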

using Zygote
using LinearAlgebra

struct PointSet{T}
    points::T
end

Base.iterate(ps::PointSet) = iterate(ps.points)
Base.iterate(ps::PointSet, state) = iterate(ps.points, state)
Base.length(ps::PointSet) = length(ps.points)
Base.size(ps::PointSet) = size(ps.points)

function pointset_correlation(reference_set, transformable_set, K) # K is the kernel function
    corr = 0.
    for rp in reference_set
        for tp in transformable_set
            dist = norm(rp - tp)
            corr += K(dist)
        end
    end
    return corr
end

function transform_point(point, shift, rotmat::AbstractMatrix, rotation_center)
    return (rotmat * (point .- rotation_center)) .+ rotation_center .+ shift
end

function rigid_transform(ps, shift, ϕ, rotation_center)
    rotmat = [cos(ϕ) -sin(ϕ)
              sin(ϕ) cos(ϕ)]
    return PointSet([transform_point(p, shift, rotmat, rotation_center) for p in ps])
end

ps = PointSet([rand(2) for _ in 1:10])

Zygote.gradient([0.0, 0.0], 0.0) do shift, angle
    transformed_points = rigid_transform(ps, shift, angle, [0.0, 0.0])
    pointset_correlation(ps, transformed_points, z->exp(-z^2))
end

ERROR: MethodError: no method matching +(::NamedTuple{(:points,), Tuple{Vector{Vector{Float64}}}}, ::Vector{Vector{Float64}})

Any idea what I’m doing wrong? Ideally, I want to use static arrays for the points and maybe also ForwardDiff to make it fast. When I drop the PointSet wrapper and use plain vectors of vectors, it seems to work, but as I understand it, it should work with the wrapper too, and having the wrapper is practical for dispatch.

You need to define adjoints for your custom methods (in the Zygote docs on custom adjoints, read the paragraph immediately before that header in particular). I am not certain these are the correct definitions, but adding them does produce an answer:

using Zygote: @adjoint

@adjoint Base.iterate(ps::PointSet) = iterate(ps.points), ((pt, st),)->(PointSet([pt]), st)
@adjoint Base.iterate(ps::PointSet, state) = iterate(ps.points, state), ((pt, st),)->(PointSet([pt]), st)
Base.:(+)(x::PointSet{T}, y::PointSet{T}) where T = PointSet([a .+ b for (a, b) in zip(x.points, y.points)])
# The types on this one should probably be locked down a little better.
Base.:(+)(x::PointSet, y) = PointSet([a .+ b for (a, b) in zip(x.points, y)])
@adjoint PointSet(ps) = PointSet(ps), x->(x.points,)
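
Since I'm not sure the adjoints above are right, one way to sanity-check whatever gradient comes out is to compare it against central finite differences. The sketch below is only an illustration: it uses plain vectors of vectors instead of the PointSet wrapper, hard-codes the kernel exp(-d^2) from the question, and the names `correlation` and `fd_gradient` as well as the step size `h` are made up for this example.

```julia
using LinearAlgebra

# Kernel from the question: K(d) = exp(-d^2).
kernel(d) = exp(-d^2)

# Kernel correlation between `ref` and a rigidly transformed copy of `pts`.
# Plain vectors of 2-element vectors stand in for the PointSet wrapper here.
function correlation(ref, pts, shift, ϕ, center)
    R = [cos(ϕ) -sin(ϕ); sin(ϕ) cos(ϕ)]
    total = 0.0
    for r in ref, p in pts
        q = R * (p .- center) .+ center .+ shift
        total += kernel(norm(r - q))
    end
    return total
end

# Central finite differences with respect to (shift[1], shift[2], ϕ).
# `fd_gradient` and the default step `h` are hypothetical choices for illustration.
function fd_gradient(ref, pts, shift, ϕ, center; h=1e-6)
    f(x) = correlation(ref, pts, x[1:2], x[3], center)
    x0 = [shift[1], shift[2], ϕ]
    g = zeros(3)
    for i in 1:3
        e = zeros(3)
        e[i] = h
        g[i] = (f(x0 .+ e) - f(x0 .- e)) / (2h)
    end
    return g
end

# Example call with fixed points, a small shift, and a small angle.
ref = [[0.1, 0.2], [0.7, 0.4], [0.3, 0.9]]
pts = [[0.2, 0.1], [0.6, 0.5], [0.4, 0.8]]
g = fd_gradient(ref, pts, [0.05, -0.02], 0.1, [0.0, 0.0])
println(g)
```

The entries of `g` are ordered as (shift[1], shift[2], angle), so they can be compared against the two values returned by Zygote.gradient. One caveat about the setup in the question: correlating a point set with itself at zero shift and zero angle is symmetric under flipping the sign of the shift or the angle, so the true gradient there is zero. Two different sets, or a nonzero starting shift and angle, make for a more informative check.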

Thanks for your reply! I ended up using Enzyme in combination with StaticArrays, which sped things up by a factor of 40. For the sake of completeness, here’s the code:

using Enzyme
using LinearAlgebra
using StaticArrays

struct PointSet{T}
    points::T
end

Base.iterate(ps::PointSet) = iterate(ps.points)
Base.iterate(ps::PointSet, state) = iterate(ps.points, state)
Base.length(ps::PointSet) = length(ps.points)
Base.size(ps::PointSet) = size(ps.points)

K(x) = exp(-x^2)

function pointset_correlation(reference_set::PointSet, transformable_set::PointSet, K)
    corr = 0.
    for rp in reference_set
        for tp in transformable_set
            dist = norm(rp-tp)
            corr += K(dist)
        end
    end
    return corr
end

function rigid_transform(ps::PointSet, shift, ϕ, rotation_center)
    rotmat = @SMatrix[cos(ϕ) -sin(ϕ)
                      sin(ϕ) cos(ϕ)]
    PointSet([(rotmat * (p .- rotation_center)) .+ rotation_center .+ shift for p in ps])
end

function transformed_correlation(reference_set, transformable_set, shift, ϕ, rc)
    transformed_set = rigid_transform(transformable_set, shift, ϕ, rc)
    return pointset_correlation(reference_set, transformed_set, K)
end

ps = [rand(2) for _ in 1:10]
ps2 = [rand(2) for _ in 1:10]
pss  = PointSet([MVector{2,Float64}(p...) for p in ps])
pss2 = PointSet([MVector{2,Float64}(p...) for p in ps2])

function gradtest(ps, ps2)
    scratch  = deepcopy(ps)
    scratch2 = deepcopy(ps2)
    s  = @MVector[0.,0.]
    ds = @MVector[0.,0.]
    r = 0.
    rc = @SVector[0.,0.]
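    # Written against the Enzyme API at the time of posting. Newer Enzyme releases expect
    # the mode and return activity first, e.g. autodiff(Reverse, f, Active, args...),
    # and return the derivatives as a nested tuple, so the call and indexing may need adjusting.
    grad = autodiff(transformed_correlation, Const(scratch), Const(scratch2), Duplicated(s, ds), Active(r), Const(rc))
    return ds, grad[1]  # gradient w.r.t. the shift (accumulated in ds) and w.r.t. the angle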
end

gradtest(pss, pss2)

In theory, one could speed things up even more by using out-of-place computations, but I found it doesn’t help much.
