Here is a minimal working example:
"""
    ResNetBlock(n::Int)

Build a block of two 3×3 convolutions, each followed by a `BatchNorm`, mapping `n`
channels to `n` channels. `pad=1, stride=1` on a 3×3 kernel keeps the spatial size
unchanged.

The first conv/BatchNorm pair applies `relu`. The second pair applies no activation.

Note: despite its name, this block has no skip connection; the input is never added
back to the output. A true residual block would wrap this chain, e.g. in
`SkipConnection(..., +)`.
"""
function ResNetBlock(n::Int)
    channels = n => n
    # Layers are created in the same order as they appear in the chain, so weight
    # initialisation consumes the RNG exactly as before.
    conv_a = Conv((3, 3), channels, relu; stride=1, pad=1)
    norm_a = BatchNorm(n, relu)
    conv_b = Conv((3, 3), channels; stride=1, pad=1)
    norm_b = BatchNorm(n)
    return Chain(conv_a, norm_a, conv_b, norm_b)
end
"""
    simplenet(n_filter)

Plain (non-residual) CNN used as the baseline model.

# Layers
- Six 3×3 `relu` convolutions: the first maps 4 → `n_filter` channels, the other
  five map `n_filter` → `n_filter`. Each uses `pad=1, stride=1`, so the spatial size
  is preserved.
- A 1×1 convolution down to 32 channels.
- Flattening.
- Two dense layers ending in a single `tanh` output, in (-1, 1).

The `32 * 49` input width of the first `Dense` layer assumes 7×7 spatial input
(7 * 7 = 49).
"""
function simplenet(n_filter)
    # Built in chain order so weight initialisation draws from the RNG exactly as
    # writing every layer out by hand would.
    stem = Conv((3, 3), 4 => n_filter, relu; stride=1, pad=1)
    hidden = [Conv((3, 3), n_filter => n_filter, relu; stride=1, pad=1) for _ in 1:5]
    head = Conv((1, 1), n_filter => 32, relu; stride=1, pad=0)
    return Chain(
        stem,
        hidden...,
        head,
        Flux.flatten,
        Dense(32 * 49 => 128, relu),
        Dense(128 => 1, tanh),
    )
end
"""
    resnet(n_filter)

CNN whose middle section is four `ResNetBlock`s instead of plain convolutions. It
uses the same stem and head as `simplenet`.

# Layers
- A 3×3 `relu` convolution from 4 → `n_filter` channels (`pad=1, stride=1`).
- Four `ResNetBlock(n_filter)` blocks. These contain `BatchNorm` layers and, as
  written, no skip connections.
- A 1×1 convolution down to 32 channels.
- Flattening.
- Two dense layers ending in a single `tanh` output, in (-1, 1).

The `32 * 49` input width of the first `Dense` layer assumes 7×7 spatial input.
"""
function resnet(n_filter)
    # Constructed in chain order to keep the RNG consumption for initialisation identical.
    stem = Conv((3, 3), 4 => n_filter, relu; stride=1, pad=1)
    blocks = [ResNetBlock(n_filter) for _ in 1:4]
    head = Conv((1, 1), n_filter => 32, relu; stride=1, pad=0)
    return Chain(
        stem,
        blocks...,
        head,
        Flux.flatten,
        Dense(32 * 49 => 128, relu),
        Dense(128 => 1, tanh),
    )
end
"""
    loss(model, x, y)

Return the mean-squared error (`Flux.mse`) between the prediction `model(x)` and the
target `y`.

The `(model, x, y)` argument order is what `Flux.train!` uses when each data item is
an `(x, y)` tuple.
"""
loss(model, x, y) = Flux.mse(model(x), y)
using Flux

# Ten (input, target) pairs. Each input is one random 7×7×4 sample (batch size 1),
# and the target is always 0.7.
data = [(rand(Float32, 7, 7, 4, 1), 0.7f0) for _ in 1:10]

model1 = simplenet(4)
opt1 = Flux.setup(Adam(), model1)

model2 = resnet(4)
opt2 = Flux.setup(Adam(), model2)  # module name is `Flux`; lowercase `flux` is undefined

# Ten epochs for each model. Every `for` loop needs its closing `end`.
for epoch in 1:10
    Flux.train!(loss, model1, data, opt1)
end
for epoch in 1:10
    Flux.train!(loss, model2, data, opt2)
end

# NOTE(review): model2 contains BatchNorm layers. Flux normalises them with per-batch
# statistics inside `train!` (gradient context), but with their running mean/variance
# when the model is simply called, as below. With a batch size of 1, these two can
# differ a lot, which is a likely cause of model2's output saturating near tanh's
# upper limit of 1. Worth trying: stack samples into real batches
# (e.g. x of size 7×7×4×10). TODO confirm.
@show model1(data[4][1]) model2(data[4][1])
First, I know this is not truly a ResNet (the blocks have no skip connections); I arrived at this while investigating my problem. If you evaluate `model1(data[4][1])`, the output is close to 0.7, as expected, but the same call with `model2` always returns about 0.999 and approaches 1 with more epochs. The same happens on the GPU.
I am using the latest version of Flux.