but I am a bit lost. My major challenge here is that for each layer of the network I need the trainable parameter ‘a’ of PReLU to be shared across the activations in that layer. So if the network has, say, 10 layers, then only 10 scalar trainable parameters should be added in total (one per layer).
I’m not super familiar with Flux, but this seems to work; I just modified Flux’s Dense layer a bit:
using Flux, Flux.Tracker, NNlib

# Dense layer with a single trainable PReLU coefficient `a`,
# shared by all activations of the layer.
struct DensePRELU{S,T,K}
    W::S
    b::T
    a::K
end

prelu(x, a) = x > 0 ? x : a*x

function DensePRELU(in::Integer, out::Integer;
                    initW = Flux.glorot_uniform, initb = zeros)
    return DensePRELU(param(initW(out, in)), param(initb(out)), param(0.0))
end

Flux.treelike(DensePRELU)

function (l::DensePRELU)(x)
    W, b, a = l.W, l.b, l.a
    NNlib.@fix prelu.(W*x .+ b, a)  # @fix lets the broadcast handle tracked arguments
end
m = Chain(
    DensePRELU(10, 2),
)

# fit a random linear map as fake data
M = rand(2, 10)
fake_data() = begin x = rand(10); y = M*x; (x, y) end
train = [fake_data() for i = 1:100]

loss(x, y) = sum(abs2.(m(x) .- y))
opt = ADAM(params(m))
evalcb = () -> println(mean(loss(d...) for d in train))
for i = 1:10 Flux.train!(loss, train, opt, cb = evalcb) end
Maybe there’s a way to do this directly with the default Dense layer.
@jonathanBieler, did you try running this? It produces the following error on the current stable version of Flux (v0.5.1):
> opt = ADAM(params(m))
MethodError: Cannot `convert` an object of type Flux.Tracker.TrackedReal{Float64} to an object of type Flux.Optimise.Param
This may have arisen from a call to the constructor Flux.Optimise.Param(...),
since type constructors fall back to convert methods.
in ADAM at Flux/src/optimise/interface.jl:56
in optimiser at Flux/src/optimise/interface.jl:6
in collect at base/array.jl:476
in collect_to! at base/array.jl:518
in collect_to! at base/array.jl:508
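For reference, the problem is that param(0.0) creates a Flux.Tracker.TrackedReal, which the optimiser interface cannot wrap in a Param. One way around it, and the trick the solution below relies on, is to store the coefficient as a one-element tracked array rather than a tracked scalar. A minimal sketch:

# param(0.0) yields a TrackedReal, which ADAM(params(m)) cannot convert to a Param
a_scalar = param(0.0)

# a one-element TrackedArray goes through the optimiser fine;
# index it as a[1] (or broadcast against it) in the forward pass
a = param([0.0])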
Thanks @jonathanBieler!
I’ve modified your solution below to abstract it away from a particular layer, so the PReLU can now be chained after any layer (after BatchNorm, for example).
using Flux
using Flux: Tracker, treelike
using NNlib

# Standalone PReLU layer with one shared trainable coefficient `a`.
struct PReLU{T}
    a::T
end

# `[init/1]` promotes the value to a float and wraps it in a one-element
# array, so the optimiser can update it (a tracked scalar cannot be used).
PReLU(init::Real) = PReLU(param([init/1]))
PReLU() = PReLU(0.0)

treelike(PReLU)

prelu(x, a) = x > 0 ? x : a*x

function (f::PReLU)(x)
    NNlib.@fix prelu.(x, f.a)
end
m = Chain(Dense(10, 2), PReLU())

M = rand(2, 10)
fake_data() = begin x = rand(10); y = M*x; (x, y) end
train = [fake_data() for i = 1:100]

loss(x, y) = sum(abs2.(m(x) .- y))
opt = ADAM(params(m))
evalcb = () -> println(mean(loss(d...) for d in train))
@time for i = 1:10 Flux.train!(loss, train, opt, cb = evalcb) end
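To get back to the original requirement of one shared coefficient per layer, the same PReLU can simply be dropped in after each Dense layer, which adds exactly one extra trainable scalar per layer. A quick sketch, assuming the definitions above (the layer sizes here are made up):

deep = Chain(
    Dense(10, 32), PReLU(),
    Dense(32, 32), PReLU(),
    Dense(32, 2),  PReLU(),
)

# Each Dense contributes W and b, each PReLU a single one-element array,
# so this should report 9 parameter containers in total.
length(params(deep))

# Inspect the shared coefficient of the first PReLU after training:
deep.layers[2].a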