I’m following along with this tutorial by Chris Rackauckas. This is the code provided for a simple neural network:
```julia
using Flux
NNODE = Chain(x -> [x], # Take in a scalar and transform it into an array
              Dense(1,32,tanh),
              Dense(32,1),
              first) # Take first value, i.e. return a scalar
NNODE(1.0)
g(t) = t*NNODE(t) + 1f0
using Statistics
ϵ = sqrt(eps(Float32))
loss() = mean(abs2(((g(t+ϵ)-g(t))/ϵ) - cos(2π*t)) for t in 0:1f-2:1f0)
opt = Flux.Descent(0.01)
data = Iterators.repeated((), 1000)
iter = 0
cb = function () #callback function to observe training
    global iter += 1
    if iter % 500 == 0
        display(loss())
    end
end
display(loss())
Flux.train!(loss, Flux.params(NNODE), data, opt; cb=cb)
```
This code runs fine. However, I noticed that the numbers are written as Float32 rather than Float64, through expressions such as `1f0` and `eps(Float32)`. If I change all of these to Float64 (replacing `1f0` with `1e0`, `1f-2` with `1e-2`, and `eps(Float32)` with `eps(Float64)`), I get the following code, which no longer runs:
```julia
using Flux
NNODE = Chain(x -> [x], # Take in a scalar and transform it into an array
              Dense(1,32,tanh),
              Dense(32,1),
              first) # Take first value, i.e. return a scalar
NNODE(1.0)
g(t) = t*NNODE(t) + 1e0
using Statistics
ϵ = sqrt(eps(Float64))
loss() = mean(abs2(((g(t+ϵ)-g(t))/ϵ) - cos(2π*t)) for t in 0:1e-2:1e0)
opt = Flux.Descent(0.01)
data = Iterators.repeated((), 1000)
iter = 0
cb = function () #callback function to observe training
    global iter += 1
    if iter % 500 == 0
        display(loss())
    end
end
display(loss())
Flux.train!(loss, Flux.params(NNODE), data, opt; cb=cb)
```
Running it prints the initial loss and then throws an error during training:

```
0.6015101698151941
MethodError: no method matching zero(::Tuple{Float64, Float64})
Closest candidates are:
  zero(::Union{Type{P}, P}) where P<:Dates.Period at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\Dates\src\periods.jl:53
  zero(::StatsBase.Histogram{T, N, E}) where {T, N, E} at C:\Users\ruvil\.julia\packages\StatsBase\Q76Ni\src\hist.jl:538
  zero(::SparseArrays.AbstractSparseArray) at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\SparseArrays\src\SparseArrays.jl:55
…
```
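The following sketch uses only Base Julia to compare the two ranges that `loss` iterates over. The literals behave as expected: `1f0` is a `Float32` and `1e0` is a `Float64`. The ranges, however, differ in more than their element type. The Float64 range stores its step internally as `Base.TwicePrecision{Float64}`, a `hi`/`lo` pair of Float64 values, and the Float32 range does not. It is only an assumption, suggested by the `Tuple{Float64, Float64}` in the error, that the automatic differentiation trips over these internals; this is not a confirmed diagnosis. The `step` field read below is a Base implementation detail, not public API, and may differ between Julia versions.

```julia
# Literal suffixes select the float type: `f0` -> Float32, `e0` -> Float64.
@show typeof(1f0) typeof(1e0)

# The two ranges iterated over in `loss` (original Float32 and modified Float64 versions).
r32 = 0:1f-2:1f0
r64 = 0:1e-2:1e0

# The element types match the literals...
@show eltype(r32) eltype(r64)

# ...but the stored step differs. `step` is an internal field of StepRangeLen
# (assumption: field layout as in current Base; it is not public API).
@show typeof(r32.step)
@show typeof(r64.step)   # Base.TwicePrecision{Float64}: a (hi, lo) pair of Float64s
```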
I have read that Float32 is recommended with Flux, because the extra precision isn't needed and it halves the memory usage. However, I would still like to understand the source of this error.
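The loss approximates the derivative of the trial solution with a forward difference whose step is `sqrt(eps(T))` for the working type `T`. The sketch below gives a rough idea of how much the precision matters for that residual alone. It substitutes the exact solution of the ODE the loss encodes, `u(t) = 1 + sin(2πt)/(2π)`, and evaluates the same residual at both precisions. That solution is derived here from `g(0) = 1` and the `cos(2π*t)` target; it is not taken from the tutorial. The helper names `u_exact` and `fd_loss` are hypothetical. Any nonzero result is finite-difference error, not model error.

```julia
using Statistics  # for `mean`, as in the original code

# Exact solution of u'(t) = cos(2πt), u(0) = 1, the problem the loss targets
# (derived from the trial form g(t) = t*NNODE(t) + 1 and the residual in `loss`).
u_exact(t) = 1 + sin(2π * t) / (2π)

# Same forward-difference residual as `loss`, with the trial solution swapped for `u`.
fd_loss(u, ϵ, ts) = mean(abs2(((u(t + ϵ) - u(t)) / ϵ) - cos(2π * t)) for t in ts)

loss32 = fd_loss(u_exact, sqrt(eps(Float32)), 0:1f-2:1f0)  # step ≈ 3.5e-4
loss64 = fd_loss(u_exact, sqrt(eps(Float64)), 0:1e-2:1e0)  # step ≈ 1.5e-8
println((loss32, loss64))
```

With the Float32 step, the residual is dominated by the truncation error of the forward difference. With the Float64 step, the residual is many orders of magnitude smaller. Both values are far below the initial loss printed above.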