Hi @ToucheSir, thanks for the pointers. I tried to follow the solution you proposed and wrote a resample_model! function for it (Method 2 below).
I benchmarked both approaches to see which one is faster; I was under the assumption that updating the parameters in place should do better. The relevant snippets are below.
# Function to build the model
using Flux

function LeNet5(; initializer=Flux.glorot_uniform, act_fn=relu, im_size=(28, 28, 1), n_classes=10)
    # Spatial size of the feature maps entering the dense layers
    out_conv_size = (im_size[1] ÷ 4 - 3, im_size[2] ÷ 4 - 3, 16)
    return Chain(
        Conv((5, 5), im_size[end] => 6, act_fn, init=initializer),
        MaxPool((2, 2)),
        Conv((5, 5), 6 => 16, act_fn, init=initializer),
        MaxPool((2, 2)),
        flatten,
        Dense(prod(out_conv_size), 120, act_fn, init=initializer),
        Dense(120, 84, act_fn, init=initializer),
        Dense(84, n_classes, init=initializer)
    )
end
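For clarity, every call to LeNet5() allocates a brand-new set of parameter arrays (the init keyword is invoked afresh for each layer), so the two models share no storage; a quick check, not part of the benchmark:

m1, m2 = LeNet5(), LeNet5()
m1[1].weight === m2[1].weight   # false: independent arrays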
Method 1: Call the build function every time
function test_reinit_1()
    model = LeNet5()
    for i in 1:N
        model = LeNet5()   # build a fresh model every iteration
    end
end
Method 2: In-place parameter update
function resample_model!(model)
    for layer ∈ model
        if hasproperty(layer, :weight)
            layer.weight .= randn(size(layer.weight))
        end
        if hasproperty(layer, :bias)
            layer.bias .= randn(size(layer.bias))
        end
    end
end
function test_reinit_2()
    model = LeNet5()
    for i in 1:N
        resample_model!(model)
    end
end
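Both test functions read a top-level N. For completeness, the surrounding setup is roughly the following (BenchmarkTools assumed for the timings below):

using BenchmarkTools

N = 100   # number of re-initializations per benchmark run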
Results
Below are the benchmark results for `N = 100`:
julia> @benchmark test_reinit_1()
BenchmarkTools.Trial:
memory estimate: 34.66 MiB
allocs estimate: 12322
--------------
minimum time: 6.986 ms (0.00% GC)
median time: 9.145 ms (14.77% GC)
mean time: 9.143 ms (12.06% GC)
maximum time: 12.402 ms (12.52% GC)
--------------
samples: 546
evals/sample: 1
julia> @benchmark test_reinit_2()
BenchmarkTools.Trial:
memory estimate: 34.86 MiB
allocs estimate: 9522
--------------
minimum time: 16.408 ms (0.00% GC)
median time: 23.207 ms (6.69% GC)
mean time: 22.937 ms (7.06% GC)
maximum time: 29.082 ms (9.05% GC)
--------------
samples: 218
evals/sample: 1
It seems that Method 1 (calling the model-building function over and over) beats my current implementation of Method 2, which looks odd, as I was expecting the opposite result.
Any suggestions on how I can speed up Method 2 here?
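For reference, one possible direction (untested on my end; the numbers above are for the resample_model! shown earlier) would be to fill the existing arrays in place with Random.randn!, so no temporary Float64 array is allocated per layer. A minimal sketch, with a hypothetical function name:

using Random   # provides randn!

# Hypothetical variant: overwrite the existing Float32 arrays directly
# instead of materializing a new Float64 array and broadcasting it in.
function resample_model_inplace!(model)
    for layer ∈ model
        hasproperty(layer, :weight) && randn!(layer.weight)
        hasproperty(layer, :bias) && randn!(layer.bias)
    end
    return model
end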