Simple implementation of reverse mode automatic differentiation

I was looking for a simple implementation of reverse mode automatic differentiation. There are many tutorials showing how to implement forward mode using dual numbers, but I could not find anything for reverse mode.

I have implemented a simple version below, but I am not sure if this is the correct way to do it.

import Base: +, *

mutable struct Variable
	value::Float64                     # Value of the variable
	derivative::Float64                # Accumulated derivative of the output w.r.t. this variable
	parents::Vector{Variable}          # Input variables this variable was computed from
	local_derivatives::Vector{Float64} # Local derivatives of this variable w.r.t. each parent

	function Variable(value)
		x = new()
		x.value = value
		x.derivative = 0.0
		x.parents = []
		x.local_derivatives = []
		return x
	end
end

function +(a::Variable, b::Variable)
	value = a.value + b.value
	C = Variable(value) 
	C.parents = [a, b]
	C.local_derivatives = [1.0, 1.0]
	return C
end

function *(a::Variable, b::Variable)
	value = a.value*b.value
	C = Variable(value)
	C.parents = [a, b]
	C.local_derivatives = [b.value, a.value]
	return C
end

function calc_derivative(C::Variable)
	C.derivative = 1.0
	# Set all the gradients to zero initially
	function set_derivatives_to_zero(C::Variable)
		for i = 1:length(C.parents)
			C.parents[i].derivative = 0.0
			set_derivatives_to_zero(C.parents[i])
		end
		return nothing
	end

	# Backpropagation of derivatives
	function recursive_derivative(C::Variable)
		for i = 1:length(C.parents)
			C.parents[i].derivative += C.derivative * C.local_derivatives[i] 
			recursive_derivative(C.parents[i])
		end
		return nothing
	end
	set_derivatives_to_zero(C)
	recursive_derivative(C)
end

x = Variable(3.0)
y = Variable(4.0)
z = (x*x*x + x*x)*(y*y)

calc_derivative(z)
println("Derivative of z w.r.t x: ", x.derivative)
println("Derivative of z w.r.t y: ", y.derivative)
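For reference, the partials can also be checked by hand. A minimal standalone sanity check (no AD, just the closed-form derivatives of z = (x³ + x²)·y²):

```julia
# Analytic check of the example above: z = (x^3 + x^2) * y^2
x, y = 3.0, 4.0
dzdx = (3x^2 + 2x) * y^2    # ∂z/∂x = (3x² + 2x)·y²
dzdy = (x^3 + x^2) * 2y     # ∂z/∂y = (x³ + x²)·2y
println("∂z/∂x = ", dzdx)   # 528.0
println("∂z/∂y = ", dzdy)   # 288.0
```

These match what x.derivative and y.derivative should hold after calc_derivative(z).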

Also, the function calc_derivative is type unstable, but I have not been able to make it type stable. I need help in making it type stable.


Unnesting set_derivatives_to_zero and recursive_derivative seems to do the trick
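For anyone following along, here is a sketch of what the unnested version might look like (the `!` names and the condensed struct are my own choices, not from the original post):

```julia
import Base: *

mutable struct Var
    value::Float64
    derivative::Float64
    parents::Vector{Var}
    local_derivatives::Vector{Float64}
    Var(v) = new(v, 0.0, Var[], Float64[])
end

function *(a::Var, b::Var)
    c = Var(a.value * b.value)
    c.parents = [a, b]
    c.local_derivatives = [b.value, a.value]
    return c
end

# Top-level helpers: nothing is captured in a closure, so calc_derivative!
# infers cleanly
function set_derivatives_to_zero!(C::Var)
    for p in C.parents
        p.derivative = 0.0
        set_derivatives_to_zero!(p)
    end
end

function backprop!(C::Var)
    for (p, d) in zip(C.parents, C.local_derivatives)
        p.derivative += C.derivative * d
        backprop!(p)
    end
end

function calc_derivative!(C::Var)
    set_derivatives_to_zero!(C)
    C.derivative = 1.0
    backprop!(C)
end

x = Var(3.0)
z = x * x
calc_derivative!(z)
println(x.derivative)   # d(x²)/dx at x = 3 → 6.0
```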


Yes, that does make it type stable, thanks a lot. So as a rule, should one avoid nesting functions?

I’m not sure if this is another instance of performance of captured variables in closures · Issue #15276 · JuliaLang/julia · GitHub?
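If it is that issue, the likely mechanism is that a nested function referring to itself (as the recursive helpers above do) forces its binding into a `Core.Box`, which defeats inference. A minimal illustration (my own example, assuming this is the same mechanism):

```julia
function outer()
    # g refers to itself, so the binding is boxed; @code_warntype outer()
    # typically shows a Core.Box and an abstractly-typed call
    g(n) = n == 0 ? 0 : g(n - 1)
    return g(3)
end

println(outer())   # 0
```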

That seems like a cool little example!

I actually put together something very similar in an interactive Pluto notebook for a class recently: GitHub - simeonschaub/ReverseModePluto. I made a few different tradeoffs, though, by sacrificing type stability altogether in favor of simplicity and generality. With that approach, I even managed to put together a perfectly usable neural net.

I think with your approach, it might be a little difficult to extend it to the non-scalar case, which is what you typically want to use reverse AD for. If you think about matrix multiplication for example, you’ll need a way to represent multiplication from the left as well as from the right.
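To illustrate that point: for C = A*B the adjoints are Ā += Ḡ·Bᵀ and B̄ += Aᵀ·Ḡ (where Ḡ is the upstream gradient), so a single stored local-derivative matrix applied from one side no longer suffices. A minimal sketch (all names are mine, not from the thread or the notebook) that stores pullback closures instead:

```julia
# Sketch: pullback closures handle matrix multiplication, where the adjoint
# of C = A*B is  Ā += Ḡ*B'  and  B̄ += A'*Ḡ
mutable struct MVar
    value::Matrix{Float64}
    derivative::Matrix{Float64}
    parents::Vector{Any}            # (parent, pullback) pairs
    MVar(v) = new(v, zeros(size(v)), [])
end

function Base.:*(A::MVar, B::MVar)
    C = MVar(A.value * B.value)
    push!(C.parents, (A, G -> G * B.value'))   # multiplication from the right
    push!(C.parents, (B, G -> A.value' * G))   # multiplication from the left
    return C
end

function backprop!(C::MVar)
    for (p, pull) in C.parents
        p.derivative += pull(C.derivative)
        backprop!(p)
    end
end

A = MVar([1.0 2.0; 3.0 4.0])
B = MVar([1.0 0.0; 0.0 1.0])
C = A * B
C.derivative = ones(2, 2)   # seed the output gradient
backprop!(C)
println(A.derivative)       # ones(2,2) * B.value'
println(B.derivative)       # A.value' * ones(2,2)
```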


Exactly what I thought (never seen it presented like this before).

Edit: I should mention the presented example runs allocation-free on my machine…


That is an amazing notebook. Thanks for the reference. I think you should make it into a tutorial. There are some Python tutorials but I could not find a Julia tutorial on reverse mode. It would be a good reference.

I think it would be possible to generalize to the vector case by defining

mutable struct Variable
	value::Matrix{Float64}                      # Value of the variable
	derivative::Matrix{Float64}                 # Accumulated derivative of the output w.r.t. this variable
	parents::Vector{Variable}                   # Input variables this variable was computed from
	local_derivatives::Vector{Matrix{Float64}}  # Local derivatives of this variable w.r.t. each parent
end

But I have not given it much thought. I wanted to show a simple implementation of reverse mode autodiff in class, and then I will just show an example in Flux.

Maybe not exactly what you are looking for, but I like these notebooks.
