Hi all! I ran into a strange problem with a Flux RNN. My training data consists of myX (one-hot encoded sequences) and myY (one number per sequence). The data shown below trains very well with a feedforward network (20 epochs, R² = 0.9), but very poorly with the Flux RNN (200 epochs, R² = 0.2). What's more, I am fairly sure it is not the model architecture, because the same RNN trains well (R² = 1) on other training data, referred to below as demoX and demoY.
I also found that the problem is entirely in my X: the RNN also works well on [demoX, myY] and [demoX, demoY], but not on [myX, myY] or [myX, demoY].
Below is the code.
using Flux
using Random # for shuffle in the training loops
oriX=["ATAGGAGGCGCGTGACAGAGTCCCTGTCCAATTACCTACCCAAA", "ATAGGAGGCGCAAGAGAGAAGCCCAGACCAATAACCTACCCAAA", "ATAGGAGGCTAACGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGAGGCGCCTGAGAAAAGCCCAGACCAATTACCTACCCAAA", "ATAGGACGCGCATGAGAGATGCCCTGACCAATTACCTACCCAAA", "ATAGGTGGTGCATGAGATAAGCACAGCTCAATACCCTACCCAAA", "ATAGGAGACGCAGGGGCGAAGCCCGGACCATTTACCTACCCAAA", "ATAGGTGGTGCATGAGATAATCCCAGACCAATTACCTACCCAAA", "ATAGGAGGCTCATGAGATAAGGCTTGACCAATTACCTACCCAAA", "ATAGGAGGCTCATGAGAGCAGCCCAGATTAATTACCTACCCAAA", "ATAGGAGGCGCGTGAGAGAGGACCCGACCAATTACTCACCCAAC", "ATAGGCAGCGCATGAGAGAAGCCCAGACCAATTACCTACTCAAC", "ATAGGAGGCTAACGAGAGAAGCCCAGACCACTTACCTACCCAAA", "ATAGGAGGCGCATGAGAAAAGCCCCGCCCAATTACCTACCCAAG", "ATAGGCGGCGCTTGAGAGAAGCCCATACCCATTACCTACCCAAA", "ATAGGCGGCACATGAGACAAGCCGAAGCCAATTACCTACCCAAA", "ATAGGCTGCGCATGAGAGAAGGCGACACAAATTACCTACCCAAA", "ATAGGCGGCACATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGTGCAAGAGAGACGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGTATGAGAGAAGCCCAGCTCAATTACCTACCCAAA", "ATAGGAGGCGCATGAGATAACCCACCACCAAGTACCTACTCAAA", "ATAGGTGGCGCATGAGAGCACCTCAGACGAAGTACCTACCCAAA", "ATAGGCGGCGCATGAGATAAGCCTAGACCATTTACCTACCCAAA", "ATAGGTGGCGCATGAGATAAGCGCATAACACCAACTTACCCAAC", "ATAGGCGGCGCATGAGACAAATCCAGGCCAATTATCTACCCAAA", "ATAGGCGGCTCATGAGATAAGCCCAGACCAAATACCTACCCAAA", "ATAGGAGGCGCATGAGAGAATCCCAAACCAATTCCCTACAAACC", "ATAGGCGGCGCATGAGACAAGCCCATACCAATTACCTACCCAAA", "ATAGGTGCGACTTGAGAGATGCCCATATCGACTACCTACCCGAA", "ATAGGCGGTGCATGACTGACGCCCAGACCAATTACCTACCCAAA", "ATAGGGGGCTAATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGAGGCGCATGAGATAAGCCCAGACCAATTACCTACCCCGA", "ATAGGAGGTGCACGAGAGTTGCCCAGACCAATTAACTTCCCAAA", "ATAGGCGGCGCATGAGAAAAGCCCAGACCAATTACCTACCCAAA", "ATAGGTGGCCCGCGAGTTAGGACGAGACTAATTCCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCCATTACCTACCCAAA", "ATAGGCGGCGGACGAGAGAAGCCCAGACCAATTACCTACCCATA", "ATAGGTGGCGCATGAAATAAAACCAGTGCAATTACCTACCCATA", "ATAGGCGACGCATGAGAAAAGCCCAGACCCATTACCTACCCAAA", "ATAGGCTGCGCATGAGAGAAGCCCAGACCAATTATCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTCCCTACCCAAA", "ATAGGAGTCGCCTGACAGATGACCATACCAATTACCTATCCAAA", "ATAGGCCGCGGATTAGACAACATCTTACCAATTCCCTGCCCAAA", "ATAGGCGGTGCAAGAGCGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGAAGCTAAAGGGAGTAGCTCAGTACAGTTAACTACCCCAA", "ATAGGCCGCGCATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAGTTACCTACCCAAA", "ATAGGAAGCGCATGAGAAAAGCCCAGACAAATCACCTACCGAAC", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAGTTACCTACCCAAC", "ATAGGCGGCACATGAGCGCAGCCCAGTCCAATTACCTACCCAAA", "ATAGGCGGCGCATGACACAGGCCCAGACCAATGACCTACCCAAA", "ATAGGCAGCGCATGAGAGAAGCCCAGACCAATTACCTACTCAAA", "ATAGGCGACGAATGAGTGAAGCCCACATTAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAACTACATACCCAAA", "ATAGGCGGCGCATGAGACAAATCCAGGCCAATTACCTACCCAAA", "ATAGGCCGCCGATGAGAAAAGCCCGACGCACTTAACTACCCGAA", "ATAGGCGGTGCATGAGAGAGACGCAGTGCAAATACCTACCCAAA", "ATAGGCGGCGGATTAGAGAAGTCCAGACTATTTACCTACCCAAA", "ATAGGCGGCGAATGAGAGAAGCCCAGACCAATTACCTACCCAGA", "ATAGGCGGCGCATGAGATAAGCCCAGTCGAATTACCTACCCAAA", "ATAGGCCGCGCATGAGAAAAGCCTAGACCAATTGCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTACCTACCCACA", "ATAGGCGACGCATGAGAGAAGCCCAGACGAATTACCTACCCAAA", "ATAGGTCCAGCATTAAGGCAGGCCAGACCCTTTACCTACCCAAA", "ATAGGAGGGACATGCGATAGGCTCAGACCAATTTCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGGCCAATTAACTACCCAAA", "ATAGGCGGCGCATGAGAGTAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGGGAAGCCCAGACCCATTCCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTAGCTACCCAAA", "ATAGGCGACGTATGAGAGAATCCCTGACCATTTACCTACCCAAA", "ATAGGCGGCGCATGATATAAGCCCAGCCCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACATATTACCTACCCAAA", 
"ATAGGCGGCGCATGAGAGAAGCTAAGACCCATTACCTACCCAAA", "ATAGGCCGCGCATGAGAGAAGCTCAGACCCATTACCTACCCAAA", "ATAGGTGGCGCATGAGAGAAGCCCAGACCAATTACCTACACAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTACCTGCCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGTCCGATTTACTACCCAAG", "ATAGGCGGAGCATATGAGATGCCCAGACCAAATACCTACCCAAA", "ATAGGCGGCGCATGACAGAAGCCCTGACCGATAACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGAGCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCTCAGACCAATTACCTACCCAAA", "ATAGGTGTCGCTTGAAAATAGCCCAGACGAATTACCTACCCAAA", "ATAGGCGGCGCATGAGCGTTGCACAGACCAATTACCTACCCAAA", "ATAGGCGGCGTATGAGAGAAGCGCGGCCCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAGGCCCTGACCAAATAACTACCCAAA", "ATAGGCGGCTCATGAGAGAAGCCCAGACCAACTGCCTACCCAAA", "ATAGGCAGCGCATGAGTGAAGCCCAGACCAGTTACCTCCCCAAA", "ATAGGCAGCAGATGACAGTAGCCCCGACCAAATTACTACTCAAA", "ATAGGCGGCGCATGAGAGGAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCCCAGGAGAGCATCCAAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATACGAGAAGCGCAGACTAATTACCTACCCAAA", "ATAGGCGGCGCATGACATAAGCCTAGATCAATTACCTACCCAAG", "ATAGGCGGCACATGACACAGGCCCAGACCAATGACCTACCCAAA", "ATAGGCGGCGCAGGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCGTAAGAGAAGCCTAGACAAATTACCTACCGAAA", "ATAGGCGACGCATCTGCGAATCCCACACCAATTACCTACCCGAA", "ATAGGCGACGCATGAGAGAAGCCCAGACCAATTAACTATCCATC", "ATAGGCGGCGCATGAGAGCAGCCCAAACCAATTGCCTACCCAAA", "ATAGGCGGCGCGTGAGAGTAGCCCTGACCAGTTTCCTGCCCAAA"]
ytrain = Float32[3.5539734 2.7561886 2.8113236 2.7176633 2.7606876 2.6220577 2.3115172 2.4817004 1.9276409 2.5030751 1.8989105 1.6381569 3.112245 1.9992489 1.8364053 2.1545537 1.8151137 1.9761252 2.0710406 1.8238684 1.5769696 2.2978039 2.0652819 1.6795048 1.4621212 1.8550924 1.3247801 2.0052798 1.5950761 2.1166725 1.1718857 1.443101 1.4597932 2.0249891 1.659723 1.7782362 1.3042092 1.3574703 1.7164876 1.4561561 1.6886593 1.5327756 1.3272716 1.2478243 1.6909612 0.9371975 1.3504946 1.7342895 1.0429348 1.6653012 1.6186994 1.6343817 1.1894267 1.6500783 1.1910686 1.5190029 0.93479043 1.5677443 1.2633525 1.4441946 1.8120437 1.6296253 1.3869075 1.7520566 1.247555 1.4638474 1.4413416 1.5457458 1.3801547 1.312296 0.96203357 1.571632 0.2540248 1.0096036 0.8302187 0.73939687 1.4816427 1.1275434 1.1184824 1.3548776 1.3924822 1.2923665 0.9824461 1.2085876 1.3007151 1.4721189 1.3741052 0.7266495 0.5496262 1.3403294 0.931344 0.7101498 1.3628994 1.8999943 1.2633573 1.1379782 0.6508444 0.5403087 1.435614 1.319527]
# One-hot encode each 44-base sequence: a vector of 44 one-hot (length-4) vectors per sample
Xtrain = [map(x -> Flux.onehot(x, "ACGT"), collect(oriX[idx])) for idx in 1:100]
# Flatten each sequence into a single 176-element vector and stack into a 176×100 matrix for the feedforward net
Xtrain_ffnn = hcat([vcat(x...) for x ∈ Xtrain]...)
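As a quick sanity check on shapes (each sequence here happens to be 44 bases, so 44 × 4 = 176 flattened features):
@assert length(Xtrain) == 100            # 100 sequences
@assert length(Xtrain[1]) == 44          # 44 one-hot vectors per sequence
@assert size(Xtrain_ffnn) == (176, 100)  # flattened input for the feedforward net
@assert size(ytrain) == (1, 100)         # row matrix of targets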
# Loss function and "accuracy" (here: R², the coefficient of determination)
function accuracy(m, X, y)
    Flux.reset!(m) # only matters for the recurrent network
    R²(y, m(X))
end
function lossFun(m, X, y)
    Flux.reset!(m) # only matters for the recurrent network
    Flux.mse(m(X), y)
end
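R² is not shown above; I'm using the standard coefficient of determination, defined along these lines:
using Statistics
# R² = 1 - SS_res / SS_tot
R²(y, ŷ) = 1 - sum(abs2, y .- ŷ) / sum(abs2, y .- mean(y))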
# First, learn the training data with a feedforward network (176 = 44 bases × 4 letters)
ffnn = Chain(
    Dense(176 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => 1)
)
opt_ffnn = ADAM()
θ_ffnn = Flux.params(ffnn) # Keep track of the trainable parameters
epochs = 100 # Train the model for 100 epochs
for epoch ∈ 1:epochs
    # Train the model using batches of size 32
    for idx ∈ Iterators.partition(shuffle(1:size(Xtrain_ffnn, 2)), 32)
        X, y = Xtrain_ffnn[:, idx], ytrain[:, idx]
        ∇ = gradient(θ_ffnn) do
            Flux.mse(ffnn(X), y)
        end
        Flux.update!(opt_ffnn, θ_ffnn, ∇)
    end
    @show accuracy(ffnn, Xtrain_ffnn, ytrain)
end
# Then learn the training data with a sequence-to-one RNN
struct Seq2One
    rnn # recurrent layers
    fc  # fully-connected layer
end
Flux.@functor Seq2One # make the struct's fields trainable
# Define the forward pass for an instance of this struct
function (m::Seq2One)(X)
    # Run the recurrent layers on all but the final time step (hidden state is kept internally)
    [m.rnn(x) for x ∈ X[1:end-1]]
    # Pass the last time step through both the recurrent and fully-connected layers
    m.fc(m.rnn(X[end]))
end
# Create the sequence-to-one network with a layer architecture similar to the feedforward net
seq2one = Seq2One(
    Chain(
        RNN(4 => 128, relu),
        RNN(128 => 128, relu)
    ),
    Dense(128 => 1)
)
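To make the expected input format explicit: seq2one takes a vector of seqlen time steps, each a 4×batchsize matrix. A quick check with dummy data (random values just to show the shapes):
Flux.reset!(seq2one)
xb = [rand(Float32, 4, 32) for _ in 1:44] # 44 time steps, batch of 32
size(seq2one(xb)) # (1, 32): one prediction per sequence in the batch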
opt_rnn = ADAM()
θ_rnn = Flux.params(seq2one) # keep track of the trainable parameters
seqlen = length(Xtrain[1]) # 44 time steps per sequence
epochs = 200 # train the model for 200 epochs
for epoch ∈ 1:epochs
    # Train the model using batches of size 32
    for idx ∈ Iterators.partition(shuffle(1:length(Xtrain)), 32)
        Flux.reset!(seq2one) # reset hidden state
        X, y = Xtrain[idx], ytrain[:, idx]
        X = [hcat([x[i] for x ∈ X]...) for i ∈ 1:seqlen] # reshape to RNN format: seqlen steps of 4×batch matrices
        ∇ = gradient(θ_rnn) do
            Flux.mse(seq2one(X), y)
        end
        Flux.update!(opt_rnn, θ_rnn, ∇)
    end
    X, y = [hcat([x[i] for x ∈ Xtrain]...) for i ∈ 1:seqlen], ytrain
    @show accuracy(seq2one, X, y)
end
Hope someone can help! Thanks!