I am trying to build a Temporal Convolutional Network (WaveNet-style) with Flux.jl. First, I define the residual block that is the main building element of the network. Here it is:
```julia
using Flux

struct ResidualBlock{T}
    dilation::Int
    num_channels::Int
    kernel_size::Int
    dilated_conv::Conv{T}
    residual_conv::Conv{T}
    skip_conv::Conv{T}
end

function ResidualBlock(dilation, num_channels, kernel_size)
    dilated_conv = Conv((1,), num_channels => num_channels, dilation=dilation, pad=(kernel_size - 1) * dilation)
    residual_conv = Conv((1,), num_channels => num_channels)
    skip_conv = Conv((1,), num_channels => num_channels)
    return ResidualBlock(dilation, num_channels, kernel_size, dilated_conv, residual_conv, skip_conv)
end

# Forward pass: gated activation, then separate skip and residual paths.
function (rb::ResidualBlock)(x)
    dilated_output = rb.dilated_conv(x)
    dilated_output = tanh.(dilated_output) .* relu.(dilated_output)
    skip_output = rb.skip_conv(dilated_output)
    residual_output = rb.residual_conv(dilated_output)
    output = x + residual_output    # residual connection around the block
    return output, skip_output
end
```
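
For reference, this is how I sanity-check a single block on dummy data (just a quick test I wrote; Flux's 1-D `Conv` expects inputs shaped `(width, channels, batch)`):

```julia
# Hypothetical smoke test, using values matching my setup further down:
rb = ResidualBlock(1, 1, 1)      # dilation = 1, 1 channel, kernel size 1
x = rand(Float32, 64, 1, 1)      # (width, channels, batch)
y, skip = rb(x)
@assert size(y) == size(x)       # the residual path preserves the input shape
@assert size(skip) == size(x)    # so does the skip path (all convs here have width 1)
```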
Then, I define the full network using the struct above:
```julia
struct WaveNet{T}
    num_blocks::Int
    num_layers::Int
    num_channels::Int
    num_classes::Int
    entry_conv::Conv{T}
    residual_blocks::Vector{ResidualBlock}
    relu::typeof(relu)
    tanh::typeof(tanh_fast)
    skip_conv::Conv{T}
    output_conv1::Conv{T}
    output_conv2::Conv{T}
end

function WaveNet(num_blocks, num_layers, num_channels, num_classes, kernel_size=2)
    entry_conv = Conv((kernel_size,), num_classes => num_channels)
    residual_blocks = ResidualBlock[]
    for _ in 1:num_blocks
        for layer in 1:num_layers
            dilation = 2^(layer - 1)   # dilations double within each block: 1, 2, 4, ...
            push!(residual_blocks, ResidualBlock(dilation, num_channels, kernel_size))
        end
    end
    relu = Flux.relu
    tanh = Flux.tanh_fast
    skip_conv = Conv((1,), num_channels => num_channels)
    output_conv1 = Conv((1,), num_channels => num_channels)
    output_conv2 = Conv((1,), num_channels => num_classes)
    return WaveNet(num_blocks, num_layers, num_channels, num_classes, entry_conv, residual_blocks, relu, tanh, skip_conv, output_conv1, output_conv2)
end

function (wn::WaveNet)(x)
    #println("INPUT ", size(x))
    x = wn.entry_conv(x)
    #println("FIRST CONV ", size(x))
    # Accumulate the skip outputs of all residual blocks (third dim assumes batch size 1).
    skip_sum = zeros(Float32, size(x)[1], wn.num_channels, 1)
    for (i, block) in enumerate(wn.residual_blocks)
        x, skip = block(x)
        skip_sum += skip
    end
    x = wn.relu(skip_sum)
    x = wn.skip_conv(x)
    x = wn.relu(x)
    x = wn.output_conv1(x)
    x = wn.relu(x)
    x = wn.output_conv2(x)
    return x
end
```
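
To check that the forward pass at least runs, I call the model on random data like this (a quick sanity check; the `(512, 1, 1)` shape matches one batch of my data, as in the warning further down):

```julia
wn = WaveNet(16, 8, 1, 1, 1)   # num_blocks, num_layers, num_channels, num_classes, kernel_size
x = rand(Float32, 512, 1, 1)   # (width, channels, batch)
ŷ = wn(x)
@show size(ŷ)                  # (512, 1, 1): width is preserved since every conv has width 1
```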
Now, here is the training code for this network:
```julia
using JLD2, Statistics

Flux.@functor ResidualBlock
Flux.@functor WaveNet

function train(model::WaveNet, in_file, out_file, batch_size, lr, epochs)
    opt = Flux.setup(Adam(lr), model)
    data = JLD2.load("./data/data.jld2")
    x_train, y_train, x_valid, y_valid, x_test, y_test =
        data["x_train"], data["y_train"], data["x_valid"], data["y_valid"], data["x_test"], data["y_test"]
    train_num_batches = size(x_train)[1] / batch_size
    train_data = Flux.Data.DataLoader((x_train, y_train), batchsize=batch_size)
    valid_data = Flux.Data.DataLoader((x_valid, y_valid), batchsize=batch_size)
    loss(ŷ, y) = Flux.Losses.mse(ŷ, y)
    best_loss = Inf
    best_model = Flux.state(model)   # initialized here so it is visible outside the loop
    for epoch in 1:epochs
        for (i, data) in enumerate(train_data)
            if i > train_num_batches
                break
            end
            # Unpack this element (for supervised training):
            x, y = data
            x = x[:, :, 1:1]
            y = y[:, :, 1:1]
            Flux.trainmode!(model, true)
            # Calculate the gradient of the objective
            # with respect to the parameters within the model:
            ∇ = Flux.gradient(model) do m
                ŷ = m(x)
                loss(ŷ, y)
            end
            # Update the parameters so as to reduce the objective,
            # according to the chosen optimisation rule:
            Flux.update!(opt, Flux.params(model), ∇[1])
        end
        train_loss = mean(loss(model(x_train[:, :, 1:1]), y_train))
        valid_loss = mean(loss(model(x_valid[:, :, 1:1]), y_valid))
        if valid_loss < best_loss
            best_loss = valid_loss
            best_model = Flux.state(model)
        end
        println("Epoch $epoch | Train Loss: $train_loss, Valid Loss: $valid_loss")
    end
    jldsave("model.jld2"; model_state = best_model)
    return best_model
end
```
I am using the following parameters for the model and the training function (the full call is sketched below):
```julia
batch_size = 512
lr = 1e-1
epochs = 40
num_blocks = 16
num_layers = 8
num_channels = 1
num_classes = 1
kernel_size = 1
```
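
Putting it together, I construct and train the model roughly like this (a sketch; the `in_file`/`out_file` arguments are effectively placeholders here, since `train` currently hardcodes the `./data/data.jld2` path anyway):

```julia
model = WaveNet(num_blocks, num_layers, num_channels, num_classes, kernel_size)
best_state = train(model, "data.jld2", "model.jld2", batch_size, lr, epochs)
```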
The problem is that the training and validation losses do not change at all during the training loop. This is my first time using this framework, and I don't know where else to look for errors.
Here is the output of the first 3 epochs:
```
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Conv((1,), 1 => 1)  # 2 parameters
│   summary(x) = "512×1×1 Array{Float64, 3}"
└ @ Flux ~/.julia/packages/Flux/EHgZm/src/layers/stateless.jl:60
Epoch 1 | Train Loss: 0.06058520324398762, Valid Loss: 0.044346793228467046
Epoch 2 | Train Loss: 0.06058520324398762, Valid Loss: 0.044346793228467046
Epoch 3 | Train Loss: 0.06058520324398762, Valid Loss: 0.044346793228467046
```
I would be grateful for any advice about the problem or the code!