Hi, I implemented an input convex neural network (ICNN) with Lux.jl, but the Hessian at the origin does not come out positive semidefinite, and it is not symmetric within the tolerance I check below. A minimal working example follows; any comments are appreciated.
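(For reference, the convexity I expect comes from the standard ICNN construction: with hidden updates z_{k+1} = σ(W_z^{(k)} z_k + W_x^{(k)} x + b^{(k)}), the scalar output is convex in x as long as every W_z^{(k)} has non-negative entries and σ is convex and non-decreasing; softplus satisfies both.)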
using Lux
using Random
using LinearAlgebra
using ForwardDiff
function huber_nonneg(x)
return @. ifelse(abs(x) < 1.0, 0.5 * x^2, abs(x) - 0.5)
end
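# huber_nonneg is continuously differentiable and everywhere >= 0, e.g.
# huber_nonneg(0.0) == 0.0, huber_nonneg(0.5) == 0.125, huber_nonneg(-2.0) == 1.5.
# It is used below to keep the quadratic weights d of the final layer non-negative.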
"""
ICNNLayer
Input Convex Neural Network layer with two weight matrices:
- W_z: non-negative weights applied to the previous layer's output (enforces convexity)
- W_x: unconstrained weights for the raw-input passthrough
"""
struct ICNNLayer{T, F} <: Lux.AbstractLuxLayer
in_dims::Int
hidden_dims::Int
out_dims::Int
activation::F
use_bias::Bool
end
function ICNNLayer(in_dims::Int, hidden_dims::Int, out_dims::Int,
activation=softplus; use_bias=true, T::Type=Float64)
return ICNNLayer{T, typeof(activation)}(in_dims, hidden_dims, out_dims, activation, use_bias)
end
function Lux.initialparameters(rng::AbstractRNG, layer::ICNNLayer{T,F}) where {T,F}
W_z = rand(rng, T, layer.out_dims, layer.hidden_dims) # rand gives entries in [0, 1), so W_z is non-negative at initialization
W_x = randn(rng, T, layer.out_dims, layer.in_dims)
if layer.use_bias
bias = zeros(T, layer.out_dims)
return (W_z=W_z, W_x=W_x, bias=bias)
else
return (W_z=W_z, W_x=W_x)
end
end
Lux.initialstates(rng::AbstractRNG, ::ICNNLayer) = NamedTuple()
function Lux.parameterlength(layer::ICNNLayer)
n = layer.out_dims * layer.hidden_dims + layer.out_dims * layer.in_dims
return layer.use_bias ? n + layer.out_dims : n
end
Lux.statelength(::ICNNLayer) = 0
function (layer::ICNNLayer{T,F})((x, z)::Tuple, ps, st) where {T,F}
y = ps.W_z * z + ps.W_x * x
if layer.use_bias
y = y .+ ps.bias
end
return layer.activation.(y), st
end
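# Quick standalone check of the (x, z) calling convention for a single ICNNLayer.
# Dimensions and the RNG here are arbitrary and only for illustration; this is not part of the model below.
let _rng = Random.MersenneTwister(0)
    _layer = ICNNLayer(2, 4, 3)  # in_dims = 2, hidden_dims = 4, out_dims = 3
    _ps = Lux.initialparameters(_rng, _layer)
    _st = Lux.initialstates(_rng, _layer)
    _y, _ = _layer((randn(_rng, 2), rand(_rng, 4)), _ps, _st)  # x has in_dims entries, z has hidden_dims entries
    @assert length(_y) == 3
end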
"""
ICNNFirstLayer
First layer of ICNN (no previous hidden layer)
"""
struct ICNNFirstLayer{T,F} <: Lux.AbstractLuxLayer
in_dims::Int
out_dims::Int
activation::F
use_bias::Bool
end
function ICNNFirstLayer(in_dims::Int, out_dims::Int,
activation=softplus; use_bias=true, T::Type=Float64)
return ICNNFirstLayer{T, typeof(activation)}(in_dims, out_dims, activation, use_bias)
end
function Lux.initialparameters(rng::AbstractRNG, layer::ICNNFirstLayer{T,F}) where {T,F}
weight = randn(rng, T, layer.out_dims, layer.in_dims)
if layer.use_bias
bias = zeros(T, layer.out_dims)
return (weight=weight, bias=bias)
else
return (weight=weight,)
end
end
Lux.initialstates(rng::AbstractRNG, ::ICNNFirstLayer) = NamedTuple()
function Lux.parameterlength(layer::ICNNFirstLayer)
n = layer.out_dims * layer.in_dims
return layer.use_bias ? n + layer.out_dims : n
end
Lux.statelength(::ICNNFirstLayer) = 0
function (layer::ICNNFirstLayer{T,F})(x, ps, st) where {T,F}
y = ps.weight * x
if layer.use_bias
y = y .+ ps.bias
end
return layer.activation.(y), st
end
"""
FinalICNNLayer
Final output layer of ICNN with optional quadratic terms
"""
struct FinalICNNLayer{T, F} <: Lux.AbstractLuxLayer
in_dims::Int
hidden_dims::Int
out_dims::Int
activation::F
use_bias::Bool
use_quadratic::Bool
end
function FinalICNNLayer(in_dims::Int, hidden_dims::Int, out_dims::Int=1,
activation=identity; use_bias=false, use_quadratic=false, T::Type=Float64)
return FinalICNNLayer{T, typeof(activation)}(in_dims, hidden_dims, out_dims, activation, use_bias, use_quadratic)
end
function Lux.initialparameters(rng::AbstractRNG, layer::FinalICNNLayer{T,F}) where {T,F}
W_z = rand(rng, T, layer.out_dims, layer.hidden_dims)
if layer.use_quadratic
W_x = randn(rng, T, layer.out_dims, layer.in_dims)
d = rand(rng, T, layer.in_dims)
if layer.use_bias
bias = zeros(T, layer.out_dims)
return (W_z=W_z, W_x=W_x, d=d, bias=bias)
else
return (W_z=W_z, W_x=W_x, d=d)
end
else
if layer.use_bias
bias = zeros(T, layer.out_dims)
return (W_z=W_z, bias=bias)
else
return (W_z=W_z,)
end
end
end
Lux.initialstates(rng::AbstractRNG, ::FinalICNNLayer) = NamedTuple()
function Lux.parameterlength(layer::FinalICNNLayer)
n = layer.out_dims * layer.hidden_dims
if layer.use_quadratic
n += layer.out_dims * layer.in_dims + layer.in_dims
end
return layer.use_bias ? n + layer.out_dims : n
end
Lux.statelength(::FinalICNNLayer) = 0
function (layer::FinalICNNLayer{T,F})((x, z)::Tuple, ps, st) where {T,F}
y = ps.W_z * z
if layer.use_quadratic
quad_term = sum(x .^ 2 .* huber_nonneg(ps.d), dims=1)
linear_term = ps.W_x * x
y = y .+ quad_term .+ linear_term
end
if layer.use_bias
y = y .+ ps.bias
end
return layer.activation.(y), st
end
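# With use_quadratic = true the final layer computes
#   y = W_z * z + sum_i huber_nonneg(d_i) * x_i^2 + W_x * x (+ bias),
# i.e. a convex quadratic-plus-linear term in the raw input x is added to the hidden-state term.
# Small illustration with arbitrary dimensions (not part of the model below):
let _rng = Random.MersenneTwister(0)
    _fin = FinalICNNLayer(2, 4, 1, identity; use_quadratic=true)
    _ps = Lux.initialparameters(_rng, _fin)
    _st = Lux.initialstates(_rng, _fin)
    _y, _ = _fin((randn(_rng, 2), rand(_rng, 4)), _ps, _st)
    @assert length(_y) == 1
end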
"""
ICNNChain
Chain of ICNN layers; the raw input x is passed to every layer together with the running hidden state z
"""
struct ICNNChain{L<:Tuple} <: Lux.AbstractLuxLayer
layers::L
end
ICNNChain(layers...) = ICNNChain(layers)
function Base.show(io::IO, chain::ICNNChain)
print(io, "ICNNChain(")
for (i, layer) in enumerate(chain.layers)
i > 1 && print(io, ", ")
print(io, layer)
end
print(io, ")")
end
function Lux.initialparameters(rng::AbstractRNG, chain::ICNNChain)
params = map(chain.layers) do layer
Lux.initialparameters(rng, layer)
end
return (layers=params,)
end
function Lux.initialstates(rng::AbstractRNG, chain::ICNNChain)
states = map(chain.layers) do layer
Lux.initialstates(rng, layer)
end
return (layers=states,)
end
function Lux.parameterlength(chain::ICNNChain)
return sum(Lux.parameterlength, chain.layers)
end
function Lux.statelength(chain::ICNNChain)
return sum(Lux.statelength, chain.layers)
end
function (chain::ICNNChain)(x, ps, st)
# First layer - only takes x
z, st_first = chain.layers[1](x, ps.layers[1], st.layers[1])
states = [st_first]
# Subsequent layers - take both x and previous hidden state
for i in 2:length(chain.layers)
z, st_new = chain.layers[i]((x, z), ps.layers[i], st.layers[i])
push!(states, st_new)
end
return z, (layers=tuple(states...),)
end
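# Data flow for, e.g., three layers (schematically):
#   z1 = layer1(x);  z2 = layer2((x, z1));  output = layer3((x, z2))
# i.e. the raw input x is fed to every layer alongside the running hidden state z.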
"""
create_icnn(n_vars, hidden_dims=[32, 32]; T=Float64, rng=Random.default_rng())
Create an Input Convex Neural Network for an entropy function
"""
function create_icnn(n_vars::Int, hidden_dims::Vector{Int}=[32, 32];
T::Type=Float64, rng::AbstractRNG=Random.default_rng())
@assert !isempty(hidden_dims) "hidden_dims cannot be empty"
layers = []
# First layer
push!(layers, ICNNFirstLayer(n_vars, hidden_dims[1], softplus; T=T))
# Hidden layers
for i in 1:(length(hidden_dims)-1)
push!(layers, ICNNLayer(n_vars, hidden_dims[i],
hidden_dims[i+1], softplus; T=T))
end
# Output layer (identity activation; quadratic input terms disabled via use_quadratic=false)
push!(layers, FinalICNNLayer(n_vars, hidden_dims[end], 1, identity;
use_bias=false, use_quadratic=false, T=T))
model = ICNNChain(layers...)
ps, st = Lux.setup(rng, model)
return model, ps, st
end
rng = Random.default_rng()
model, ps, st = create_icnn(2, [16, 16]; T=Float64, rng=rng)
u = Float64[0.0, 0.0]
# Forward pass
η, _ = model(u, ps, st)
println(η[1]) # 38.30048
# Hessian
H = ForwardDiff.hessian(u_vec -> model(u_vec, ps, st)[1][1], u)
println(isapprox(H, H', rtol=1e-15)) # false
println(all(eigvals(H) .> -1e-15)) # false
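# One further check that could be added: a rough midpoint-convexity spot test on random pairs
# of points (a sanity check only, not a proof of convexity), reusing model/ps/st from above.
f(v) = model(v, ps, st)[1][1]
point_pairs = [(randn(2), randn(2)) for _ in 1:100]
println(all(f((a .+ b) ./ 2) <= (f(a) + f(b)) / 2 + 1e-10 for (a, b) in point_pairs))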