Is there a way to impose a constraint on the output of a Flux neural network model?

Suppose that I have observed data $y_i \sim N(0, 1/w(x_i))$, where $x_i = i$ for $i = 1, \dots, 100$. I define the true function $w(x_i) = (x_i - 50)^2$ and hope to recover it with a Flux neural network model. Denote the network's estimate by $\tilde{w}(x_i)$.

Below is my current code. It throws a domain error, which I believe occurs because the network allows $\tilde{w}(x_i)$ to be negative, giving an improper probability distribution (the variance must be positive). Is there a way to constrain the Flux model so that $\tilde{w}(x_i) > 0$?

My code:

using Flux
using Flux: train!
using LinearAlgebra, Random, Statistics, Distributions, SparseArrays, StatsBase # stats
using Plots # for chains

## data simulation 
T = 100
w = zeros(T)
times = hcat(1:T...)
[w[t] = (times[t]-50)^2 for t in 1:T]
Sig = Diagonal(vec(1/w))
mu = zeros(T)
y = rand(MvNormal(mu, Sig), 1)
plot(y)

## specify NN
model = Dense(1 => 1)

## specify loss function as negative log-likelihood (using univariate specification to make things easier)
loss(model, x, y) = -sum([logpdf(Normal(0, 1 / model(x)[i]), y[i]) for i in 1:T])

## initial loss 
loss(model, times, y)

## optimization function
opt = Descent()

## structure data 
data = [(times, y)]

## train it,  baby
for epoch in 1:200
    train!(loss, model, data, opt)
end

## RESULTING ERROR:
ERROR: DomainError with -0.8109541:
Normal: the condition σ >= zero(σ) is not satisfied.

I think the most usual approach to imposing such a constraint is to use an appropriate nonlinearity on the output layer. For strictly positive output, I would use softplus, log(1+exp(x)).
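As a minimal sketch of that suggestion, the snippet below puts softplus on the last layer so that every output is positive. The layer sizes, the tanh hidden activation, and the input scaling are illustrative assumptions, not something taken from the posts above.

```julia
using Flux  # softplus comes from NNlib and is available through Flux

# Hypothetical architecture: the sizes and the tanh activation are
# illustrative choices only.
model = Chain(
    Dense(1 => 8, tanh),
    Dense(8 => 1, softplus),   # softplus(z) = log(1 + exp(z)) > 0
)

x = Float32.(reshape(1:100, 1, :)) ./ 100f0   # 1×100 input, scaled to (0, 1]
w_est = model(x)                               # 1×100 matrix of outputs

@show all(>(0), w_est)
```

In floating point, softplus can underflow to exactly zero for very negative pre-activations, so adding a small constant to the output is a common extra safeguard when the value is used as a variance.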

Thank you for your comment, Tomas. Per your suggestion, I now have the following code:

###
### Playing w/ Flux in JL 
###

using Flux
using Flux: train!
using LinearAlgebra, Random, Statistics, Distributions, SparseArrays, StatsBase # stats
using Plots # for chains
using ProgressBars
using NNlib # for activation functions

## data simulation 
T = 100
w = zeros(T)
times = hcat(1:T...)
[w[t] = (((times[t]-50)+1)/10)^2 for t in 1:T]
plot(w)
Sig = Diagonal(vec(w)) #Diagonal(vec(1/w))
mu = zeros(T)
y = rand(MvNormal(mu, Sig), 1)
plot(y)

## specify NN
model = Chain(Dense(1, 12), 
    Dense(12, 1, softplus))

## specify loss function as negative log-likelihood
function loss(model, x, y) # we minimize loss, so...
    newSig = Diagonal(vec(model(x)))
    lpdf = logpdf(MvNormal(mu, Sig), y) # will be negative
    -lpdf
end

## initial loss 
loss(model, times, y)

The initial loss is NaN. As I’m playing around with this package, I’m finding myself getting many Inf and NaN values. I assume this is due to the initial weights/biases, but I’m not sure how to modify these (or if I even should). Do you have any recommendations here?

There's probably a bug in your loss: the line that calls logpdf should use newSig instead of Sig.
(The NaN arises because Sig contains one entry that is exactly zero, at t = 49, where ((49-50)+1)/10 = 0. A zero variance makes the normal density diverge.)
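The zero entry can be checked directly from the simulation formula. The sketch below uses a hand-written Gaussian log-density (the helper `gauss_logpdf` is hypothetical and is not how Distributions.jl computes it internally) to show how a zero variance turns one term into NaN.

```julia
# Variances from the simulation in the post: w[t] = ((t - 50 + 1) / 10)^2
T = 100
w = [((t - 50 + 1) / 10)^2 for t in 1:T]

zero_idx = findall(iszero, w)   # positions where the variance is exactly zero

# One term of the Gaussian log-density with mean 0 and variance v
# (hypothetical helper written out by hand for illustration)
gauss_logpdf(y, v) = -0.5 * (log(2π * v) + y^2 / v)

# A draw with zero variance equals the mean, so y = 0 at that position;
# log(0) = -Inf and 0/0 = NaN, so the term is NaN
bad_term = gauss_logpdf(0.0, w[zero_idx[1]])

@show zero_idx isnan(bad_term)
```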

That resolves the NaN issue, but now I’m seeing the following error when attempting to train my model:

┌ Warning: setup found no trainable parameters in this model
└ @ Optimisers ~/.julia/packages/Optimisers/1x8gl/src/interface.jl:28
ERROR: MethodError: no method matching (::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(softplus), Matrix{Float32}, Vector{Float32}}}})(::typeof(loss), ::Matrix{Int64}, ::Matrix{Float64})

Updated code:

###
### Playing w/ Flux in JL 
###

using Flux
using Flux: train!
using LinearAlgebra, Random, Statistics, Distributions, SparseArrays, StatsBase # stats
using Plots # for chains
using ProgressBars
using NNlib # for activation functions

## data simulation 
T = 100
w = zeros(T)
times = hcat(1:T...)
[w[t] = (((times[t]-50)+1)/10)^2 for t in 1:T]
plot(w)
Sig = Diagonal(vec(w)) #Diagonal(vec(1/w))
mu = zeros(T)
y = rand(MvNormal(mu, Sig), 1)
plot(y)

## specify NN
model = Chain(Dense(1, 12), 
    Dense(12, 1, softplus))

## specify loss function as negative log-likelihood
function loss(model, x, y) # we minimize loss, so...
    newSig = Diagonal(vec(model(x)))
    lpdf = logpdf(MvNormal(mu, newSig), y) # will be negative
    -lpdf
end

## initial loss 
loss(model, times, y)

## structure data 
data = [(times, y)]

## train model
train!(model, loss, data, Descent())

I’m following the data structure of the Flux tutorial for fitting a line, but for some reason, there are no trainable parameters in my model and there is not an existing method that is applicable. I’m not even sure which function is providing the “no method matching” error.
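One reading of the error message: the MethodError shows the Chain being called with `loss` as its first argument, which suggests that `train!` received the model in the position where it expects the loss function. A minimal sketch of the call order, assuming Flux's explicit-style API (`Flux.setup` followed by `train!(loss, model, data, opt_state)`), might look like the following. The hand-written likelihood, the shifted variance formula, the input scaling, the `1f-6` offset, and the Adam learning rate are all assumptions made for this example.

```julia
using Flux

# Simulated data similar to the post, in Float32. The shift by 0.5 is an
# assumption added here so that no variance is exactly zero.
T = 100
x = Float32.(reshape(1:T, 1, :)) ./ T            # 1×T inputs scaled to (0, 1]
w_true = Float32[((t - 49.5) / 10)^2 for t in 1:T]
y = reshape(sqrt.(w_true) .* randn(Float32, T), 1, :)

model = Chain(Dense(1 => 12, tanh), Dense(12 => 1, softplus))

# Gaussian negative log-likelihood (up to a constant) with mean 0 and
# variance given by the network; 1f-6 guards against underflow to zero.
function loss(model, x, y)
    v = model(x) .+ 1f-6
    return sum(log.(v) .+ y .^ 2 ./ v) / 2
end

opt_state = Flux.setup(Adam(0.01), model)        # optimiser state for the model
data = [(x, y)]

for epoch in 1:200
    Flux.train!(loss, model, data, opt_state)    # loss first, then the model
end

@show loss(model, x, y)
```

Writing the likelihood as a plain array expression keeps the example independent of Distributions.jl; whether it recovers the true variance function well depends on the architecture and training settings, which are not tuned here.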