DiffEqFlux gradient error: breaks as the problem size increases, although size should not affect autodiff

I am optimizing a simple ODE problem with DiffEqFlux as a way to get started with Julia. Equivalent code worked in TensorFlow.

In the following code, running main(S=5, B=256) works, but main(S=7, B=256) fails. I previously used an EnsembleProblem with otherwise the same code and got the same result. The error begins:

ERROR: Need an adjoint for constructor StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}}. Gradient is of type Vector{Nothing}

However, as noted, the code runs for S=5. Separately, it also seems a bit slow, so any hints on writing it for better performance are welcome.

Thanks.

using Random: seed!
using RandomNumbers.Xorshifts
using DifferentialEquations, DiffEqFlux
using DiffEqSensitivity, OrdinaryDiffEq, ForwardDiff
using Statistics, LinearAlgebra		# LinearAlgebra provides norm, used in reg()

r = Xoroshiro128Plus(0x1234567890abcdef)

L1            = 0.02            #/(S*S) # L1 reg, normalized by size of weight matrix, W
L2            = 0.05            #/(S*S) # L2 regularization parameter
useL1         = true            # true -> L1, false -> L2

const final_t = 3.0             # final time
const tspan   = (0.0, final_t)  # range of times
atol          = 1e-6            # abs tol for diffeq
rtol          = 1e-4            # rel tol for diffeq
####################################################################

phi(x) = 1.0 ./ (exp.(-x) .+ 1.0)	# elementwise logistic sigmoid

function ode!(ds, s, w, t, input)
	ds .= phi(w*s) .- s .+ input
end

function newbatch(B, S; stochastic=false)
	initial = 0.1 .* ones(B,S)
	if stochastic
		inputvec = 2.0 .* rand(r,B)
	else
		inputvec = [range(0,2,length=B);]	# need semicolon to expand range
	end
	# make inputs [B,S] with first column as input to y0 at t=0, other cols as zeros
	input = hcat(inputvec, zeros(B,S-1))
	output = zeros(B,S)
	output = map(x -> convert(Float64,x >= 0.5 && x <= 1.5), inputvec)
	return input, initial, output
end

## cost function calculations
function reg(w, regularize)
	if !regularize
		return 0.0
	elseif useL1
		return L1 * norm(w,1)
	else
		return L2 * norm(w)
	end
end

function loss(p, B, S, regularize, stoch_batch)
	input, initial, output = newbatch(B, S, stochastic=stoch_batch)
	result = []
	for i in 1:B
		prob = ODEProblem((ds, s, p, t) -> ode!(ds, s, w(p,S), t, input[i,:]),
						initial[i,:], tspan, p)
		soln = solve(prob, Tsit5(), save_everystep = false, reltol = rtol, abstol = atol)	# only the final state is used
		(i == 1) ? result = soln[end]' : result = vcat(result, soln[end]')
	end
	loss_val = mean(abs2.(a(p) .* result[:,end] .- output)) .+ reg(w(p,S), regularize)
	return loss_val
end

function callback(p, l)
	#display(l)
	return false
end

w(p, S) = reshape(p[1:end-1],S,S)
a(p) = p[end]
	
# parameters p:
# a = p[end] = 1.0
# w = reshape(p[1:end-1],S,S)
function main(;B::Int=4, S::Int=3, regularize::Bool=false, stoch_batch::Bool=false,
			rnd_entropy::Bool=false)
	rnd_entropy ? seed!(r) : seed!(r, 0x1234567890abcdef)
	p = ones(S*S+1)
	p[1:end-1] = 0.1 .* randn(r, S*S)
	opt_result = DiffEqFlux.sciml_train(p -> loss(p, B, S, regularize, stoch_batch), 
			p, ADAM(0.1); cb = callback, maxiters = 10)
	finalp = opt_result.u
	
	println()
	print(opt_result)
	println()
	print("Final param = ", typeof(finalp))
	println()
	println("Loss calculated from final param = ", 
				loss(finalp, B, S, regularize, stoch_batch)[1])
	return w(opt_result.u,S)
end

main(S=5, B=256)
main(S=7, B=256)

First of all: excellent MWE!

How long do I have to wait? OK, just finished. One more test…

Confirmed. Fails for me, too, on Julia 1.6.3 with

ERROR: LoadError: Need an adjoint for constructor StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}}. Gradient is of type Vector{Nothing}
Stacktrace:
  [1] error(s::String)  
    @ Base .\error.jl:33
  [2] (::Zygote.Jnew{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}}, Nothing, false})(Δ::Vector{Nothing})

Time to file an issue against Zygote?

P.S.: only random observation: Tsit5()
P.P.S.: the only thing I’d check before filing is whether the gradient is valid.

Thanks very much for the quick reply. It fails similarly for a different solver and a different optimizer.

Random idea, try Array(soln)[:,end]' and see if the error goes away. Seems like an odd omission of an adjoint. @dhairyagandhi96
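For readers new to Julia: `Array(soln)` turns the solution into a plain S×T matrix with one column per saved time point, `[:,end]` takes the last column (the state at `final_t`), and `'` transposes it into a 1×S row that the loop stacks into a B×S matrix with `vcat`. The sketch below reproduces only that shape manipulation, using a hypothetical stand-in `fake_solution` in place of a real ODE solution, and compares the question's incremental loop with a one-expression `reduce(vcat, ...)` version. Whether the `reduce` form is faster or friendlier to Zygote was not measured in this thread.

```julia
# Stand-in for Array(soln): an S×T matrix whose columns are the states at the
# saved time points. `fake_solution` is a hypothetical helper for illustration,
# not part of DifferentialEquations.jl; entry (s, t) is simply i + t.
fake_solution(i, S, T) = fill(float(i), S, T) .+ reshape(1:T, 1, T)

# Loop from the question: take the last column (final state), transpose it to a
# 1×S row, and stack the rows into a B×S matrix.
function stack_final_states_loop(B, S, T)
    result = nothing
    for i in 1:B
        row = fake_solution(i, S, T)[:, end]'
        result = (i == 1) ? row : vcat(result, row)
    end
    return result
end

# Same B×S matrix built in one expression instead of growing `result` in a loop.
stack_final_states(B, S, T) = reduce(vcat, (fake_solution(i, S, T)[:, end]' for i in 1:B))

A = stack_final_states_loop(4, 3, 5)
println(size(A))   # (4, 3)
```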

Thanks for considering this. Made this change:

(i == 1) ? result = Array(soln)[:,end]' : result = vcat(result, Array(soln)[:,end]')

and got the same error result.

Simplify and delete some things until it goes away. That’ll isolate what it is.

Commenting this line:

#output = map(x -> convert(Float64,x >= 0.5 && x <= 1.5), inputvec)

allowed the code to run for larger values of S. I can experiment with alternative ways to build the same output vector. Is using map in that way expected to break AD?

Note: when commenting it out as shown, the preceding line output = zeros(B,S) should be changed to output = ones(B), so that output still has one entry per batch element.
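If you want to try an alternative to the `map` call, here is a broadcast-only way to compute the same target vector. It is a sketch with a made-up batch size (`B = 8`) and a hypothetical helper name, `band_targets`, and it only checks that the values match the original `map` expression; whether Zygote differentiates through this form without hitting the same `StepRangeLen` adjoint error is untested here. Because `output` is fixed data rather than a function of `p`, another option worth trying is to build the batch once outside the function passed to `sciml_train`, so AD never traces through it.

```julia
# Deterministic batch values, as in newbatch(); collect(...) gives the same
# Vector{Float64} as [range(0, 2, length=B);].
B = 8
inputvec = collect(range(0, 2, length=B))

# Original formulation from the question (the line whose removal avoided the error).
output_map = map(x -> convert(Float64, x >= 0.5 && x <= 1.5), inputvec)

# Hypothetical broadcast-only alternative: 1.0 where 0.5 <= x <= 1.5, else 0.0.
band_targets(v) = Float64.((v .>= 0.5) .& (v .<= 1.5))

output_bc = band_targets(inputvec)
println(output_bc)   # [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```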

Looks like a bug in the map adjoint.