Hello everyone,
I want to be able to use Enzyme to differentiate a scalar-valued function with respect to only part of the fields of an arbitrary struct
.
As an example, in my MWE below, I would like to differentiate the calculate_area
function with respect to the z-positions of all the points
in my surface s
. For starters, I try to differentiate w.r.t. s
completely. However, I get the following error:
ERROR: Enzyme Mutability Error: Cannot add one in place to immutable value 1.0 tup[i]=0.0 i=1 w=1 tup=(0.0, -0.0)
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] runtime_generic_rev(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, tape::Enzyme.Compiler.Tape{…}, f::typeof(-), df::Nothing, primal_1::Float64, shadow_1_1::Float64, primal_2::Float64, shadow_2_1::Float64)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Vjlrr/src/rules/jitrules.jl:584
[3] calculate_area
@ ~/Work/Projects/geom-sensitivity/enzyme_mwe.jl:249 [inlined]
[4] calculate_area
@ ~/Work/Projects/geom-sensitivity/enzyme_mwe.jl:0 [inlined]
[5] diffejulia_calculate_area_2754_inner_1wrap
@ ~/Work/Projects/geom-sensitivity/enzyme_mwe.jl:0
[6] macro expansion
@ ~/.julia/packages/Enzyme/Vjlrr/src/compiler.jl:8839 [inlined]
[7] enzyme_call
@ ~/.julia/packages/Enzyme/Vjlrr/src/compiler.jl:8405 [inlined]
[8] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/Vjlrr/src/compiler.jl:8178 [inlined]
[9] autodiff
@ ~/.julia/packages/Enzyme/Vjlrr/src/Enzyme.jl:491 [inlined]
[10] autodiff(mode::ReverseMode{…}, f::typeof(calculate_area), ::Type{…}, args::Duplicated{…})
@ Enzyme ~/.julia/packages/Enzyme/Vjlrr/src/Enzyme.jl:512
[11] top-level scope
@ ~/Work/Projects/geom-sensitivity/enzyme_mwe.jl:263
Some type information was truncated. Use `show(err)` to see complete types.
Following is the MWE:
using Enzyme
mutable struct Point
x::Real
y::Real
z::Real
end
function Point(x, y, z_func::Function)
z = z_func(x, y)
Point(x, y, z)
end
struct Edge
from::Integer
to::Integer
end
struct Triangle
p1::Integer
p2::Integer
p3::Integer
end
struct Surface
points::Vector{Point}
edges::Vector{Edge}
triangles::Vector{Triangle}
end
# X, Y grid positions for points
positions = Dict{Integer, Vector{Real}}(
1=>[-0.5, -0.5],
2=>[0.5, -0.5],
3=>[0.5, 0.5],
4=>[-0.5, 0.5],
5=>[-0.3, -0.5],
6=>[-0.1, -0.5],
7=>[0.1, -0.5],
8=>[0.3, -0.5],
9=>[0.5, -0.3],
10=>[0.5, -0.1],
11=>[0.5, 0.1],
12=>[0.5, 0.3],
13=>[0.3, 0.5],
14=>[0.1, 0.5],
15=>[-0.1, 0.5],
16=>[-0.3, 0.5],
17=>[-0.5, 0.3],
18=>[-0.5, 0.1],
19=>[-0.5, -0.1],
20=>[-0.5, -0.3],
21=>[-0.3, -0.3],
22=>[-0.1, -0.3],
23=>[0.1, -0.3],
24=>[0.3, -0.3],
25=>[0.3, -0.1],
26=>[0.3, 0.1],
27=>[0.3, 0.3],
28=>[0.1, 0.3],
29=>[-0.1, 0.3],
30=>[-0.3, 0.3],
31=>[-0.3, 0.1],
32=>[-0.3, -0.1],
33=>[-0.1, -0.1],
34=>[0.1, -0.1],
35=>[0.1, 0.1],
36=>[-0.1, 0.1],
)
# Height Function
height(x, y) = 500 - (x^2 + y^2) * 10
# Creating points on the surface
points = Point[]
for i in eachindex(1:length(positions))
push!(points, Point(positions[i][1], positions[i][2], height))
end
# Setup Edge connectivity
edges = Edge[
Edge(1, 5), #1
Edge(5, 6), #2
Edge(6, 7), #3
Edge(7, 8), #4
Edge(8, 2), #5
Edge(2, 9), #6
Edge(9, 10), #7
Edge(10, 11), #8
Edge(11, 12), #9
Edge(12, 3), #10
Edge(3, 13), #11
Edge(13, 14), #12
Edge(14, 15), #13
Edge(15, 16), #14
Edge(16, 4), #15
Edge(4, 17), #16
Edge(17, 18), #17
Edge(18, 19), #18
Edge(19, 20), #19
Edge(20, 1), #20
Edge(20, 21), #21
Edge(21, 22), #22
Edge(22, 23), #23
Edge(23, 24), #24
Edge(24, 9), #25
Edge(19, 32), #26
Edge(32, 33), #27
Edge(33, 34), #28
Edge(34, 25), #29
Edge(25, 10), #30
Edge(18, 31), #31
Edge(31, 36), #32
Edge(36, 35), #33
Edge(35, 26), #34
Edge(26, 11), #35
Edge(17, 30), #36
Edge(30, 29), #37
Edge(29, 28), #38
Edge(28, 27), #39
Edge(27, 12), #40
Edge(5, 21), #41
Edge(21, 32), #42
Edge(32, 31), #43
Edge(31, 30), #44
Edge(30, 16), #45
Edge(6, 22), #46
Edge(22, 33), #47
Edge(33, 36), #48
Edge(36, 29), #49
Edge(29, 15), #50
Edge(7, 23), #51
Edge(23, 34), #52
Edge(34, 35), #53
Edge(35, 28), #54
Edge(28, 14), #55
Edge(8, 24), #56
Edge(24, 25), #57
Edge(25, 26), #58
Edge(26, 27), #59
Edge(27, 13), #60
Edge(1, 21), #61
Edge(5, 22), #62
Edge(6, 23), #63
Edge(7, 24), #64
Edge(8, 9), #65
Edge(20, 32), #66
Edge(21, 33), #67
Edge(22, 34), #68
Edge(23, 25), #69
Edge(24, 10), #70
Edge(19, 31), #71
Edge(32, 36), #72
Edge(33, 35), #73
Edge(34, 26), #74
Edge(25, 11), #75
Edge(18, 30), #76
Edge(31, 29), #77
Edge(36, 28), #78
Edge(35, 27), #79
Edge(26, 12), #80
Edge(17, 16), #81
Edge(30, 15), #82
Edge(29, 14), #83
Edge(28, 13), #84
Edge(27, 3), #85
]
# Setup Face connectivity
triangles = Triangle[
Triangle(1,5,21),
Triangle(5,6,22),
Triangle(6,7,23),
Triangle(7,8,24),
Triangle(8,2,9),
Triangle(1,21,20),
Triangle(5,22,21),
Triangle(6,23,22),
Triangle(7,24,23),
Triangle(8,9,24),
Triangle(20,21,32),
Triangle(21,22,33),
Triangle(22,23,34),
Triangle(23,24,25),
Triangle(24,9,10),
Triangle(20,32,19),
Triangle(21,33,32),
Triangle(22,34,33),
Triangle(23,25,34),
Triangle(24,10,9),
Triangle(19,32,31),
Triangle(32,33,36),
Triangle(33,34,35),
Triangle(35,25,26),
Triangle(25,10,11),
Triangle(19,31,18),
Triangle(32,36,31),
Triangle(33,35,36),
Triangle(34,26,35),
Triangle(25,11,26),
Triangle(18,31,30),
Triangle(31,36,29),
Triangle(36,35,28),
Triangle(35,26,27),
Triangle(26,11,12),
Triangle(18,30,17),
Triangle(31,29,30),
Triangle(36,28,29),
Triangle(35,27,28),
Triangle(26,12,27),
Triangle(17,30,16),
Triangle(30,29,15),
Triangle(29,28,14),
Triangle(28,27,13),
Triangle(27,12,3),
Triangle(17,16,4),
Triangle(30,15,16),
Triangle(29,14,15),
Triangle(28,13,14),
Triangle(27,3,13),
]
# Create surface
s = Surface(points, edges, triangles)
function calculate_area(s::Surface)
area = 0.0
for triangle in s.triangles
x1 = s.points[triangle.p1].x
y1 = s.points[triangle.p1].y
z1 = s.points[triangle.p1].z
x2 = s.points[triangle.p2].x
y2 = s.points[triangle.p2].y
z2 = s.points[triangle.p2].z
x3 = s.points[triangle.p3].x
y3 = s.points[triangle.p3].y
z3 = s.points[triangle.p3].z
u1 = x2 - x1
u2 = y2 - y1
u3 = z2 - z1
v1 = x3 - x1
v2 = y3 - y1
v3 = z3 - z1
area += 0.5 * sqrt((u2*v3 - u3*v2)^2 + (u3*v1-u1*v3)^2 + (u1*v2-u2*v1)^2)
end
return area
end
calculate_area(s)
ds = deepcopy(s)
for i in eachindex(ds.points)
ds.points[i].x = 0.0
ds.points[i].y = 0.0
ds.points[i].z = 1.0
end
Enzyme.autodiff(Reverse, calculate_area, Const, Duplicated(s, ds))
On the other hand, using Zygote
I am able to get the gradient without any issue, and it also ignores the fields having structs involving only integers.
All I need to do is:
using Zygote
ds = gradient(calculate_area, s)
# ((points = @NamedTuple{x::Float64, y::Float64, z::Float64}[(x = 1.6816612310819545, y = -1.7344882854613735, z = -0.6339246525544604), (x = 0.008804509063256116, y = -0.017609018126512232, z = -0.42261643503629953), (x = 1.725683776398136, y = -1.6728567220187194, z = -0.6339246525544604), (x = -0.008804509063256116, y = 0.017609018126512232, z = -0.42261643503629953), (x = 0.7357816691599668, y = -2.004896671653612, z = -1.066666666666638), (x = -2.147125321541682, y = -1.1881166765601836, z = -1.0263834223756458), (x = -4.571552002947907, y = -3.313695835806661, z = -0.8310166963474414), (x = -5.86692575011869, y = -7.062454279801859, z = -0.6113082175181475), (x = 0.019327603161698126, y = -3.004637588201552, z = -3.0113082175181356), (x = 1.8646020919901334, y = -1.3942506430834438, z = -0.4930500890423126) … (x = 9.447877599958657, y = -0.007566278284033556, z = 2.6310441645923865), (x = 1.9903061459690947, y = 6.274385700856671, z = 3.1663573686795106), (x = -8.238190034026063, y = 3.80098806245576, z = 2.72186511510064), (x = -14.000475935202989, y = -4.92425414159782, z = 1.7704399505975197), (x = -10.090331899443346, y = -8.052581033706907, z = 1.5311316630442386), (x = -3.812455430534781, y = -11.615446302781455, z = 0.40218636850508027), (x = -3.0384345019264316, y = -6.842144905546649, z = 2.6639788117457117), (x = 2.5273878910781, y = -6.684578626092981, z = 1.544401872924625), (x = 6.27753593731042, y = 3.2121382526298734, z = 5.917081683388485), (x = -5.788945407332804, y = 1.0208251410105178, z = 3.999487497523498)], edges = nothing, triangles = nothing),)
However, it is not possible in Zygote
to make only the Z-values of each point active, while keeping the X and Y values as constant. So I was trying out Enzyme
I tried reading the documentation, and going through other issues here on discord, but was not able to fix the issue.
Could anyone explain what I’m doing wrong?
@wsmoses Any help is appreciated
Thanks
Alan