# Zygote gradient with custom type results in error

Hi! I’m once more struggling with Zygote. I’m trying to maximize the kernel correlation between two 2D point sets. Take two scattered sets of points, compute the distance from every point in one set to every point in the other, weight each distance with a kernel function, and sum the results: that’s the kernel correlation. I now want to apply a rigid transform (shift and rotation) to one of the sets so that the correlation is maximized, which yields the optimal shift and rotation. Let’s dive into some code:
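Written out as a formula, the objective described above (and implemented by `transform_point`/`pointset_correlation` in the code that follows) looks like this. The symbol names are chosen here for illustration: $\mathbf{r}_i$ are the reference points, $\mathbf{t}_j$ the points being transformed, $\mathbf{c}$ the rotation center, $\mathbf{s}$ the shift, and $K$ the Gaussian kernel used in the example.

$$
C(\mathbf{s}, \phi) = \sum_{i=1}^{N} \sum_{j=1}^{M} K\left(\left\lVert \mathbf{r}_i - \big(R(\phi)(\mathbf{t}_j - \mathbf{c}) + \mathbf{c} + \mathbf{s}\big) \right\rVert\right),
\qquad
R(\phi) = \begin{pmatrix} \cos\phi & -\sin\phi \\ \sin\phi & \cos\phi \end{pmatrix},
\qquad
K(d) = e^{-d^2}
$$

$$
(\mathbf{s}^\ast, \phi^\ast) = \arg\max_{\mathbf{s},\, \phi} \; C(\mathbf{s}, \phi)
$$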

```julia
using Zygote
using LinearAlgebra

struct PointSet{T}
    points::T
end

# Make PointSet iterable by forwarding to the wrapped collection
Base.iterate(ps::PointSet) = iterate(ps.points)
Base.iterate(ps::PointSet, state) = iterate(ps.points, state)
Base.length(ps::PointSet) = length(ps.points)
Base.size(ps::PointSet) = size(ps.points)

function pointset_correlation(reference_set, transformable_set, K) # K is the kernel function
    corr = 0.0
    for rp in reference_set
        for tp in transformable_set
            dist = norm(rp - tp)
            corr += K(dist)
        end
    end
    return corr
end

function transform_point(point, shift, rotmat::AbstractMatrix, rotation_center)
    return (rotmat * (point .- rotation_center)) .+ rotation_center .+ shift
end

function rigid_transform(ps, shift, ϕ, rotation_center)
    rotmat = [cos(ϕ) -sin(ϕ)
              sin(ϕ)  cos(ϕ)]
    return PointSet([transform_point(p, shift, rotmat, rotation_center) for p in ps])
end

ps = PointSet([rand(2) for _ in 1:10])

# Differentiate w.r.t. shift and angle (zero starting values assumed)
gradient([0.0, 0.0], 0.0) do shift, angle
    transformed_points = rigid_transform(ps, shift, angle, [0, 0])
    pointset_correlation(ps, transformed_points, z -> exp(-z^2))
end

# ERROR: MethodError: no method matching +(::NamedTuple{(:points,), Tuple{Vector{Vector{Float64}}}}, ::Vector{Vector{Float64}})
```

Any idea what I’m doing wrong? Ideally, I want to use static arrays for the points, and maybe also ForwardDiff, to make it fast. When I drop the `PointSet` wrapper and use a plain vector of vectors, it seems to work. As I understand it, though, the wrapped version should work too, and the wrapper is handy for dispatch.
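The `NamedTuple{(:points,), ...}` in the error message is how Zygote represents the gradient of a plain struct: a "structural" gradient with one entry per field. A minimal sketch illustrating this, assuming a current Zygote version (the values are arbitrary and a flat vector is used for simplicity):

```julia
using Zygote

struct PointSet{T}
    points::T
end

ps = PointSet([1.0, 2.0, 3.0])

# Gradient of a scalar function that reads the wrapped field
g = gradient(p -> sum(abs2, p.points), ps)[1]

# The result mirrors the struct's fields as a NamedTuple;
# it is neither a PointSet nor a bare Vector
println(typeof(g))
println(g.points)
```

A plausible reading of the error (not verified) is that `ps` receives this NamedTuple from the `for` loop in `pointset_correlation`, while the comprehension over `ps` in `rigid_transform` produces a plain `Vector` gradient for the same object, and Zygote has no method to add the two.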

```julia
using Zygote: @adjoint

# Custom pullbacks for iterating over a PointSet
@adjoint Base.iterate(ps::PointSet) = iterate(ps.points), ((pt, st),)->(PointSet([pt]), st)
@adjoint Base.iterate(ps::PointSet, state) = iterate(ps.points, state), ((pt, st),)->(PointSet([pt]), st)
# Addition methods so that PointSet gradients can be accumulated
Base.:(+)(x::PointSet{T}, y::PointSet{T}) where T = PointSet([a .+ b for (a, b) in zip(x.points, y.points)])
# The types on this one should probably be locked down a little better.
Base.:(+)(x::PointSet, y) = PointSet([a .+ b for (a, b) in zip(x.points, y)])
```
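Another option that avoids custom adjoints is sketched below; it has not been tested against every Zygote version. The idea is to keep `PointSet` purely for dispatch and have the differentiated code iterate over the wrapped `points` vector explicitly. Every gradient reaching the wrapper then comes from field access and has the same `(points = ...,)` NamedTuple form, so Zygote should be able to accumulate them. The helper `objective` and the starting values `[0.1, -0.2]` and `0.3` are illustrative choices. The values are nonzero so that no distance is exactly zero, since `norm` is not differentiable there.

```julia
using Zygote
using LinearAlgebra

struct PointSet{T}
    points::T
end

# Loop over the wrapped vector explicitly; the wrapper is only used for dispatch
function pointset_correlation(reference_set::PointSet, transformable_set::PointSet, K)
    corr = 0.0
    for rp in reference_set.points, tp in transformable_set.points
        corr += K(norm(rp - tp))
    end
    return corr
end

function rigid_transform(ps::PointSet, shift, ϕ, rotation_center)
    rotmat = [cos(ϕ) -sin(ϕ); sin(ϕ) cos(ϕ)]
    # map over the plain Vector in ps.points, then rewrap the result
    return PointSet(map(p -> rotmat * (p .- rotation_center) .+ rotation_center .+ shift, ps.points))
end

ref = PointSet([rand(2) for _ in 1:10])

# Hypothetical helper: correlation of ref with a rigidly transformed copy of itself
objective(shift, ϕ) = pointset_correlation(ref, rigid_transform(ref, shift, ϕ, [0.0, 0.0]), z -> exp(-z^2))

dshift, dϕ = gradient(objective, [0.1, -0.2], 0.3)
```

A central finite difference on `objective` is a cheap cross-check for `dshift` and `dϕ`, whichever AD backend ends up being used.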

Thanks for your reply! I ended up using Enzyme in combination with StaticArrays, which sped things up by a factor of 40. For the sake of completeness, here’s the code:

```julia
using Enzyme
using LinearAlgebra
using StaticArrays

struct PointSet{T}
    points::T
end

Base.iterate(ps::PointSet) = iterate(ps.points)
Base.iterate(ps::PointSet, state) = iterate(ps.points, state)
Base.length(ps::PointSet) = length(ps.points)
Base.size(ps::PointSet) = size(ps.points)

# Gaussian kernel
K(x) = exp(-x^2)

function pointset_correlation(reference_set::PointSet, transformable_set::PointSet, K)
    corr = 0.0
    for rp in reference_set
        for tp in transformable_set
            dist = norm(rp - tp)
            corr += K(dist)
        end
    end
    return corr
end

function rigid_transform(ps::PointSet, shift, ϕ, rotation_center)
    rotmat = @SMatrix [cos(ϕ) -sin(ϕ)
                       sin(ϕ)  cos(ϕ)]
    PointSet([(rotmat * (p .- rotation_center)) .+ rotation_center .+ shift for p in ps])
end

function transformed_correlation(reference_set, transformable_set, shift, ϕ, rc)
    transformed_set = rigid_transform(transformable_set, shift, ϕ, rc)
    return pointset_correlation(reference_set, transformed_set, K)
end

ps  = [rand(2) for _ in 1:10]
ps2 = [rand(2) for _ in 1:10]
pss  = PointSet([MVector{2,Float64}(p...) for p in ps])
pss2 = PointSet([MVector{2,Float64}(p...) for p in ps2])

# Copy the PointSet-wrapped static vectors: the methods above dispatch on PointSet
scratch  = deepcopy(pss)
scratch2 = deepcopy(pss2)
s  = @MVector [0.0, 0.0]   # shift
ds = @MVector [0.0, 0.0]   # shadow of s: the shift gradient is accumulated here
r  = 0.0                   # rotation angle ϕ
rc = @SVector [0.0, 0.0]   # rotation center

# Recent Enzyme versions take the mode and return activity explicitly.
# The gradient w.r.t. the Active angle is returned; the shift gradient ends up in ds.
grad = autodiff(Reverse, transformed_correlation, Active,
                Const(scratch), Const(scratch2), Duplicated(s, ds), Active(r), Const(rc))
```