# Flux: 1D convolutions (on genomic data)
#
# Some small changes in the one-hot encoding of the DNA sequences and in the
# reshaping of the encoded array, plus some changes in the network structure.

using Flux
using Random
using Downloads
using Flux.Losses: logitcrossentropy
using Flux: onehotbatch
using Flux: onecold
using Flux.Data: DataLoader

#%%

bases_dna = ['A', 'C', 'G', 'T']

"""
    ohe_row(sequence)

One-hot encode a single DNA sequence against `bases_dna`.

Returns a `length(sequence) × 4` `BitMatrix` whose entry `(i, j)` is `true`
iff the `i`-th character of `sequence` equals `bases_dna[j]`.
"""
function ohe_row(sequence)
    chars = collect(sequence)                 # column vector of characters
    alphabet_row = reshape(bases_dna, 1, :)   # 1 × 4 row for broadcasting
    return chars .== alphabet_row
end

"""
    onehot(sequences)

One-hot encode a collection of equal-length DNA sequences.

Returns an `L × N_bases × N` `BitArray` where `L` is the sequence length
(taken from the first sequence), `N_bases = length(bases_dna)`, and `N` is
the number of sequences.

Generalized from the original hard-coded `L = 50` / `N_bases = 4`: both are
now derived from the data, so any equal-length sequence set works. A
sequence of a different length triggers a `DimensionMismatch` on assignment.
"""
function onehot(sequences)
    N_bases = length(bases_dna)
    # Handle the empty case explicitly so `first` is never called on an
    # empty collection.
    isempty(sequences) && return BitArray(undef, 0, N_bases, 0)
    L = length(first(sequences))
    out = BitArray(undef, L, N_bases, length(sequences))
    for (i, sequence) in enumerate(sequences)
        # `ohe_row` already returns an L × N_bases matrix; no reshape needed.
        out[:, :, i] = ohe_row(sequence)
    end
    return out
end

#%%

"""
    get_data()

Download the example genomic dataset from the `deep-learning-genomics-primer`
repository, one-hot encode sequences and binary labels, and return
`(train_loader, test_loader)` as Flux `DataLoader`s (batch size 32, 75/25
shuffled split; the train loader reshuffles every epoch).
"""
function get_data()
    url_seqs = "https://raw.githubusercontent.com/abidlabs/deep-learning-genomics-primer/master/sequences.txt"
    sequences_raw = split(String(take!(Downloads.download(url_seqs, IOBuffer()))))

    url_labels = "https://raw.githubusercontent.com/abidlabs/deep-learning-genomics-primer/master/labels.txt"
    labels_raw = parse.(Int64, split(String(take!(Downloads.download(url_labels, IOBuffer())))))

    # One-hot-encode the sequences and insert a singleton channel dimension:
    # (L, 4, N) -> (L, 4, 1, N), the WHCN layout Flux's `Conv` expects.
    hotseq = Float32.(onehot(sequences_raw))
    L, N_bases = size(hotseq, 1), size(hotseq, 2)    # derived, not hard-coded 50/4
    sequences = reshape(hotseq, (L, N_bases, 1, :))

    # One-hot-encode the binary labels (classes 0 and 1).
    labels = onehotbatch(labels_raw, 0:1)

    # Shuffled 75/25 train/test split.
    N_sequences = last(size(sequences))
    idxs = shuffle(1:N_sequences)
    # `floor` instead of `Int(...)`: `Int(N * 0.75)` throws InexactError
    # whenever N is not a multiple of 4.
    idx_split = floor(Int, 0.75 * N_sequences)
    train = idxs[1:idx_split]
    test = idxs[idx_split+1:end]

    # Create DataLoaders (mini-batch iterators).
    train_loader = DataLoader((sequences[:, :, :, train], labels[:, train]), batchsize=32, shuffle=true)
    test_loader = DataLoader((sequences[:, :, :, test], labels[:, test]), batchsize=32)

    return train_loader, test_loader
end

#%%

"""
    build_model()

Build the 1-D CNN. Input batches have shape `(50, 4, 1, batch)`; the output
is 2 raw logits per sample (pair with `logitcrossentropy`).
"""
function build_model()
    # BUG FIX: `init` must be the initializer *function* itself
    # (`init(dims...)` is called by the layer constructor). The original
    # `Flux.glorot_uniform()` invoked it with no dims instead of passing it.
    init = Flux.glorot_uniform
    return Chain(
        # (50, 4, 1) -> (45, 6, 16); pad=1 pads both spatial dims.
        Conv((8, 1), 1 => 16; pad=1, bias=true, init=init),
        MaxPool((4, 1)),                # -> (11, 6, 16)
        Conv((8, 1), 16 => 32, relu),   # -> (4, 6, 32)
        MaxPool((4, 1)),                # -> (1, 6, 32)
        Flux.flatten,                   # -> 1*6*32 = 192 features
                                        #    (= Flux.outputsize(front, (50,4,1,32)))
        Dense(192, 16, relu),
        Dense(16, 2),                   # logits for the two classes
    )
end



"""
    loss_and_accuracy(data_loader, model)

Compute the average logit-cross-entropy loss and classification accuracy of
`model` over all batches of `data_loader`. Returns `(loss, accuracy)`.
"""
function loss_and_accuracy(data_loader, model)
    total_loss = 0.0f0
    n_correct = 0
    n_seen = 0
    for (x, y) in data_loader
        logits = model(x)
        total_loss += logitcrossentropy(logits, y, agg=sum)
        n_correct += count(onecold(logits) .== onecold(y))
        n_seen += last(size(x))   # batch size of this mini-batch
    end
    return total_loss / n_seen, n_correct / n_seen
end

loss(ŷ, y) = logitcrossentropy(ŷ, y)

"""
    train()

Run the full experiment: fetch the data, build the model, and train for 50
epochs with ADAM, evaluating on both splits after every epoch.

Returns `(train_losses, train_accuracies, test_losses, test_accuracies)`,
each a `Vector{Float64}` with one entry per epoch.
"""
function train()
    train_loader, test_loader = get_data()

    # Construct the model and collect its trainable parameters
    # (implicit-parameter / Zygote style).
    model = build_model()
    ps = Flux.params(model)

    # ADAM with default momentum coefficients.
    opt = ADAM(0.001, (0.9, 0.999))

    train_losses, train_accuracies = Float64[], Float64[]
    test_losses, test_accuracies = Float64[], Float64[]

    epochs = 50
    for epoch in 1:epochs
        # One pass over the (reshuffled) training data.
        for (x, y) in train_loader
            gs = Flux.gradient(ps) do
                loss(model(x), y)
            end
            Flux.Optimise.update!(opt, ps, gs)
        end

        # Evaluate on both splits and report.
        train_loss, train_acc = loss_and_accuracy(train_loader, model)
        test_loss, test_acc = loss_and_accuracy(test_loader, model)
        println("Epoch=$epoch")
        println("  train_loss = $train_loss, train_accuracy = $train_acc")
        println("  test_loss = $test_loss, test_accuracy = $test_acc")

        push!(train_losses, train_loss)
        push!(train_accuracies, train_acc)
        push!(test_losses, test_loss)
        push!(test_accuracies, test_acc)
    end

    return train_losses, train_accuracies, test_losses, test_accuracies
end

# Run the full experiment; the trailing semicolon suppresses printing the
# returned tuple of four 50-element vectors.
train_losses, train_accuracies, test_losses, test_accuracies = train();

# Overlay the per-epoch train and test accuracy curves.
using Plots
plot(train_accuracies)
plot!(test_accuracies)
