Hi! I’m once more struggling with Zygote. I’m trying to maximize the kernel correlation between two point sets in 2D. Take two scattered sets of points, calculate the distances from every point in one set to every point in the other set, weight each distance with some kernel function, and sum the results — that’s the kernel correlation. Now I want to apply a rigid transform (shift and rotation) to one of the point sets so that the correlation is maximized, which yields the optimal shift and rotation. Let’s dive into some code:
using Zygote
using LinearAlgebra
"""
    PointSet{T}

Thin wrapper around a collection of points stored in the field `points`. It exists so
that point sets can be targeted by dispatch; iteration, `length` and `size` are forwarded
to `points`.
"""
struct PointSet{T}
points::T
end
# Iteration protocol for `PointSet`: everything is forwarded to the wrapped `points`.
Base.iterate(ps::PointSet) = iterate(ps.points)
Base.iterate(ps::PointSet, state) = iterate(ps.points, state)
Base.length(ps::PointSet) = length(ps.points)
Base.size(ps::PointSet) = size(ps.points)
# Report the element type of the wrapped collection. Without this, `eltype` falls back to
# `Any` (while `IteratorEltype` still claims a known eltype), so e.g. `collect(ps)` would
# build an abstractly typed `Vector{Any}`.
Base.eltype(::Type{PointSet{T}}) where {T} = eltype(T)
"""
    pointset_correlation(reference_set, transformable_set, K)

Kernel correlation of two point sets: the sum of `K(norm(rp - tp))` over every pair of a
point `rp` from `reference_set` and a point `tp` from `transformable_set`. Returns `0.0`
when either set is empty.

NOTE(review): the accumulator starts as the `Float64` literal `0.0`. If `K` returns
another number type (e.g. `Float32`, or dual numbers under ForwardDiff), the accumulator
changes type inside the loop. Consider a type-generic start value.
"""
function pointset_correlation(reference_set, transformable_set, K)
    total = 0.0
    # `tp` is the inner loop, exactly as in nested `for` loops.
    for rp in reference_set, tp in transformable_set
        total += K(norm(rp - tp))
    end
    return total
end
"""
    transform_point(point, shift, rotmat, rotation_center)

Rotate `point` about `rotation_center` by applying `rotmat` to its offset from the
center, then translate the result by `shift`.
"""
function transform_point(point, shift, rotmat::AbstractMatrix, rotation_center)
    offset = point .- rotation_center
    rotated = rotmat * offset
    return rotated .+ rotation_center .+ shift
end
"""
    rigid_transform(ps, shift, ϕ, rotation_center)

Rotate every point of `ps` counter-clockwise by the angle `ϕ` (radians) about
`rotation_center`, then translate it by `shift`. The moved points are returned as a new
`PointSet`. `ps` may be a `PointSet` or a plain collection of points.
"""
function rigid_transform(ps, shift, ϕ, rotation_center)
    rotmat = [cos(ϕ) -sin(ϕ)
              sin(ϕ) cos(ϕ)]
    # Iterate the unwrapped collection rather than the `PointSet` itself.
    # `pointset_correlation` iterates a `PointSet` through `iterate(ps.points)`, which gives
    # Zygote a `NamedTuple` cotangent `(points = ...,)`. The reported
    # `MethodError: no method matching +(::NamedTuple{(:points,)...}, ::Vector{Vector{Float64}})`
    # shows that a comprehension over the wrapper itself yields a plain `Vector` cotangent
    # instead. When the same set is used both ways (as in the `gradient` example), the two
    # cannot be added. Going through `ps.points` makes both paths produce the same
    # `NamedTuple` form.
    points = _points(ps)
    return PointSet([transform_point(p, shift, rotmat, rotation_center) for p in points])
end

# Raw point collection behind `ps`: unwrap a `PointSet`, pass any other collection through.
_points(ps::PointSet) = ps.points
_points(ps) = ps
# Ten random 2-D points. The set is correlated against a rigidly moved copy of itself,
# and Zygote differentiates that correlation w.r.t. the shift and the rotation angle.
ps = PointSet([rand(2) for _ in 1:10])
Zygote.gradient([0,0], 0) do shift, angle
    moved = rigid_transform(ps, shift, angle, [0,0])
    gaussian = z -> exp(-z^2)
    return pointset_correlation(ps, moved, gaussian)
end
ERROR: MethodError: no method matching +(::NamedTuple{(:points,), Tuple{Vector{Vector{Float64}}}}, ::Vector{Vector{Float64}})
Any idea what I’m doing wrong? Ideally, I want to use static arrays for the points, and maybe also ForwardDiff, to make it fast. When I drop the `PointSet` wrapper and use plain vectors of vectors, it seems to work. As far as I understand, it should work with the wrapper too, and having the wrapper is practical for dispatch.
You need to define adjoints for your custom methods (see the section on custom adjoints in the Zygote documentation, and specifically the paragraph immediately before that section’s header). I am not certain the definitions below are correct, but adding them does produce an answer:
using Zygote: @adjoint
# Custom Zygote adjoints for iterating a `PointSet`. The forward pass forwards to
# `iterate(ps.points, ...)`. Each pullback destructures the cotangent of the returned
# `(point, state)` tuple and hands the point's cotangent back wrapped in a one-element
# `PointSet`.
# NOTE(review): the one-element `PointSet([pt])` does not record which index the point
# came from. The `+` methods for `PointSet` combine cotangents with `zip`, which stops at
# the shorter operand. Per-point gradients therefore appear to collapse into a single
# entry, so gradients that depend on individual point positions (e.g. the rotation angle
# in `rigid_transform`) may be wrong. Check against finite differences before relying on
# them.
# NOTE(review): the one-argument method returns two cotangents (`PointSet([pt])` and
# `st`) for a single argument. Confirm that Zygote ignores the extra entry.
@adjoint Base.iterate(ps::PointSet) = iterate(ps.points), ((pt, st),)->(PointSet([pt]), st)
@adjoint Base.iterate(ps::PointSet, state) = iterate(ps.points, state), ((pt, st),)->(PointSet([pt]), st)
# Elementwise sum of two point sets of the same type: entry i of the result is
# `x.points[i] .+ y.points[i]`.
# NOTE(review): `zip` stops at the shorter operand, so sets of different length are
# silently truncated rather than rejected.
function Base.:(+)(x::PointSet{T}, y::PointSet{T}) where {T}
    summed = [xp .+ yp for (xp, yp) in zip(x.points, y.points)]
    return PointSet(summed)
end

# Sum of a point set and any other iterable `y` of points, paired up with `zip`.
# NOTE(review): `y` is untyped, so this method accepts any iterable. Its argument types
# should probably be locked down further.
function Base.:(+)(x::PointSet, y)
    summed = [xp .+ yp for (xp, yp) in zip(x.points, y)]
    return PointSet(summed)
end
# Zygote adjoint for the `PointSet` constructor. The pullback takes the cotangent of the
# constructed set and returns its `points` field as the cotangent of the wrapped collection.
@adjoint function PointSet(points)
    pointset_pullback(Δ) = (Δ.points,)
    return PointSet(points), pointset_pullback
end
Thanks for your reply! I ended up using Enzyme in combination with StaticArrays, which sped things up by a factor of 40. For the sake of completeness, here’s the code:
using Enzyme
using LinearAlgebra
using StaticArrays
"""
    PointSet{T}

Thin wrapper around a collection of points stored in the field `points` (here a vector of
`MVector`s). It exists so that point sets can be targeted by dispatch; iteration,
`length` and `size` are forwarded to `points`.
"""
struct PointSet{T}
points::T
end
# Iteration protocol for `PointSet`: everything is forwarded to the wrapped `points`.
Base.iterate(ps::PointSet) = iterate(ps.points)
Base.iterate(ps::PointSet, state) = iterate(ps.points, state)
Base.length(ps::PointSet) = length(ps.points)
Base.size(ps::PointSet) = size(ps.points)
# Report the element type of the wrapped collection (e.g. `MVector{2,Float64}`). Without
# this, `eltype` falls back to `Any`, so e.g. `collect(ps)` would build a `Vector{Any}`.
Base.eltype(::Type{PointSet{T}}) where {T} = eltype(T)
"""
    K(x)

Gaussian-shaped kernel `exp(-x^2)` used to weight point-to-point distances.
"""
function K(x)
    return exp(-x^2)
end
"""
    pointset_correlation(reference_set::PointSet, transformable_set::PointSet, K)

Sum of `K(norm(rp - tp))` over all pairs with `rp` from `reference_set` and `tp` from
`transformable_set`. Returns `0.0` when either set is empty.
"""
function pointset_correlation(reference_set::PointSet, transformable_set::PointSet, K)
    acc = 0.0
    for rp in reference_set, tp in transformable_set
        acc += K(norm(rp - tp))
    end
    return acc
end
"""
    rigid_transform(ps::PointSet, shift, ϕ, rotation_center)

Rotate each point of `ps` counter-clockwise by `ϕ` about `rotation_center`, then add
`shift`. The moved points are returned as a new `PointSet`. The rotation matrix is a
static 2×2 `SMatrix`.
"""
function rigid_transform(ps::PointSet, shift, ϕ, rotation_center)
    rotmat = @SMatrix [cos(ϕ) -sin(ϕ); sin(ϕ) cos(ϕ)]
    moved = map(ps) do p
        (rotmat * (p .- rotation_center)) .+ rotation_center .+ shift
    end
    return PointSet(moved)
end
"""
    transformed_correlation(reference_set, transformable_set, shift, ϕ, rc)

Kernel correlation, using the global kernel `K`, between `reference_set` and
`transformable_set` after the latter is rotated by `ϕ` about `rc` and shifted by `shift`.
"""
transformed_correlation(reference_set, transformable_set, shift, ϕ, rc) =
    pointset_correlation(reference_set, rigid_transform(transformable_set, shift, ϕ, rc), K)
# Two random sets of ten 2-D points, converted to mutable static vectors and wrapped.
ps = [rand(2) for _ in 1:10]
ps2 = [rand(2) for _ in 1:10]
pss = PointSet(map(p -> MVector{2,Float64}(p...), ps))
pss2 = PointSet(map(p -> MVector{2,Float64}(p...), ps2))
"""
    gradtest(ps, ps2)

Differentiate `transformed_correlation` with Enzyme at zero shift, zero rotation angle
and rotation center at the origin. `ps` is the reference set and `ps2` the set being
transformed.

Returns `(ds, g)`:
- `ds` is the shadow of the shift, which receives the shift gradient.
- `g` is the first element of `autodiff`'s result, presumably the gradient for the
  `Active` angle argument.

NOTE(review): this uses the mode-less `autodiff(f, args...)` call form. Newer Enzyme
releases may require an explicit mode and return the gradients differently nested, so
confirm against the installed version.
"""
function gradtest(ps, ps2)
    # Work on deep copies of both sets; they are passed as `Const` (not differentiated).
    reference = deepcopy(ps)
    moving = deepcopy(ps2)
    center = @SVector[0.,0.]
    phi = 0.
    shift = @MVector[0.,0.]
    # Starts at zero; presumably Enzyme accumulates the shift gradient into it.
    shift_shadow = @MVector[0.,0.]
    result = autodiff(transformed_correlation, Const(reference), Const(moving),
                      Duplicated(shift, shift_shadow), Active(phi), Const(center))
    return shift_shadow, result[1]
end
# Run `gradtest`. It returns the shift shadow `ds` and the first entry of Enzyme's result
# for the `Active` angle, both evaluated at zero shift and zero rotation.
gradtest(pss, pss2)
In theory, one could speed things up even more by using out-of-place computations, but I found that it doesn’t help much.