Autodiff on image convolution

See https://github.com/FluxML/NNlib.jl/blob/master/src/conv.jl

NNlib is what's used by Flux.
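
For reference, a minimal sketch of calling NNlib's conv directly (the WHCN layout and the sizes here are my own assumptions, not from the thread):

using NNlib
x = rand(Float32, 100, 100, 1, 1)  # width × height × channels × batch (WHCN)
w = rand(Float32, 11, 11, 1, 1)    # kernel, same layout
y = NNlib.conv(x, w)               # "valid" convolution; output is 90×90×1×1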


I think in this case it is best to use the explicit form of the gradient instead of using AD.

Why wouldn't you use that?
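
For a quadratic loss on a "valid" correlation, loss(k) = ½‖corr(x, k) − y‖², the explicit gradient with respect to the kernel is itself a correlation of x with the residual. A minimal sketch (my own illustration, not code from the thread):

# "valid" cross-correlation: out[i,j] = Σ_{a,b} x[i+a-1, j+b-1] * k[a,b]
function valid_corr(x, k)
    m, n = size(x) .- size(k) .+ 1
    [sum(x[i:i+size(k,1)-1, j:j+size(k,2)-1] .* k) for i in 1:m, j in 1:n]
end

# gradient of ½‖valid_corr(x, k) − y‖² w.r.t. k: correlate x with the residual
function kernel_grad(x, k, y)
    r = valid_corr(x, k) .- y
    [sum(x[a:a+size(r,1)-1, b:b+size(r,2)-1] .* r) for a in 1:size(k,1), b in 1:size(k,2)]
end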


@baggepinnen Thank you, I think this will solve my issues.

@RoyiAvital That is also a good way to solve the problem, but my model will probably get a lot more complicated, so I'd rather exploit AD than write the gradient myself. (For the simple problem at hand it is actually very easy to compute, but if my model becomes more involved, computing the gradient by hand will become a productivity bottleneck.)


New issue: I cannot manage to get the model to train.

MWE:

using Flux, LinearAlgebra
function MWE()
    η = 0.01
    model = Conv((64, 64), 1=>1, relu, stride=30)          # one 64×64 kernel, stride 30
    loss(x, y) = Flux.mse(model(x), y)
    dataset = [(rand(100, 100, 1, 1), rand(2, 2, 1, 1))]   # WHCN input, 2×2 target
    @show model(dataset[1][1])
    @show loss(dataset[1]...)
    for i in 1:10
        Flux.train!(loss, params(model), dataset, ADAM(η))
        @show i, loss(dataset[1]...)
    end
end
MWE (generic function with 1 method)
MWE()
model((dataset[1])[1]) = [2.74843 1.46427; 0.0 1.95925] (tracked)
loss(dataset[1]...) = 2.7275042760204453 (tracked)
(i, loss(dataset[1]...)) = (1, 0.2400080973923382 (tracked))
(i, loss(dataset[1]...)) = (2, 0.2400080973923382 (tracked))
(i, loss(dataset[1]...)) = (3, 0.2400080973923382 (tracked))
(i, loss(dataset[1]...)) = (4, 0.2400080973923382 (tracked))
(i, loss(dataset[1]...)) = (5, 0.2400080973923382 (tracked))
(i, loss(dataset[1]...)) = (6, 0.2400080973923382 (tracked))
(i, loss(dataset[1]...)) = (7, 0.2400080973923382 (tracked))
(i, loss(dataset[1]...)) = (8, 0.2400080973923382 (tracked))
(i, loss(dataset[1]...)) = (9, 0.2400080973923382 (tracked))
(i, loss(dataset[1]...)) = (10, 0.2400080973923382 (tracked))

You may be able to just encapsulate the particular calculation in a function, calculate the derivative by hand, and integrate it into the AD system you wish to use. I have done this for simple functions in ForwardDiff, for a case that worked OK but was much faster with the adjoint method.

If you want help with this, it would be useful to provide an MWE for the particular AD system you want to use. You did not provide the dimensions, but I imagine that for typical sizes, forward mode would make sense.
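
To make the encapsulation idea concrete for the Tracker backend that Flux uses here, the pattern is roughly the following (a sketch with a toy function; sqnorm is a made-up name, and d(sum(abs2, x))/dx = 2x is the hand-written adjoint):

import Flux
import Flux.Tracker
using Flux.Tracker: TrackedArray, track, @grad, data

sqnorm(x::AbstractArray) = sum(abs2, x)      # plain implementation
sqnorm(x::TrackedArray) = track(sqnorm, x)   # route tracked arrays through Tracker

@grad function sqnorm(x)
    return sqnorm(data(x)), Δ -> (2 .* data(x) .* Δ,)
end

x = Flux.param(rand(3, 3))
Tracker.back!(sqnorm(x))
Tracker.grad(x)   # equals 2 .* data(x)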

I would already be happy to have it working with Flux's predefined objects. Anyway, here is what happened with the imfilter approach:

using Images, LinearAlgebra   # LinearAlgebra provides norm
K = rand(11, 11)
img = randn(100, 100)
img2 = randn(100, 100)
m(x, ker) = imfilter(x, centered(ker))
loss(x, y, ker) = norm(y .- m(x, ker))
loss (generic function with 1 method)
loss(img,img2,K)
759.0782976571238
import AutoGrad
AutoGrad.grad(x->loss(img,img2,x),1)(K)
MethodError: no method matching centered(::Param{Array{Float64,2}})
Closest candidates are:
  centered(!Matched::AxisArray) at /home/aizen/.julia/packages/ImageFiltering/8fmJ4/src/ImageFiltering.jl:95
  centered(!Matched::ImageMeta) at /home/aizen/.julia/packages/ImageFiltering/8fmJ4/src/ImageFiltering.jl:98
  centered(!Matched::AbstractArray) at /home/aizen/.julia/packages/ImageFiltering/8fmJ4/src/utils.jl:15
import ForwardDiff
ForwardDiff.gradient(x->loss(img,img2,x),K)
MethodError: no method matching svd!(::Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##14#15")),Float64},Float64,11},2}; full=false)
Closest candidates are:
  svd!(!Matched::LinearAlgebra.AbstractTriangular) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/triangular.jl:2471 got unsupported keyword argument "full"
  svd!(!Matched::Union{DenseArray{T,2}, ...}; full) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/svd.jl:59 [very long signature elided]
  svd!(!Matched::Union{DenseArray{T,2}, ...}, !Matched::Union{DenseArray{T,2}, ...}) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/svd.jl:274 got unsupported keyword argument "full" [very long signature elided]
import Zygote
Zygote.gradient(x->loss(img,img2,x),K)
Compiling Tuple{typeof(imfilter!),Array{Float64,2},Array{Float64,2},Tuple{OffsetArrays.OffsetArray{Int64,2,Array{Int64,2}},OffsetArrays.OffsetArray{Float64,2,Array{Float64,2}}},Pad{0},ImageFiltering.Algorithm.FFT}: try/catch is not supported.
import Flux
Flux.Tracker.gradient(x->loss(img,img2,x),K)
MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
  Float64(::Real, !Matched::RoundingMode) where T<:AbstractFloat at rounding.jl:194
  Float64(::T<:Number) where T<:Number at boot.jl:741
  Float64(!Matched::Int8) at float.jl:60
  ...
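
For completeness, the ForwardDiff attempt appears to fail because imfilter internally calls svd! on the kernel, which has no method for Dual element types. One workaround consistent with the advice above (a sketch, my own code; note that a "valid" correlation shrinks the output, so the target must be sized to match):

import ForwardDiff
using LinearAlgebra

function valid_corr(x, k)   # plain generic Julia, so Dual element types pass through
    m, n = size(x) .- size(k) .+ 1
    T = promote_type(eltype(x), eltype(k))
    out = zeros(T, m, n)
    for j in 1:n, i in 1:m
        out[i, j] = sum(@view(x[i:i+size(k,1)-1, j:j+size(k,2)-1]) .* k)
    end
    return out
end

x  = randn(100, 100)
y  = randn(90, 90)    # (100 - 11 + 1) = 90 in each dimension
k0 = rand(11, 11)
g  = ForwardDiff.gradient(k -> norm(y .- valid_corr(x, k)), k0)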