Hi All,
I’m trying to port this example of a recurrent neural network in PyTorch to Flux to help me learn the API. I know that I’m not putting the data together with the loss function in the right way (I’m using the char-rnn model from the model zoo as a guide), but I was wondering whether anyone would chip in to see where I’m going wrong. Apart from only going through a single train of a minibatch, I’m trying to stay as faithful to the original implementation as possible. The code below gives
MethodError: no method matching isless(::TrackedArray{…,Array{Float64,2}}, ::Array{Float64,2}) Closest candidates are: isless(!Matched::Missing, ::Any) at missing.jl:66 isless(::Any, !Matched::Missing) at missing.jl:67
using Flux
using Flux: chunk, batchseq, onehot, onehotbatch, mse
using StatsBase: sample, wsample
# using CuArrays
# Make simulated sequence
bases = ['A','C','G','T']
alphabet = [bases;'_']
seq_len = 220
seq = [sample(bases) for i in 1:seq_len]
seq = join(seq)
function sim_error(seq,pins=0.05,pdel=0.05,psub=0.01)
out_seq = []
for c in seq
while true
r=rand()
if r < pins
push!(out_seq,sample(bases))
else
break
end
end
r -= pins
if r < pdel
continue
end
r -= pdel
if r < psub
push!(out_seq,sample(bases))
continue
end
push!(out_seq,c)
end
return join(out_seq)
end
num_sim = 20
seqs = [sim_error(seq) for i in 1:num_sim]
max_len = maximum([length(s) for s in seqs])
# Generate one-hot
input_t = [onehotbatch(input[1:(end-1)],bases) for input in seqs]
output_t = [onehotbatch(input[2:end],bases) for input in seqs]
# Define model
hidden_dim = 32
layer1_dim = 12
layer2_dim = 12
num_bases = 4
m = Chain(
LSTM(num_bases, hidden_dim),
Dense(hidden_dim,layer1_dim),
relu,
Dense(layer1_dim,layer2_dim),
relu,
Dense(layer2_dim,num_bases)
)
# Define MSE loss
function loss(xs, ys)
l = sum(Flux.mse.(m.(xs)), ys)
Flux.truncate!(m)
return l
end
# Set optimiser
lr = 0.1
opt = SGD(params(m), lr)
# Train one minibatch
mini_batch_size = 5
idx = [sample(1:num_sim) for x in 1:mini_batch_size]
train = [(input_t[i], output_t[i]) for i in idx]
Flux.train!(loss, train, opt)