Inconsistent matrix multiply output from Flux.Dense depending on shape of input

Does anyone know a workaround for this inconsistency in matrix multiplication?

Here’s a workaround for now, in case this affects others’ sanity as much as mine!

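# Correction below thanks to "skleinbo"
# Note: this replaces Flux's own Dense method for AbstractArray input (type piracy)
using Flux  # batched_mul comes from NNlib, which Flux re-exports

function (d::Dense)(x::AbstractArray)
    # Internal Flux helper: checks that size(x, 1) matches the layer's input size
    Flux._size_check(d, x, 1 => size(d.weight, 2))
    # Written with 3-d input (in, len, batch) in mind: multiplies d.weight with each slice
    d.σ.(batched_mul(d.weight, x) .+ d.bias)
end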
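For anyone who wants to see the effect without Flux, here is a minimal sketch in plain Julia of two ways a dense layer can be applied to a 3-d input of shape (in, len, batch): flattening the trailing dimensions into a single matrix multiply, versus multiplying each slice separately (conceptually what batched_mul does with a shared weight). The helper names dense_flat and dense_sliced are hypothetical, and I'm assuming, not verifying here, that the flattening route is roughly how Flux handles N-d input. Whether the two agree bitwise depends on the BLAS library and CPU, so the sketch prints the comparison instead of claiming a result.

```julia
using Random

Random.seed!(1)

# Hypothetical helper: flatten (in, len, batch) into (in, len*batch), then one matrix multiply
function dense_flat(W, b, x::AbstractArray{<:Real,3})
    X = reshape(x, size(x, 1), :)
    Y = W * X .+ b
    return reshape(Y, size(W, 1), size(x, 2), size(x, 3))
end

# Hypothetical helper: one matrix multiply per slice x[:, :, k], all sharing the same W
function dense_sliced(W, b, x::AbstractArray{<:Real,3})
    Y = similar(x, size(W, 1), size(x, 2), size(x, 3))
    for k in axes(x, 3)
        Y[:, :, k] = W * x[:, :, k] .+ b
    end
    return Y
end

W, b = randn(Float32, 8, 5), randn(Float32, 8)   # arbitrary layer: 5 inputs, 8 outputs
x = randn(Float32, 5, 7, 3)                       # arbitrary input: (in, len, batch)

ya = dense_flat(W, b, x)
yb = dense_sliced(W, b, x)
println("bitwise equal:      ", ya == yb)
println("max abs difference: ", maximum(abs.(ya .- yb)))
println("isapprox:           ", isapprox(ya, yb))
```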

It seems that including a and b along the third dimension changes the order of operations somehow. Maybe the blocks of memory are traversed following a different heuristic?
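A change in the order of operations is enough to change the result, because floating-point addition is not associative: a kernel that groups the same products differently can return slightly different bits. A minimal illustration with three scalars:

```julia
x1, x2, x3 = 0.1, 0.2, 0.3

left  = (x1 + x2) + x3   # 0.30000000000000004 + 0.3 == 0.6000000000000001
right = x1 + (x2 + x3)   # 0.1 + 0.5 == 0.6

println(left == right)            # false
println(isapprox(left, right))    # true
```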

I can’t reproduce on Flux@0.14.6 and Julia 1.9.3.

P.S.: I suppose you meant to write

d.σ.(batched_mul(d.weight, x) .+ d.bias)

Thanks for that fix!

So, I made sure my Julia and Flux versions matched yours… here’s exactly what I get:

Here’s a link to the code. Should have provided that before. Sorry!

Copy-pasted your code verbatim. No difference for me.

Can you post the output of

versioninfo()
Pkg.status(mode=Pkg.PKGMODE_MANIFEST)
LinearAlgebra.BLAS.get_config()
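If it’s easier to gather these from a script than from the REPL, a sketch along these lines should work. Outside the REPL, versioninfo comes from InteractiveUtils, and LinearAlgebra has to be loaded before BLAS.get_config() can be called (that is the UndefVarError that shows up below).

```julia
using InteractiveUtils   # provides versioninfo() outside the REPL
using LinearAlgebra      # must be loaded before LinearAlgebra.BLAS.get_config() is called
using Pkg

versioninfo()
Pkg.status(mode=Pkg.PKGMODE_MANIFEST)
config = LinearAlgebra.BLAS.get_config()
println(config)          # shows which BLAS library libblastrampoline forwards to
```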

julia> versioninfo()
Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (x86_64-apple-darwin22.4.0)
  CPU: 4 × Intel(R) Core(TM) i5-7500 CPU @ 3.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 8 on 4 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8
julia> Pkg.status(mode=Pkg.PKGMODE_MANIFEST)
Status `~/.julia/dev/DSTModels/util/weirdissue/Manifest.toml`
  [621f4979] AbstractFFTs v1.5.0
  [79e6a3ab] Adapt v3.6.2
  [dce04be8] ArgCheck v2.3.0
  [a9b6321e] Atomix v0.1.0
  [198e06fe] BangBang v0.3.39
  [9718e550] Baselet v0.1.1
  [fa961155] CEnum v0.4.2
  [082447d4] ChainRules v1.54.0
  [d360d2e6] ChainRulesCore v1.16.0
  [bbf7d656] CommonSubexpressions v0.3.0
  [34da2185] Compat v4.9.0
  [a33af91c] CompositionsBase v0.1.2
  [187b0558] ConstructionBase v1.5.4
  [6add18c4] ContextVariablesX v0.1.3
  [9a962f9c] DataAPI v1.15.0
  [864edb3b] DataStructures v0.18.15
  [e2d170a0] DataValueInterfaces v1.0.0
  [244e2a9f] DefineSingletons v0.1.2
  [8bb1440f] DelimitedFiles v1.9.1
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [ffbed154] DocStringExtensions v0.9.3
  [cc61a311] FLoops v0.2.1
  [b9860ae5] FLoopsBase v0.1.1
  [1a297f60] FillArrays v1.6.1
  [587475ba] Flux v0.14.6
  [f6369f11] ForwardDiff v0.10.36
  [d9f16b24] Functors v0.4.5
  [0c68f7d7] GPUArrays v9.0.0
  [46192b85] GPUArraysCore v0.1.5
  [7869d1d1] IRTools v0.4.10
  [22cec73e] InitialValues v0.3.1
  [92d709cd] IrrationalConstants v0.2.2
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.5.0
  [b14d175d] JuliaVariables v0.2.4
  [63c18a36] KernelAbstractions v0.9.8
  [929cbde3] LLVM v6.2.1
  [2ab3a3ac] LogExpFunctions v0.3.26
  [d8e11817] MLStyle v0.4.17
  [f1d291b0] MLUtils v0.4.3
  [1914dd2f] MacroTools v0.5.11
  [128add7d] MicroCollections v0.1.4
  [e1d29d7a] Missings v1.1.0
  [872c559c] NNlib v0.9.6
  [77ba4419] NaNMath v1.0.2
  [71a1bf82] NameResolution v0.1.5
  [0b1bfda6] OneHotArrays v0.2.4
  [3bd65402] Optimisers v0.3.1
  [bac558e1] OrderedCollections v1.6.2
  [aea7be01] PrecompileTools v1.2.0
  [21216c6a] Preferences v1.4.1
  [8162dcfd] PrettyPrint v0.2.0
  [33c8b6b6] ProgressLogging v0.1.4
  [c1ae055f] RealDot v0.1.0
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.0
  [efcf1570] Setfield v1.1.1
  [605ecd9f] ShowCases v0.1.0
  [699a6c99] SimpleTraits v0.9.4
  [a2af1166] SortingAlgorithms v1.1.1
  [dc90abb0] SparseInverseSubset v0.1.1
  [276daf66] SpecialFunctions v2.3.1
  [171d559e] SplittablesBase v0.1.15
  [90137ffa] StaticArrays v1.6.3
  [1e83bf80] StaticArraysCore v1.4.2
  [82ae8749] StatsAPI v1.7.0
  [2913bbd2] StatsBase v0.34.0
  [09ab397b] StructArrays v0.6.16
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.11.0
  [28d57a85] Transducers v0.4.78
  [013be700] UnsafeAtomics v0.2.1
  [d80eeb9a] UnsafeAtomicsLLVM v0.1.3
  [e88e6eb3] Zygote v0.6.64
  [700de1a5] ZygoteRules v0.2.3
  [dad2f222] LLVMExtra_jll v0.0.25+0
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.3
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.9.2
  [de0858da] Printf
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [6462fe0b] Sockets
  [2f01184e] SparseArrays
  [10745b16] Statistics v1.9.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.0.5+0
  [deac9b47] LibCURL_jll v7.84.0+0
  [29816b5a] LibSSH2_jll v1.10.2+0
  [c8ffd9c3] MbedTLS_jll v2.28.2+0
  [14a3606d] MozillaCACerts_jll v2022.10.11
  [4536629a] OpenBLAS_jll v0.3.21+4
  [05823500] OpenLibm_jll v0.8.1+0
  [bea87d4a] SuiteSparse_jll v5.10.1+6
  [83775a58] Zlib_jll v1.2.13+0
  [8e850b90] libblastrampoline_jll v5.8.0+0
  [8e850ede] nghttp2_jll v1.48.0+0
  [3f19e933] p7zip_jll v17.4.0+0
julia> LinearAlgebra.BLAS.get_config()
ERROR: UndefVarError: `LinearAlgebra` not defined
Stacktrace:
 [1] top-level scope
   @ REPL[4]:1

You have to import LinearAlgebra first.

julia> import LinearAlgebra

julia> LinearAlgebra.BLAS.get_config()
LinearAlgebra.BLAS.LBTConfig
Libraries: 
└ [ILP64] libopenblas64_.0.3.21.dylib

Ha! I just tested it on my machine at home, which has an Intel i5-6500 CPU, i.e. a Skylake like yours, and now I can reproduce your observation. Previously I was on an Apple M1.

My guess is that the two matrix multiplications hit different code paths in OpenBLAS on that CPU.

You can strip away Flux and call BLAS.gemm directly to the same effect. Probably the operations get reordered or the loops unroll differently; I don’t know.
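A minimal sketch of that Flux-free comparison: one BLAS.gemm call over the whole batch versus BLAS.gemv applied column by column. The sizes and the Float32 element type are arbitrary assumptions. Whether the two results differ bitwise depends on the BLAS library and the CPU, so the sketch prints the difference rather than claiming a particular value.

```julia
using LinearAlgebra: BLAS
using Random

Random.seed!(42)
W = randn(Float32, 64, 128)   # arbitrary weight matrix
X = randn(Float32, 128, 32)   # arbitrary batch of 32 column vectors

# One gemm call over the whole batch
Y_batch = BLAS.gemm('N', 'N', W, X)

# The same product, one column at a time via gemv
Y_cols = similar(Y_batch)
for j in axes(X, 2)
    Y_cols[:, j] = BLAS.gemv('N', W, X[:, j])
end

println("bitwise equal:      ", Y_batch == Y_cols)
println("max abs difference: ", maximum(abs.(Y_batch .- Y_cols)))
println("eps(Float32):       ", eps(Float32))
```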

However, it’s ultimately inconsequential, because the difference is within machine precision. As a rule of thumb, one should compare floats with isapprox() anyway.
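To make the isapprox suggestion concrete: for arrays it compares norms, and when no atol is given the default relative tolerance is sqrt(eps) of the element type. A purely relative comparison therefore cannot match a nonzero value against an exact zero unless you pass atol. The values below are only illustrative.

```julia
using LinearAlgebra

x = [0.1 + 0.2, 1.0e-20]
y = [0.3, 0.0]

println(x == y)                               # false: the bits differ
println(norm(x - y))                          # tiny, but not zero
println(isapprox(x, y))                       # true: norm(x - y) is within rtol * max(norm(x), norm(y))
println(isapprox(1.0e-20, 0.0))               # false: a relative tolerance cannot match nonzero against zero
println(isapprox(1.0e-20, 0.0; atol=1e-12))   # true once an absolute tolerance is given
```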

Edit: MKL behaves nicely here: see JuliaLinearAlgebra/MKL.jl on GitHub, the Intel MKL linear algebra backend for Julia.


The problem is that numerical precision errors compound in a deep neural net. That’s just one Dense layer; there could be hundreds of layers! I don’t have a problem with machine precision error as such. My issue is with the expectation that changing the shape of a function’s input should change the shape of its output and nothing else; otherwise, it’s not proper function composition. It’s like reshaping an array: the coordinates should change, but not the bytes.

Thanks for your help!
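Whether such rounding differences actually compound through depth depends on the network. Here is a hypothetical sketch for measuring it: a stack of 100 tanh layers with made-up weights (scaled by 0.25/sqrt(n), an assumption chosen so the layers tend to shrink perturbations), run once on the whole batch and once column by column. In this setup I would expect the gap to stay near rounding level; a network whose layers amplify perturbations could behave differently, so treat this as a measuring tool rather than evidence either way.

```julia
using Random

Random.seed!(0)
n, batch, depth = 32, 16, 100   # arbitrary sizes for the sketch
# Hypothetical network: weights scaled by 0.25/sqrt(n) so layers tend to shrink perturbations
layers = [(0.25 .* randn(n, n) ./ sqrt(n), 0.1 .* randn(n)) for _ in 1:depth]

# Push an input (vector or matrix) through every tanh layer in turn
function forward(layers, X)
    for (W, b) in layers
        X = tanh.(W * X .+ b)
    end
    return X
end

X = randn(n, batch)
Yb = forward(layers, X)                                            # whole batch per layer
Yc = reduce(hcat, [forward(layers, X[:, j]) for j in axes(X, 2)])  # one column at a time

println("bitwise equal:      ", Yb == Yc)
println("max abs difference: ", maximum(abs.(Yb .- Yc)))
```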
